Refactor the SSH agent tests
Marco Ricci

Marco Ricci commited on 2025-08-17 18:00:57
Zeige 3 geänderte Dateien mit 263 Einfügungen und 226 Löschungen.


(This is part 5 of a series of refactorings for the test suite.)

In the basic tests, for the stubbed SSH agent socket and the agent error
response tests, factor out the common environment setup for the
respective group (in particular, the multiple copies of the fake
`request` function from the SSH agent client).  Combine the SSH agent
data signing tests for suitable and unsuitable keys into a single, more
heavily parametrized test.  (The combined test is branchier, but more
straightforward for discerning the differences between the two types of
keys.)  Also fix some miscellaneous errors: the IDs in the
`RESOLVE_CHAINS` parametrization, and the order of contexts in the
`test_request_error_responses` method.

Further concerning the basic tests, add an explicit test for the
`list-extended@putty.projects.tartarus.org` extension (which the
extended stubbed SSH agent implements to signal support for
deterministic DSA signatures, in a lackluster attempt to masquerade as
Pageant).  Also address a TODO in the test for SSH key selection
concerning degenerate key lists.

In the test data, add an extra comment on how suitable and unsuitable
SSH test keys are determined, and add a function signature for the
SSH agent client `request` function.

In the heavy-duty tests, merely add missing `draw` labels.
... ...
@@ -21,7 +21,7 @@ from __future__ import annotations
21 21
 
22 22
 import base64
23 23
 import enum
24
-from typing import TYPE_CHECKING
24
+from typing import TYPE_CHECKING, Protocol
25 25
 
26 26
 from typing_extensions import NamedTuple
27 27
 
... ...
@@ -31,7 +31,7 @@ from derivepassphrase.ssh_agent import socketprovider
31 31
 __all__ = ()
32 32
 
33 33
 if TYPE_CHECKING:
34
-    from collections.abc import Mapping
34
+    from collections.abc import Iterable, Mapping
35 35
 
36 36
     from typing_extensions import Any
37 37
 
... ...
@@ -244,6 +244,26 @@ class RunningSSHAgentInfo(NamedTuple):
244 244
         return self.socket
245 245
 
246 246
 
247
+# `derivepassphrase` internal functions
248
+# -------------------------------------
249
+
250
+
251
+class RequestFunc(Protocol):
252
+    """The call signature of [`ssh_agent.SSHAgentClient.request`][]."""
253
+
254
+    def __call__(
255
+        self,
256
+        request_code: int | _types.SSH_AGENTC,
257
+        payload: bytes | bytearray,
258
+        /,
259
+        *,
260
+        response_code: Iterable[int | _types.SSH_AGENT]
261
+        | int
262
+        | _types.SSH_AGENT
263
+        | None = None,
264
+    ) -> tuple[int, bytes | bytearray] | bytes | bytearray: ...
265
+
266
+
247 267
 # Vault configurations
248 268
 # ====================
249 269
 
... ...
@@ -1039,11 +1059,23 @@ Rlc3Qga2V5IHdpdGhvdXQgcGFzc3BocmFzZQ==
1039 1059
 SUPPORTED_KEYS: Mapping[str, SSHTestKey] = {
1040 1060
     k: v for k, v in ALL_KEYS.items() if v.is_suitable()
1041 1061
 }
1042
-"""The subset of SSH test keys suitable for use with vault."""
1062
+"""The subset of SSH test keys suitable for use with vault.
1063
+
1064
+Suitability is tested -- via [SSHTestKey.is_suitable][], then
1065
+[vault.Vault.is_suitable_ssh_key][] -- via an internal whitelist, not
1066
+via the presence or absence of a specific expected signature class.
1067
+
1068
+"""
1043 1069
 UNSUITABLE_KEYS: Mapping[str, SSHTestKey] = {
1044 1070
     k: v for k, v in ALL_KEYS.items() if not v.is_suitable()
1045 1071
 }
1046
-"""The subset of SSH test keys not suitable for use with vault."""
1072
+"""The subset of SSH test keys not suitable for use with vault.
1073
+
1074
+Suitability is tested -- via [SSHTestKey.is_suitable][], then
1075
+[vault.Vault.is_suitable_ssh_key][] -- via an internal whitelist, not
1076
+via the presence or absence of a specific expected signature class.
1077
+
1078
+"""
1047 1079
 
1048 1080
 
1049 1081
 # Vault test configurations
