Marco Ricci commited on 2024-09-21 11:45:20
Zeige 2 geänderte Dateien mit 187 Einfügungen und 29 Löschungen.
This shifts the remaining error checking into the `SSHAgentClient.request` method, most of the time. On the other hand, this makes mocking that method somewhat more involved.
| ... | ... |
@@ -10,14 +10,14 @@ import collections |
| 10 | 10 |
import errno |
| 11 | 11 |
import os |
| 12 | 12 |
import socket |
| 13 |
-from typing import TYPE_CHECKING |
|
| 13 |
+from typing import TYPE_CHECKING, overload |
|
| 14 | 14 |
|
| 15 | 15 |
from typing_extensions import Self |
| 16 | 16 |
|
| 17 | 17 |
from derivepassphrase import _types |
| 18 | 18 |
|
| 19 | 19 |
if TYPE_CHECKING: |
| 20 |
- from collections.abc import Sequence |
|
| 20 |
+ from collections.abc import Iterable, Sequence |
|
| 21 | 21 |
from types import TracebackType |
| 22 | 22 |
|
| 23 | 23 |
__all__ = ('SSHAgentClient',)
|
| ... | ... |
@@ -268,9 +268,48 @@ class SSHAgentClient: |
| 268 | 268 |
bytestring[m + HEAD_LEN :], |
| 269 | 269 |
) |
| 270 | 270 |
|
| 271 |
+ @overload |
|
| 272 |
+ def request( # pragma: no cover |
|
| 273 |
+ self, |
|
| 274 |
+ code: int | _types.SSH_AGENTC, |
|
| 275 |
+ payload: bytes | bytearray, |
|
| 276 |
+ /, |
|
| 277 |
+ *, |
|
| 278 |
+ response_code: None = None, |
|
| 279 |
+ ) -> tuple[int, bytes | bytearray]: ... |
|
| 280 |
+ |
|
| 281 |
+ @overload |
|
| 282 |
+ def request( # pragma: no cover |
|
| 283 |
+ self, |
|
| 284 |
+ code: int | _types.SSH_AGENTC, |
|
| 285 |
+ payload: bytes | bytearray, |
|
| 286 |
+ /, |
|
| 287 |
+ *, |
|
| 288 |
+ response_code: Iterable[_types.SSH_AGENT | int] = frozenset({
|
|
| 289 |
+ _types.SSH_AGENT.SUCCESS |
|
| 290 |
+ }), |
|
| 291 |
+ ) -> tuple[int, bytes | bytearray]: ... |
|
| 292 |
+ |
|
| 293 |
+ @overload |
|
| 294 |
+ def request( # pragma: no cover |
|
| 295 |
+ self, |
|
| 296 |
+ code: int | _types.SSH_AGENTC, |
|
| 297 |
+ payload: bytes | bytearray, |
|
| 298 |
+ /, |
|
| 299 |
+ *, |
|
| 300 |
+ response_code: _types.SSH_AGENT | int = _types.SSH_AGENT.SUCCESS, |
|
| 301 |
+ ) -> bytes | bytearray: ... |
|
| 302 |
+ |
|
| 271 | 303 |
def request( |
| 272 |
- self, code: int, payload: bytes | bytearray, / |
|
| 273 |
- ) -> tuple[int, bytes | bytearray]: |
|
| 304 |
+ self, |
|
| 305 |
+ code: int | _types.SSH_AGENTC, |
|
| 306 |
+ payload: bytes | bytearray, |
|
| 307 |
+ /, |
|
| 308 |
+ *, |
|
| 309 |
+ response_code: ( |
|
| 310 |
+ Iterable[_types.SSH_AGENT | int] | _types.SSH_AGENT | int | None |
|
| 311 |
+ ) = None, |
|
| 312 |
+ ) -> tuple[int, bytes | bytearray] | bytes | bytearray: |
|
| 274 | 313 |
"""Issue a generic request to the SSH agent. |
| 275 | 314 |
|
| 276 | 315 |
Args: |
| ... | ... |
@@ -283,6 +322,10 @@ class SSHAgentClient: |
| 283 | 322 |
the request. Request-specific. `request` will add any |
| 284 | 323 |
necessary wire framing around the request code and the |
| 285 | 324 |
payload. |
| 325 |
+ response_code: |
|
| 326 |
+ An optional response code, or a set of response codes, |
|
| 327 |
+ that we expect. If given, and the actual response code |
|
| 328 |
+ does not match, raise an error. |
|
| 286 | 329 |
|
| 287 | 330 |
Returns: |
| 288 | 331 |
A 2-tuple consisting of the response code and the payload, |
| ... | ... |
@@ -291,9 +334,24 @@ class SSHAgentClient: |
| 291 | 334 |
Raises: |
| 292 | 335 |
EOFError: |
| 293 | 336 |
The response from the SSH agent is truncated or missing. |
| 337 |
+ OSError: |
|
| 338 |
+ There was a communication error with the SSH agent. |
|
| 339 |
+ SSHAgentFailedError: |
|
| 340 |
+ We expected specific response codes, but did not receive |
|
| 341 |
+ any of them. |
|
| 294 | 342 |
|
| 295 | 343 |
""" |
| 296 |
- request_message = bytearray([code]) |
|
| 344 |
+ if isinstance( # pragma: no branch |
|
| 345 |
+ response_code, int | _types.SSH_AGENT |
|
| 346 |
+ ): |
|
| 347 |
+ response_code = frozenset({response_code})
|
|
| 348 |
+ if response_code is not None: # pragma: no branch |
|
| 349 |
+ response_code = frozenset({
|
|
| 350 |
+ c if isinstance(c, int) else c.value for c in response_code |
|
| 351 |
+ }) |
|
| 352 |
+ request_message = bytearray([ |
|
| 353 |
+ code if isinstance(code, int) else code.value |
|
| 354 |
+ ]) |
|
| 297 | 355 |
request_message.extend(payload) |
| 298 | 356 |
self._connection.sendall(self.string(request_message)) |
| 299 | 357 |
chunk = self._connection.recv(HEAD_LEN) |
| ... | ... |
@@ -305,7 +363,11 @@ class SSHAgentClient: |
| 305 | 363 |
if len(response) < response_length: |
| 306 | 364 |
msg = 'truncated response from SSH agent' |
| 307 | 365 |
raise EOFError(msg) |
| 366 |
+ if not response_code: # pragma: no cover |
|
| 308 | 367 |
return response[0], response[1:] |
| 368 |
+ if response[0] not in response_code: |
|
| 369 |
+ raise SSHAgentFailedError(response[0], response[1:]) |
|
| 370 |
+ return response[1:] |
|
| 309 | 371 |
|
| 310 | 372 |
def list_keys(self) -> Sequence[_types.KeyCommentPair]: |
| 311 | 373 |
"""Request a list of keys known to the SSH agent. |
| ... | ... |
@@ -316,17 +378,19 @@ class SSHAgentClient: |
| 316 | 378 |
Raises: |
| 317 | 379 |
EOFError: |
| 318 | 380 |
The response from the SSH agent is truncated or missing. |
| 381 |
+ OSError: |
|
| 382 |
+ There was a communication error with the SSH agent. |
|
| 319 | 383 |
TrailingDataError: |
| 320 | 384 |
The response from the SSH agent is too long. |
| 321 | 385 |
SSHAgentFailedError: |
| 322 | 386 |
The agent failed to complete the request. |
| 323 | 387 |
|
| 324 | 388 |
""" |
| 325 |
- response_code, response = self.request( |
|
| 326 |
- _types.SSH_AGENTC.REQUEST_IDENTITIES.value, b'' |
|
| 389 |
+ response = self.request( |
|
| 390 |
+ _types.SSH_AGENTC.REQUEST_IDENTITIES.value, |
|
| 391 |
+ b'', |
|
| 392 |
+ response_code=_types.SSH_AGENT.IDENTITIES_ANSWER, |
|
| 327 | 393 |
) |
| 328 |
- if response_code != _types.SSH_AGENT.IDENTITIES_ANSWER.value: |
|
| 329 |
- raise SSHAgentFailedError(response_code, response) |
|
| 330 | 394 |
response_stream = collections.deque(response) |
| 331 | 395 |
|
| 332 | 396 |
def shift(num: int) -> bytes: |
| ... | ... |
@@ -390,6 +454,8 @@ class SSHAgentClient: |
| 390 | 454 |
Raises: |
| 391 | 455 |
EOFError: |
| 392 | 456 |
The response from the SSH agent is truncated or missing. |
| 457 |
+ OSError: |
|
| 458 |
+ There was a communication error with the SSH agent. |
|
| 393 | 459 |
TrailingDataError: |
| 394 | 460 |
The response from the SSH agent is too long. |
| 395 | 461 |
SSHAgentFailedError: |
| ... | ... |
@@ -407,9 +473,10 @@ class SSHAgentClient: |
| 407 | 473 |
request_data = bytearray(self.string(key)) |
| 408 | 474 |
request_data.extend(self.string(payload)) |
| 409 | 475 |
request_data.extend(self.uint32(flags)) |
| 410 |
- response_code, response = self.request( |
|
| 411 |
- _types.SSH_AGENTC.SIGN_REQUEST.value, request_data |
|
| 476 |
+ return self.unstring( |
|
| 477 |
+ self.request( |
|
| 478 |
+ _types.SSH_AGENTC.SIGN_REQUEST.value, |
|
| 479 |
+ request_data, |
|
| 480 |
+ response_code=_types.SSH_AGENT.SIGN_RESPONSE, |
|
| 481 |
+ ) |
|
| 412 | 482 |
) |
| 413 |
- if response_code != _types.SSH_AGENT.SIGN_RESPONSE.value: |
|
| 414 |
- raise SSHAgentFailedError(response_code, response) |
|
| 415 |
- return self.unstring(response) |
| ... | ... |
@@ -22,7 +22,7 @@ import tests |
| 22 | 22 |
from derivepassphrase import _types, cli, ssh_agent, vault |
| 23 | 23 |
|
| 24 | 24 |
if TYPE_CHECKING: |
| 25 |
- from collections.abc import Iterator |
|
| 25 |
+ from collections.abc import Iterable, Iterator |
|
| 26 | 26 |
|
| 27 | 27 |
|
| 28 | 28 |
class TestStaticFunctionality: |
| ... | ... |
@@ -387,30 +387,66 @@ class TestAgentInteraction: |
| 387 | 387 |
exc_type: type[Exception], |
| 388 | 388 |
exc_pattern: str, |
| 389 | 389 |
) -> None: |
| 390 |
- client = ssh_agent.SSHAgentClient() |
|
| 391 |
- monkeypatch.setattr( |
|
| 392 |
- client, |
|
| 393 |
- 'request', |
|
| 394 |
- lambda *a, **kw: (response_code.value, response), # noqa: ARG005 |
|
| 390 |
+ passed_response_code = response_code |
|
| 391 |
+ |
|
| 392 |
+ def request( |
|
| 393 |
+ request_code: int | _types.SSH_AGENTC, |
|
| 394 |
+ payload: bytes | bytearray, |
|
| 395 |
+ /, |
|
| 396 |
+ *, |
|
| 397 |
+ response_code: Iterable[int | _types.SSH_AGENT] |
|
| 398 |
+ | int |
|
| 399 |
+ | _types.SSH_AGENT |
|
| 400 |
+ | None = None, |
|
| 401 |
+ ) -> tuple[int, bytes | bytearray] | bytes | bytearray: |
|
| 402 |
+ del request_code |
|
| 403 |
+ del payload |
|
| 404 |
+ if isinstance( # pragma: no branch |
|
| 405 |
+ response_code, int | _types.SSH_AGENT |
|
| 406 |
+ ): |
|
| 407 |
+ response_code = frozenset({response_code})
|
|
| 408 |
+ if response_code is not None: # pragma: no branch |
|
| 409 |
+ response_code = frozenset({
|
|
| 410 |
+ c if isinstance(c, int) else c.value for c in response_code |
|
| 411 |
+ }) |
|
| 412 |
+ |
|
| 413 |
+ if not response_code: # pragma: no cover |
|
| 414 |
+ return (passed_response_code.value, response) |
|
| 415 |
+ if passed_response_code.value not in response_code: |
|
| 416 |
+ raise ssh_agent.SSHAgentFailedError( |
|
| 417 |
+ passed_response_code.value, response |
|
| 395 | 418 |
) |
| 419 |
+ return response |
|
| 420 |
+ |
|
| 421 |
+ client = ssh_agent.SSHAgentClient() |
|
| 422 |
+ monkeypatch.setattr(client, 'request', request) |
|
| 396 | 423 |
with pytest.raises(exc_type, match=exc_pattern): |
| 397 | 424 |
client.list_keys() |
| 398 | 425 |
|
| 399 | 426 |
@tests.skip_if_no_agent |
| 400 | 427 |
@pytest.mark.parametrize( |
| 401 |
- ['key', 'check', 'response', 'exc_type', 'exc_pattern'], |
|
| 428 |
+ [ |
|
| 429 |
+ 'key', |
|
| 430 |
+ 'check', |
|
| 431 |
+ 'response_code', |
|
| 432 |
+ 'response', |
|
| 433 |
+ 'exc_type', |
|
| 434 |
+ 'exc_pattern', |
|
| 435 |
+ ], |
|
| 402 | 436 |
[ |
| 403 | 437 |
( |
| 404 | 438 |
b'invalid-key', |
| 405 | 439 |
True, |
| 406 |
- (_types.SSH_AGENT.FAILURE, b''), |
|
| 440 |
+ _types.SSH_AGENT.FAILURE, |
|
| 441 |
+ b'', |
|
| 407 | 442 |
KeyError, |
| 408 | 443 |
'target SSH key not loaded into agent', |
| 409 | 444 |
), |
| 410 | 445 |
( |
| 411 | 446 |
tests.SUPPORTED_KEYS['ed25519']['public_key_data'], |
| 412 | 447 |
True, |
| 413 |
- (_types.SSH_AGENT.FAILURE, b''), |
|
| 448 |
+ _types.SSH_AGENT.FAILURE, |
|
| 449 |
+ b'', |
|
| 414 | 450 |
ssh_agent.SSHAgentFailedError, |
| 415 | 451 |
'failed to complete the request', |
| 416 | 452 |
), |
| ... | ... |
@@ -421,16 +457,46 @@ class TestAgentInteraction: |
| 421 | 457 |
monkeypatch: pytest.MonkeyPatch, |
| 422 | 458 |
key: bytes | bytearray, |
| 423 | 459 |
check: bool, |
| 424 |
- response: tuple[_types.SSH_AGENT, bytes | bytearray], |
|
| 460 |
+ response_code: _types.SSH_AGENT, |
|
| 461 |
+ response: bytes | bytearray, |
|
| 425 | 462 |
exc_type: type[Exception], |
| 426 | 463 |
exc_pattern: str, |
| 427 | 464 |
) -> None: |
| 428 |
- client = ssh_agent.SSHAgentClient() |
|
| 429 |
- monkeypatch.setattr( |
|
| 430 |
- client, |
|
| 431 |
- 'request', |
|
| 432 |
- lambda a, b: (response[0].value, response[1]), # noqa: ARG005 |
|
| 465 |
+ passed_response_code = response_code |
|
| 466 |
+ |
|
| 467 |
+ def request( |
|
| 468 |
+ request_code: int | _types.SSH_AGENTC, |
|
| 469 |
+ payload: bytes | bytearray, |
|
| 470 |
+ /, |
|
| 471 |
+ *, |
|
| 472 |
+ response_code: Iterable[int | _types.SSH_AGENT] |
|
| 473 |
+ | int |
|
| 474 |
+ | _types.SSH_AGENT |
|
| 475 |
+ | None = None, |
|
| 476 |
+ ) -> tuple[int, bytes | bytearray] | bytes | bytearray: |
|
| 477 |
+ del request_code |
|
| 478 |
+ del payload |
|
| 479 |
+ if isinstance( # pragma: no branch |
|
| 480 |
+ response_code, int | _types.SSH_AGENT |
|
| 481 |
+ ): |
|
| 482 |
+ response_code = frozenset({response_code})
|
|
| 483 |
+ if response_code is not None: # pragma: no branch |
|
| 484 |
+ response_code = frozenset({
|
|
| 485 |
+ c if isinstance(c, int) else c.value for c in response_code |
|
| 486 |
+ }) |
|
| 487 |
+ |
|
| 488 |
+ if not response_code: # pragma: no cover |
|
| 489 |
+ return (passed_response_code.value, response) |
|
| 490 |
+ if ( |
|
| 491 |
+ passed_response_code.value not in response_code |
|
| 492 |
+ ): # pragma: no branch |
|
| 493 |
+ raise ssh_agent.SSHAgentFailedError( |
|
| 494 |
+ passed_response_code.value, response |
|
| 433 | 495 |
) |
| 496 |
+ return response # pragma: no cover |
|
| 497 |
+ |
|
| 498 |
+ client = ssh_agent.SSHAgentClient() |
|
| 499 |
+ monkeypatch.setattr(client, 'request', request) |
|
| 434 | 500 |
KeyCommentPair = _types.KeyCommentPair # noqa: N806 |
| 435 | 501 |
loaded_keys = [ |
| 436 | 502 |
KeyCommentPair(v['public_key_data'], b'no comment') |
| ... | ... |
@@ -439,3 +505,28 @@ class TestAgentInteraction: |
| 439 | 505 |
monkeypatch.setattr(client, 'list_keys', lambda: loaded_keys) |
| 440 | 506 |
with pytest.raises(exc_type, match=exc_pattern): |
| 441 | 507 |
client.sign(key, b'abc', check_if_key_loaded=check) |
| 508 |
+ |
|
| 509 |
+ @tests.skip_if_no_agent |
|
| 510 |
+ @pytest.mark.parametrize( |
|
| 511 |
+ ['request_code', 'response_code', 'exc_type', 'exc_pattern'], |
|
| 512 |
+ [ |
|
| 513 |
+ ( |
|
| 514 |
+ _types.SSH_AGENTC.REQUEST_IDENTITIES, |
|
| 515 |
+ _types.SSH_AGENT.SUCCESS, |
|
| 516 |
+ ssh_agent.SSHAgentFailedError, |
|
| 517 |
+ f'[Code {_types.SSH_AGENT.IDENTITIES_ANSWER.value}]',
|
|
| 518 |
+ ), |
|
| 519 |
+ ], |
|
| 520 |
+ ) |
|
| 521 |
+ def test_340_request_error_responses( |
|
| 522 |
+ self, |
|
| 523 |
+ request_code: _types.SSH_AGENTC, |
|
| 524 |
+ response_code: _types.SSH_AGENT, |
|
| 525 |
+ exc_type: type[Exception], |
|
| 526 |
+ exc_pattern: str, |
|
| 527 |
+ ) -> None: |
|
| 528 |
+ with ( |
|
| 529 |
+ pytest.raises(exc_type, match=exc_pattern), |
|
| 530 |
+ ssh_agent.SSHAgentClient() as client, |
|
| 531 |
+ ): |
|
| 532 |
+ client.request(request_code, b'', response_code=response_code) |
|
| 442 | 533 |