e8b3ecf264495b6e5cf9b5f07889545ed242b64b
Marco Ricci Rename and regroup all test...

Marco Ricci authored 4 months ago

1) # SPDX-FileCopyrightText: 2024 Marco Ricci <m@the13thletter.info>
2) #
3) # SPDX-License-Identifier: MIT
4) 
5) """Test OpenSSH key loading and signing."""
6) 
7) from __future__ import annotations
8) 
9) import base64
10) import io
11) import os
12) import socket
13) import subprocess
Marco Ricci Support Python 3.10 and PyP...

Marco Ricci authored 3 months ago

14) from typing_extensions import Any
Marco Ricci Rename and regroup all test...

Marco Ricci authored 4 months ago

15) 
16) import click
17) import click.testing
18) import derivepassphrase
19) import derivepassphrase.cli
20) import pytest
21) import ssh_agent_client
22) import tests
23) 
24) class TestStaticFunctionality:
25) 
26)     @pytest.mark.parametrize(['public_key', 'public_key_data'],
27)                              [(val['public_key'], val['public_key_data'])
28)                               for val in tests.SUPPORTED_KEYS.values()])
29)     def test_100_key_decoding(self, public_key, public_key_data):
30)         keydata = base64.b64decode(public_key.split(None, 2)[1])
31)         assert (
32)             keydata == public_key_data
33)         ), "recorded public key data doesn't match"
34) 
35)     def test_200_constructor_no_running_agent(self, monkeypatch):
36)         monkeypatch.delenv('SSH_AUTH_SOCK', raising=False)
37)         sock = socket.socket(family=socket.AF_UNIX)
Marco Ricci Distinguish errors when con...

Marco Ricci authored 4 months ago

