Introduce a stubbed SSH agent, for testing
Marco Ricci

Marco Ricci commited on 2025-08-02 13:22:55
Zeige 2 geänderte Dateien mit 681 Einfügungen und 1 Löschungen.


Introduce a stubbed SSH agent socket that simulates a real agent's
responses (for test keys only).  Include tests to verify the correct
workings of the agent.

The stub agent comes in three versions.  The basic version implements
all the necessary functionality, but intentionally disables SSH agent
extension support.  The "address" version additionally supports
generating errors upon construction by both requiring an address (in
`SSH_AUTH_SOCK`) and by raising specific errors for specific addresses.
The "deterministic DSA" version additionally supports the "query" and
the "list-extended@putty.projects.tartarus.org" extension, and yields
the recorded RFC 6979 deterministic DSA signatures for DSA and ECDSA
test keys.  The "address" and "deterministic DSA" versions are intended
for use in the test suite, as test doubles for real SSH agents,
depending on whether deterministic DSA signatures are required or not.
The "basic" version is intended for verification of the stubbed SSH
agent itself, whereever it can be used in place of more powerful stubbed
SSH agent versions.

Though tested *individually*, as a new piece of code added to the code
base, the stub agent (in all its variations) is not yet *integrated*
into the test suite as a "proper", spawnable SSH agent.  That we leave
to the following commits.
... ...
@@ -8,6 +8,7 @@ import base64
8 8
 import contextlib
9 9
 import copy
10 10
 import enum
11
+import errno
11 12
 import importlib.util
12 13
 import json
13 14
 import logging
... ...
@@ -30,6 +31,7 @@ from typing_extensions import NamedTuple, assert_never
30 31
 
31 32
 from derivepassphrase import _types, cli, ssh_agent, vault
32 33
 from derivepassphrase._internals import cli_helpers, cli_machinery
34
+from derivepassphrase.ssh_agent import socketprovider
33 35
 
34 36
 __all__ = ()
35 37
 
... ...
@@ -39,7 +41,7 @@ if TYPE_CHECKING:
39 41
     from contextlib import AbstractContextManager
40 42
     from typing import IO, NotRequired
41 43
 
42
-    from typing_extensions import Any
44
+    from typing_extensions import Any, Buffer, Self
43 45
 
44 46
 
45 47
 class SSHTestKeyDeterministicSignatureClass(str, enum.Enum):
... ...
@@ -607,6 +609,9 @@ class KnownSSHAgent(str, enum.Enum):
607 609
             The agent from Simon Tatham's PuTTY suite.
608 610
         OpenSSHAgent (str):
609 611
             The agent from OpenBSD's OpenSSH suite.
612
+        StubbedSSHAgent (str):
613
+            The stubbed, fake agent pseudo-socket defined in this test
614
+            suite.
610 615
 
