c3c2e901e299525d8a42b56169b36583551332b0
Marco Ricci Rename and regroup all test...

Marco Ricci authored 2 months ago

1) # SPDX-FileCopyrightText: 2024 Marco Ricci <m@the13thletter.info>
2) #
3) # SPDX-License-Identifier: MIT
4) 
5) """Test OpenSSH key loading and signing."""
6) 
7) from __future__ import annotations
8) 
9) import base64
10) import errno
11) import io
12) import os
13) import socket
14) import subprocess
15) from typing import Any
16) 
17) import click
18) import click.testing
19) import derivepassphrase
20) import derivepassphrase.cli
21) import pytest
22) import ssh_agent_client
23) import tests
24) 
25) class TestStaticFunctionality:
26) 
27)     @pytest.mark.parametrize(['public_key', 'public_key_data'],
28)                              [(val['public_key'], val['public_key_data'])
29)                               for val in tests.SUPPORTED_KEYS.values()])
30)     def test_100_key_decoding(self, public_key, public_key_data):
31)         keydata = base64.b64decode(public_key.split(None, 2)[1])
32)         assert (
33)             keydata == public_key_data
34)         ), "recorded public key data doesn't match"
35) 
36)     def test_200_constructor_no_running_agent(self, monkeypatch):
37)         monkeypatch.delenv('SSH_AUTH_SOCK', raising=False)
38)         sock = socket.socket(family=socket.AF_UNIX)
Marco Ricci Distinguish errors when con...

Marco Ricci authored 2 months ago

