Marco Ricci commited on 2024-09-21 11:45:20
              Zeige 2 geänderte Dateien mit 187 Einfügungen und 29 Löschungen.
            
This shifts the remaining error checking into the `SSHAgentClient.request` method, most of the time. On the other hand, this makes mocking that method somewhat more involved.
| ... | ... | 
                      @@ -10,14 +10,14 @@ import collections  | 
                  
| 10 | 10 | 
                        import errno  | 
                    
| 11 | 11 | 
                        import os  | 
                    
| 12 | 12 | 
                        import socket  | 
                    
| 13 | 
                        -from typing import TYPE_CHECKING  | 
                    |
| 13 | 
                        +from typing import TYPE_CHECKING, overload  | 
                    |
| 14 | 14 | 
                         | 
                    
| 15 | 15 | 
                        from typing_extensions import Self  | 
                    
| 16 | 16 | 
                         | 
                    
| 17 | 17 | 
                        from derivepassphrase import _types  | 
                    
| 18 | 18 | 
                         | 
                    
| 19 | 19 | 
                        if TYPE_CHECKING:  | 
                    
| 20 | 
                        - from collections.abc import Sequence  | 
                    |
| 20 | 
                        + from collections.abc import Iterable, Sequence  | 
                    |
| 21 | 21 | 
                        from types import TracebackType  | 
                    
| 22 | 22 | 
                         | 
                    
| 23 | 23 | 
                         __all__ = ('SSHAgentClient',)
                       | 
                    
| ... | ... | 
                      @@ -268,9 +268,48 @@ class SSHAgentClient:  | 
                  
| 268 | 268 | 
                        bytestring[m + HEAD_LEN :],  | 
                    
| 269 | 269 | 
                        )  | 
                    
| 270 | 270 | 
                         | 
                    
| 271 | 
                        + @overload  | 
                    |
| 272 | 
                        + def request( # pragma: no cover  | 
                    |
| 273 | 
                        + self,  | 
                    |
| 274 | 
                        + code: int | _types.SSH_AGENTC,  | 
                    |
| 275 | 
                        + payload: bytes | bytearray,  | 
                    |
| 276 | 
                        + /,  | 
                    |
| 277 | 
                        + *,  | 
                    |
| 278 | 
                        + response_code: None = None,  | 
                    |
| 279 | 
                        + ) -> tuple[int, bytes | bytearray]: ...  | 
                    |
| 280 | 
                        +  | 
                    |
| 281 | 
                        + @overload  | 
                    |
| 282 | 
                        + def request( # pragma: no cover  | 
                    |
| 283 | 
                        + self,  | 
                    |
| 284 | 
                        + code: int | _types.SSH_AGENTC,  | 
                    |
| 285 | 
                        + payload: bytes | bytearray,  | 
                    |
| 286 | 
                        + /,  | 
                    |
| 287 | 
                        + *,  | 
                    |
| 288 | 
                        +        response_code: Iterable[_types.SSH_AGENT | int] = frozenset({
                       | 
                    |
| 289 | 
                        + _types.SSH_AGENT.SUCCESS  | 
                    |
| 290 | 
                        + }),  | 
                    |
| 291 | 
                        + ) -> tuple[int, bytes | bytearray]: ...  | 
                    |
| 292 | 
                        +  | 
                    |
| 293 | 
                        + @overload  | 
                    |
| 294 | 
                        + def request( # pragma: no cover  | 
                    |
| 295 | 
                        + self,  | 
                    |
| 296 | 
                        + code: int | _types.SSH_AGENTC,  | 
                    |
| 297 | 
                        + payload: bytes | bytearray,  | 
                    |
| 298 | 
                        + /,  | 
                    |
| 299 | 
                        + *,  | 
                    |
| 300 | 
                        + response_code: _types.SSH_AGENT | int = _types.SSH_AGENT.SUCCESS,  | 
                    |
| 301 | 
                        + ) -> bytes | bytearray: ...  | 
                    |
| 302 | 
                        +  | 
                    |
| 271 | 303 | 
                        def request(  | 
                    
| 272 | 
                        - self, code: int, payload: bytes | bytearray, /  | 
                    |
| 273 | 
                        - ) -> tuple[int, bytes | bytearray]:  | 
                    |
