Support passing expected SSH agent response codes
Marco Ricci

Marco Ricci commited on 2024-09-21 11:45:20
Zeige 2 geänderte Dateien mit 187 Einfügungen und 29 Löschungen.


This shifts the remaining error checking into the
`SSHAgentClient.request` method, most of the time.  On the other hand,
this makes mocking that method somewhat more involved.
... ...
@@ -10,14 +10,14 @@ import collections
10 10
 import errno
11 11
 import os
12 12
 import socket
13
-from typing import TYPE_CHECKING
13
+from typing import TYPE_CHECKING, overload
14 14
 
15 15
 from typing_extensions import Self
16 16
 
17 17
 from derivepassphrase import _types
18 18
 
19 19
 if TYPE_CHECKING:
20
-    from collections.abc import Sequence
20
+    from collections.abc import Iterable, Sequence
21 21
     from types import TracebackType
22 22
 
23 23
 __all__ = ('SSHAgentClient',)
... ...
@@ -268,9 +268,48 @@ class SSHAgentClient:
268 268
             bytestring[m + HEAD_LEN :],
269 269
         )
270 270
 
271
+    @overload
272
+    def request(  # pragma: no cover
273
+        self,
274
+        code: int | _types.SSH_AGENTC,
275
+        payload: bytes | bytearray,
276
+        /,
277
+        *,
278
+        response_code: None = None,
279
+    ) -> tuple[int, bytes | bytearray]: ...
280
+
281
+    @overload
282
+    def request(  # pragma: no cover
283
+        self,
284
+        code: int | _types.SSH_AGENTC,
285
+        payload: bytes | bytearray,
286
+        /,
287
+        *,
288
+        response_code: Iterable[_types.SSH_AGENT | int] = frozenset({
289
+            _types.SSH_AGENT.SUCCESS
290
+        }),
291
+    ) -> tuple[int, bytes | bytearray]: ...
292
+
293
+    @overload
294
+    def request(  # pragma: no cover
295
+        self,
296
+        code: int | _types.SSH_AGENTC,
297
+        payload: bytes | bytearray,
298
+        /,
299
+        *,
300
+        response_code: _types.SSH_AGENT | int = _types.SSH_AGENT.SUCCESS,
301
+    ) -> bytes | bytearray: ...
302
+
271 303
     def request(
272
-        self, code: int, payload: bytes | bytearray, /
273
-    ) -> tuple[int, bytes | bytearray]:
304
+        self,
305
+        code: int | _types.SSH_AGENTC,
306
+        payload: bytes | bytearray,
307
+        /,
308
+        *,
309
+        response_code: (
310
+            Iterable[_types.SSH_AGENT | int] | _types.SSH_AGENT | int | None
311
+        ) = None,
312
+    ) -> tuple[int, bytes | bytearray] | bytes | bytearray:
274 313
         """Issue a generic request to the SSH agent.
275 314
 
276 315
         Args:
... ...
@@ -283,6 +322,10 @@ class SSHAgentClient:
283 322
                 the request.  Request-specific.  `request` will add any
284 323
                 necessary wire framing around the request code and the
285 324
                 payload.
325
+            response_code:
326
+                An optional response code, or a set of response codes,
327
+                that we expect.  If given, and the actual response code
328
+                does not match, raise an error.
286 329
 
287 330
         Returns:
288 331
             A 2-tuple consisting of the response code and the payload,
... ...
@@ -291,9 +334,24 @@ class SSHAgentClient:
291 334
         Raises:
292 335
             EOFError:
293 336
                 The response from the SSH agent is truncated or missing.
337
+            OSError:
338
+                There was a communication error with the SSH agent.
339
+            SSHAgentFailedError:
340
+                We expected specific response codes, but did not receive
341
+                any of them.
294 342
 
295 343
         """
296
-        request_message = bytearray([code])
344
+        if isinstance(  # pragma: no branch
345
+            response_code, int | _types.SSH_AGENT
346
+        ):
347
+            response_code = frozenset({response_code})
348
+        if response_code is not None:  # pragma: no branch
349
+            response_code = frozenset({
350
+                c if isinstance(c, int) else c.value for c in response_code
351
+            })
352
+        request_message = bytearray([
353
+            code if isinstance(code, int) else code.value
354
+        ])
297 355
         request_message.extend(payload)
298 356
         self._connection.sendall(self.string(request_message))
299 357
         chunk = self._connection.recv(HEAD_LEN)
... ...
@@ -305,7 +363,11 @@ class SSHAgentClient:
305 363
         if len(response) < response_length:
306 364
             msg = 'truncated response from SSH agent'
307 365
             raise EOFError(msg)
366
+        if not response_code:  # pragma: no cover
308 367
             return response[0], response[1:]