38)         with pytest.raises(KeyError,
Marco Ricci Rename and regroup all test...

Marco Ricci authored 4 months ago

39)                            match='SSH_AUTH_SOCK environment variable'):
40)             ssh_agent_client.SSHAgentClient(socket=sock)
41) 
42)     @pytest.mark.parametrize(['input', 'expected'], [
43)         (16777216, b'\x01\x00\x00\x00'),
44)     ])
45)     def test_210_uint32(self, input, expected):
46)         uint32 = ssh_agent_client.SSHAgentClient.uint32
47)         assert uint32(input) == expected
48) 
49)     @pytest.mark.parametrize(['input', 'expected'], [
50)         (b'ssh-rsa', b'\x00\x00\x00\x07ssh-rsa'),
51)         (b'ssh-ed25519', b'\x00\x00\x00\x0bssh-ed25519'),
52)         (
53)             ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
54)             b'\x00\x00\x00\x0f\x00\x00\x00\x0bssh-ed25519',
55)         ),
56)     ])
57)     def test_211_string(self, input, expected):
58)         string = ssh_agent_client.SSHAgentClient.string
59)         assert bytes(string(input)) == expected
60) 
61)     @pytest.mark.parametrize(['input', 'expected'], [
62)         (b'\x00\x00\x00\x07ssh-rsa', b'ssh-rsa'),
63)         (
64)             ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
65)             b'ssh-ed25519',
66)         ),
67)     ])
68)     def test_212_unstring(self, input, expected):
69)         unstring = ssh_agent_client.SSHAgentClient.unstring
70)         unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
71)         assert bytes(unstring(input)) == expected
72)         assert tuple(
73)             bytes(x) for x in unstring_prefix(input)
74)         ) == (expected, b'')
75) 
76)     @pytest.mark.parametrize(['value', 'exc_type', 'exc_pattern'], [
77)         (10000000000000000, OverflowError, 'int too big to convert'),
78)         (-1, OverflowError, "can't convert negative int to unsigned"),
79)     ])
80)     def test_310_uint32_exceptions(self, value, exc_type, exc_pattern):
81)         uint32 = ssh_agent_client.SSHAgentClient.uint32
82)         with pytest.raises(exc_type, match=exc_pattern):
83)             uint32(value)
84) 
85)     @pytest.mark.parametrize(['input', 'exc_type', 'exc_pattern'], [
86)         ('some string', TypeError, 'invalid payload type'),
87)     ])
88)     def test_311_string_exceptions(self, input, exc_type, exc_pattern):
89)         string = ssh_agent_client.SSHAgentClient.string
90)         with pytest.raises(exc_type, match=exc_pattern):
91)             string(input)
92) 
93)     @pytest.mark.parametrize(
94)         ['input', 'exc_type', 'exc_pattern', 'has_trailer', 'parts'], [
95)             (b'ssh', ValueError, 'malformed SSH byte string', False, None),
96)             (
97)                 b'\x00\x00\x00\x08ssh-rsa',
98)                 ValueError, 'malformed SSH byte string',
99)                 False, None,
100)             ),
101)             (
102)                 b'\x00\x00\x00\x04XXX trailing text',
103)                 ValueError, 'malformed SSH byte string',
104)                 True, (b'XXX ', b'trailing text'),
105)             ),
106)     ])
107)     def test_312_unstring_exceptions(self, input, exc_type, exc_pattern,
108)                                      has_trailer, parts):
109)         unstring = ssh_agent_client.SSHAgentClient.unstring
110)         unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
111)         with pytest.raises(exc_type, match=exc_pattern):
112)             unstring(input)
113)         if has_trailer:
114)             assert tuple(bytes(x) for x in unstring_prefix(input)) == parts
115)         else:
116)             with pytest.raises(exc_type, match=exc_pattern):
117)                 unstring_prefix(input)
118) 
119) @tests.skip_if_no_agent
120) class TestAgentInteraction:
121) 
122)     @pytest.mark.parametrize(['keytype', 'data_dict'],
123)                              list(tests.SUPPORTED_KEYS.items()))
124)     def test_200_sign_data_via_agent(self, keytype, data_dict):
125)         private_key = data_dict['private_key']
126)         try:
127)             result = subprocess.run(['ssh-add', '-t', '30', '-q', '-'],
128)                                     input=private_key, check=True,
129)                                     capture_output=True)
130)         except subprocess.CalledProcessError as e:
131)             pytest.skip(
132)                 f"uploading test key: {e!r}, stdout={e.stdout!r}, "
133)                 f"stderr={e.stderr!r}"
134)             )
135)         else:
136)             try:
137)                 client = ssh_agent_client.SSHAgentClient()
138)             except OSError:  # pragma: no cover
139)                 pytest.skip('communication error with the SSH agent')
140)         with client:
141)             key_comment_pairs = {bytes(k): bytes(c)
142)                                  for k, c in client.list_keys()}
143)             public_key_data = data_dict['public_key_data']
144)             expected_signature = data_dict['expected_signature']
145)             derived_passphrase = data_dict['derived_passphrase']
146)             if public_key_data not in key_comment_pairs:  # pragma: no cover
147)                 pytest.skip('prerequisite SSH key not loaded')
148)             signature = bytes(client.sign(
149)                 payload=derivepassphrase.Vault._UUID, key=public_key_data))
150)             assert signature == expected_signature, 'SSH signature mismatch'
151)             signature2 = bytes(client.sign(
152)                 payload=derivepassphrase.Vault._UUID, key=public_key_data))
153)             assert signature2 == expected_signature, 'SSH signature mismatch'
154)             assert (
155)                 derivepassphrase.Vault.phrase_from_key(public_key_data) ==
156)                 derived_passphrase
157)             ), 'SSH signature mismatch'
158) 
159)     @pytest.mark.parametrize(['keytype', 'data_dict'],
160)                              list(tests.UNSUITABLE_KEYS.items()))
161)     def test_201_sign_data_via_agent_unsupported(self, keytype, data_dict):
162)         private_key = data_dict['private_key']
163)         try:
164)             result = subprocess.run(['ssh-add', '-t', '30', '-q', '-'],
165)                                     input=private_key, check=True,
166)                                     capture_output=True)
167)         except subprocess.CalledProcessError as e:  # pragma: no cover
168)             pytest.skip(
169)                 f"uploading test key: {e!r}, stdout={e.stdout!r}, "
170)                 f"stderr={e.stderr!r}"
171)             )
172)         else:
173)             try:
174)                 client = ssh_agent_client.SSHAgentClient()
175)             except OSError:  # pragma: no cover
176)                 pytest.skip('communication error with the SSH agent')
177)         with client:
178)             key_comment_pairs = {bytes(k): bytes(c)
179)                                  for k, c in client.list_keys()}
180)             public_key_data = data_dict['public_key_data']
181)             expected_signature = data_dict['expected_signature']
182)             if public_key_data not in key_comment_pairs:  # pragma: no cover
183)                 pytest.skip('prerequisite SSH key not loaded')
184)             signature = bytes(client.sign(
185)                 payload=derivepassphrase.Vault._UUID, key=public_key_data))
186)             signature2 = bytes(client.sign(
187)                 payload=derivepassphrase.Vault._UUID, key=public_key_data))
188)             assert signature != signature2, 'SSH signature repeatable?!'
189)             with pytest.raises(ValueError, match='unsuitable SSH key'):
190)                 derivepassphrase.Vault.phrase_from_key(public_key_data)
191) 
192)     @staticmethod
193)     def _params():
194)         for value in tests.SUPPORTED_KEYS.values():
195)             key = value['public_key_data']
196)             yield (key, False)
197)         singleton_key = tests.list_keys_singleton()[0].key
198)         for value in tests.SUPPORTED_KEYS.values():
199)             key = value['public_key_data']
200)             if key == singleton_key:
201)                 yield (key, True)
202) 
203)     @pytest.mark.parametrize(['key', 'single'], list(_params()))
204)     def test_210_ssh_key_selector(self, monkeypatch, key, single):
205)         def key_is_suitable(key: bytes):
206)             return key in {v['public_key_data']
207)                            for v in tests.SUPPORTED_KEYS.values()}
208)         if single:
209)             monkeypatch.setattr(ssh_agent_client.SSHAgentClient,
210)                                 'list_keys', tests.list_keys_singleton)
211)             keys = [pair.key for pair in tests.list_keys_singleton()
212)                     if key_is_suitable(pair.key)]
213)             index = '1'
214)             text = f'Use this key? yes\n'
215)         else:
216)             monkeypatch.setattr(ssh_agent_client.SSHAgentClient,
217)                                 'list_keys', tests.list_keys)
218)             keys = [pair.key for pair in tests.list_keys()
219)                     if key_is_suitable(pair.key)]
220)             index = str(1 + keys.index(key))
221)             n = len(keys)
222)             text = f'Your selection? (1-{n}, leave empty to abort): {index}\n'
223)         b64_key = base64.standard_b64encode(key).decode('ASCII')
224) 
225)         @click.command()
226)         def driver():
227)             key = derivepassphrase.cli._select_ssh_key()
228)             click.echo(base64.standard_b64encode(key).decode('ASCII'))
229) 
230)         runner = click.testing.CliRunner(mix_stderr=True)
231)         result = runner.invoke(driver, [],
232)                                input=('yes\n' if single else f'{index}\n'),
233)                                catch_exceptions=True)
234)         assert result.stdout.startswith('Suitable SSH keys:\n'), (
235)             'missing expected output'
236)         )
237)         assert text in result.stdout, 'missing expected output'
238)         assert (
239)             result.stdout.endswith(f'\n{b64_key}\n')
240)         ), 'missing expected output'
241)         assert result.exit_code == 0, 'driver program failed?!'
242) 
243)     del _params
244) 
245)     def test_300_constructor_bad_running_agent(self, monkeypatch):
246)         monkeypatch.setenv('SSH_AUTH_SOCK',
247)                            os.environ['SSH_AUTH_SOCK'] + '~')
248)         sock = socket.socket(family=socket.AF_UNIX)
249)         with pytest.raises(OSError):
250)             ssh_agent_client.SSHAgentClient(socket=sock)
251) 
252)     @pytest.mark.parametrize(['response'], [
253)         (b'\x00\x00',),
254)         (b'\x00\x00\x00\x1f some bytes missing',),
255)     ])
256)     def test_310_truncated_server_response(self, monkeypatch, response):
257)         client = ssh_agent_client.SSHAgentClient()
258)         response_stream = io.BytesIO(response)
259)         class PseudoSocket(object):
260)             def sendall(self, *args: Any, **kwargs: Any) -> Any:
261)                 return None
262)             def recv(self, *args: Any, **kwargs: Any) -> Any:
263)                 return response_stream.read(*args, **kwargs)
264)         pseudo_socket = PseudoSocket()
265)         monkeypatch.setattr(client, '_connection', pseudo_socket)
266)         with pytest.raises(EOFError):
267)             client.request(255, b'')
268) 
269)     @tests.skip_if_no_agent
270)     @pytest.mark.parametrize(
271)         ['response_code', 'response', 'exc_type', 'exc_pattern'],
272)         [
273)             (255, b'', RuntimeError, 'error return from SSH agent:'),
274)             (12, b'\x00\x00\x00\x01', EOFError, 'truncated response'),
Marco Ricci Introduce TrailingDataError...

Marco Ricci authored 4 months ago

275)             (
276)                 12,
277)                 b'\x00\x00\x00\x00abc',
278)                 ssh_agent_client.TrailingDataError,
279)                 'overlong response',
280)             ),