39)         with pytest.raises(KeyError,
Marco Ricci Rename and regroup all test...

Marco Ricci authored 2 months ago

40)                            match='SSH_AUTH_SOCK environment variable'):
41)             ssh_agent_client.SSHAgentClient(socket=sock)
42) 
43)     @pytest.mark.parametrize(['input', 'expected'], [
44)         (16777216, b'\x01\x00\x00\x00'),
45)     ])
46)     def test_210_uint32(self, input, expected):
47)         uint32 = ssh_agent_client.SSHAgentClient.uint32
48)         assert uint32(input) == expected
49) 
50)     @pytest.mark.parametrize(['input', 'expected'], [
51)         (b'ssh-rsa', b'\x00\x00\x00\x07ssh-rsa'),
52)         (b'ssh-ed25519', b'\x00\x00\x00\x0bssh-ed25519'),
53)         (
54)             ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
55)             b'\x00\x00\x00\x0f\x00\x00\x00\x0bssh-ed25519',
56)         ),
57)     ])
58)     def test_211_string(self, input, expected):
59)         string = ssh_agent_client.SSHAgentClient.string
60)         assert bytes(string(input)) == expected
61) 
62)     @pytest.mark.parametrize(['input', 'expected'], [
63)         (b'\x00\x00\x00\x07ssh-rsa', b'ssh-rsa'),
64)         (
65)             ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
66)             b'ssh-ed25519',
67)         ),
68)     ])
69)     def test_212_unstring(self, input, expected):
70)         unstring = ssh_agent_client.SSHAgentClient.unstring
71)         unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
72)         assert bytes(unstring(input)) == expected
73)         assert tuple(
74)             bytes(x) for x in unstring_prefix(input)
75)         ) == (expected, b'')
76) 
77)     @pytest.mark.parametrize(['value', 'exc_type', 'exc_pattern'], [
78)         (10000000000000000, OverflowError, 'int too big to convert'),
79)         (-1, OverflowError, "can't convert negative int to unsigned"),
80)     ])
81)     def test_310_uint32_exceptions(self, value, exc_type, exc_pattern):
82)         uint32 = ssh_agent_client.SSHAgentClient.uint32
83)         with pytest.raises(exc_type, match=exc_pattern):
84)             uint32(value)
85) 
86)     @pytest.mark.parametrize(['input', 'exc_type', 'exc_pattern'], [
87)         ('some string', TypeError, 'invalid payload type'),
88)     ])
89)     def test_311_string_exceptions(self, input, exc_type, exc_pattern):
90)         string = ssh_agent_client.SSHAgentClient.string
91)         with pytest.raises(exc_type, match=exc_pattern):
92)             string(input)
93) 
94)     @pytest.mark.parametrize(
95)         ['input', 'exc_type', 'exc_pattern', 'has_trailer', 'parts'], [
96)             (b'ssh', ValueError, 'malformed SSH byte string', False, None),
97)             (
98)                 b'\x00\x00\x00\x08ssh-rsa',
99)                 ValueError, 'malformed SSH byte string',
100)                 False, None,
101)             ),
102)             (
103)                 b'\x00\x00\x00\x04XXX trailing text',
104)                 ValueError, 'malformed SSH byte string',
105)                 True, (b'XXX ', b'trailing text'),
106)             ),
107)     ])
108)     def test_312_unstring_exceptions(self, input, exc_type, exc_pattern,
109)                                      has_trailer, parts):
110)         unstring = ssh_agent_client.SSHAgentClient.unstring
111)         unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
112)         with pytest.raises(exc_type, match=exc_pattern):
113)             unstring(input)
114)         if has_trailer:
115)             assert tuple(bytes(x) for x in unstring_prefix(input)) == parts
116)         else:
117)             with pytest.raises(exc_type, match=exc_pattern):
118)                 unstring_prefix(input)
119) 
120) @tests.skip_if_no_agent
121) class TestAgentInteraction:
122) 
123)     @pytest.mark.parametrize(['keytype', 'data_dict'],
124)                              list(tests.SUPPORTED_KEYS.items()))
125)     def test_200_sign_data_via_agent(self, keytype, data_dict):
126)         private_key = data_dict['private_key']
127)         try:
128)             result = subprocess.run(['ssh-add', '-t', '30', '-q', '-'],
129)                                     input=private_key, check=True,
130)                                     capture_output=True)
131)         except subprocess.CalledProcessError as e:
132)             pytest.skip(
133)                 f"uploading test key: {e!r}, stdout={e.stdout!r}, "
134)                 f"stderr={e.stderr!r}"
135)             )
136)         else:
137)             try:
138)                 client = ssh_agent_client.SSHAgentClient()
139)             except OSError:  # pragma: no cover
140)                 pytest.skip('communication error with the SSH agent')
141)         with client:
142)             key_comment_pairs = {bytes(k): bytes(c)
143)                                  for k, c in client.list_keys()}
144)             public_key_data = data_dict['public_key_data']
145)             expected_signature = data_dict['expected_signature']
146)             derived_passphrase = data_dict['derived_passphrase']
147)             if public_key_data not in key_comment_pairs:  # pragma: no cover
148)                 pytest.skip('prerequisite SSH key not loaded')
149)             signature = bytes(client.sign(
150)                 payload=derivepassphrase.Vault._UUID, key=public_key_data))
151)             assert signature == expected_signature, 'SSH signature mismatch'
152)             signature2 = bytes(client.sign(
153)                 payload=derivepassphrase.Vault._UUID, key=public_key_data))
154)             assert signature2 == expected_signature, 'SSH signature mismatch'
155)             assert (
156)                 derivepassphrase.Vault.phrase_from_key(public_key_data) ==
157)                 derived_passphrase
158)             ), 'SSH signature mismatch'
159) 
160)     @pytest.mark.parametrize(['keytype', 'data_dict'],
161)                              list(tests.UNSUITABLE_KEYS.items()))
162)     def test_201_sign_data_via_agent_unsupported(self, keytype, data_dict):
163)         private_key = data_dict['private_key']
164)         try:
165)             result = subprocess.run(['ssh-add', '-t', '30', '-q', '-'],
166)                                     input=private_key, check=True,
167)                                     capture_output=True)
168)         except subprocess.CalledProcessError as e:  # pragma: no cover
169)             pytest.skip(
170)                 f"uploading test key: {e!r}, stdout={e.stdout!r}, "
171)                 f"stderr={e.stderr!r}"
172)             )
173)         else:
174)             try:
175)                 client = ssh_agent_client.SSHAgentClient()
176)             except OSError:  # pragma: no cover
177)                 pytest.skip('communication error with the SSH agent')
178)         with client:
179)             key_comment_pairs = {bytes(k): bytes(c)
180)                                  for k, c in client.list_keys()}
181)             public_key_data = data_dict['public_key_data']
182)             expected_signature = data_dict['expected_signature']
183)             if public_key_data not in key_comment_pairs:  # pragma: no cover
184)                 pytest.skip('prerequisite SSH key not loaded')
185)             signature = bytes(client.sign(
186)                 payload=derivepassphrase.Vault._UUID, key=public_key_data))
187)             signature2 = bytes(client.sign(
188)                 payload=derivepassphrase.Vault._UUID, key=public_key_data))
189)             assert signature != signature2, 'SSH signature repeatable?!'
190)             with pytest.raises(ValueError, match='unsuitable SSH key'):
191)                 derivepassphrase.Vault.phrase_from_key(public_key_data)
192) 
193)     @staticmethod
194)     def _params():
195)         for value in tests.SUPPORTED_KEYS.values():
196)             key = value['public_key_data']
197)             yield (key, False)
198)         singleton_key = tests.list_keys_singleton()[0].key
199)         for value in tests.SUPPORTED_KEYS.values():
200)             key = value['public_key_data']
201)             if key == singleton_key:
202)                 yield (key, True)
203) 
204)     @pytest.mark.parametrize(['key', 'single'], list(_params()))
205)     def test_210_ssh_key_selector(self, monkeypatch, key, single):
206)         def key_is_suitable(key: bytes):
207)             return key in {v['public_key_data']
208)                            for v in tests.SUPPORTED_KEYS.values()}
209)         if single:
210)             monkeypatch.setattr(ssh_agent_client.SSHAgentClient,
211)                                 'list_keys', tests.list_keys_singleton)
212)             keys = [pair.key for pair in tests.list_keys_singleton()
213)                     if key_is_suitable(pair.key)]
214)             index = '1'
215)             text = f'Use this key? yes\n'
216)         else:
217)             monkeypatch.setattr(ssh_agent_client.SSHAgentClient,
218)                                 'list_keys', tests.list_keys)
219)             keys = [pair.key for pair in tests.list_keys()
220)                     if key_is_suitable(pair.key)]
221)             index = str(1 + keys.index(key))
222)             n = len(keys)
223)             text = f'Your selection? (1-{n}, leave empty to abort): {index}\n'
224)         b64_key = base64.standard_b64encode(key).decode('ASCII')
225) 
226)         @click.command()
227)         def driver():
228)             key = derivepassphrase.cli._select_ssh_key()
229)             click.echo(base64.standard_b64encode(key).decode('ASCII'))
230) 
231)         runner = click.testing.CliRunner(mix_stderr=True)
232)         result = runner.invoke(driver, [],
233)                                input=('yes\n' if single else f'{index}\n'),
234)                                catch_exceptions=True)
235)         assert result.stdout.startswith('Suitable SSH keys:\n'), (
236)             'missing expected output'
237)         )
238)         assert text in result.stdout, 'missing expected output'
239)         assert (
240)             result.stdout.endswith(f'\n{b64_key}\n')
241)         ), 'missing expected output'
242)         assert result.exit_code == 0, 'driver program failed?!'
243) 
244)     del _params
245) 
246)     def test_300_constructor_bad_running_agent(self, monkeypatch):
247)         monkeypatch.setenv('SSH_AUTH_SOCK',
248)                            os.environ['SSH_AUTH_SOCK'] + '~')
249)         sock = socket.socket(family=socket.AF_UNIX)
250)         with pytest.raises(OSError):
251)             ssh_agent_client.SSHAgentClient(socket=sock)
252) 
253)     @pytest.mark.parametrize(['response'], [
254)         (b'\x00\x00',),
255)         (b'\x00\x00\x00\x1f some bytes missing',),
256)     ])
257)     def test_310_truncated_server_response(self, monkeypatch, response):
258)         client = ssh_agent_client.SSHAgentClient()
259)         response_stream = io.BytesIO(response)
260)         class PseudoSocket(object):
261)             def sendall(self, *args: Any, **kwargs: Any) -> Any:
262)                 return None
263)             def recv(self, *args: Any, **kwargs: Any) -> Any:
264)                 return response_stream.read(*args, **kwargs)
265)         pseudo_socket = PseudoSocket()
266)         monkeypatch.setattr(client, '_connection', pseudo_socket)
267)         with pytest.raises(EOFError):
268)             client.request(255, b'')
269) 
270)     @tests.skip_if_no_agent
271)     @pytest.mark.parametrize(
272)         ['response_code', 'response', 'exc_type', 'exc_pattern'],
273)         [
274)             (255, b'', RuntimeError, 'error return from SSH agent:'),
275)             (12, b'\x00\x00\x00\x01', EOFError, 'truncated response'),
Marco Ricci Introduce TrailingDataError...

Marco Ricci authored 2 months ago

276)             (
277)                 12,
278)                 b'\x00\x00\x00\x00abc',
279)                 ssh_agent_client.TrailingDataError,
280)                 'overlong response',
281)             ),