368
+        if response[0] not in response_code:
369
+            raise SSHAgentFailedError(response[0], response[1:])
370
+        return response[1:]
309 371
 
310 372
     def list_keys(self) -> Sequence[_types.KeyCommentPair]:
311 373
         """Request a list of keys known to the SSH agent.
... ...
@@ -316,17 +378,19 @@ class SSHAgentClient:
316 378
         Raises:
317 379
             EOFError:
318 380
                 The response from the SSH agent is truncated or missing.
381
+            OSError:
382
+                There was a communication error with the SSH agent.
319 383
             TrailingDataError:
320 384
                 The response from the SSH agent is too long.
321 385
             SSHAgentFailedError:
322 386
                 The agent failed to complete the request.
323 387
 
324 388
         """
325
-        response_code, response = self.request(
326
-            _types.SSH_AGENTC.REQUEST_IDENTITIES.value, b''
389
+        response = self.request(
390
+            _types.SSH_AGENTC.REQUEST_IDENTITIES.value,
391
+            b'',
392
+            response_code=_types.SSH_AGENT.IDENTITIES_ANSWER,
327 393
         )
328
-        if response_code != _types.SSH_AGENT.IDENTITIES_ANSWER.value:
329
-            raise SSHAgentFailedError(response_code, response)
330 394
         response_stream = collections.deque(response)
331 395
 
332 396
         def shift(num: int) -> bytes:
... ...
@@ -390,6 +454,8 @@ class SSHAgentClient:
390 454
         Raises:
391 455
             EOFError:
392 456
                 The response from the SSH agent is truncated or missing.
457
+            OSError:
458
+                There was a communication error with the SSH agent.
393 459
             TrailingDataError:
394 460
                 The response from the SSH agent is too long.
395 461
             SSHAgentFailedError:
... ...
@@ -407,9 +473,10 @@ class SSHAgentClient:
407 473
         request_data = bytearray(self.string(key))
408 474
         request_data.extend(self.string(payload))
409 475
         request_data.extend(self.uint32(flags))
410
-        response_code, response = self.request(
411
-            _types.SSH_AGENTC.SIGN_REQUEST.value, request_data
476
+        return self.unstring(
477
+            self.request(
478
+                _types.SSH_AGENTC.SIGN_REQUEST.value,
479
+                request_data,
480
+                response_code=_types.SSH_AGENT.SIGN_RESPONSE,
481
+            )
412 482
         )
413
-        if response_code != _types.SSH_AGENT.SIGN_RESPONSE.value:
414
-            raise SSHAgentFailedError(response_code, response)
415
-        return self.unstring(response)
... ...
@@ -22,7 +22,7 @@ import tests
22 22
 from derivepassphrase import _types, cli, ssh_agent, vault
23 23
 
24 24
 if TYPE_CHECKING:
25
-    from collections.abc import Iterator
25
+    from collections.abc import Iterable, Iterator
26 26
 
27 27
 
28 28
 class TestStaticFunctionality:
... ...
@@ -387,30 +387,66 @@ class TestAgentInteraction:
387 387
         exc_type: type[Exception],
388 388
         exc_pattern: str,
389 389
     ) -> None:
390
-        client = ssh_agent.SSHAgentClient()
391
-        monkeypatch.setattr(
392
-            client,
393
-            'request',
394
-            lambda *a, **kw: (response_code.value, response),  # noqa: ARG005
390
+        passed_response_code = response_code
391
+
392
+        def request(
393
+            request_code: int | _types.SSH_AGENTC,
394
+            payload: bytes | bytearray,
395
+            /,
396
+            *,
397
+            response_code: Iterable[int | _types.SSH_AGENT]
398
+            | int
399
+            | _types.SSH_AGENT
400
+            | None = None,
401
+        ) -> tuple[int, bytes | bytearray] | bytes | bytearray:
402
+            del request_code
403
+            del payload
404
+            if isinstance(  # pragma: no branch
405
+                response_code, int | _types.SSH_AGENT
406
+            ):
407
+                response_code = frozenset({response_code})
408
+            if response_code is not None:  # pragma: no branch
409
+                response_code = frozenset({
410
+                    c if isinstance(c, int) else c.value for c in response_code
411
+                })
412
+
413
+            if not response_code:  # pragma: no cover
414
+                return (passed_response_code.value, response)
415
+            if passed_response_code.value not in response_code:
416
+                raise ssh_agent.SSHAgentFailedError(
417
+                    passed_response_code.value, response
395 418
                 )
419
+            return response
420
+
421
+        client = ssh_agent.SSHAgentClient()
422
+        monkeypatch.setattr(client, 'request', request)
396 423
         with pytest.raises(exc_type, match=exc_pattern):
397 424
             client.list_keys()
398 425
 
399 426
     @tests.skip_if_no_agent
400 427
     @pytest.mark.parametrize(
401
-        ['key', 'check', 'response', 'exc_type', 'exc_pattern'],
428
+        [
429
+            'key',
430
+            'check',
431
+            'response_code',
432
+            'response',
433
+            'exc_type',
434
+            'exc_pattern',
435
+        ],
402 436
         [
403 437
             (
404 438
                 b'invalid-key',
405 439
                 True,
406
-                (_types.SSH_AGENT.FAILURE, b''),
440
+                _types.SSH_AGENT.FAILURE,
441
+                b'',
407 442
                 KeyError,
408 443
                 'target SSH key not loaded into agent',
409 444
             ),
410 445
             (
411 446
                 tests.SUPPORTED_KEYS['ed25519']['public_key_data'],
412 447
                 True,
413
-                (_types.SSH_AGENT.FAILURE, b''),
448
+                _types.SSH_AGENT.FAILURE,
449
+                b'',
414 450
                 ssh_agent.SSHAgentFailedError,
415 451
                 'failed to complete the request',
416 452
             ),
... ...
@@ -421,16 +457,46 @@ class TestAgentInteraction:
421 457
         monkeypatch: pytest.MonkeyPatch,
422 458
         key: bytes | bytearray,
423 459
         check: bool,
424
-        response: tuple[_types.SSH_AGENT, bytes | bytearray],
460
+        response_code: _types.SSH_AGENT,
461
+        response: bytes | bytearray,
425 462
         exc_type: type[Exception],
426 463
         exc_pattern: str,
427 464
     ) -> None:
428
-        client = ssh_agent.SSHAgentClient()
429
-        monkeypatch.setattr(
430
-            client,
431
-            'request',
432
-            lambda a, b: (response[0].value, response[1]),  # noqa: ARG005
465
+        passed_response_code = response_code
466
+
467
+        def request(
468
+            request_code: int | _types.SSH_AGENTC,
469
+            payload: bytes | bytearray,
470
+            /,
471
+            *,
472
+            response_code: Iterable[int | _types.SSH_AGENT]
473
+            | int
474
+            | _types.SSH_AGENT
475
+            | None = None,
476
+        ) -> tuple[int, bytes | bytearray] | bytes | bytearray:
477
+            del request_code
478
+            del payload
479
+            if isinstance(  # pragma: no branch
480
+                response_code, int | _types.SSH_AGENT
481
+            ):
482
+                response_code = frozenset({response_code})
483
+            if response_code is not None:  # pragma: no branch
484
+                response_code = frozenset({
485
+                    c if isinstance(c, int) else c.value for c in response_code
486
+                })
487
+
488
+            if not response_code:  # pragma: no cover
489
+                return (passed_response_code.value, response)
490
+            if (
491
+                passed_response_code.value not in response_code
492
+            ):  # pragma: no branch
493
+                raise ssh_agent.SSHAgentFailedError(
494
+                    passed_response_code.value, response
433 495
                 )
496
+            return response  # pragma: no cover
497
+
498
+        client = ssh_agent.SSHAgentClient()
499
+        monkeypatch.setattr(client, 'request', request)
434 500
         KeyCommentPair = _types.KeyCommentPair  # noqa: N806
435 501
         loaded_keys = [
436 502
             KeyCommentPair(v['public_key_data'], b'no comment')
... ...
@@ -439,3 +505,28 @@ class TestAgentInteraction:
439 505
         monkeypatch.setattr(client, 'list_keys', lambda: loaded_keys)
440 506
         with pytest.raises(exc_type, match=exc_pattern):
441 507
             client.sign(key, b'abc', check_if_key_loaded=check)
508
+
509
+    @tests.skip_if_no_agent
510
+    @pytest.mark.parametrize(
511
+        ['request_code', 'response_code', 'exc_type', 'exc_pattern'],
512
+        [
513
+            (
514
+                _types.SSH_AGENTC.REQUEST_IDENTITIES,
515
+                _types.SSH_AGENT.SUCCESS,
516
+                ssh_agent.SSHAgentFailedError,
517
+                f'[Code {_types.SSH_AGENT.IDENTITIES_ANSWER.value}]',
518
+            ),
519
+        ],
520
+    )
521
+    def test_340_request_error_responses(
522
+        self,
523
+        request_code: _types.SSH_AGENTC,
524
+        response_code: _types.SSH_AGENT,
525
+        exc_type: type[Exception],
526
+        exc_pattern: str,
527
+    ) -> None:
528
+        with (
529
+            pytest.raises(exc_type, match=exc_pattern),
530
+            ssh_agent.SSHAgentClient() as client,
531
+        ):
532
+            client.request(request_code, b'', response_code=response_code)
442 533