| 304 | 
                        + self,  | 
                    |
| 305 | 
                        + code: int | _types.SSH_AGENTC,  | 
                    |
| 306 | 
                        + payload: bytes | bytearray,  | 
                    |
| 307 | 
                        + /,  | 
                    |
| 308 | 
                        + *,  | 
                    |
| 309 | 
                        + response_code: (  | 
                    |
| 310 | 
                        + Iterable[_types.SSH_AGENT | int] | _types.SSH_AGENT | int | None  | 
                    |
| 311 | 
                        + ) = None,  | 
                    |
| 312 | 
                        + ) -> tuple[int, bytes | bytearray] | bytes | bytearray:  | 
                    |
| 274 | 313 | 
                        """Issue a generic request to the SSH agent.  | 
                    
| 275 | 314 | 
                         | 
                    
| 276 | 315 | 
                        Args:  | 
                    
| ... | ... | 
                      @@ -283,6 +322,10 @@ class SSHAgentClient:  | 
                  
| 283 | 322 | 
                        the request. Request-specific. `request` will add any  | 
                    
| 284 | 323 | 
                        necessary wire framing around the request code and the  | 
                    
| 285 | 324 | 
                        payload.  | 
                    
| 325 | 
                        + response_code:  | 
                    |
| 326 | 
                        + An optional response code, or a set of response codes,  | 
                    |
| 327 | 
                        + that we expect. If given, and the actual response code  | 
                    |
| 328 | 
                        + does not match, raise an error.  | 
                    |
| 286 | 329 | 
                         | 
                    
| 287 | 330 | 
                        Returns:  | 
                    
| 288 | 331 | 
                        A 2-tuple consisting of the response code and the payload,  | 
                    
| ... | ... | 
                      @@ -291,9 +334,24 @@ class SSHAgentClient:  | 
                  
| 291 | 334 | 
                        Raises:  | 
                    
| 292 | 335 | 
                        EOFError:  | 
                    
| 293 | 336 | 
                        The response from the SSH agent is truncated or missing.  | 
                    
| 337 | 
                        + OSError:  | 
                    |
| 338 | 
                        + There was a communication error with the SSH agent.  | 
                    |
| 339 | 
                        + SSHAgentFailedError:  | 
                    |
| 340 | 
                        + We expected specific response codes, but did not receive  | 
                    |
| 341 | 
                        + any of them.  | 
                    |
| 294 | 342 | 
                         | 
                    
| 295 | 343 | 
                        """  | 
                    
| 296 | 
                        - request_message = bytearray([code])  | 
                    |
| 344 | 
                        + if isinstance( # pragma: no branch  | 
                    |
| 345 | 
                        + response_code, int | _types.SSH_AGENT  | 
                    |
| 346 | 
                        + ):  | 
                    |
| 347 | 
                        +            response_code = frozenset({response_code})
                       | 
                    |
| 348 | 
                        + if response_code is not None: # pragma: no branch  | 
                    |
| 349 | 
                        +            response_code = frozenset({
                       | 
                    |
| 350 | 
                        + c if isinstance(c, int) else c.value for c in response_code  | 
                    |
| 351 | 
                        + })  | 
                    |
| 352 | 
                        + request_message = bytearray([  | 
                    |
| 353 | 
                        + code if isinstance(code, int) else code.value  | 
                    |
| 354 | 
                        + ])  | 
                    |
| 297 | 355 | 
                        request_message.extend(payload)  | 
                    
| 298 | 356 | 
                        self._connection.sendall(self.string(request_message))  | 
                    
| 299 | 357 | 
                        chunk = self._connection.recv(HEAD_LEN)  | 
                    
| ... | ... | 
                      @@ -305,7 +363,11 @@ class SSHAgentClient:  | 
                  
| 305 | 363 | 
                        if len(response) < response_length:  | 
                    
| 306 | 364 | 
                        msg = 'truncated response from SSH agent'  | 
                    
| 307 | 365 | 
                        raise EOFError(msg)  | 
                    
| 366 | 
                        + if not response_code: # pragma: no cover  | 
                    |
| 308 | 367 | 
                        return response[0], response[1:]  | 
                    
