# SPDX-FileCopyrightText: 2024 Marco Ricci # # SPDX-License-Identifier: MIT """Test OpenSSH key loading and signing.""" from __future__ import annotations import base64 import io import os import socket import subprocess from typing import TYPE_CHECKING import click import click.testing import pytest from typing_extensions import Any import tests from derivepassphrase import _types, cli, ssh_agent, vault if TYPE_CHECKING: from collections.abc import Iterator class TestStaticFunctionality: @pytest.mark.parametrize( ['public_key', 'public_key_data'], [ (val['public_key'], val['public_key_data']) for val in tests.SUPPORTED_KEYS.values() ], ) def test_100_key_decoding( self, public_key: bytes, public_key_data: bytes ) -> None: keydata = base64.b64decode(public_key.split(None, 2)[1]) assert ( keydata == public_key_data ), "recorded public key data doesn't match" def test_200_constructor_no_running_agent( self, monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.delenv('SSH_AUTH_SOCK', raising=False) sock = socket.socket(family=socket.AF_UNIX) with pytest.raises( KeyError, match='SSH_AUTH_SOCK environment variable' ): ssh_agent.SSHAgentClient(socket=sock) @pytest.mark.parametrize( ['input', 'expected'], [ (16777216, b'\x01\x00\x00\x00'), ], ) def test_210_uint32(self, input: int, expected: bytes | bytearray) -> None: uint32 = ssh_agent.SSHAgentClient.uint32 assert uint32(input) == expected @pytest.mark.parametrize( ['input', 'expected'], [ (b'ssh-rsa', b'\x00\x00\x00\x07ssh-rsa'), (b'ssh-ed25519', b'\x00\x00\x00\x0bssh-ed25519'), ( ssh_agent.SSHAgentClient.string(b'ssh-ed25519'), b'\x00\x00\x00\x0f\x00\x00\x00\x0bssh-ed25519', ), ], ) def test_211_string( self, input: bytes | bytearray, expected: bytes | bytearray ) -> None: string = ssh_agent.SSHAgentClient.string assert bytes(string(input)) == expected @pytest.mark.parametrize( ['input', 'expected'], [ (b'\x00\x00\x00\x07ssh-rsa', b'ssh-rsa'), ( ssh_agent.SSHAgentClient.string(b'ssh-ed25519'), b'ssh-ed25519', ), ], ) def test_212_unstring( self, input: bytes | bytearray, expected: bytes | bytearray ) -> None: unstring = ssh_agent.SSHAgentClient.unstring unstring_prefix = ssh_agent.SSHAgentClient.unstring_prefix assert bytes(unstring(input)) == expected assert tuple(bytes(x) for x in unstring_prefix(input)) == ( expected, b'', ) @pytest.mark.parametrize( ['value', 'exc_type', 'exc_pattern'], [ (10000000000000000, OverflowError, 'int too big to convert'), (-1, OverflowError, "can't convert negative int to unsigned"), ], ) def test_310_uint32_exceptions( self, value: int, exc_type: type[Exception], exc_pattern: str ) -> None: uint32 = ssh_agent.SSHAgentClient.uint32 with pytest.raises(exc_type, match=exc_pattern): uint32(value) @pytest.mark.parametrize( ['input', 'exc_type', 'exc_pattern'], [ ('some string', TypeError, 'invalid payload type'), ], ) def test_311_string_exceptions( self, input: Any, exc_type: type[Exception], exc_pattern: str ) -> None: string = ssh_agent.SSHAgentClient.string with pytest.raises(exc_type, match=exc_pattern): string(input) @pytest.mark.parametrize( ['input', 'exc_type', 'exc_pattern', 'has_trailer', 'parts'], [ (b'ssh', ValueError, 'malformed SSH byte string', False, None), ( b'\x00\x00\x00\x08ssh-rsa', ValueError, 'malformed SSH byte string', False, None, ), ( b'\x00\x00\x00\x04XXX trailing text', ValueError, 'malformed SSH byte string', True, (b'XXX ', b'trailing text'), ), ], ) def test_312_unstring_exceptions( self, input: bytes | bytearray, exc_type: type[Exception], exc_pattern: str, has_trailer: bool, parts: tuple[bytes | bytearray, bytes | bytearray] | None, ) -> None: unstring = ssh_agent.SSHAgentClient.unstring unstring_prefix = ssh_agent.SSHAgentClient.unstring_prefix with pytest.raises(exc_type, match=exc_pattern): unstring(input) if has_trailer: assert tuple(bytes(x) for x in unstring_prefix(input)) == parts else: with pytest.raises(exc_type, match=exc_pattern): unstring_prefix(input) @tests.skip_if_no_agent class TestAgentInteraction: @pytest.mark.parametrize( ['keytype', 'data_dict'], list(tests.SUPPORTED_KEYS.items()) ) def test_200_sign_data_via_agent( self, keytype: str, data_dict: tests.SSHTestKey ) -> None: del keytype # Unused. private_key = data_dict['private_key'] try: _ = subprocess.run( ['ssh-add', '-t', '30', '-q', '-'], input=private_key, check=True, capture_output=True, ) except subprocess.CalledProcessError as e: pytest.skip( f'uploading test key: {e!r}, stdout={e.stdout!r}, ' f'stderr={e.stderr!r}' ) else: try: client = ssh_agent.SSHAgentClient() except OSError: # pragma: no cover pytest.skip('communication error with the SSH agent') with client: key_comment_pairs = { bytes(k): bytes(c) for k, c in client.list_keys() } public_key_data = data_dict['public_key_data'] expected_signature = data_dict['expected_signature'] derived_passphrase = data_dict['derived_passphrase'] if public_key_data not in key_comment_pairs: # pragma: no cover pytest.skip('prerequisite SSH key not loaded') signature = bytes( client.sign(payload=vault.Vault._UUID, key=public_key_data) ) assert signature == expected_signature, 'SSH signature mismatch' signature2 = bytes( client.sign(payload=vault.Vault._UUID, key=public_key_data) ) assert signature2 == expected_signature, 'SSH signature mismatch' assert ( vault.Vault.phrase_from_key(public_key_data) == derived_passphrase ), 'SSH signature mismatch' @pytest.mark.parametrize( ['keytype', 'data_dict'], list(tests.UNSUITABLE_KEYS.items()) ) def test_201_sign_data_via_agent_unsupported( self, keytype: str, data_dict: tests.SSHTestKey ) -> None: del keytype # Unused. private_key = data_dict['private_key'] try: _ = subprocess.run( ['ssh-add', '-t', '30', '-q', '-'], input=private_key, check=True, capture_output=True, ) except subprocess.CalledProcessError as e: # pragma: no cover pytest.skip( f'uploading test key: {e!r}, stdout={e.stdout!r}, ' f'stderr={e.stderr!r}' ) else: try: client = ssh_agent.SSHAgentClient() except OSError: # pragma: no cover pytest.skip('communication error with the SSH agent') with client: key_comment_pairs = { bytes(k): bytes(c) for k, c in client.list_keys() } public_key_data = data_dict['public_key_data'] _ = data_dict['expected_signature'] if public_key_data not in key_comment_pairs: # pragma: no cover pytest.skip('prerequisite SSH key not loaded') signature = bytes( client.sign(payload=vault.Vault._UUID, key=public_key_data) ) signature2 = bytes( client.sign(payload=vault.Vault._UUID, key=public_key_data) ) assert signature != signature2, 'SSH signature repeatable?!' with pytest.raises(ValueError, match='unsuitable SSH key'): vault.Vault.phrase_from_key(public_key_data) @staticmethod def _params() -> Iterator[tuple[bytes, bool]]: for value in tests.SUPPORTED_KEYS.values(): key = value['public_key_data'] yield (key, False) singleton_key = tests.list_keys_singleton()[0].key for value in tests.SUPPORTED_KEYS.values(): key = value['public_key_data'] if key == singleton_key: yield (key, True) @pytest.mark.parametrize(['key', 'single'], list(_params())) def test_210_ssh_key_selector( self, monkeypatch: pytest.MonkeyPatch, key: bytes, single: bool ) -> None: def key_is_suitable(key: bytes) -> bool: return key in { v['public_key_data'] for v in tests.SUPPORTED_KEYS.values() } if single: monkeypatch.setattr( ssh_agent.SSHAgentClient, 'list_keys', tests.list_keys_singleton, ) keys = [ pair.key for pair in tests.list_keys_singleton() if key_is_suitable(pair.key) ] index = '1' text = 'Use this key? yes\n' else: monkeypatch.setattr( ssh_agent.SSHAgentClient, 'list_keys', tests.list_keys ) keys = [ pair.key for pair in tests.list_keys() if key_is_suitable(pair.key) ] index = str(1 + keys.index(key)) n = len(keys) text = f'Your selection? (1-{n}, leave empty to abort): {index}\n' b64_key = base64.standard_b64encode(key).decode('ASCII') @click.command() def driver() -> None: key = cli._select_ssh_key() click.echo(base64.standard_b64encode(key).decode('ASCII')) runner = click.testing.CliRunner(mix_stderr=True) _result = runner.invoke( driver, [], input=('yes\n' if single else f'{index}\n'), catch_exceptions=True, ) result = tests.ReadableResult.parse(_result) for snippet in ('Suitable SSH keys:\n', text, f'\n{b64_key}\n'): assert result.clean_exit(output=snippet), 'expected clean exit' del _params def test_300_constructor_bad_running_agent( self, monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.setenv('SSH_AUTH_SOCK', os.environ['SSH_AUTH_SOCK'] + '~') sock = socket.socket(family=socket.AF_UNIX) with pytest.raises(OSError): # noqa: PT011 ssh_agent.SSHAgentClient(socket=sock) @pytest.mark.parametrize( 'response', [ b'\x00\x00', b'\x00\x00\x00\x1f some bytes missing', ], ) def test_310_truncated_server_response( self, monkeypatch: pytest.MonkeyPatch, response: bytes ) -> None: client = ssh_agent.SSHAgentClient() response_stream = io.BytesIO(response) class PseudoSocket: def sendall(self, *args: Any, **kwargs: Any) -> Any: # noqa: ARG002 return None def recv(self, *args: Any, **kwargs: Any) -> Any: return response_stream.read(*args, **kwargs) pseudo_socket = PseudoSocket() monkeypatch.setattr(client, '_connection', pseudo_socket) with pytest.raises(EOFError): client.request(255, b'') @tests.skip_if_no_agent @pytest.mark.parametrize( ['response_code', 'response', 'exc_type', 'exc_pattern'], [ ( _types.SSH_AGENT.FAILURE, b'', ssh_agent.SSHAgentFailedError, 'failed to complete the request', ), ( _types.SSH_AGENT.IDENTITIES_ANSWER, b'\x00\x00\x00\x01', EOFError, 'truncated response', ), ( _types.SSH_AGENT.IDENTITIES_ANSWER, b'\x00\x00\x00\x00abc', ssh_agent.TrailingDataError, 'Overlong response', ), ], ) def test_320_list_keys_error_responses( self, monkeypatch: pytest.MonkeyPatch, response_code: _types.SSH_AGENT, response: bytes | bytearray, exc_type: type[Exception], exc_pattern: str, ) -> None: client = ssh_agent.SSHAgentClient() monkeypatch.setattr( client, 'request', lambda *a, **kw: (response_code.value, response), # noqa: ARG005 ) with pytest.raises(exc_type, match=exc_pattern): client.list_keys() @tests.skip_if_no_agent @pytest.mark.parametrize( ['key', 'check', 'response', 'exc_type', 'exc_pattern'], [ ( b'invalid-key', True, (_types.SSH_AGENT.FAILURE, b''), KeyError, 'target SSH key not loaded into agent', ), ( tests.SUPPORTED_KEYS['ed25519']['public_key_data'], True, (_types.SSH_AGENT.FAILURE, b''), ssh_agent.SSHAgentFailedError, 'failed to complete the request', ), ], ) def test_330_sign_error_responses( self, monkeypatch: pytest.MonkeyPatch, key: bytes | bytearray, check: bool, response: tuple[_types.SSH_AGENT, bytes | bytearray], exc_type: type[Exception], exc_pattern: str, ) -> None: client = ssh_agent.SSHAgentClient() monkeypatch.setattr( client, 'request', lambda a, b: (response[0].value, response[1]), # noqa: ARG005 ) KeyCommentPair = _types.KeyCommentPair # noqa: N806 loaded_keys = [ KeyCommentPair(v['public_key_data'], b'no comment') for v in tests.SUPPORTED_KEYS.values() ] monkeypatch.setattr(client, 'list_keys', lambda: loaded_keys) with pytest.raises(exc_type, match=exc_pattern): client.sign(key, b'abc', check_if_key_loaded=check)