611 616
     """
612 617
 
... ...
@@ -616,6 +621,8 @@ class KnownSSHAgent(str, enum.Enum):
616 621
     """"""
617 622
     OpenSSHAgent = 'OpenSSHAgent'
618 623
     """"""
624
+    StubbedSSHAgent = 'StubbedSSHAgent'
625
+    """"""
619 626
 
620 627
 
621 628
 class SpawnedSSHAgentInfo(NamedTuple):
... ...
@@ -1851,6 +1858,361 @@ def xfail_on_the_annoying_os(
1851 1858
     return mark if f is None else mark(f)
1852 1859
 
1853 1860
 
1861
+@socketprovider.SocketProvider.register('stub_agent')
1862
+class StubbedSSHAgentSocket:
1863
+    """A stubbed SSH agent presenting an [`_types.SSHAgentSocket`][]."""
1864
+
1865
+    _SOCKET_IS_CLOSED = 'Socket is closed.'
1866
+    _NO_FLAG_SUPPORT = 'This stubbed SSH agent socket does not support flags.'
1867
+    _PROTOCOL_VIOLATION = 'SSH agent protocol violation.'
1868
+    _INVALID_REQUEST = 'Invalid request.'
1869
+    _UNSUPPORTED_REQUEST = 'Unsupported request.'
1870
+
1871
+    HEADER_SIZE = 4
1872
+    CODE_SIZE = 1
1873
+
1874
+    KNOWN_EXTENSIONS = frozenset({
1875
+        'query',
1876
+        'list-extended@putty.projects.tartarus.org',
1877
+    })
1878
+    """Known and implemented protocol extensions."""
1879
+
1880
+    def __init__(self, *extensions: str) -> None:
1881
+        """Initialize the agent."""
1882
+        self.send_to_client = bytearray()
1883
+        """
1884
+        The buffered response to the client, read piecemeal by [`recv`][].
1885
+        """
1886
+        self.receive_from_client = bytearray()
1887
+        """The last request issued by the client."""
1888
+        self.closed = False
1889
+        """True if the connection is closed, false otherwise."""
1890
+        self.enabled_extensions = frozenset(extensions) & self.KNOWN_EXTENSIONS
1891
+        """
1892
+        Extensions actually enabled in this particular stubbed SSH agent.
1893
+        """
1894
+        self.try_rfc6979 = False
1895
+        """
1896
+        Attempt to issue DSA and ECDSA signatures according to RFC 6979?
1897
+        """
1898
+        self.try_pageant_068_080 = False
1899
+        """
1900
+        Attempt to issue DSA and ECDSA signatures as per Pageant 0.68–0.80?
1901
+        """  # noqa: RUF001
1902
+
1903
+    def __enter__(self) -> Self:
1904
+        """Return self."""
1905
+        return self
1906
+
1907
+    def __exit__(self, *args: object) -> None:
1908
+        """Mark the agent's socket as closed."""
1909
+        self.closed = True
1910
+
1911
+    def sendall(self, data: Buffer, flags: int = 0, /) -> None:
1912
+        """Send data to the SSH agent.
1913
+
1914
+        The signature, and behavior, is identical to
1915
+        [`socket.socket.sendall`][].  Upon successful sending, this
1916
+        agent will parse the request, call the appropriate handler, and
1917
+        buffer the result such that it can be read via [`recv`][], in
1918
+        accordance with the SSH agent protocol.
1919
+
1920
+        Args:
1921
+            data: Binary data to send to the agent.
1922
+            flags: Reserved.  Must be 0.
1923
+
1924
+        Returns:
1925
+            Nothing.  The result should be requested via [`recv`][], and
1926
+            interpreted in accordance with the SSH agent protocol.
1927
+
1928
+        Raises:
1929
+            AssertionError:
1930
+                The flags argument, if specified, must be 0.
1931
+            ValueError:
1932
+                The agent's socket is already closed.  No further
1933
+                requests can be sent.
1934
+
1935
+        """
1936
+        assert not flags, self._NO_FLAG_SUPPORT
1937
+        if self.closed:
1938
+            raise ValueError(self._SOCKET_IS_CLOSED)
1939
+        self.receive_from_client.extend(memoryview(data))
1940
+        try:
1941
+            self.parse_client_request_and_dispatch()
1942
+        except ValueError:
1943
+            payload = int.to_bytes(_types.SSH_AGENT.FAILURE.value, 1, 'big')
1944
+            self.send_to_client.extend(int.to_bytes(len(payload), 4, 'big'))
1945
+            self.send_to_client.extend(payload)
1946
+        finally:
1947
+            self.receive_from_client.clear()
1948
+
1949
+    def recv(self, count: int, flags: int = 0, /) -> bytes:
1950
+        """Read data from the SSH agent.
1951
+
1952
+        As per the SSH agent protocol, data is only available to be read
1953
+        immediately after a request via [`sendall`][].  Calls to
1954
+        [`recv`][] at other points in time that attempt to read data
1955
+        violate the protocol, and will fail.  Notwithstanding the last
1956
+        sentence, at any point in time, though pointless, it is
1957
+        additionally permissible to read 0 bytes from the agent, or any
1958
+        number of bytes from a closed socket.
1959
+
1960
+        Args:
1961
+            count:
1962
+                Number of bytes to read from the agent.
1963
+            flags:
1964
+                Reserved.  Must be 0.
1965
+
1966
+        Returns:
1967
+            (A chunk of) the SSH agent's response to the most recent
1968
+            request.  If reading 0 bytes, or if reading from a closed
1969
+            socket, the returned chunk is always an empty byte string.
1970
+
1971
+        Raises:
1972
+            AssertionError:
1973
+                The flags argument, if specified, must be 0.
1974
+
1975
+                Alternatively, `recv` was called when there was no
1976
+                response to be obtained, in violation of the SSH agent
1977
+                protocol.
1978
+
1979
+        """
1980
+        assert not flags, self._NO_FLAG_SUPPORT
1981
+        assert not count or self.closed or self.send_to_client, (
1982
+            self._PROTOCOL_VIOLATION
1983
+        )
1984
+        ret = bytes(self.send_to_client[:count])
1985
+        del self.send_to_client[:count]
1986
+        return ret
1987
+
1988
+    def parse_client_request_and_dispatch(self) -> None:
1989
+        """Parse the client request and call the matching handler.
1990
+
1991
+        This agent supports the
1992
+        [`SSH_AGENTC_REQUEST_IDENTITIES`][_types.SSH_AGENTC.REQUEST_IDENTITIES],
1993
+        [`SSH_AGENTC_SIGN_REQUEST`][_types.SSH_AGENTC.SIGN_REQUEST] and
1994
+        the [`SSH_AGENTC_EXTENSION`][_types.SSH_AGENTC.EXTENSION]
1995
+        request types.
1996
+
1997
+        """
1998
+
1999
+        if len(self.receive_from_client) < self.HEADER_SIZE + self.CODE_SIZE:
2000
+            raise ValueError(self._INVALID_REQUEST)
2001
+        target_header = ssh_agent.SSHAgentClient.uint32(
2002
+            len(self.receive_from_client) - self.HEADER_SIZE
2003
+        )
2004
+        if target_header != self.receive_from_client[: self.HEADER_SIZE]:
2005
+            raise ValueError(self._INVALID_REQUEST)
2006
+        code = _types.SSH_AGENTC(
2007
+            int.from_bytes(
2008
+                self.receive_from_client[
2009
+                    self.HEADER_SIZE : self.HEADER_SIZE + self.CODE_SIZE
2010
+                ],
2011
+                'big',
2012
+            )
2013
+        )
2014
+
2015
+        def is_enabled_extension(extension: str) -> bool:
2016
+            if (
2017
+                extension not in self.enabled_extensions
2018
+                or code != _types.SSH_AGENTC.EXTENSION
2019
+            ):
2020
+                return False
2021
+            string = ssh_agent.SSHAgentClient.string
2022
+            extension_marker = b'\x1b' + string(extension.encode('ascii'))
2023
+            return self.receive_from_client.startswith(extension_marker, 4)
2024
+
2025
+        result: Buffer | Iterator[int]
2026
+        if code == _types.SSH_AGENTC.REQUEST_IDENTITIES:
2027
+            result = self.request_identities(list_extended=False)
2028
+        elif code == _types.SSH_AGENTC.SIGN_REQUEST:
2029
+            result = self.sign()
2030
+        elif is_enabled_extension('query'):
2031
+            result = self.query_extensions()
2032
+        elif is_enabled_extension('list-extended@putty.projects.tartarus.org'):
2033
+            result = self.request_identities(list_extended=True)
2034
+        else:
2035
+            raise ValueError(self._UNSUPPORTED_REQUEST)
2036
+        self.send_to_client.extend(
2037
+            ssh_agent.SSHAgentClient.string(bytes(result))
2038
+        )
2039
+
2040
+    def query_extensions(self) -> Iterator[int]:
2041
+        """Answer an `SSH_AGENTC_EXTENSION` request.
2042
+
2043
+        Yields:
2044
+            The bytes payload of the response, without the protocol
2045
+            framing.  The payload is yielded byte by byte, as an
2046
+            iterable of 8-bit integers.
2047
+
2048
+        """
2049
+        yield _types.SSH_AGENT.EXTENSION_RESPONSE.value
2050
+        yield from ssh_agent.SSHAgentClient.string(b'query')
2051
+        extension_answers = [
2052
+            b'query',
2053
+            b'list-extended@putty.projects.tartarus.org',
2054
+        ]
2055
+        for a in extension_answers:
2056
+            yield from ssh_agent.SSHAgentClient.string(a)
2057
+
2058
+    def request_identities(
2059
+        self, *, list_extended: bool = False
2060
+    ) -> Iterator[int]:
2061
+        """Answer an `SSH_AGENTC_REQUEST_IDENTITIES` request.
2062
+
2063
+        Args:
2064
+            list_extended:
2065
+                If true, answer an `SSH_AGENTC_EXTENSION` request for
2066
+                the `list-extended@putty.projects.tartarus.org`
2067
+                extension. Otherwise, answer an
2068
+                `SSH_AGENTC_REQUEST_IDENTITIES` request.
2069
+
2070
+        Yields:
2071
+            The bytes payload of the response, without the protocol
2072
+            framing.  The payload is yielded byte by byte, as an
2073
+            iterable of 8-bit integers.
2074
+
2075
+        """
2076
+        if list_extended:
2077
+            yield _types.SSH_AGENT.SUCCESS.value
2078
+        else:
2079
+            yield _types.SSH_AGENT.IDENTITIES_ANSWER.value
2080
+        signature_classes = [
2081
+            SSHTestKeyDeterministicSignatureClass.SPEC,
2082
+        ]
2083
+        if (
2084
+            'list-extended@putty.projects.tartarus.org'
2085
+            in self.enabled_extensions
2086
+        ):
2087
+            signature_classes.append(
2088
+                SSHTestKeyDeterministicSignatureClass.RFC_6979
2089
+            )
2090
+        keys = [
2091
+            v
2092
+            for v in ALL_KEYS.values()
2093
+            if any(cls in v.expected_signatures for cls in signature_classes)
2094
+        ]
2095
+        yield from ssh_agent.SSHAgentClient.uint32(len(keys))
2096
+        for key in keys:
2097
+            yield from ssh_agent.SSHAgentClient.string(key.public_key_data)
2098
+            yield from ssh_agent.SSHAgentClient.string(
2099
+                b'test key without passphrase'
2100
+            )
2101
+            if list_extended:
2102
+                yield from ssh_agent.SSHAgentClient.string(
2103
+                    ssh_agent.SSHAgentClient.uint32(0)
2104
+                )
2105
+
2106
+    def sign(self) -> bytes:
2107
+        """Answer an `SSH_AGENTC_SIGN_REQUEST` request.
2108
+
2109
+        Returns:
2110
+            The bytes payload of the response, without the protocol
2111
+            framing.
2112
+
2113
+        """
2114
+        try_rfc6979 = (
2115
+            'list-extended@putty.projects.tartarus.org'
2116
+            in self.enabled_extensions
2117
+        )
2118
+        spec = SSHTestKeyDeterministicSignatureClass.SPEC
2119
+        rfc6979 = SSHTestKeyDeterministicSignatureClass.RFC_6979
2120
+        key_blob, rest = ssh_agent.SSHAgentClient.unstring_prefix(
2121
+            self.receive_from_client[self.HEADER_SIZE + self.CODE_SIZE :]
2122
+        )
2123
+        sign_data, rest = ssh_agent.SSHAgentClient.unstring_prefix(rest)
2124
+        if len(rest) != 4:
2125
+            raise ValueError(self._INVALID_REQUEST)
2126
+        flags = int.from_bytes(rest, 'big')
2127
+        if flags:
2128
+            raise ValueError(self._UNSUPPORTED_REQUEST)
2129
+        if sign_data != vault.Vault.UUID:
2130
+            raise ValueError(self._UNSUPPORTED_REQUEST)
2131
+        for key in ALL_KEYS.values():
2132
+            if key.public_key_data == key_blob:
2133
+                if spec in key.expected_signatures:
2134
+                    return int.to_bytes(
2135
+                        _types.SSH_AGENT.SIGN_RESPONSE.value, 1, 'big'
2136
+                    ) + ssh_agent.SSHAgentClient.string(
2137
+                        key.expected_signatures[spec].signature
2138
+                    )
2139
+                if try_rfc6979 and rfc6979 in key.expected_signatures:
2140
+                    return int.to_bytes(
2141
+                        _types.SSH_AGENT.SIGN_RESPONSE.value, 1, 'big'
2142
+                    ) + ssh_agent.SSHAgentClient.string(
2143
+                        key.expected_signatures[rfc6979].signature
2144
+                    )
2145
+                raise ValueError(self._UNSUPPORTED_REQUEST)
2146
+        raise ValueError(self._UNSUPPORTED_REQUEST)
2147
+
2148
+
2149
+@socketprovider.SocketProvider.register('stub_with_address')
2150
+class StubbedSSHAgentSocketWithAddress(StubbedSSHAgentSocket):
2151
+    """A [`StubbedSSHAgentSocket`][] requiring a specific address."""
2152
+
2153
+    ADDRESS = 'stub-ssh-agent:'
2154
+    """The correct address for connecting to this stubbed agent."""
2155
+
2156
+    def __init__(self, *extensions: str) -> None:
2157
+        """Initialize the agent, based on `SSH_AUTH_SOCK`.
2158
+
2159
+        Socket addresses of the form `stub-ssh-agent:<errno_value>` will
2160
+        raise an [`OSError`][] (or the respective subclass) with the
2161
+        specified [`errno`][] value.  For example,
2162
+        `stub-ssh-agent:EPERM` will raise a [`PermissionError`][].
2163
+
2164
+        Raises:
2165
+            KeyError:
2166
+                The `SSH_AUTH_SOCK` environment variable is not set.
2167
+            OSError:
2168
+                The address in `SSH_AUTH_SOCK` is unsuited.
2169
+
2170
+        """
2171
+        super().__init__(*extensions)
2172
+        try:
2173
+            orig_address = os.environ['SSH_AUTH_SOCK']
2174
+        except KeyError as exc:
2175
+            msg = 'SSH_AUTH_SOCK environment variable'
2176
+            raise KeyError(msg) from exc
2177
+        address = orig_address
2178
+        if not address.startswith(self.ADDRESS):
2179
+            address = self.ADDRESS + 'ENOENT'
2180
+        errcode = address.removeprefix(self.ADDRESS)
2181
+        if errcode and not (
2182
+            errcode.startswith('E') and hasattr(errno, errcode)
2183
+        ):
2184
+            errcode = 'EINVAL'
2185
+        if errcode:
2186
+            errno_val = getattr(errno, errcode)
2187
+            raise OSError(errno_val, os.strerror(errno_val), orig_address)
2188
+
2189
+
2190
+@socketprovider.SocketProvider.register(
2191
+    'stub_with_address_and_deterministic_dsa'
2192
+)
2193
+class StubbedSSHAgentSocketWithAddressAndDeterministicDSA(
2194
+    StubbedSSHAgentSocketWithAddress
2195
+):
2196
+    """A [`StubbedSSHAgentSocketWithAddress`][] supporting deterministic DSA."""
2197
+
2198
+    def __init__(self) -> None:
2199
+        """Initialize the agent.
2200
+
2201
+        Set the supported extensions, and try issuing RFC 6979 and
2202
+        Pageant 0.68–0.80 DSA/ECDSA signatures, if possible.  See the
2203
+        [superclass constructor][StubbedSSHAgentSocketWithAddress] for
2204
+        other details.
2205
+
2206
+        Raises:
2207
+            KeyError: See superclass.
2208
+            OSError: See superclass.
2209
+
2210
+        """  # noqa: RUF002
2211
+        super().__init__('query', 'list-extended@putty.projects.tartarus.org')
2212
+        self.try_rfc6979 = True
2213
+        self.try_pageant_068_080 = True
2214
+
2215
+
1854 2216
 def list_keys(self: Any = None) -> list[_types.SSHKeyCommentPair]:
