Support one-off SSH agent client child contexts
Marco Ricci

Marco Ricci commited on 2024-11-13 20:54:26
Zeige 3 geänderte Dateien mit 67 Einfügungen und 51 Löschungen.


Centralize functionality for constructing one-off SSH agent clients in
child contexts.
... ...
@@ -8,7 +8,6 @@ from __future__ import annotations
8 8
 
9 9
 import base64
10 10
 import collections
11
-import contextlib
12 11
 import copy
13 12
 import enum
14 13
 import importlib
... ...
@@ -16,7 +15,6 @@ import inspect
16 15
 import json
17 16
 import logging
18 17
 import os
19
-import socket
20 18
 import unicodedata
21 19
 from typing import (
22 20
     TYPE_CHECKING,
... ...
@@ -37,6 +35,7 @@ from derivepassphrase import _types, exporter, ssh_agent, vault
37 35
 
38 36
 if TYPE_CHECKING:
39 37
     import pathlib
38
+    import socket
40 39
     import types
41 40
     from collections.abc import (
42 41
         Iterator,
... ...
@@ -484,21 +483,8 @@ def _get_suitable_ssh_keys(
484 483
 
485 484
     Args:
486 485
         conn:
487
-            An optional connection hint to the SSH agent; specifically,
488
-            an SSH agent client, or a socket connected to an SSH agent.
489
-
490
-            If an existing SSH agent client, then this client will be
491
-            queried for the SSH keys, and otherwise left intact.
492
-
493
-            If a socket, then a one-shot client will be constructed
494
-            based on the socket to query the agent, and deconstructed
495
-            afterwards.
496
-
497
-            If neither are given, then the agent's socket location is
498
-            looked up in the `SSH_AUTH_SOCK` environment variable, and
499
-            used to construct/deconstruct a one-shot client, as in the
500
-            previous case.  This requires the [`socket.AF_UNIX`][]
501
-            symbol to exist.
486
+            An optional connection hint to the SSH agent.  See
487
+            [`ssh_agent.SSHAgentClient.ensure_agent_subcontext`][].
502 488
 
503 489
     Yields:
504 490
         Every SSH key from the SSH agent that is suitable for passphrase
... ...
@@ -524,20 +510,7 @@ def _get_suitable_ssh_keys(
524 510
             The agent failed to supply a list of loaded keys.
525 511
 
526 512
     """
527
-    client: ssh_agent.SSHAgentClient
528
-    client_context: contextlib.AbstractContextManager[Any]
529
-    # Use match/case here once Python 3.9 becomes unsupported.
530
-    if isinstance(conn, ssh_agent.SSHAgentClient):
531
-        client = conn
532
-        client_context = contextlib.nullcontext()
533
-    elif isinstance(conn, socket.socket) or conn is None:
534
-        client = ssh_agent.SSHAgentClient(socket=conn)
535
-        client_context = client
536
-    else:  # pragma: no cover
537
-        assert_never(conn)
538
-        msg = f'invalid connection hint: {conn!r}'
539
-        raise TypeError(msg)  # noqa: DOC501
540
-    with client_context:
513
+    with ssh_agent.SSHAgentClient.ensure_agent_subcontext(conn) as client:
541 514
         try:
542 515
             all_key_comment_pairs = list(client.list_keys())
543 516
         except EOFError as e:  # pragma: no cover
... ...
@@ -633,21 +606,8 @@ def _select_ssh_key(
633 606
 
634 607
     Args:
635 608
         conn:
636
-            An optional connection hint to the SSH agent; specifically,
637
-            an SSH agent client, or a socket connected to an SSH agent.
638
-
639
-            If an existing SSH agent client, then this client will be
640
-            queried for the SSH keys, and otherwise left intact.
641
-
642
-            If a socket, then a one-shot client will be constructed
643
-            based on the socket to query the agent, and deconstructed
644
-            afterwards.
645
-
646
-            If neither are given, then the agent's socket location is
647
-            looked up in the `SSH_AUTH_SOCK` environment variable, and
648
-            used to construct/deconstruct a one-shot client, as in the
649
-            previous case.  This requires the [`socket.AF_UNIX`][]
650
-            symbol to exist.
609
+            An optional connection hint to the SSH agent.  See
610
+            [`ssh_agent.SSHAgentClient.ensure_agent_subcontext`][].
651 611
 
652 612
     Returns:
653 613
         The selected SSH key.
... ...
@@ -7,16 +7,17 @@
7 7
 from __future__ import annotations
8 8
 
9 9
 import collections
10
+import contextlib
10 11
 import os
11 12
 import socket
12 13
 from typing import TYPE_CHECKING, overload
13 14
 
14
-from typing_extensions import Self
15
+from typing_extensions import Self, assert_never
15 16
 
16 17
 from derivepassphrase import _types
17 18
 
18 19
 if TYPE_CHECKING:
19
-    from collections.abc import Iterable, Sequence
20
+    from collections.abc import Iterable, Iterator, Sequence
20 21
     from types import TracebackType
21 22
 
22 23
     from typing_extensions import Buffer
... ...
@@ -282,6 +283,61 @@ class SSHAgentClient:
282 283
             bytes(bytestring[m + HEAD_LEN :]),
283 284
         )
284 285
 
286
+    @classmethod
287
+    @contextlib.contextmanager
288
+    def ensure_agent_subcontext(
289
+        cls,
290
+        conn: SSHAgentClient | socket.socket | None = None,
291
+    ) -> Iterator[SSHAgentClient]:
292
+        """Return an SSH agent client subcontext.
293
+
294
+        If necessary, construct an SSH agent client first using the
295
+        connection hint.
296
+
297
+        Args:
298
+            conn:
299
+                If an existing SSH agent client, then enter a context
300
+                within this client's scope.  After exiting the context,
301
+                the client persists, including its socket.
302
+
303
+                If a socket, then construct a client using this socket,
304
+                then enter a context within this client's scope.  After
305
+                exiting the context, the client is destroyed and the
306
+                socket is closed.
307
+
308
+                If `None`, construct a client using agent
309
+                auto-discovery, then enter a context within this
310
+                client's scope.  After exiting the context, both the
311
+                client and its socket are destroyed.
312
+
313
+        Yields:
314
+            When entering this context, return the SSH agent client.
315
+
316
+        Raises:
317
+            KeyError:
318
+                `conn` was `None`, and the `SSH_AUTH_SOCK` environment
319
+                variable was not found.
320
+            NotImplementedError:
321
+                `conn` was `None`, and this Python does not support
322
+                [`socket.AF_UNIX`][], so the SSH agent client cannot be
323
+                automatically set up.
324
+            OSError:
325
+                `conn` was a socket or `None`, and there was an error
326
+                setting up a socket connection to the agent.
327
+
328
+        """
329
+        # Use match/case here once Python 3.9 becomes unsupported.
330
+        if isinstance(conn, SSHAgentClient):
331
+            with contextlib.nullcontext():
332
+                yield conn
333
+        elif isinstance(conn, socket.socket) or conn is None:
334
+            with SSHAgentClient(socket=conn) as client:
335
+                yield client
336
+        else:  # pragma: no cover
337
+            assert_never(conn)
338
+            msg = f'invalid connection hint: {conn!r}'
339
+            raise TypeError(msg)  # noqa: DOC501
340
+
285 341
     @overload
286 342
     def request(  # pragma: no cover
287 343
         self,
... ...
@@ -522,10 +522,10 @@ class Vault:
522 522
                 'signature not deterministic'
523 523
             )
524 524
             raise ValueError(msg)
525
-        with ssh_agent.SSHAgentClient() as client:
525
+        with ssh_agent.SSHAgentClient.ensure_agent_subcontext() as client:
526 526
             raw_sig = client.sign(key, cls._UUID)
527
-        _keytype, trailer = client.unstring_prefix(raw_sig)
528
-        signature_blob = client.unstring(trailer)
527
+        _keytype, trailer = ssh_agent.SSHAgentClient.unstring_prefix(raw_sig)
528
+        signature_blob = ssh_agent.SSHAgentClient.unstring(trailer)
529 529
         return bytes(base64.standard_b64encode(signature_blob))
530 530
 
531 531
     @staticmethod
532 532