| 368 | 
                        + if response[0] not in response_code:  | 
                    |
| 369 | 
                        + raise SSHAgentFailedError(response[0], response[1:])  | 
                    |
| 370 | 
                        + return response[1:]  | 
                    |
| 309 | 371 | 
                         | 
                    
| 310 | 372 | 
                        def list_keys(self) -> Sequence[_types.KeyCommentPair]:  | 
                    
| 311 | 373 | 
                        """Request a list of keys known to the SSH agent.  | 
                    
| ... | ... | 
                      @@ -316,17 +378,19 @@ class SSHAgentClient:  | 
                  
| 316 | 378 | 
                        Raises:  | 
                    
| 317 | 379 | 
                        EOFError:  | 
                    
| 318 | 380 | 
                        The response from the SSH agent is truncated or missing.  | 
                    
| 381 | 
                        + OSError:  | 
                    |
| 382 | 
                        + There was a communication error with the SSH agent.  | 
                    |
| 319 | 383 | 
                        TrailingDataError:  | 
                    
| 320 | 384 | 
                        The response from the SSH agent is too long.  | 
                    
| 321 | 385 | 
                        SSHAgentFailedError:  | 
                    
| 322 | 386 | 
                        The agent failed to complete the request.  | 
                    
| 323 | 387 | 
                         | 
                    
| 324 | 388 | 
                        """  | 
                    