... ...
@@ -17,13 +17,14 @@ import re
17 17
 import socket
18 18
 import sys
19 19
 import types
20
-from typing import TYPE_CHECKING
20
+from typing import TYPE_CHECKING, NamedTuple
21 21
 
22 22
 import click
23 23
 import click.testing
24 24
 import hypothesis
25 25
 import pytest
26 26
 from hypothesis import strategies
27
+from typing_extensions import TypeAlias
27 28
 
28 29
 from derivepassphrase import _types, ssh_agent, vault
29 30
 from derivepassphrase._internals import cli_helpers
... ...
@@ -33,7 +34,7 @@ from tests.data import callables
33 34
 from tests.machinery import pytest as pytest_machinery
34 35
 
35 36
 if TYPE_CHECKING:
36
-    from collections.abc import Iterable
37
+    from collections.abc import Iterable, Iterator, Mapping
37 38
 
38 39
     from typing_extensions import Any, Buffer, Literal
39 40
 
... ...
@@ -489,17 +490,17 @@ class Parametrize(types.SimpleNamespace):
489 490
         list(data.SUPPORTED_KEYS.items()),
490 491
         ids=data.SUPPORTED_KEYS.keys(),
491 492
     )
492
-    UNSUITABLE_SSH_TEST_KEYS = pytest.mark.parametrize(
493
+    ALL_SSH_TEST_KEYS = pytest.mark.parametrize(
493 494
         ["ssh_test_key_type", "ssh_test_key"],
494
-        list(data.UNSUITABLE_KEYS.items()),
495
-        ids=data.UNSUITABLE_KEYS.keys(),
495
+        list(data.ALL_KEYS.items()),
496
+        ids=data.ALL_KEYS.keys(),
496 497
     )
497 498
     RESOLVE_CHAINS = pytest.mark.parametrize(
498 499
         ["terminal", "chain"],
499 500
         [
500 501
             pytest.param("callable", ["a"], id="callable-1"),
501 502
             pytest.param("callable", ["a", "b", "c", "d"], id="callable-4"),
502
-            pytest.param("alias", ["e"], id="alias-5"),
503
+            pytest.param("alias", ["e"], id="alias-1"),
503 504
             pytest.param("alias", ["e", "f", "g", "h", "i"], id="alias-5"),
504 505
             pytest.param("unimplemented", ["j"], id="unimplemented-1"),
505 506
             pytest.param("unimplemented", ["j", "k"], id="unimplemented-2"),
... ...
@@ -510,17 +511,24 @@ class Parametrize(types.SimpleNamespace):
510 511
 class TestStubbedSSHAgentSocketRequests:
511 512
     """Test the stubbed SSH agent socket: normal requests."""
512 513
 
513
-    def test_query_extensions_base(self) -> None:
514
-        """The base agent implements no extensions."""
514
+    @contextlib.contextmanager
515
+    def _get_addressed_agent(
516
+        self, *, extended_agent: bool = False
517
+    ) -> Iterator[machinery.StubbedSSHAgentSocketWithAddress]:
518
+        agent_class: type[machinery.StubbedSSHAgentSocketWithAddress] = (
519
+            machinery.StubbedSSHAgentSocketWithAddressAndDeterministicDSA
520
+            if extended_agent
521
+            else machinery.StubbedSSHAgentSocketWithAddress
522
+        )
515 523
         with contextlib.ExitStack() as stack:
516 524
             monkeypatch = stack.enter_context(pytest.MonkeyPatch.context())
517
-            monkeypatch.setenv(
518
-                "SSH_AUTH_SOCK",
519
-                machinery.StubbedSSHAgentSocketWithAddress.ADDRESS,
520
-            )
521
-            agent = stack.enter_context(
522
-                machinery.StubbedSSHAgentSocketWithAddress()
523
-            )
525
+            monkeypatch.setenv("SSH_AUTH_SOCK", agent_class.ADDRESS)
526
+            agent = stack.enter_context(agent_class())
527
+            yield agent
528
+
529
+    def test_query_extensions_base(self) -> None:
530
+        """The base agent implements no extensions."""
531
+        with self._get_addressed_agent(extended_agent=False) as agent:
524 532
             assert "query" not in agent.enabled_extensions
525 533
             query_request = (
526 534
                 # SSH string header
... ...
@@ -541,15 +549,7 @@ class TestStubbedSSHAgentSocketRequests:
541 549
 
542 550
     def test_query_extensions_extended(self) -> None:
543 551
         """The extended agent implements a known list of extensions."""
544
-        with contextlib.ExitStack() as stack:
545
-            monkeypatch = stack.enter_context(pytest.MonkeyPatch.context())
546
-            monkeypatch.setenv(
547
-                "SSH_AUTH_SOCK",
548
-                machinery.StubbedSSHAgentSocketWithAddress.ADDRESS,
549
-            )
550
-            agent = stack.enter_context(
551
-                machinery.StubbedSSHAgentSocketWithAddressAndDeterministicDSA()
552
-            )
552
+        with self._get_addressed_agent(extended_agent=True) as agent:
553 553
             assert "query" in agent.enabled_extensions
554 554
             query_request = (
555 555
                 # SSH string header
... ...
@@ -605,6 +605,41 @@ class TestStubbedSSHAgentSocketRequests:
605 605
                 }
606 606
             assert not message
607 607
 
608
+    def test_request_identities_extended(self) -> None:
609
+        """The extended agent implements PuTTY's `list-extended` extension."""
610
+        unstring_prefix = ssh_agent.SSHAgentClient.unstring_prefix
611
+        with self._get_addressed_agent(extended_agent=True) as agent:
612
+            extension_request = (
613
+                # SSH string header
614
+                b"\x00\x00\x00\x2e"
615
+                # request code: SSH_AGENTC_REQUEST_IDENTITIES
616
+                b"\x1b"
617
+                # extension type: list-extended@putty.projects.tartarus.org
618
+                b"\x00\x00\x00\x29list-extended@putty.projects.tartarus.org"
619
+                # (no payload)
620
+            )
621
+            agent.sendall(extension_request)
622
+            message_length = int.from_bytes(agent.recv(4), "big")
623
+            orig_message: bytes | bytearray = bytearray(
624
+                agent.recv(message_length)
625
+            )
626
+            assert (
627
+                _types.SSH_AGENT(orig_message[0])
628
+                == _types.SSH_AGENT.SUCCESS
629
+            )
630
+            identity_count = int.from_bytes(orig_message[1:5], "big")
631
+            message = bytes(orig_message[5:])
632
+            for _ in range(identity_count):
633
+                key, message = unstring_prefix(message)
634
+                _comment, message = unstring_prefix(message)
635
+                flags, message = unstring_prefix(message)
636
+                assert flags == b"\x00\x00\x00\x00"
637
+                assert key
638
+                assert key in {
639
+                    k.public_key_data for k in data.ALL_KEYS.values()
640
+                }
641
+            assert not message
642
+
608 643
     @Parametrize.SUPPORTED_SSH_TEST_KEYS
609 644
     def test_sign(
610 645
         self,
... ...
@@ -1242,14 +1277,14 @@ class TestSSHAgentSocketProviderRegistry:
1242 1277
 class TestAgentSigning:
1243 1278
     """Test actually talking to the SSH agent: signing data."""
1244 1279
 
1245
-    @Parametrize.SUPPORTED_SSH_TEST_KEYS
1280
+    @Parametrize.ALL_SSH_TEST_KEYS
1246 1281
     def test_sign_data_via_agent(
1247 1282
         self,
1248 1283
         ssh_agent_client_with_test_keys_loaded: ssh_agent.SSHAgentClient,
1249 1284
         ssh_test_key_type: str,
1250 1285
         ssh_test_key: data.SSHTestKey,
1251 1286
     ) -> None:
1252
-        """Signing data with specific SSH keys works.
1287
+        """Signing data with specific SSH keys works iff the key is suitable.
1253 1288
 
1254 1289
         Single tests may abort early (skip) if the indicated key is not
1255 1290
         loaded in the agent.  Presumably this means the key type is
... ...
@@ -1259,6 +1294,9 @@ class TestAgentSigning:
1259 1294
         client = ssh_agent_client_with_test_keys_loaded
1260 1295
         key_comment_pairs = {bytes(k): bytes(c) for k, c in client.list_keys()}
1261 1296
         public_key_data = ssh_test_key.public_key_data
1297
+        if public_key_data not in key_comment_pairs:  # pragma: no cover
1298
+            pytest.skip(f"prerequisite {ssh_test_key_type} SSH key not loaded")
1299
+        if ssh_test_key.is_suitable():
1262 1300
             assert (
1263 1301
                 data.SSHTestKeyDeterministicSignatureClass.SPEC
1264 1302
                 in ssh_test_key.expected_signatures
... ...
@@ -1268,8 +1306,6 @@ class TestAgentSigning:
1268 1306
             ]
1269 1307
             expected_signature = sig.signature
1270 1308
             derived_passphrase = sig.derived_passphrase
1271
-        if public_key_data not in key_comment_pairs:  # pragma: no cover
1272
-            pytest.skip(f"prerequisite {ssh_test_key_type} SSH key not loaded")
1273 1309
             signature = bytes(
1274 1310
                 client.sign(payload=vault.Vault.UUID, key=public_key_data)
1275 1311
             )
... ...
@@ -1286,35 +1322,28 @@ class TestAgentSigning:
1286 1322
                 vault.Vault.phrase_from_key(public_key_data, conn=client)
1287 1323
                 == derived_passphrase
1288 1324
             ), f"SSH signature mismatch ({ssh_test_key_type})"
1289
-
1290
-    @Parametrize.UNSUITABLE_SSH_TEST_KEYS
1291
-    def test_sign_data_via_agent_unsupported(
1292
-        self,
1293
-        ssh_agent_client_with_test_keys_loaded: ssh_agent.SSHAgentClient,
1294
-        ssh_test_key_type: str,
1295
-        ssh_test_key: data.SSHTestKey,
1296
-    ) -> None:
1297
-        """Using an unsuitable key with [`vault.Vault`][] fails.
1298
-
1299
-        Single tests may abort early (skip) if the indicated key is not
1300
-        loaded in the agent.  Presumably this means the key type is
1301
-        unsupported.  Single tests may also abort early if the agent
1302
-        ensures that the generally unsuitable key is actually suitable
1303
-        under this agent.
1304
-
1305
-        """
1306
-        client = ssh_agent_client_with_test_keys_loaded
1307
-        key_comment_pairs = {bytes(k): bytes(c) for k, c in client.list_keys()}
1308
-        public_key_data = ssh_test_key.public_key_data
1309
-        if public_key_data not in key_comment_pairs:  # pragma: no cover
1310
-            pytest.skip(f"prerequisite {ssh_test_key_type} SSH key not loaded")
1325
+        else:
1326
+            assert (
1327
+                data.SSHTestKeyDeterministicSignatureClass.SPEC
1328
+                not in ssh_test_key.expected_signatures
1329
+            )
1311 1330
             assert not vault.Vault.is_suitable_ssh_key(
1312 1331
                 public_key_data, client=None
1313 1332
             ), f"Expected {ssh_test_key_type} key to be unsuitable in general"
1314
-        if vault.Vault.is_suitable_ssh_key(public_key_data, client=client):
1315
-            pytest.skip(
1316
-                f"agent automatically ensures {ssh_test_key_type} key is suitable"
1333
+            if vault.Vault.is_suitable_ssh_key(
1334
+                public_key_data, client=client
1335
+            ):  # pragma: no cover [external]
1336
+                potential_signatures = {
1337
+                    sig_class: sig.signature
1338
+                    for sig_class, sig in ssh_test_key.expected_signatures.items()
1339
+                    if sig_class
1340
+                    != data.SSHTestKeyDeterministicSignatureClass.SPEC
1341
+                }
1342
+                signature = bytes(
1343
+                    client.sign(payload=vault.Vault.UUID, key=public_key_data)
1317 1344
                 )
1345
+                assert signature in potential_signatures.values()
1346
+            else:  # pragma: no cover [external]
1318 1347
                 with pytest.raises(ValueError, match="unsuitable SSH key"):
1319 1348
                     vault.Vault.phrase_from_key(public_key_data, conn=client)
1320 1349
 
... ...
@@ -1338,50 +1367,30 @@ class TestSuitableKeys:
1338 1367
         """
1339 1368
         client = ssh_agent_client_with_test_keys_loaded
1340 1369
 
1341
-        def key_is_suitable(key: bytes) -> bool:
1342
-            """Stub out [`vault.Vault.key_is_suitable`][]."""
1343
-            always = {v.public_key_data for v in data.SUPPORTED_KEYS.values()}
1344
-            dsa = {
1345
-                v.public_key_data
1346
-                for k, v in data.UNSUITABLE_KEYS.items()
1347
-                if k.startswith(("dsa", "ecdsa"))
1348
-            }
1349
-            return key in always or (
1350
-                client.has_deterministic_dsa_signatures() and key in dsa
1351
-            )
1352
-
1353
-        # TODO(the-13th-letter): Handle the unlikely(?) case that only
1354
-        # one test key is loaded, but `single` is False.  Rename the
1355
-        # `index` variable to `input`, store the `input` in there, and
1356
-        # make the definition of `text` in the else block dependent on
1357
-        # `n` being singular or non-singular.
1358
-        if single:
1359 1370
         monkeypatch.setattr(
1360 1371
             ssh_agent.SSHAgentClient,
1361 1372
             "list_keys",
1362
-                callables.list_keys_singleton,
1373
+            callables.list_keys_singleton if single else callables.list_keys,
1374
+        )
1375
+        raw_keys_list = (
1376
+            callables.list_keys_singleton()
1377
+            if single
1378
+            else callables.list_keys()
1363 1379
         )
1364 1380
         keys = [
1365 1381
             pair.key
1366
-                for pair in callables.list_keys_singleton()
1367
-                if key_is_suitable(pair.key)
1382
+            for pair in raw_keys_list
1383
+            if vault.Vault.is_suitable_ssh_key(pair.key, client=client)
1368 1384
         ]
1369
-            index = "1"
1385
+        n = len(keys)
1386
+        if single or n == 1:
1387
+            input_text = "yes\n"
1370 1388
             text = "Use this key? yes\n"
1371 1389
         else:
1372
-            monkeypatch.setattr(
1373
-                ssh_agent.SSHAgentClient,
1374
-                "list_keys",
1375
-                callables.list_keys,
1390
+            input_text = f"{1 + keys.index(key)}\n"
1391
+            text = (
1392
+                f"Your selection? (1-{n}, leave empty to abort): {input_text}"
1376 1393
             )
1377
-            keys = [
1378
-                pair.key
1379
-                for pair in callables.list_keys()
1380
-                if key_is_suitable(pair.key)
1381
-            ]
1382
-            index = str(1 + keys.index(key))
1383
-            n = len(keys)
1384
-            text = f"Your selection? (1-{n}, leave empty to abort): {index}\n"
1385 1394
         b64_key = base64.standard_b64encode(key).decode("ASCII")
1386 1395
 
1387 1396
         @click.command()
... ...
@@ -1390,14 +1399,9 @@ class TestSuitableKeys:
1390 1399
             key = cli_helpers.select_ssh_key(client)
1391 1400
             click.echo(base64.standard_b64encode(key).decode("ASCII"))
1392 1401
 
1393
-        # TODO(the-13th-letter): (Continued from above.)  Update input
1394
-        # data to use `index`/`input` directly and unconditionally.
1395 1402
         runner = machinery.CliRunner(mix_stderr=True)
1396 1403
         result = runner.invoke(
1397
-            driver,
1398
-            [],
1399
-            input=("yes\n" if single else f"{index}\n"),
1400
-            catch_exceptions=True,
1404
+            driver, [], input=input_text, catch_exceptions=True
1401 1405
         )
1402 1406
         for snippet in ("Suitable SSH keys:\n", text, f"\n{b64_key}\n"):
1403 1407
             assert result.clean_exit(output=snippet), "expected clean exit"
... ...
@@ -1465,6 +1469,104 @@ class TestOtherConstructorFeatures:
1465 1469
 class TestAgentErrorResponses:
1466 1470
     """Test actually talking to the SSH agent: errors from the SSH agent."""
1467 1471
 
1472
+    class Request(NamedTuple):
1473
+        """A request for the SSH agent protocol.
1474
+
1475
+        Attributes:
1476
+            code:
1477
+                The expected request code.
1478
+            payload:
1479
+                The expected payload.  Optional.
1480
+
1481
+        """
1482
+
1483
+        code: _types.SSH_AGENTC
1484
+        """"""
1485
+        payload: bytes | None
1486
+        """"""
1487
+
1488
+    class Response(NamedTuple):
1489
+        """A response for the SSH agent protocol.
1490
+
1491
+        Attributes:
1492
+            code:
1493
+                The desired response code.
1494
+            payload:
1495
+                The desired response payload.
1496
+
1497
+        """
1498
+
1499
+        code: _types.SSH_AGENT
1500
+        """"""
1501
+        payload: bytes
1502
+        """"""
1503
+
1504
+    def _make_request_stub(
1505
+        self,
1506
+        request_response_map: Mapping[Request, Response],
1507
+        /,
1508
+    ) -> data.RequestFunc:
1509
+        """Return a stubbed SSH agent client `request` function.
1510
+
1511
+        Args:
1512
+            request_response_map:
1513
+                A map of request/response pairs.
1514
+
1515
+        """
1516
+
1517
+        def request(
1518
+            request_code: int | _types.SSH_AGENTC,
1519
+            payload: bytes | bytearray,
1520
+            /,
1521
+            *,
1522
+            response_code: Iterable[int | _types.SSH_AGENT]
1523
+            | int
1524
+            | _types.SSH_AGENT
1525
+            | None = None,
1526
+        ) -> tuple[int, bytes | bytearray] | bytes | bytearray:
1527
+            request_code = _types.SSH_AGENTC(request_code)
1528
+            response_code = (
1529
+                frozenset({_types.SSH_AGENT.SUCCESS})
1530
+                if response_code is None
1531
+                else frozenset({_types.SSH_AGENT(response_code)})
1532
+                if isinstance(response_code, (int, _types.SSH_AGENT))
1533
+                else frozenset(map(_types.SSH_AGENT, response_code))
1534
+            )
1535
+            response_ = request_response_map.get(
1536
+                self.Request(request_code, bytes(payload))
1537
+            ) or request_response_map.get(self.Request(request_code, None))
1538
+            if response_ is None:  # pragma: no cover [failsafe]
1539
+                raise ssh_agent.SSHAgentFailedError(
1540
+                    _types.SSH_AGENT.FAILURE.value,
1541
+                    "No prepared response for the request "
1542
+                    f"({request_code!r}, {payload!r})".encode(),
1543
+                )
1544
+            code, response = response_
1545
+            if code not in response_code:
1546
+                raise ssh_agent.SSHAgentFailedError(code.value, response)
1547
+            return response
1548
+
1549
+        return request
1550
+
1551
+    @contextlib.contextmanager
1552
+    def _setup_agent_and_request_handler(
1553
+        self, request_response_map: Mapping[Request, Response], /
1554
+    ) -> Iterator[tuple[pytest.MonkeyPatch, ssh_agent.SSHAgentClient]]:
1555
+        # TODO(the-13th-letter): Rewrite using parenthesized
1556
+        # with-statements.
1557
+        # https://the13thletter.info/derivepassphrase/latest/pycompatibility/#after-eol-py3.9
1558
+        with contextlib.ExitStack() as stack:
1559
+            monkeypatch = stack.enter_context(pytest.MonkeyPatch.context())
1560
+            client = stack.enter_context(
1561
+                ssh_agent.SSHAgentClient.ensure_agent_subcontext()
1562
+            )
1563
+            monkeypatch.setattr(
1564
+                client,
1565
+                "request",
1566
+                self._make_request_stub(request_response_map),
1567
+            )
1568
+            yield monkeypatch, client
1569
+
1468 1570
     @Parametrize.TRUNCATED_AGENT_RESPONSES
1469 1571
     def test_truncated_server_response(
1470 1572
         self,
... ...
@@ -1509,43 +1611,12 @@ class TestAgentErrorResponses:
1509 1611
 
1510 1612
         """
1511 1613
         del running_ssh_agent
1512
-
1513
-        passed_response_code = response_code
1514
-
1515
-        # TODO(the-13th-letter): Extract this mock function into a common
1516
-        # top-level "request" mock function.
1517
-        def request(
1518
-            request_code: int | _types.SSH_AGENTC,
1519
-            payload: bytes | bytearray,
1520
-            /,
1521
-            *,
1522
-            response_code: Iterable[int | _types.SSH_AGENT]
1523
-            | int
1524
-            | _types.SSH_AGENT
1525
-            | None = None,
1526
-        ) -> tuple[int, bytes | bytearray] | bytes | bytearray:
1527
-            del request_code
1528
-            del payload
1529
-            if isinstance(  # pragma: no branch
1530
-                response_code, (int, _types.SSH_AGENT)
1531
-            ):
1532
-                response_code = frozenset({response_code})
1533
-            if response_code is not None:  # pragma: no branch
1534
-                response_code = frozenset({
1535
-                    c if isinstance(c, int) else c.value for c in response_code
1536
-                })
1537
-
1538
-            if not response_code:  # pragma: no cover
1539
-                return (passed_response_code.value, response)
1540
-            if passed_response_code.value not in response_code:
1541
-                raise ssh_agent.SSHAgentFailedError(
1542
-                    passed_response_code.value, response
1543
-                )
1544
-            return response
1545
-
1546
-        with pytest.MonkeyPatch.context() as monkeypatch:
1547
-            client = ssh_agent.SSHAgentClient()
1548
-            monkeypatch.setattr(client, "request", request)
1614
+        with self._setup_agent_and_request_handler({
1615
+            self.Request(
1616
+                _types.SSH_AGENTC.REQUEST_IDENTITIES, None
1617
+            ): self.Response(response_code, response),
1618
+        }) as contexts:
1619
+            _, client = contexts
1549 1620
             with pytest.raises(exc_type, match=exc_pattern):
1550 1621
                 client.list_keys()
1551 1622
 
... ...
@@ -1570,50 +1641,23 @@ class TestAgentErrorResponses:
1570 1641
 
1571 1642
         """
1572 1643
         del running_ssh_agent
1573
-        passed_response_code = response_code
1574
-
1575
-        # TODO(the-13th-letter): Extract this mock function into a common
1576
-        # top-level "request" mock function.
1577
-        def request(
1578
-            request_code: int | _types.SSH_AGENTC,
1579
-            payload: bytes | bytearray,
1580
-            /,
1581
-            *,
1582
-            response_code: Iterable[int | _types.SSH_AGENT]
1583
-            | int
1584
-            | _types.SSH_AGENT
1585
-            | None = None,
1586
-        ) -> tuple[int, bytes | bytearray] | bytes | bytearray:
1587
-            del request_code
1588
-            del payload
1589
-            if isinstance(  # pragma: no branch
1590
-                response_code, (int, _types.SSH_AGENT)
1591
-            ):
1592
-                response_code = frozenset({response_code})
1593
-            if response_code is not None:  # pragma: no branch
1594
-                response_code = frozenset({
1595
-                    c if isinstance(c, int) else c.value for c in response_code
1596
-                })
1597
-
1598
-            if not response_code:  # pragma: no cover
1599
-                return (passed_response_code.value, response)
1600
-            if (
1601
-                passed_response_code.value not in response_code
1602
-            ):  # pragma: no branch
1603
-                raise ssh_agent.SSHAgentFailedError(
1604
-                    passed_response_code.value, response
1605
-                )
1606
-            return response  # pragma: no cover
1607
-
1608
-        with pytest.MonkeyPatch.context() as monkeypatch:
1609
-            client = ssh_agent.SSHAgentClient()
1610
-            monkeypatch.setattr(client, "request", request)
1611
-            Pair = _types.SSHKeyCommentPair  # noqa: N806
1644
+        Pair: TypeAlias = _types.SSHKeyCommentPair
1612 1645
         com = b"no comment"
1613 1646
         loaded_keys = [
1614 1647
             Pair(v.public_key_data, com).toreadonly()
1615 1648
             for v in data.SUPPORTED_KEYS.values()
1616 1649
         ]
1650
+        sign_payload = (
1651
+            ssh_agent.SSHAgentClient.string(key)
1652
+            + ssh_agent.SSHAgentClient.string(b"abc")
1653
+            + ssh_agent.SSHAgentClient.uint32(0)
1654
+        )
1655
+        with self._setup_agent_and_request_handler({
1656
+            self.Request(
1657
+                _types.SSH_AGENTC.SIGN_REQUEST, sign_payload
1658
+            ): self.Response(response_code, response),
1659
+        }) as contexts:
1660
+            monkeypatch, client = contexts
1617 1661
             monkeypatch.setattr(client, "list_keys", lambda: loaded_keys)
1618 1662
             with pytest.raises(exc_type, match=exc_pattern):
1619 1663
                 client.sign(key, b"abc", check_if_key_loaded=check)
... ...
@@ -1642,68 +1686,27 @@ class TestAgentErrorResponses:
1642 1686
         # with-statements.
1643 1687
         # https://the13thletter.info/derivepassphrase/latest/pycompatibility/#after-eol-py3.9
1644 1688
         with contextlib.ExitStack() as stack:
1645
-            stack.enter_context(pytest.raises(exc_type, match=exc_pattern))
1646 1689
             client = stack.enter_context(ssh_agent.SSHAgentClient())
1690
+            stack.enter_context(pytest.raises(exc_type, match=exc_pattern))
1647 1691
             client.request(request_code, b"", response_code=response_code)
1648 1692
 
1649 1693
     @Parametrize.QUERY_EXTENSIONS_MALFORMED_RESPONSES
1650 1694
     def test_query_extensions_malformed_responses(
1651 1695
         self,
1652
-        monkeypatch: pytest.MonkeyPatch,
1653 1696
         running_ssh_agent: data.RunningSSHAgentInfo,
1654 1697
         response_data: bytes,
1655 1698
     ) -> None:
1656 1699
         """Fail on malformed responses while querying extensions."""
1657 1700
         del running_ssh_agent
1658
-
1659
-        # TODO(the-13th-letter): Extract this mock function into a common
1660
-        # top-level "request" mock function after removing the
1661
-        # payload-specific parts.
1662
-        def request(
1663
-            code: int | _types.SSH_AGENTC,
1664
-            payload: Buffer,
1665
-            /,
1666
-            *,
1667
-            response_code: (
1668
-                Iterable[_types.SSH_AGENT | int]
1669
-                | _types.SSH_AGENT
1670
-                | int
1671
-                | None
1672
-            ) = None,
1673
-        ) -> tuple[int, bytes] | bytes:
1674
-            request_codes = {
1701
+        with self._setup_agent_and_request_handler({
1702
+            self.Request(
1675 1703
                 _types.SSH_AGENTC.EXTENSION,
1676
-                _types.SSH_AGENTC.EXTENSION.value,
1677
-            }
1678
-            assert code in request_codes
1679
-            response_codes = {
1680
-                _types.SSH_AGENT.EXTENSION_RESPONSE,
1681
-                _types.SSH_AGENT.EXTENSION_RESPONSE.value,
1682
-                _types.SSH_AGENT.SUCCESS,
1683
-                _types.SSH_AGENT.SUCCESS.value,
1684
-            }
1685
-            assert payload == b"\x00\x00\x00\x05query"
1686
-            if response_code is None:  # pragma: no cover
1687
-                return (
1688
-                    _types.SSH_AGENT.EXTENSION_RESPONSE.value,
1689
-                    response_data,
1690
-                )
1691
-            if isinstance(  # pragma: no cover
1692
-                response_code, (_types.SSH_AGENT, int)
1693
-            ):
1694
-                assert response_code in response_codes
1695
-                return response_data
1696
-            for single_code in response_code:  # pragma: no cover
1697
-                assert single_code in response_codes
1698
-            return response_data  # pragma: no cover
1699
-
1700
-        # TODO(the-13th-letter): Rewrite using parenthesized
1701
-        # with-statements.
1702
-        # https://the13thletter.info/derivepassphrase/latest/pycompatibility/#after-eol-py3.9
1703
-        with contextlib.ExitStack() as stack:
1704
-            monkeypatch2 = stack.enter_context(monkeypatch.context())
1705
-            client = stack.enter_context(ssh_agent.SSHAgentClient())
1706
-            monkeypatch2.setattr(client, "request", request)
1704
+                ssh_agent.SSHAgentClient.string(b"query"),
1705
+            ): self.Response(
1706
+                _types.SSH_AGENT.EXTENSION_RESPONSE, response_data
1707
+            ),
1708
+        }) as contexts:
1709
+            _, client = contexts
1707 1710
             with pytest.raises(
1708 1711
                 RuntimeError,
1709 1712
                 match=r"Malformed response|does not match request",
... ...
@@ -96,13 +96,15 @@ def draw_alias_chain(
96 96
 
97 97
     err_msg_chain_size = "Chain sizes must always be 1 or larger."
98 98
 
99
-    size = draw(chain_size)
99
+    size = draw(chain_size, label="chain_size")
100 100
     if size < 1:  # pragma: no cover
101 101
         raise ValueError(err_msg_chain_size)
102 102
     names: list[str] = []
103 103
     base: str | None = None
104 104
     if existing:
105
-        names.append(draw(known_keys_strategy.filter(not_an_alias)))
105
+        names.append(
106
+            draw(known_keys_strategy.filter(not_an_alias), label="base")
107
+        )
106 108
         base = names[0]
107 109
         size -= 1
108 110
         new_key_strategy = new_keys_strategy.filter(
... ...
@@ -124,7 +126,7 @@ def draw_alias_chain(
124 126
         max_size=size,
125 127
         unique=True,
126 128
     )
127
-    names.extend(draw(list_strategy))
129
+    names.extend(draw(list_strategy, label="others"))
128 130
     return tuple(names)
129 131
 
130 132
 
131 133