Marco Ricci commited on 2025-08-17 18:00:57
Zeige 3 geänderte Dateien mit 263 Einfügungen und 226 Löschungen.
(This is part 5 of a series of refactorings for the test suite.) In the basic tests, for the stubbed SSH agent socket and the agent error response tests, factor out the common environment setup for the respective group (in particular, the multiple copies of the fake `request` function from the SSH agent client). Combine the SSH agent data signing tests for suitable and unsuitable keys into a single, more heavily parametrized test. (The combined test is branchier, but more straightforward for discerning the differences between the two types of keys.) Also fix some miscellaneous errors: the IDs in the `RESOLVE_CHAINS` parametrization, and the order of contexts in the `test_request_error_responses` method. Further concerning the basic tests, add an explicit test for the `list-extended@putty.projects.tartarus.org` extension (which the extended stubbed SSH agent implements to signal support for deterministic DSA signatures, in a lackluster attempt to masquerade as Pageant). Also address a TODO in the test for SSH key selection concerning degenerate key lists. In the test data, add an extra comment on how suitable and unsuitable SSH test keys are determined, and add a function signature for the SSH agent client `request` function. In the heavy-duty tests, merely add missing `draw` labels.
... | ... |
@@ -21,7 +21,7 @@ from __future__ import annotations |
21 | 21 |
|
22 | 22 |
import base64 |
23 | 23 |
import enum |
24 |
-from typing import TYPE_CHECKING |
|
24 |
+from typing import TYPE_CHECKING, Protocol |
|
25 | 25 |
|
26 | 26 |
from typing_extensions import NamedTuple |
27 | 27 |
|
... | ... |
@@ -31,7 +31,7 @@ from derivepassphrase.ssh_agent import socketprovider |
31 | 31 |
__all__ = () |
32 | 32 |
|
33 | 33 |
if TYPE_CHECKING: |
34 |
- from collections.abc import Mapping |
|
34 |
+ from collections.abc import Iterable, Mapping |
|
35 | 35 |
|
36 | 36 |
from typing_extensions import Any |
37 | 37 |
|
... | ... |
@@ -244,6 +244,26 @@ class RunningSSHAgentInfo(NamedTuple): |
244 | 244 |
return self.socket |
245 | 245 |
|
246 | 246 |
|
247 |
+# `derivepassphrase` internal functions |
|
248 |
+# ------------------------------------- |
|
249 |
+ |
|
250 |
+ |
|
251 |
+class RequestFunc(Protocol): |
|
252 |
+ """The call signature of [`ssh_agent.SSHAgentClient.request`][].""" |
|
253 |
+ |
|
254 |
+ def __call__( |
|
255 |
+ self, |
|
256 |
+ request_code: int | _types.SSH_AGENTC, |
|
257 |
+ payload: bytes | bytearray, |
|
258 |
+ /, |
|
259 |
+ *, |
|
260 |
+ response_code: Iterable[int | _types.SSH_AGENT] |
|
261 |
+ | int |
|
262 |
+ | _types.SSH_AGENT |
|
263 |
+ | None = None, |
|
264 |
+ ) -> tuple[int, bytes | bytearray] | bytes | bytearray: ... |
|
265 |
+ |
|
266 |
+ |
|
247 | 267 |
# Vault configurations |
248 | 268 |
# ==================== |
249 | 269 |
|
... | ... |
@@ -1039,11 +1059,23 @@ Rlc3Qga2V5IHdpdGhvdXQgcGFzc3BocmFzZQ== |
1039 | 1059 |
SUPPORTED_KEYS: Mapping[str, SSHTestKey] = { |
1040 | 1060 |
k: v for k, v in ALL_KEYS.items() if v.is_suitable() |
1041 | 1061 |
} |
1042 |
-"""The subset of SSH test keys suitable for use with vault.""" |
|
1062 |
+"""The subset of SSH test keys suitable for use with vault. |
|
1063 |
+ |
|
1064 |
+Suitability is tested -- via [SSHTestKey.is_suitable][], then |
|
1065 |
+[vault.Vault.is_suitable_ssh_key][] -- via an internal whitelist, not |
|
1066 |
+via the presence or absence of a specific expected signature class. |
|
1067 |
+ |
|
1068 |
+""" |
|
1043 | 1069 |
UNSUITABLE_KEYS: Mapping[str, SSHTestKey] = { |
1044 | 1070 |
k: v for k, v in ALL_KEYS.items() if not v.is_suitable() |
1045 | 1071 |
} |
1046 |
-"""The subset of SSH test keys not suitable for use with vault.""" |
|
1072 |
+"""The subset of SSH test keys not suitable for use with vault. |
|
1073 |
+ |
|
1074 |
+Suitability is tested -- via [SSHTestKey.is_suitable][], then |
|
1075 |
+[vault.Vault.is_suitable_ssh_key][] -- via an internal whitelist, not |
|
1076 |
+via the presence or absence of a specific expected signature class. |
|
1077 |
+ |
|
1078 |
+""" |
|
1047 | 1079 |
|
1048 | 1080 |
|
1049 | 1081 |
# Vault test configurations |
... | ... |
@@ -17,13 +17,14 @@ import re |
17 | 17 |
import socket |
18 | 18 |
import sys |
19 | 19 |
import types |
20 |
-from typing import TYPE_CHECKING |
|
20 |
+from typing import TYPE_CHECKING, NamedTuple |
|
21 | 21 |
|
22 | 22 |
import click |
23 | 23 |
import click.testing |
24 | 24 |
import hypothesis |
25 | 25 |
import pytest |
26 | 26 |
from hypothesis import strategies |
27 |
+from typing_extensions import TypeAlias |
|
27 | 28 |
|
28 | 29 |
from derivepassphrase import _types, ssh_agent, vault |
29 | 30 |
from derivepassphrase._internals import cli_helpers |
... | ... |
@@ -33,7 +34,7 @@ from tests.data import callables |
33 | 34 |
from tests.machinery import pytest as pytest_machinery |
34 | 35 |
|
35 | 36 |
if TYPE_CHECKING: |
36 |
- from collections.abc import Iterable |
|
37 |
+ from collections.abc import Iterable, Iterator, Mapping |
|
37 | 38 |
|
38 | 39 |
from typing_extensions import Any, Buffer, Literal |
39 | 40 |
|
... | ... |
@@ -489,17 +490,17 @@ class Parametrize(types.SimpleNamespace): |
489 | 490 |
list(data.SUPPORTED_KEYS.items()), |
490 | 491 |
ids=data.SUPPORTED_KEYS.keys(), |
491 | 492 |
) |
492 |
- UNSUITABLE_SSH_TEST_KEYS = pytest.mark.parametrize( |
|
493 |
+ ALL_SSH_TEST_KEYS = pytest.mark.parametrize( |
|
493 | 494 |
["ssh_test_key_type", "ssh_test_key"], |
494 |
- list(data.UNSUITABLE_KEYS.items()), |
|
495 |
- ids=data.UNSUITABLE_KEYS.keys(), |
|
495 |
+ list(data.ALL_KEYS.items()), |
|
496 |
+ ids=data.ALL_KEYS.keys(), |
|
496 | 497 |
) |
497 | 498 |
RESOLVE_CHAINS = pytest.mark.parametrize( |
498 | 499 |
["terminal", "chain"], |
499 | 500 |
[ |
500 | 501 |
pytest.param("callable", ["a"], id="callable-1"), |
501 | 502 |
pytest.param("callable", ["a", "b", "c", "d"], id="callable-4"), |
502 |
- pytest.param("alias", ["e"], id="alias-5"), |
|
503 |
+ pytest.param("alias", ["e"], id="alias-1"), |
|
503 | 504 |
pytest.param("alias", ["e", "f", "g", "h", "i"], id="alias-5"), |
504 | 505 |
pytest.param("unimplemented", ["j"], id="unimplemented-1"), |
505 | 506 |
pytest.param("unimplemented", ["j", "k"], id="unimplemented-2"), |
... | ... |
@@ -510,17 +511,24 @@ class Parametrize(types.SimpleNamespace): |
510 | 511 |
class TestStubbedSSHAgentSocketRequests: |
511 | 512 |
"""Test the stubbed SSH agent socket: normal requests.""" |
512 | 513 |
|
513 |
- def test_query_extensions_base(self) -> None: |
|
514 |
- """The base agent implements no extensions.""" |
|
514 |
+ @contextlib.contextmanager |
|
515 |
+ def _get_addressed_agent( |
|
516 |
+ self, *, extended_agent: bool = False |
|
517 |
+ ) -> Iterator[machinery.StubbedSSHAgentSocketWithAddress]: |
|
518 |
+ agent_class: type[machinery.StubbedSSHAgentSocketWithAddress] = ( |
|
519 |
+ machinery.StubbedSSHAgentSocketWithAddressAndDeterministicDSA |
|
520 |
+ if extended_agent |
|
521 |
+ else machinery.StubbedSSHAgentSocketWithAddress |
|
522 |
+ ) |
|
515 | 523 |
with contextlib.ExitStack() as stack: |
516 | 524 |
monkeypatch = stack.enter_context(pytest.MonkeyPatch.context()) |
517 |
- monkeypatch.setenv( |
|
518 |
- "SSH_AUTH_SOCK", |
|
519 |
- machinery.StubbedSSHAgentSocketWithAddress.ADDRESS, |
|
520 |
- ) |
|
521 |
- agent = stack.enter_context( |
|
522 |
- machinery.StubbedSSHAgentSocketWithAddress() |
|
523 |
- ) |
|
525 |
+ monkeypatch.setenv("SSH_AUTH_SOCK", agent_class.ADDRESS) |
|
526 |
+ agent = stack.enter_context(agent_class()) |
|
527 |
+ yield agent |
|
528 |
+ |
|
529 |
+ def test_query_extensions_base(self) -> None: |
|
530 |
+ """The base agent implements no extensions.""" |
|
531 |
+ with self._get_addressed_agent(extended_agent=False) as agent: |
|
524 | 532 |
assert "query" not in agent.enabled_extensions |
525 | 533 |
query_request = ( |
526 | 534 |
# SSH string header |
... | ... |
@@ -541,15 +549,7 @@ class TestStubbedSSHAgentSocketRequests: |
541 | 549 |
|
542 | 550 |
def test_query_extensions_extended(self) -> None: |
543 | 551 |
"""The extended agent implements a known list of extensions.""" |
544 |
- with contextlib.ExitStack() as stack: |
|
545 |
- monkeypatch = stack.enter_context(pytest.MonkeyPatch.context()) |
|
546 |
- monkeypatch.setenv( |
|
547 |
- "SSH_AUTH_SOCK", |
|
548 |
- machinery.StubbedSSHAgentSocketWithAddress.ADDRESS, |
|
549 |
- ) |
|
550 |
- agent = stack.enter_context( |
|
551 |
- machinery.StubbedSSHAgentSocketWithAddressAndDeterministicDSA() |
|
552 |
- ) |
|
552 |
+ with self._get_addressed_agent(extended_agent=True) as agent: |
|
553 | 553 |
assert "query" in agent.enabled_extensions |
554 | 554 |
query_request = ( |
555 | 555 |
# SSH string header |
... | ... |
@@ -605,6 +605,41 @@ class TestStubbedSSHAgentSocketRequests: |
605 | 605 |
} |
606 | 606 |
assert not message |
607 | 607 |
|
608 |
+ def test_request_identities_extended(self) -> None: |
|
609 |
+ """The extended agent implements PuTTY's `list-extended` extension.""" |
|
610 |
+ unstring_prefix = ssh_agent.SSHAgentClient.unstring_prefix |
|
611 |
+ with self._get_addressed_agent(extended_agent=True) as agent: |
|
612 |
+ extension_request = ( |
|
613 |
+ # SSH string header |
|
614 |
+ b"\x00\x00\x00\x2e" |
|
615 |
+ # request code: SSH_AGENTC_REQUEST_IDENTITIES |
|
616 |
+ b"\x1b" |
|
617 |
+ # extension type: list-extended@putty.projects.tartarus.org |
|
618 |
+ b"\x00\x00\x00\x29list-extended@putty.projects.tartarus.org" |
|
619 |
+ # (no payload) |
|
620 |
+ ) |
|
621 |
+ agent.sendall(extension_request) |
|
622 |
+ message_length = int.from_bytes(agent.recv(4), "big") |
|
623 |
+ orig_message: bytes | bytearray = bytearray( |
|
624 |
+ agent.recv(message_length) |
|
625 |
+ ) |
|
626 |
+ assert ( |
|
627 |
+ _types.SSH_AGENT(orig_message[0]) |
|
628 |
+ == _types.SSH_AGENT.SUCCESS |
|
629 |
+ ) |
|
630 |
+ identity_count = int.from_bytes(orig_message[1:5], "big") |
|
631 |
+ message = bytes(orig_message[5:]) |
|
632 |
+ for _ in range(identity_count): |
|
633 |
+ key, message = unstring_prefix(message) |
|
634 |
+ _comment, message = unstring_prefix(message) |
|
635 |
+ flags, message = unstring_prefix(message) |
|
636 |
+ assert flags == b"\x00\x00\x00\x00" |
|
637 |
+ assert key |
|
638 |
+ assert key in { |
|
639 |
+ k.public_key_data for k in data.ALL_KEYS.values() |
|
640 |
+ } |
|
641 |
+ assert not message |
|
642 |
+ |
|
608 | 643 |
@Parametrize.SUPPORTED_SSH_TEST_KEYS |
609 | 644 |
def test_sign( |
610 | 645 |
self, |
... | ... |
@@ -1242,14 +1277,14 @@ class TestSSHAgentSocketProviderRegistry: |
1242 | 1277 |
class TestAgentSigning: |
1243 | 1278 |
"""Test actually talking to the SSH agent: signing data.""" |
1244 | 1279 |
|
1245 |
- @Parametrize.SUPPORTED_SSH_TEST_KEYS |
|
1280 |
+ @Parametrize.ALL_SSH_TEST_KEYS |
|
1246 | 1281 |
def test_sign_data_via_agent( |
1247 | 1282 |
self, |
1248 | 1283 |
ssh_agent_client_with_test_keys_loaded: ssh_agent.SSHAgentClient, |
1249 | 1284 |
ssh_test_key_type: str, |
1250 | 1285 |
ssh_test_key: data.SSHTestKey, |
1251 | 1286 |
) -> None: |
1252 |
- """Signing data with specific SSH keys works. |
|
1287 |
+ """Signing data with specific SSH keys works iff the key is suitable. |
|
1253 | 1288 |
|
1254 | 1289 |
Single tests may abort early (skip) if the indicated key is not |
1255 | 1290 |
loaded in the agent. Presumably this means the key type is |
... | ... |
@@ -1259,6 +1294,9 @@ class TestAgentSigning: |
1259 | 1294 |
client = ssh_agent_client_with_test_keys_loaded |
1260 | 1295 |
key_comment_pairs = {bytes(k): bytes(c) for k, c in client.list_keys()} |
1261 | 1296 |
public_key_data = ssh_test_key.public_key_data |
1297 |
+ if public_key_data not in key_comment_pairs: # pragma: no cover |
|
1298 |
+ pytest.skip(f"prerequisite {ssh_test_key_type} SSH key not loaded") |
|
1299 |
+ if ssh_test_key.is_suitable(): |
|
1262 | 1300 |
assert ( |
1263 | 1301 |
data.SSHTestKeyDeterministicSignatureClass.SPEC |
1264 | 1302 |
in ssh_test_key.expected_signatures |
... | ... |
@@ -1268,8 +1306,6 @@ class TestAgentSigning: |
1268 | 1306 |
] |
1269 | 1307 |
expected_signature = sig.signature |
1270 | 1308 |
derived_passphrase = sig.derived_passphrase |
1271 |
- if public_key_data not in key_comment_pairs: # pragma: no cover |
|
1272 |
- pytest.skip(f"prerequisite {ssh_test_key_type} SSH key not loaded") |
|
1273 | 1309 |
signature = bytes( |
1274 | 1310 |
client.sign(payload=vault.Vault.UUID, key=public_key_data) |
1275 | 1311 |
) |
... | ... |
@@ -1286,35 +1322,28 @@ class TestAgentSigning: |
1286 | 1322 |
vault.Vault.phrase_from_key(public_key_data, conn=client) |
1287 | 1323 |
== derived_passphrase |
1288 | 1324 |
), f"SSH signature mismatch ({ssh_test_key_type})" |
1289 |
- |
|
1290 |
- @Parametrize.UNSUITABLE_SSH_TEST_KEYS |
|
1291 |
- def test_sign_data_via_agent_unsupported( |
|
1292 |
- self, |
|
1293 |
- ssh_agent_client_with_test_keys_loaded: ssh_agent.SSHAgentClient, |
|
1294 |
- ssh_test_key_type: str, |
|
1295 |
- ssh_test_key: data.SSHTestKey, |
|
1296 |
- ) -> None: |
|
1297 |
- """Using an unsuitable key with [`vault.Vault`][] fails. |
|
1298 |
- |
|
1299 |
- Single tests may abort early (skip) if the indicated key is not |
|
1300 |
- loaded in the agent. Presumably this means the key type is |
|
1301 |
- unsupported. Single tests may also abort early if the agent |
|
1302 |
- ensures that the generally unsuitable key is actually suitable |
|
1303 |
- under this agent. |
|
1304 |
- |
|
1305 |
- """ |
|
1306 |
- client = ssh_agent_client_with_test_keys_loaded |
|
1307 |
- key_comment_pairs = {bytes(k): bytes(c) for k, c in client.list_keys()} |
|
1308 |
- public_key_data = ssh_test_key.public_key_data |
|
1309 |
- if public_key_data not in key_comment_pairs: # pragma: no cover |
|
1310 |
- pytest.skip(f"prerequisite {ssh_test_key_type} SSH key not loaded") |
|
1325 |
+ else: |
|
1326 |
+ assert ( |
|
1327 |
+ data.SSHTestKeyDeterministicSignatureClass.SPEC |
|
1328 |
+ not in ssh_test_key.expected_signatures |
|
1329 |
+ ) |
|
1311 | 1330 |
assert not vault.Vault.is_suitable_ssh_key( |
1312 | 1331 |
public_key_data, client=None |
1313 | 1332 |
), f"Expected {ssh_test_key_type} key to be unsuitable in general" |
1314 |
- if vault.Vault.is_suitable_ssh_key(public_key_data, client=client): |
|
1315 |
- pytest.skip( |
|
1316 |
- f"agent automatically ensures {ssh_test_key_type} key is suitable" |
|
1333 |
+ if vault.Vault.is_suitable_ssh_key( |
|
1334 |
+ public_key_data, client=client |
|
1335 |
+ ): # pragma: no cover [external] |
|
1336 |
+ potential_signatures = { |
|
1337 |
+ sig_class: sig.signature |
|
1338 |
+ for sig_class, sig in ssh_test_key.expected_signatures.items() |
|
1339 |
+ if sig_class |
|
1340 |
+ != data.SSHTestKeyDeterministicSignatureClass.SPEC |
|
1341 |
+ } |
|
1342 |
+ signature = bytes( |
|
1343 |
+ client.sign(payload=vault.Vault.UUID, key=public_key_data) |
|
1317 | 1344 |
) |
1345 |
+ assert signature in potential_signatures.values() |
|
1346 |
+ else: # pragma: no cover [external] |
|
1318 | 1347 |
with pytest.raises(ValueError, match="unsuitable SSH key"): |
1319 | 1348 |
vault.Vault.phrase_from_key(public_key_data, conn=client) |
1320 | 1349 |
|
... | ... |
@@ -1338,50 +1367,30 @@ class TestSuitableKeys: |
1338 | 1367 |
""" |
1339 | 1368 |
client = ssh_agent_client_with_test_keys_loaded |
1340 | 1369 |
|
1341 |
- def key_is_suitable(key: bytes) -> bool: |
|
1342 |
- """Stub out [`vault.Vault.key_is_suitable`][].""" |
|
1343 |
- always = {v.public_key_data for v in data.SUPPORTED_KEYS.values()} |
|
1344 |
- dsa = { |
|
1345 |
- v.public_key_data |
|
1346 |
- for k, v in data.UNSUITABLE_KEYS.items() |
|
1347 |
- if k.startswith(("dsa", "ecdsa")) |
|
1348 |
- } |
|
1349 |
- return key in always or ( |
|
1350 |
- client.has_deterministic_dsa_signatures() and key in dsa |
|
1351 |
- ) |
|
1352 |
- |
|
1353 |
- # TODO(the-13th-letter): Handle the unlikely(?) case that only |
|
1354 |
- # one test key is loaded, but `single` is False. Rename the |
|
1355 |
- # `index` variable to `input`, store the `input` in there, and |
|
1356 |
- # make the definition of `text` in the else block dependent on |
|
1357 |
- # `n` being singular or non-singular. |
|
1358 |
- if single: |
|
1359 | 1370 |
monkeypatch.setattr( |
1360 | 1371 |
ssh_agent.SSHAgentClient, |
1361 | 1372 |
"list_keys", |
1362 |
- callables.list_keys_singleton, |
|
1373 |
+ callables.list_keys_singleton if single else callables.list_keys, |
|
1374 |
+ ) |
|
1375 |
+ raw_keys_list = ( |
|
1376 |
+ callables.list_keys_singleton() |
|
1377 |
+ if single |
|
1378 |
+ else callables.list_keys() |
|
1363 | 1379 |
) |
1364 | 1380 |
keys = [ |
1365 | 1381 |
pair.key |
1366 |
- for pair in callables.list_keys_singleton() |
|
1367 |
- if key_is_suitable(pair.key) |
|
1382 |
+ for pair in raw_keys_list |
|
1383 |
+ if vault.Vault.is_suitable_ssh_key(pair.key, client=client) |
|
1368 | 1384 |
] |
1369 |
- index = "1" |
|
1385 |
+ n = len(keys) |
|
1386 |
+ if single or n == 1: |
|
1387 |
+ input_text = "yes\n" |
|
1370 | 1388 |
text = "Use this key? yes\n" |
1371 | 1389 |
else: |
1372 |
- monkeypatch.setattr( |
|
1373 |
- ssh_agent.SSHAgentClient, |
|
1374 |
- "list_keys", |
|
1375 |
- callables.list_keys, |
|
1390 |
+ input_text = f"{1 + keys.index(key)}\n" |
|
1391 |
+ text = ( |
|
1392 |
+ f"Your selection? (1-{n}, leave empty to abort): {input_text}" |
|
1376 | 1393 |
) |
1377 |
- keys = [ |
|
1378 |
- pair.key |
|
1379 |
- for pair in callables.list_keys() |
|
1380 |
- if key_is_suitable(pair.key) |
|
1381 |
- ] |
|
1382 |
- index = str(1 + keys.index(key)) |
|
1383 |
- n = len(keys) |
|
1384 |
- text = f"Your selection? (1-{n}, leave empty to abort): {index}\n" |
|
1385 | 1394 |
b64_key = base64.standard_b64encode(key).decode("ASCII") |
1386 | 1395 |
|
1387 | 1396 |
@click.command() |
... | ... |
@@ -1390,14 +1399,9 @@ class TestSuitableKeys: |
1390 | 1399 |
key = cli_helpers.select_ssh_key(client) |
1391 | 1400 |
click.echo(base64.standard_b64encode(key).decode("ASCII")) |
1392 | 1401 |
|
1393 |
- # TODO(the-13th-letter): (Continued from above.) Update input |
|
1394 |
- # data to use `index`/`input` directly and unconditionally. |
|
1395 | 1402 |
runner = machinery.CliRunner(mix_stderr=True) |
1396 | 1403 |
result = runner.invoke( |
1397 |
- driver, |
|
1398 |
- [], |
|
1399 |
- input=("yes\n" if single else f"{index}\n"), |
|
1400 |
- catch_exceptions=True, |
|
1404 |
+ driver, [], input=input_text, catch_exceptions=True |
|
1401 | 1405 |
) |
1402 | 1406 |
for snippet in ("Suitable SSH keys:\n", text, f"\n{b64_key}\n"): |
1403 | 1407 |
assert result.clean_exit(output=snippet), "expected clean exit" |
... | ... |
@@ -1465,6 +1469,104 @@ class TestOtherConstructorFeatures: |
1465 | 1469 |
class TestAgentErrorResponses: |
1466 | 1470 |
"""Test actually talking to the SSH agent: errors from the SSH agent.""" |
1467 | 1471 |
|
1472 |
+ class Request(NamedTuple): |
|
1473 |
+ """A request for the SSH agent protocol. |
|
1474 |
+ |
|
1475 |
+ Attributes: |
|
1476 |
+ code: |
|
1477 |
+ The expected request code. |
|
1478 |
+ payload: |
|
1479 |
+ The expected payload. Optional. |
|
1480 |
+ |
|
1481 |
+ """ |
|
1482 |
+ |
|
1483 |
+ code: _types.SSH_AGENTC |
|
1484 |
+ """""" |
|
1485 |
+ payload: bytes | None |
|
1486 |
+ """""" |
|
1487 |
+ |
|
1488 |
+ class Response(NamedTuple): |
|
1489 |
+ """A response for the SSH agent protocol. |
|
1490 |
+ |
|
1491 |
+ Attributes: |
|
1492 |
+ code: |
|
1493 |
+ The desired response code. |
|
1494 |
+ payload: |
|
1495 |
+ The desired response payload. |
|
1496 |
+ |
|
1497 |
+ """ |
|
1498 |
+ |
|
1499 |
+ code: _types.SSH_AGENT |
|
1500 |
+ """""" |
|
1501 |
+ payload: bytes |
|
1502 |
+ """""" |
|
1503 |
+ |
|
1504 |
+ def _make_request_stub( |
|
1505 |
+ self, |
|
1506 |
+ request_response_map: Mapping[Request, Response], |
|
1507 |
+ /, |
|
1508 |
+ ) -> data.RequestFunc: |
|
1509 |
+ """Return a stubbed SSH agent client `request` function. |
|
1510 |
+ |
|
1511 |
+ Args: |
|
1512 |
+ request_response_map: |
|
1513 |
+ A map of request/response pairs. |
|
1514 |
+ |
|
1515 |
+ """ |
|
1516 |
+ |
|
1517 |
+ def request( |
|
1518 |
+ request_code: int | _types.SSH_AGENTC, |
|
1519 |
+ payload: bytes | bytearray, |
|
1520 |
+ /, |
|
1521 |
+ *, |
|
1522 |
+ response_code: Iterable[int | _types.SSH_AGENT] |
|
1523 |
+ | int |
|
1524 |
+ | _types.SSH_AGENT |
|
1525 |
+ | None = None, |
|
1526 |
+ ) -> tuple[int, bytes | bytearray] | bytes | bytearray: |
|
1527 |
+ request_code = _types.SSH_AGENTC(request_code) |
|
1528 |
+ response_code = ( |
|
1529 |
+ frozenset({_types.SSH_AGENT.SUCCESS}) |
|
1530 |
+ if response_code is None |
|
1531 |
+ else frozenset({_types.SSH_AGENT(response_code)}) |
|
1532 |
+ if isinstance(response_code, (int, _types.SSH_AGENT)) |
|
1533 |
+ else frozenset(map(_types.SSH_AGENT, response_code)) |
|
1534 |
+ ) |
|
1535 |
+ response_ = request_response_map.get( |
|
1536 |
+ self.Request(request_code, bytes(payload)) |
|
1537 |
+ ) or request_response_map.get(self.Request(request_code, None)) |
|
1538 |
+ if response_ is None: # pragma: no cover [failsafe] |
|
1539 |
+ raise ssh_agent.SSHAgentFailedError( |
|
1540 |
+ _types.SSH_AGENT.FAILURE.value, |
|
1541 |
+ "No prepared response for the request " |
|
1542 |
+ f"({request_code!r}, {payload!r})".encode(), |
|
1543 |
+ ) |
|
1544 |
+ code, response = response_ |
|
1545 |
+ if code not in response_code: |
|
1546 |
+ raise ssh_agent.SSHAgentFailedError(code.value, response) |
|
1547 |
+ return response |
|
1548 |
+ |
|
1549 |
+ return request |
|
1550 |
+ |
|
1551 |
+ @contextlib.contextmanager |
|
1552 |
+ def _setup_agent_and_request_handler( |
|
1553 |
+ self, request_response_map: Mapping[Request, Response], / |
|
1554 |
+ ) -> Iterator[tuple[pytest.MonkeyPatch, ssh_agent.SSHAgentClient]]: |
|
1555 |
+ # TODO(the-13th-letter): Rewrite using parenthesized |
|
1556 |
+ # with-statements. |
|
1557 |
+ # https://the13thletter.info/derivepassphrase/latest/pycompatibility/#after-eol-py3.9 |
|
1558 |
+ with contextlib.ExitStack() as stack: |
|
1559 |
+ monkeypatch = stack.enter_context(pytest.MonkeyPatch.context()) |
|
1560 |
+ client = stack.enter_context( |
|
1561 |
+ ssh_agent.SSHAgentClient.ensure_agent_subcontext() |
|
1562 |
+ ) |
|
1563 |
+ monkeypatch.setattr( |
|
1564 |
+ client, |
|
1565 |
+ "request", |
|
1566 |
+ self._make_request_stub(request_response_map), |
|
1567 |
+ ) |
|
1568 |
+ yield monkeypatch, client |
|
1569 |
+ |
|
1468 | 1570 |
@Parametrize.TRUNCATED_AGENT_RESPONSES |
1469 | 1571 |
def test_truncated_server_response( |
1470 | 1572 |
self, |
... | ... |
@@ -1509,43 +1611,12 @@ class TestAgentErrorResponses: |
1509 | 1611 |
|
1510 | 1612 |
""" |
1511 | 1613 |
del running_ssh_agent |
1512 |
- |
|
1513 |
- passed_response_code = response_code |
|
1514 |
- |
|
1515 |
- # TODO(the-13th-letter): Extract this mock function into a common |
|
1516 |
- # top-level "request" mock function. |
|
1517 |
- def request( |
|
1518 |
- request_code: int | _types.SSH_AGENTC, |
|
1519 |
- payload: bytes | bytearray, |
|
1520 |
- /, |
|
1521 |
- *, |
|
1522 |
- response_code: Iterable[int | _types.SSH_AGENT] |
|
1523 |
- | int |
|
1524 |
- | _types.SSH_AGENT |
|
1525 |
- | None = None, |
|
1526 |
- ) -> tuple[int, bytes | bytearray] | bytes | bytearray: |
|
1527 |
- del request_code |
|
1528 |
- del payload |
|
1529 |
- if isinstance( # pragma: no branch |
|
1530 |
- response_code, (int, _types.SSH_AGENT) |
|
1531 |
- ): |
|
1532 |
- response_code = frozenset({response_code}) |
|
1533 |
- if response_code is not None: # pragma: no branch |
|
1534 |
- response_code = frozenset({ |
|
1535 |
- c if isinstance(c, int) else c.value for c in response_code |
|
1536 |
- }) |
|
1537 |
- |
|
1538 |
- if not response_code: # pragma: no cover |
|
1539 |
- return (passed_response_code.value, response) |
|
1540 |
- if passed_response_code.value not in response_code: |
|
1541 |
- raise ssh_agent.SSHAgentFailedError( |
|
1542 |
- passed_response_code.value, response |
|
1543 |
- ) |
|
1544 |
- return response |
|
1545 |
- |
|
1546 |
- with pytest.MonkeyPatch.context() as monkeypatch: |
|
1547 |
- client = ssh_agent.SSHAgentClient() |
|
1548 |
- monkeypatch.setattr(client, "request", request) |
|
1614 |
+ with self._setup_agent_and_request_handler({ |
|
1615 |
+ self.Request( |
|
1616 |
+ _types.SSH_AGENTC.REQUEST_IDENTITIES, None |
|
1617 |
+ ): self.Response(response_code, response), |
|
1618 |
+ }) as contexts: |
|
1619 |
+ _, client = contexts |
|
1549 | 1620 |
with pytest.raises(exc_type, match=exc_pattern): |
1550 | 1621 |
client.list_keys() |
1551 | 1622 |
|
... | ... |
@@ -1570,50 +1641,23 @@ class TestAgentErrorResponses: |
1570 | 1641 |
|
1571 | 1642 |
""" |
1572 | 1643 |
del running_ssh_agent |
1573 |
- passed_response_code = response_code |
|
1574 |
- |
|
1575 |
- # TODO(the-13th-letter): Extract this mock function into a common |
|
1576 |
- # top-level "request" mock function. |
|
1577 |
- def request( |
|
1578 |
- request_code: int | _types.SSH_AGENTC, |
|
1579 |
- payload: bytes | bytearray, |
|
1580 |
- /, |
|
1581 |
- *, |
|
1582 |
- response_code: Iterable[int | _types.SSH_AGENT] |
|
1583 |
- | int |
|
1584 |
- | _types.SSH_AGENT |
|
1585 |
- | None = None, |
|
1586 |
- ) -> tuple[int, bytes | bytearray] | bytes | bytearray: |
|
1587 |
- del request_code |
|
1588 |
- del payload |
|
1589 |
- if isinstance( # pragma: no branch |
|
1590 |
- response_code, (int, _types.SSH_AGENT) |
|
1591 |
- ): |
|
1592 |
- response_code = frozenset({response_code}) |
|
1593 |
- if response_code is not None: # pragma: no branch |
|
1594 |
- response_code = frozenset({ |
|
1595 |
- c if isinstance(c, int) else c.value for c in response_code |
|
1596 |
- }) |
|
1597 |
- |
|
1598 |
- if not response_code: # pragma: no cover |
|
1599 |
- return (passed_response_code.value, response) |
|
1600 |
- if ( |
|
1601 |
- passed_response_code.value not in response_code |
|
1602 |
- ): # pragma: no branch |
|
1603 |
- raise ssh_agent.SSHAgentFailedError( |
|
1604 |
- passed_response_code.value, response |
|
1605 |
- ) |
|
1606 |
- return response # pragma: no cover |
|
1607 |
- |
|
1608 |
- with pytest.MonkeyPatch.context() as monkeypatch: |
|
1609 |
- client = ssh_agent.SSHAgentClient() |
|
1610 |
- monkeypatch.setattr(client, "request", request) |
|
1611 |
- Pair = _types.SSHKeyCommentPair # noqa: N806 |
|
1644 |
+ Pair: TypeAlias = _types.SSHKeyCommentPair |
|
1612 | 1645 |
com = b"no comment" |
1613 | 1646 |
loaded_keys = [ |
1614 | 1647 |
Pair(v.public_key_data, com).toreadonly() |
1615 | 1648 |
for v in data.SUPPORTED_KEYS.values() |
1616 | 1649 |
] |
1650 |
+ sign_payload = ( |
|
1651 |
+ ssh_agent.SSHAgentClient.string(key) |
|
1652 |
+ + ssh_agent.SSHAgentClient.string(b"abc") |
|
1653 |
+ + ssh_agent.SSHAgentClient.uint32(0) |
|
1654 |
+ ) |
|
1655 |
+ with self._setup_agent_and_request_handler({ |
|
1656 |
+ self.Request( |
|
1657 |
+ _types.SSH_AGENTC.SIGN_REQUEST, sign_payload |
|
1658 |
+ ): self.Response(response_code, response), |
|
1659 |
+ }) as contexts: |
|
1660 |
+ monkeypatch, client = contexts |
|
1617 | 1661 |
monkeypatch.setattr(client, "list_keys", lambda: loaded_keys) |
1618 | 1662 |
with pytest.raises(exc_type, match=exc_pattern): |
1619 | 1663 |
client.sign(key, b"abc", check_if_key_loaded=check) |
... | ... |
@@ -1642,68 +1686,27 @@ class TestAgentErrorResponses: |
1642 | 1686 |
# with-statements. |
1643 | 1687 |
# https://the13thletter.info/derivepassphrase/latest/pycompatibility/#after-eol-py3.9 |
1644 | 1688 |
with contextlib.ExitStack() as stack: |
1645 |
- stack.enter_context(pytest.raises(exc_type, match=exc_pattern)) |
|
1646 | 1689 |
client = stack.enter_context(ssh_agent.SSHAgentClient()) |
1690 |
+ stack.enter_context(pytest.raises(exc_type, match=exc_pattern)) |
|
1647 | 1691 |
client.request(request_code, b"", response_code=response_code) |
1648 | 1692 |
|
1649 | 1693 |
@Parametrize.QUERY_EXTENSIONS_MALFORMED_RESPONSES |
1650 | 1694 |
def test_query_extensions_malformed_responses( |
1651 | 1695 |
self, |
1652 |
- monkeypatch: pytest.MonkeyPatch, |
|
1653 | 1696 |
running_ssh_agent: data.RunningSSHAgentInfo, |
1654 | 1697 |
response_data: bytes, |
1655 | 1698 |
) -> None: |
1656 | 1699 |
"""Fail on malformed responses while querying extensions.""" |
1657 | 1700 |
del running_ssh_agent |
1658 |
- |
|
1659 |
- # TODO(the-13th-letter): Extract this mock function into a common |
|
1660 |
- # top-level "request" mock function after removing the |
|
1661 |
- # payload-specific parts. |
|
1662 |
- def request( |
|
1663 |
- code: int | _types.SSH_AGENTC, |
|
1664 |
- payload: Buffer, |
|
1665 |
- /, |
|
1666 |
- *, |
|
1667 |
- response_code: ( |
|
1668 |
- Iterable[_types.SSH_AGENT | int] |
|
1669 |
- | _types.SSH_AGENT |
|
1670 |
- | int |
|
1671 |
- | None |
|
1672 |
- ) = None, |
|
1673 |
- ) -> tuple[int, bytes] | bytes: |
|
1674 |
- request_codes = { |
|
1701 |
+ with self._setup_agent_and_request_handler({ |
|
1702 |
+ self.Request( |
|
1675 | 1703 |
_types.SSH_AGENTC.EXTENSION, |
1676 |
- _types.SSH_AGENTC.EXTENSION.value, |
|
1677 |
- } |
|
1678 |
- assert code in request_codes |
|
1679 |
- response_codes = { |
|
1680 |
- _types.SSH_AGENT.EXTENSION_RESPONSE, |
|
1681 |
- _types.SSH_AGENT.EXTENSION_RESPONSE.value, |
|
1682 |
- _types.SSH_AGENT.SUCCESS, |
|
1683 |
- _types.SSH_AGENT.SUCCESS.value, |
|
1684 |
- } |
|
1685 |
- assert payload == b"\x00\x00\x00\x05query" |
|
1686 |
- if response_code is None: # pragma: no cover |
|
1687 |
- return ( |
|
1688 |
- _types.SSH_AGENT.EXTENSION_RESPONSE.value, |
|
1689 |
- response_data, |
|
1690 |
- ) |
|
1691 |
- if isinstance( # pragma: no cover |
|
1692 |
- response_code, (_types.SSH_AGENT, int) |
|
1693 |
- ): |
|
1694 |
- assert response_code in response_codes |
|
1695 |
- return response_data |
|
1696 |
- for single_code in response_code: # pragma: no cover |
|
1697 |
- assert single_code in response_codes |
|
1698 |
- return response_data # pragma: no cover |
|
1699 |
- |
|
1700 |
- # TODO(the-13th-letter): Rewrite using parenthesized |
|
1701 |
- # with-statements. |
|
1702 |
- # https://the13thletter.info/derivepassphrase/latest/pycompatibility/#after-eol-py3.9 |
|
1703 |
- with contextlib.ExitStack() as stack: |
|
1704 |
- monkeypatch2 = stack.enter_context(monkeypatch.context()) |
|
1705 |
- client = stack.enter_context(ssh_agent.SSHAgentClient()) |
|
1706 |
- monkeypatch2.setattr(client, "request", request) |
|
1704 |
+ ssh_agent.SSHAgentClient.string(b"query"), |
|
1705 |
+ ): self.Response( |
|
1706 |
+ _types.SSH_AGENT.EXTENSION_RESPONSE, response_data |
|
1707 |
+ ), |
|
1708 |
+ }) as contexts: |
|
1709 |
+ _, client = contexts |
|
1707 | 1710 |
with pytest.raises( |
1708 | 1711 |
RuntimeError, |
1709 | 1712 |
match=r"Malformed response|does not match request", |
... | ... |
@@ -96,13 +96,15 @@ def draw_alias_chain( |
96 | 96 |
|
97 | 97 |
err_msg_chain_size = "Chain sizes must always be 1 or larger." |
98 | 98 |
|
99 |
- size = draw(chain_size) |
|
99 |
+ size = draw(chain_size, label="chain_size") |
|
100 | 100 |
if size < 1: # pragma: no cover |
101 | 101 |
raise ValueError(err_msg_chain_size) |
102 | 102 |
names: list[str] = [] |
103 | 103 |
base: str | None = None |
104 | 104 |
if existing: |
105 |
- names.append(draw(known_keys_strategy.filter(not_an_alias))) |
|
105 |
+ names.append( |
|
106 |
+ draw(known_keys_strategy.filter(not_an_alias), label="base") |
|
107 |
+ ) |
|
106 | 108 |
base = names[0] |
107 | 109 |
size -= 1 |
108 | 110 |
new_key_strategy = new_keys_strategy.filter( |
... | ... |
@@ -124,7 +126,7 @@ def draw_alias_chain( |
124 | 126 |
max_size=size, |
125 | 127 |
unique=True, |
126 | 128 |
) |
127 |
- names.extend(draw(list_strategy)) |
|
129 |
+ names.extend(draw(list_strategy, label="others")) |
|
128 | 130 |
return tuple(names) |
129 | 131 |
|
130 | 132 |
|
131 | 133 |