| 325 | 
                        - response_code, response = self.request(  | 
                    |
| 326 | 
                        - _types.SSH_AGENTC.REQUEST_IDENTITIES.value, b''  | 
                    |
| 389 | 
                        + response = self.request(  | 
                    |
| 390 | 
                        + _types.SSH_AGENTC.REQUEST_IDENTITIES.value,  | 
                    |
| 391 | 
                        + b'',  | 
                    |
| 392 | 
                        + response_code=_types.SSH_AGENT.IDENTITIES_ANSWER,  | 
                    |
| 327 | 393 | 
                        )  | 
                    
| 328 | 
                        - if response_code != _types.SSH_AGENT.IDENTITIES_ANSWER.value:  | 
                    |
| 329 | 
                        - raise SSHAgentFailedError(response_code, response)  | 
                    |
| 330 | 394 | 
                        response_stream = collections.deque(response)  | 
                    
| 331 | 395 | 
                         | 
                    
| 332 | 396 | 
                        def shift(num: int) -> bytes:  | 
                    
| ... | ... | 
                      @@ -390,6 +454,8 @@ class SSHAgentClient:  | 
                  
| 390 | 454 | 
                        Raises:  | 
                    
| 391 | 455 | 
                        EOFError:  | 
                    
| 392 | 456 | 
                        The response from the SSH agent is truncated or missing.  | 
                    
| 457 | 
                        + OSError:  | 
                    |
| 458 | 
                        + There was a communication error with the SSH agent.  | 
                    |
| 393 | 459 | 
                        TrailingDataError:  | 
                    
| 394 | 460 | 
                        The response from the SSH agent is too long.  | 
                    
| 395 | 461 | 
                        SSHAgentFailedError:  | 
                    
| ... | ... | 
                      @@ -407,9 +473,10 @@ class SSHAgentClient:  | 
                  
| 407 | 473 | 
                        request_data = bytearray(self.string(key))  | 
                    
| 408 | 474 | 
                        request_data.extend(self.string(payload))  | 
                    
| 409 | 475 | 
                        request_data.extend(self.uint32(flags))  | 
                    
| 410 | 
                        - response_code, response = self.request(  | 
                    |
| 411 | 
                        - _types.SSH_AGENTC.SIGN_REQUEST.value, request_data  | 
                    |
| 476 | 
                        + return self.unstring(  | 
                    |
| 477 | 
                        + self.request(  | 
                    |
| 478 | 
                        + _types.SSH_AGENTC.SIGN_REQUEST.value,  | 
                    |
| 479 | 
                        + request_data,  | 
                    |
| 480 | 
                        + response_code=_types.SSH_AGENT.SIGN_RESPONSE,  | 
                    |
| 481 | 
                        + )  | 
                    |
| 412 | 482 | 
                        )  | 
                    
| 413 | 
                        - if response_code != _types.SSH_AGENT.SIGN_RESPONSE.value:  | 
                    |
| 414 | 
                        - raise SSHAgentFailedError(response_code, response)  | 
                    |
| 415 | 
                        - return self.unstring(response)  | 
                    
| ... | ... | 
                      @@ -22,7 +22,7 @@ import tests  | 
                  
| 22 | 22 | 
                        from derivepassphrase import _types, cli, ssh_agent, vault  | 
                    
| 23 | 23 | 
                         | 
                    
| 24 | 24 | 
                        if TYPE_CHECKING:  | 
                    
| 25 | 
                        - from collections.abc import Iterator  | 
                    |
| 25 | 
                        + from collections.abc import Iterable, Iterator  | 
                    |
| 26 | 26 | 
                         | 
                    
| 27 | 27 | 
                         | 
                    
| 28 | 28 | 
                        class TestStaticFunctionality:  | 
                    
| ... | ... | 
                      @@ -387,30 +387,66 @@ class TestAgentInteraction:  | 
                  
| 387 | 387 | 
                        exc_type: type[Exception],  | 
                    
| 388 | 388 | 
                        exc_pattern: str,  | 
                    
| 389 | 389 | 
                        ) -> None:  | 
                    
| 390 | 
                        - client = ssh_agent.SSHAgentClient()  | 
                    |
| 391 | 
                        - monkeypatch.setattr(  | 
                    |
| 392 | 
                        - client,  | 
                    |
| 393 | 
                        - 'request',  | 
                    |
| 394 | 
                        - lambda *a, **kw: (response_code.value, response), # noqa: ARG005  | 
                    |
| 390 | 
                        + passed_response_code = response_code  | 
                    |
| 391 | 
                        +  | 
                    |
| 392 | 
                        + def request(  | 
                    |
| 393 | 
                        + request_code: int | _types.SSH_AGENTC,  | 
                    |
| 394 | 
                        + payload: bytes | bytearray,  | 
                    |
| 395 | 
                        + /,  | 
                    |
| 396 | 
                        + *,  | 
                    |
| 397 | 
                        + response_code: Iterable[int | _types.SSH_AGENT]  | 
                    |
| 398 | 
                        + | int  | 
                    |
| 399 | 
                        + | _types.SSH_AGENT  | 
                    |
| 400 | 
                        + | None = None,  | 
                    |
| 401 | 
                        + ) -> tuple[int, bytes | bytearray] | bytes | bytearray:  | 
                    |
| 402 | 
                        + del request_code  | 
                    |
| 403 | 
                        + del payload  | 
                    |
| 404 | 
                        + if isinstance( # pragma: no branch  | 
                    |
| 405 | 
                        + response_code, int | _types.SSH_AGENT  | 
                    |
| 406 | 
                        + ):  | 
                    |
| 407 | 
                        +                response_code = frozenset({response_code})
                       | 
                    |
| 408 | 
                        + if response_code is not None: # pragma: no branch  | 
                    |
| 409 | 
                        +                response_code = frozenset({
                       | 
                    |
| 410 | 
                        + c if isinstance(c, int) else c.value for c in response_code  | 
                    |
| 411 | 
                        + })  | 
                    |
| 412 | 
                        +  | 
                    |
| 413 | 
                        + if not response_code: # pragma: no cover  | 
                    |
| 414 | 
                        + return (passed_response_code.value, response)  | 
                    |
| 415 | 
                        + if passed_response_code.value not in response_code:  | 
                    |
| 416 | 
                        + raise ssh_agent.SSHAgentFailedError(  | 
                    |
| 417 | 
                        + passed_response_code.value, response  | 
                    |
| 395 | 418 | 
                        )  | 
                    
| 419 | 
                        + return response  | 
                    |
| 420 | 
                        +  | 
                    |
| 421 | 
                        + client = ssh_agent.SSHAgentClient()  | 
                    |
| 422 | 
                        + monkeypatch.setattr(client, 'request', request)  | 
                    |
| 396 | 423 | 
                        with pytest.raises(exc_type, match=exc_pattern):  | 
                    
| 397 | 424 | 
                        client.list_keys()  | 
                    
| 398 | 425 | 
                         | 
                    
| 399 | 426 | 
                        @tests.skip_if_no_agent  | 
                    
| 400 | 427 | 
                        @pytest.mark.parametrize(  | 
                    
| 401 | 
                        - ['key', 'check', 'response', 'exc_type', 'exc_pattern'],  | 
                    |
| 428 | 
                        + [  | 
                    |
| 429 | 
                        + 'key',  | 
                    |
| 430 | 
                        + 'check',  | 
                    |
| 431 | 
                        + 'response_code',  | 
                    |
| 432 | 
                        + 'response',  | 
                    |
| 433 | 
                        + 'exc_type',  | 
                    |
| 434 | 
                        + 'exc_pattern',  | 
                    |
| 435 | 
                        + ],  | 
                    |
| 402 | 436 | 
                        [  | 
                    
| 403 | 437 | 
                        (  | 
                    
| 404 | 438 | 
                        b'invalid-key',  | 
                    
| 405 | 439 | 
                        True,  | 
                    
| 406 | 
                        - (_types.SSH_AGENT.FAILURE, b''),  | 
                    |
| 440 | 
                        + _types.SSH_AGENT.FAILURE,  | 
                    |
| 441 | 
                        + b'',  | 
                    |
| 407 | 442 | 
                        KeyError,  | 
                    
| 408 | 443 | 
                        'target SSH key not loaded into agent',  | 
                    
| 409 | 444 | 
                        ),  | 
                    
| 410 | 445 | 
                        (  | 
                    
| 411 | 446 | 
                        tests.SUPPORTED_KEYS['ed25519']['public_key_data'],  | 
                    
| 412 | 447 | 
                        True,  | 
                    
| 413 | 
                        - (_types.SSH_AGENT.FAILURE, b''),  | 
                    |
| 448 | 
                        + _types.SSH_AGENT.FAILURE,  | 
                    |
| 449 | 
                        + b'',  | 
                    |
| 414 | 450 | 
                        ssh_agent.SSHAgentFailedError,  | 
                    
| 415 | 451 | 
                        'failed to complete the request',  | 
                    
| 416 | 452 | 
                        ),  | 
                    
| ... | ... | 
                      @@ -421,16 +457,46 @@ class TestAgentInteraction:  | 
                  
| 421 | 457 | 
                        monkeypatch: pytest.MonkeyPatch,  | 
                    
| 422 | 458 | 
                        key: bytes | bytearray,  | 
                    
| 423 | 459 | 
                        check: bool,  | 
                    
| 424 | 
                        - response: tuple[_types.SSH_AGENT, bytes | bytearray],  | 
                    |
| 460 | 
                        + response_code: _types.SSH_AGENT,  | 
                    |
| 461 | 
                        + response: bytes | bytearray,  | 
                    |
| 425 | 462 | 
                        exc_type: type[Exception],  | 
                    
| 426 | 463 | 
                        exc_pattern: str,  | 
                    
| 427 | 464 | 
                        ) -> None:  | 
                    
| 428 | 
                        - client = ssh_agent.SSHAgentClient()  | 
                    |
| 429 | 
                        - monkeypatch.setattr(  | 
                    |
| 430 | 
                        - client,  | 
                    |
| 431 | 
                        - 'request',  | 
                    |
| 432 | 
                        - lambda a, b: (response[0].value, response[1]), # noqa: ARG005  | 
                    |
| 465 | 
                        + passed_response_code = response_code  | 
                    |
| 466 | 
                        +  | 
                    |
| 467 | 
                        + def request(  | 
                    |
| 468 | 
                        + request_code: int | _types.SSH_AGENTC,  | 
                    |
| 469 | 
                        + payload: bytes | bytearray,  | 
                    |
| 470 | 
                        + /,  | 
                    |
| 471 | 
                        + *,  | 
                    |
| 472 | 
                        + response_code: Iterable[int | _types.SSH_AGENT]  | 
                    |
| 473 | 
                        + | int  | 
                    |
| 474 | 
                        + | _types.SSH_AGENT  | 
                    |
| 475 | 
                        + | None = None,  | 
                    |
| 476 | 
                        + ) -> tuple[int, bytes | bytearray] | bytes | bytearray:  | 
                    |
| 477 | 
                        + del request_code  | 
                    |
| 478 | 
                        + del payload  | 
                    |
| 479 | 
                        + if isinstance( # pragma: no branch  | 
                    |
| 480 | 
                        + response_code, int | _types.SSH_AGENT  | 
                    |
| 481 | 
                        + ):  | 
                    |
| 482 | 
                        +                response_code = frozenset({response_code})
                       | 
                    |
| 483 | 
                        + if response_code is not None: # pragma: no branch  | 
                    |
| 484 | 
                        +                response_code = frozenset({
                       | 
                    |
| 485 | 
                        + c if isinstance(c, int) else c.value for c in response_code  | 
                    |
| 486 | 
                        + })  | 
                    |
| 487 | 
                        +  | 
                    |
| 488 | 
                        + if not response_code: # pragma: no cover  | 
                    |
| 489 | 
                        + return (passed_response_code.value, response)  | 
                    |
| 490 | 
                        + if (  | 
                    |
| 491 | 
                        + passed_response_code.value not in response_code  | 
                    |
| 492 | 
                        + ): # pragma: no branch  | 
                    |
| 493 | 
                        + raise ssh_agent.SSHAgentFailedError(  | 
                    |
| 494 | 
                        + passed_response_code.value, response  | 
                    |
| 433 | 495 | 
                        )  | 
                    
| 496 | 
                        + return response # pragma: no cover  | 
                    |
| 497 | 
                        +  | 
                    |
| 498 | 
                        + client = ssh_agent.SSHAgentClient()  | 
                    |
| 499 | 
                        + monkeypatch.setattr(client, 'request', request)  | 
                    |
| 434 | 500 | 
                        KeyCommentPair = _types.KeyCommentPair # noqa: N806  | 
                    
| 435 | 501 | 
                        loaded_keys = [  | 
                    
| 436 | 502 | 
                        KeyCommentPair(v['public_key_data'], b'no comment')  | 
                    
| ... | ... | 
                      @@ -439,3 +505,28 @@ class TestAgentInteraction:  | 
                  
| 439 | 505 | 
                        monkeypatch.setattr(client, 'list_keys', lambda: loaded_keys)  | 
                    
| 440 | 506 | 
                        with pytest.raises(exc_type, match=exc_pattern):  | 
                    
| 441 | 507 | 
                        client.sign(key, b'abc', check_if_key_loaded=check)  | 
                    
| 508 | 
                        +  | 
                    |
| 509 | 
                        + @tests.skip_if_no_agent  | 
                    |
| 510 | 
                        + @pytest.mark.parametrize(  | 
                    |
| 511 | 
                        + ['request_code', 'response_code', 'exc_type', 'exc_pattern'],  | 
                    |
| 512 | 
                        + [  | 
                    |
| 513 | 
                        + (  | 
                    |
| 514 | 
                        + _types.SSH_AGENTC.REQUEST_IDENTITIES,  | 
                    |
| 515 | 
                        + _types.SSH_AGENT.SUCCESS,  | 
                    |
| 516 | 
                        + ssh_agent.SSHAgentFailedError,  | 
                    |
| 517 | 
                        +                f'[Code {_types.SSH_AGENT.IDENTITIES_ANSWER.value}]',
                       | 
                    |
| 518 | 
                        + ),  | 
                    |
| 519 | 
                        + ],  | 
                    |
| 520 | 
                        + )  | 
                    |
| 521 | 
                        + def test_340_request_error_responses(  | 
                    |
| 522 | 
                        + self,  | 
                    |
| 523 | 
                        + request_code: _types.SSH_AGENTC,  | 
                    |
| 524 | 
                        + response_code: _types.SSH_AGENT,  | 
                    |
| 525 | 
                        + exc_type: type[Exception],  | 
                    |
| 526 | 
                        + exc_pattern: str,  | 
                    |
| 527 | 
                        + ) -> None:  | 
                    |
| 528 | 
                        + with (  | 
                    |
| 529 | 
                        + pytest.raises(exc_type, match=exc_pattern),  | 
                    |
| 530 | 
                        + ssh_agent.SSHAgentClient() as client,  | 
                    |
| 531 | 
                        + ):  | 
                    |
| 532 | 
                        + client.request(request_code, b'', response_code=response_code)  | 
                    |
| 442 | 533 |