1855 2217
     """Return a list of all SSH test keys, as key/comment pairs.
1856 2218
 
... ...
@@ -8,7 +8,10 @@ from __future__ import annotations
8 8
 
9 9
 import base64
10 10
 import contextlib
11
+import errno
11 12
 import io
13
+import os
14
+import pathlib
12 15
 import re
13 16
 import socket
14 17
 import types
... ...
@@ -23,6 +26,7 @@ from hypothesis import strategies
23 26
 import tests
24 27
 from derivepassphrase import _types, ssh_agent, vault
25 28
 from derivepassphrase._internals import cli_helpers
29
+from derivepassphrase.ssh_agent import socketprovider
26 30
 
27 31
 if TYPE_CHECKING:
28 32
     from collections.abc import Iterable
... ...
@@ -31,6 +35,31 @@ if TYPE_CHECKING:
31 35
 
32 36
 
33 37
 class Parametrize(types.SimpleNamespace):
38
+    STUBBED_AGENT_ADDRESSES = pytest.mark.parametrize(
39
+        ['address', 'exception', 'match'],
40
+        [
41
+            pytest.param(None, KeyError, 'SSH_AUTH_SOCK', id='unset'),
42
+            pytest.param('stub-ssh-agent:', None, '', id='standard'),
43
+            pytest.param(
44
+                str(pathlib.Path('~').expanduser()),
45
+                FileNotFoundError,
46
+                os.strerror(errno.ENOENT),
47
+                id='invalid-url',
48
+            ),
49
+            pytest.param(
50
+                'stub-ssh-agent:EPROTONOSUPPORT',
51
+                OSError,
52
+                os.strerror(errno.EPROTONOSUPPORT),
53
+                id='protocol-not-supported',
54
+            ),
55
+            pytest.param(
56
+                'stub-ssh-agent:ABCDEFGHIJKLMNOPQRSTUVWXYZ',
57
+                OSError,
58
+                os.strerror(errno.EINVAL),
59
+                id='invalid-error-code',
60
+            ),
61
+        ],
62
+    )
34 63
     SSH_STRING_EXCEPTIONS = pytest.mark.parametrize(
35 64
         ['input', 'exc_type', 'exc_pattern'],
36 65
         [
... ...
@@ -236,6 +265,76 @@ class Parametrize(types.SimpleNamespace):
236 265
             ),
237 266
         ],
238 267
     )
268
+    INVALID_SSH_AGENT_MESSAGES = pytest.mark.parametrize(
269
+        'message',
270
+        [
271
+            pytest.param(b'\x00\x00\x00\x00', id='empty-message'),
272
+            pytest.param(b'\x00\x00\x00\x0f\x0d', id='truncated-message'),
273
+            pytest.param(
274
+                b'\x00\x00\x00\x06\x1b\x00\x00\x00\x01\xff',
275
+                id='invalid-extension-name',
276
+            ),
277
+            pytest.param(
278
+                b'\x00\x00\x00\x11\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
279
+                id='sign-with-trailing-data',
280
+            ),
281
+        ],
282
+    )
283
+    UNSUPPORTED_SSH_AGENT_MESSAGES = pytest.mark.parametrize(
284
+        'message',
285
+        [
286
+            pytest.param(
287
+                ssh_agent.SSHAgentClient.string(
288
+                    b''.join([
289
+                        b'\x0d',
290
+                        ssh_agent.SSHAgentClient.string(
291
+                            tests.ALL_KEYS['rsa'].public_key_data
292
+                        ),
293
+                        ssh_agent.SSHAgentClient.string(vault.Vault.UUID),
294
+                        b'\x00\x00\x00\x02',
295
+                    ])
296
+                ),
297
+                id='sign-with-flags',
298
+            ),
299
+            pytest.param(
300
+                ssh_agent.SSHAgentClient.string(
301
+                    b''.join([
302
+                        b'\x0d',
303
+                        ssh_agent.SSHAgentClient.string(
304
+                            tests.ALL_KEYS['ed25519'].public_key_data
305
+                        ),
306
+                        b'\x00\x00\x00\x08\x00\x01\x02\x03\x04\x05\x06\x07',
307
+                        b'\x00\x00\x00\x00',
308
+                    ])
309
+                ),
310
+                id='sign-with-nonstandard-passphrase',
311
+            ),
312
+            pytest.param(
313
+                ssh_agent.SSHAgentClient.string(
314
+                    b''.join([
315
+                        b'\x0d',
316
+                        ssh_agent.SSHAgentClient.string(
317
+                            tests.ALL_KEYS['dsa1024'].public_key_data
318
+                        ),
319
+                        ssh_agent.SSHAgentClient.string(vault.Vault.UUID),
320
+                        b'\x00\x00\x00\x00',
321
+                    ])
322
+                ),
323
+                id='sign-key-no-expected-signature',
324
+            ),
325
+            pytest.param(
326
+                ssh_agent.SSHAgentClient.string(
327
+                    b''.join([
328
+                        b'\x0d',
329
+                        b'\x00\x00\x00\x00',
330
+                        ssh_agent.SSHAgentClient.string(vault.Vault.UUID),
331
+                        b'\x00\x00\x00\x00',
332
+                    ])
333
+                ),
334
+                id='sign-key-unregistered-test-key',
335
+            ),
336
+        ],
337
+    )
239 338
     PUBLIC_KEY_DATA = pytest.mark.parametrize(
240 339
         'public_key_struct',
241 340
         list(tests.SUPPORTED_KEYS.values()),
... ...
@@ -317,6 +416,225 @@ class Parametrize(types.SimpleNamespace):
317 416
     )
318 417
 
319 418
 
419
+class TestTestingMachineryStubbedSSHAgentSocket:
420
+    """Test the stubbed SSH agent socket for the `ssh_agent` module tests."""
421
+
422
+    def test_100a_query_extensions_base(self) -> None:
423
+        """The base agent implements no extensions."""
424
+        with contextlib.ExitStack() as stack:
425
+            monkeypatch = stack.enter_context(pytest.MonkeyPatch.context())
426
+            monkeypatch.setenv(
427
+                'SSH_AUTH_SOCK', tests.StubbedSSHAgentSocketWithAddress.ADDRESS
428
+            )
429
+            agent = stack.enter_context(
430
+                tests.StubbedSSHAgentSocketWithAddress()
431
+            )
432
+            assert 'query' not in agent.enabled_extensions
433
+            query_request = (
434
+                # SSH string header
435
+                b'\x00\x00\x00\x0a'
436
+                # request code: SSH_AGENTC_EXTENSION
437
+                b'\x1b'
438
+                # payload: SSH string "query"
439
+                b'\x00\x00\x00\x05query'
440
+            )
441
+            query_response = (
442
+                # SSH string header
443
+                b'\x00\x00\x00\x01'
444
+                # response code: SSH_AGENT_FAILURE
445
+                b'\x05'
446
+            )
447
+            agent.sendall(query_request)
448
+            assert agent.recv(1000) == query_response
449
+
450
+    def test_100b_query_extensions_extended(self) -> None:
451
+        """The extended agent implements a known list of extensions."""
452
+        with contextlib.ExitStack() as stack:
453
+            monkeypatch = stack.enter_context(pytest.MonkeyPatch.context())
454
+            monkeypatch.setenv(
455
+                'SSH_AUTH_SOCK', tests.StubbedSSHAgentSocketWithAddress.ADDRESS
456
+            )
457
+            agent = stack.enter_context(
458
+                tests.StubbedSSHAgentSocketWithAddressAndDeterministicDSA()
459
+            )
460
+            assert 'query' in agent.enabled_extensions
461
+            query_request = (
462
+                # SSH string header
463
+                b'\x00\x00\x00\x0a'
464
+                # request code: SSH_AGENTC_EXTENSION
465
+                b'\x1b'
466
+                # payload: SSH string "query"
467
+                b'\x00\x00\x00\x05query'
468
+            )
469
+            query_response = (
470
+                # SSH string header
471
+                b'\x00\x00\x00\x40'
472
+                # response code: SSH_AGENT_EXTENSION_RESPONSE
473
+                b'\x1d'
474
+                # extension response: extension type ("query")
475
+                b'\x00\x00\x00\x05query'
476
+                # supported extension #1: query
477
+                b'\x00\x00\x00\x05query'
478
+                # supported extension #2:
479
+                # list-extended@putty.projects.tartarus.org
480
+                b'\x00\x00\x00\x29list-extended@putty.projects.tartarus.org'
481
+            )
482
+            agent.sendall(query_request)
483
+            assert agent.recv(1000) == query_response
484
+
485
+    def test_101_request_identities(self) -> None:
486
+        """The agent implements a known list of identities."""
487
+        unstring_prefix = ssh_agent.SSHAgentClient.unstring_prefix
488
+        with tests.StubbedSSHAgentSocket() as agent:
489
+            query_request = (
490
+                # SSH string header
491
+                b'\x00\x00\x00\x01'
492
+                # request code: SSH_AGENTC_REQUEST_IDENTITIES
493
+                b'\x0b'
494
+            )
495
+            agent.sendall(query_request)
496
+            message_length = int.from_bytes(agent.recv(4), 'big')
497
+            orig_message: bytes | bytearray = bytearray(
498
+                agent.recv(message_length)
499
+            )
500
+            assert (
501
+                _types.SSH_AGENT(orig_message[0])
502
+                == _types.SSH_AGENT.IDENTITIES_ANSWER
503
+            )
504
+            identity_count = int.from_bytes(orig_message[1:5], 'big')
505
+            message = bytes(orig_message[5:])
506
+            for _ in range(identity_count):
507
+                key, message = unstring_prefix(message)
508
+                _comment, message = unstring_prefix(message)
509
+                assert key
510
+                assert key in {
511
+                    k.public_key_data for k in tests.ALL_KEYS.values()
512
+                }
513
+            assert not message
514
+
515
+    @Parametrize.SUPPORTED_SSH_TEST_KEYS
516
+    def test_102_sign(
517
+        self,
518
+        ssh_test_key_type: str,
519
+        ssh_test_key: tests.SSHTestKey,
520
+    ) -> None:
521
+        """The agent signs known key/message pairs."""
522
+        del ssh_test_key_type
523
+        spec = tests.SSHTestKeyDeterministicSignatureClass.SPEC
524
+        assert ssh_test_key.expected_signatures[spec].signature is not None
525
+        string = ssh_agent.SSHAgentClient.string
526
+        query_request = string(
527
+            # request code: SSH_AGENTC_SIGN_REQUEST
528
+            b'\x0d'
529
+            # key: SSH string of the public key
530
+            + string(ssh_test_key.public_key_data)
531
+            # payload: SSH string of the vault UUID
532
+            + string(vault.Vault.UUID)
533
+            # signing flags (uint32, empty)
534
+            + b'\x00\x00\x00\x00'
535
+        )
536
+        query_response = string(
537
+            # response code: SSH_AGENT_SIGN_RESPONSE
538
+            b'\x0e'
539
+            # expected payload: the binary signature as recorded in the test key data structure
540
+            + string(ssh_test_key.expected_signatures[spec].signature)
541
+        )
542
+        with tests.StubbedSSHAgentSocket() as agent:
543
+            agent.sendall(query_request)
544
+            assert agent.recv(1000) == query_response
545
+
546
+    def test_120_close_multiple(self) -> None:
547
+        """The agent can be closed repeatedly."""
548
+        with tests.StubbedSSHAgentSocket() as agent:
549
+            pass
550
+        with tests.StubbedSSHAgentSocket() as agent:
551
+            pass
552
+        del agent
553
+
554
+    def test_121_closed_agents_cannot_be_interacted_with(self) -> None:
555
+        """The agent can be closed repeatedly."""
556
+        with tests.StubbedSSHAgentSocket() as agent:
557
+            pass
558
+        query_request = (
559
+            # SSH string header
560
+            b'\x00\x00\x00\x0a'
561
+            # request code: SSH_AGENTC_EXTENSION
562
+            b'\x1b'
563
+            # payload: SSH string "query"
564
+            b'\x00\x00\x00\x05query'
565
+        )
566
+        query_response = b''
567
+        with pytest.raises(
568
+            ValueError,
569
+            match=re.escape(tests.StubbedSSHAgentSocket._SOCKET_IS_CLOSED),
570
+        ):
571
+            agent.sendall(query_request)
572
+        assert agent.recv(100) == query_response
573
+
574
+    def test_122_no_recv_without_sendall(self) -> None:
575
+        """The agent requires a message before sending a response."""
576
+        with tests.StubbedSSHAgentSocket() as agent:  # noqa: SIM117
577
+            with pytest.raises(
578
+                AssertionError,
579
+                match=re.escape(
580
+                    tests.StubbedSSHAgentSocket._PROTOCOL_VIOLATION
581
+                ),
582
+            ):
583
+                agent.recv(100)
584
+
585
+    @Parametrize.INVALID_SSH_AGENT_MESSAGES
586
+    def test_123_invalid_ssh_agent_messages(
587
+        self,
588
+        message: Buffer,
589
+    ) -> None:
590
+        """The agent responds with errors on invalid messages."""
591
+        query_response = (
592
+            # SSH string header
593
+            b'\x00\x00\x00\x01'
594
+            # response code: SSH_AGENT_FAILURE
595
+            b'\x05'
596
+        )
597
+        with tests.StubbedSSHAgentSocket() as agent:
598
+            agent.sendall(message)
599
+            assert agent.recv(100) == query_response
600
+
601
+    @Parametrize.UNSUPPORTED_SSH_AGENT_MESSAGES
602
+    def test_124_unsupported_ssh_agent_messages(
603
+        self,
604
+        message: Buffer,
605
+    ) -> None:
606
+        """The agent responds with errors on unsupported messages."""
607
+        query_response = (
608
+            # SSH string header
609
+            b'\x00\x00\x00\x01'
610
+            # response code: SSH_AGENT_FAILURE
611
+            b'\x05'
612
+        )
613
+        with tests.StubbedSSHAgentSocket() as agent:
614
+            agent.sendall(message)
615
+            assert agent.recv(100) == query_response
616
+
617
+    @Parametrize.STUBBED_AGENT_ADDRESSES
618
+    def test_125_addresses(
619
+        self,
620
+        address: str | None,
621
+        exception: type[Exception] | None,
622
+        match: str,
623
+    ) -> None:
624
+        """The agent accepts addresses."""
625
+        with contextlib.ExitStack() as stack:
626
+            monkeypatch = stack.enter_context(pytest.MonkeyPatch.context())
627
+            if address:
628
+                monkeypatch.setenv('SSH_AUTH_SOCK', address)
629
+            else:
630
+                monkeypatch.delenv('SSH_AUTH_SOCK', raising=False)
631
+            if exception:
632
+                stack.enter_context(
633
+                    pytest.raises(exception, match=re.escape(match))
634
+                )
635
+            tests.StubbedSSHAgentSocketWithAddress()
636
+
637
+
320 638
 class TestStaticFunctionality:
321 639
     """Test the static functionality of the `ssh_agent` module."""
322 640
 
323 641