Turn the built in SSH agent socket provider names into an enum
Marco Ricci

Marco Ricci commited on 2026-01-24 22:59:24
Zeige 6 geänderte Dateien mit 246 Einfügungen und 74 Löschungen.


This eliminates typos once and for all.  It also makes it really easy to
distinguish third-party socket providers from first-party ones, as the
enum cannot be amended later.  Finally, it centralizes the knowledge for
testing whether the socket provider is functional directly to the enum,
similar to the other version info "feature items".

(In fact, the "master SSH key" vault feature (`Features.SSH_KEY`) can
now delegate the feature support check to the socket provider name
enum.)
... ...
@@ -864,30 +864,21 @@ class Feature(str, enum.Enum):
864 864
         reporting whether this can principally work, or not.
865 865
 
866 866
         """
867
-        import sys  # noqa: PLC0415
868
-
869
-        import exceptiongroup  # noqa: PLC0415
870
-
867
+        # Because BuiltinSSHAgentSocketProvider.test tests the same
868
+        # thing, just specific to a certain socket provider, we delegate
869
+        # to that test... not only out of laziness, but also out of
870
+        # consistency, avoiding two reimplementations of the same logic.
871 871
         from derivepassphrase import ssh_agent  # noqa: PLC0415
872 872
 
873
-        if sys.version_info < (3, 11):
874
-            from exceptiongroup import BaseExceptionGroup  # noqa: PLC0415
875
-
876
-        ret = True
877
-
878
-        def handle_notimplementederror(
879
-            _exc: BaseExceptionGroup,
880
-        ) -> None:  # pragma: no cover [unused]
881
-            nonlocal ret
882
-            ret = False
883
-
884
-        with exceptiongroup.catch({  # noqa: SIM117
885
-            NotImplementedError: handle_notimplementederror,
886
-            Exception: lambda _exc: None,
887
-        }):
888
-            with ssh_agent.SSHAgentClient.ensure_agent_subcontext():
889
-                pass
890
-        return ret
873
+        for provider_name in ssh_agent.SSHAgentClient.SOCKET_PROVIDERS:
874
+            try:
875
+                provider = BuiltinSSHAgentSocketProvider(provider_name)
876
+            except ValueError:
877
+                continue
878
+            else:
879
+                if provider.test():
880
+                    return True
881
+        return False
891 882
 
892 883
     def test(self, *_args: Any, **_kwargs: Any) -> bool:  # noqa: ANN401
893 884
         """Return true if this feature is enabled."""
... ...
@@ -1018,6 +1009,123 @@ class Subcommand(str, enum.Enum):
1018 1009
     __format__ = str.__format__  # type: ignore[assignment]
1019 1010
 
1020 1011
 
1012
+class BuiltinSSHAgentSocketProvider(str, enum.Enum):
1013
+    """SSH agent socket providers built into `derivepassphrase`.
1014
+
1015
+    Attributes:
1016
+        SSH_AUTH_SOCK_ON_POSIX:
1017
+            A socket provider (on POSIX) that queries the
1018
+            `SSH_AUTH_SOCK` environment variable.
1019
+        SSH_AUTH_SOCK_ON_WINDOWS:
1020
+            A socket provider (on The Annoying OS, a.k.a. Microsoft
1021
+            Windows) that queries the `SSH_AUTH_SOCK` environment
1022
+            variable.
1023
+        PAGEANT_ON_WINDOWS:
1024
+            A socket provider (on The Annoying OS, a.k.a. Microsoft
1025
+            Windows) that connects to Pageant's standard socket.  The
1026
+            socket address is computed by the socket provider.
1027
+        OPENSSH_ON_WINDOWS:
1028
+            A socket provider (on The Annoying OS, a.k.a. Microsoft
1029
+            Windows) that connects to OpenSSH on Windows's standard
1030
+            socket.  The socket address is hardcoded by the socket
1031
+            provider.
1032
+        STUB_AGENT:
1033
+            A basic fake agent's socket provider that only reacts to
1034
+            known test keys.  Used by the test suite.
1035
+        STUB_AGENT_WITH_ADDRESS:
1036
+            A more orchestratable fake agent's socket provider, compared
1037
+            to [`STUB_AGENT`][], that only reacts to known test keys.
1038
+            Used by the test suite.
1039
+        STUB_AGENT_WITH_ADDRESS_AND_DETERMINISTIC_DSA:
1040
+            An elaborate fake agent's socket provider that only reacts
1041
+            to known test keys.  Used by the test suite.
1042
+        SSH_AUTH_SOCK:
1043
+            A registry alias for [`SSH_AUTH_SOCK_ON_POSIX`][].
1044
+        UNIX_DOMAIN:
1045
+            A registry alias for [`SSH_AUTH_SOCK_ON_POSIX`][].
1046
+        POSIX:
1047
+            A registry alias for [`UNIX_DOMAIN`][].
1048
+        WINDOWS_NAMED_PIPE:
1049
+            A registry alias for [`SSH_AUTH_SOCK_ON_WINDOWS`][].
1050
+        WINDOWS:
1051
+            A registry alias for [`WINDOWS_NAMED_PIPE`][].
1052
+        NATIVE:
1053
+            A registry alias for [`WINDOWS`][] if on The Annoying OS, or
1054
+            for [`POSIX`][] otherwise.
1055
+
1056
+    """
1057
+
1058
+    SSH_AUTH_SOCK_ON_POSIX = "ssh_auth_sock_on_posix"
1059
+    """"""
1060
+    SSH_AUTH_SOCK_ON_WINDOWS = "ssh_auth_sock_on_windows"
1061
+    """"""
1062
+    PAGEANT_ON_WINDOWS = "pageant_on_windows"
1063
+    """"""
1064
+    OPENSSH_ON_WINDOWS = "openssh_on_windows"
1065
+    """"""
1066
+
1067
+    STUB_AGENT = "stub_agent"
1068
+    """"""
1069
+    STUB_AGENT_WITH_ADDRESS = "stub_agent_with_address"
1070
+    """"""
1071
+    STUB_AGENT_WITH_ADDRESS_AND_DETERMINISTIC_DSA = (
1072
+        "stub_agent_with_address_and_deterministic_dsa"
1073
+    )
1074
+    """"""
1075
+
1076
+    SSH_AUTH_SOCK = "ssh_auth_sock"
1077
+    """"""
1078
+    UNIX_DOMAIN = "unix_domain"
1079
+    """"""
1080
+    POSIX = "posix"
1081
+    """"""
1082
+    WINDOWS_NAMED_PIPE = "windows_named_pipe"
1083
+    """"""
1084
+    WINDOWS = "windows"
1085
+    """"""
1086
+    NATIVE = "native"
1087
+    """"""
1088
+
1089
+    def test(
1090
+        self,
1091
+        *_args: Any,  # noqa: ANN401
1092
+        **_kwargs: Any,  # noqa: ANN401
1093
+    ) -> bool:  # pragma: no cover [external]
1094
+        """Return true if this SSH agent socket provider is available.
1095
+
1096
+        This works by actually attempting to connet to an agent via the
1097
+        socket provider.
1098
+
1099
+        """
1100
+        import sys  # noqa: PLC0415
1101
+
1102
+        import exceptiongroup  # noqa: PLC0415
1103
+
1104
+        from derivepassphrase import ssh_agent  # noqa: PLC0415
1105
+
1106
+        if sys.version_info < (3, 11):
1107
+            from exceptiongroup import BaseExceptionGroup  # noqa: PLC0415
1108
+
1109
+        ret = True
1110
+
1111
+        def handle_notimplementederror(
1112
+            _exc: BaseExceptionGroup,
1113
+        ) -> None:  # pragma: no cover [unused]
1114
+            nonlocal ret
1115
+            ret = False
1116
+
1117
+        with exceptiongroup.catch({  # noqa: SIM117
1118
+            NotImplementedError: handle_notimplementederror,
1119
+            Exception: lambda _exc: None,
1120
+        }):
1121
+            with ssh_agent.SSHAgentClient.ensure_agent_subcontext(conn=[self]):
1122
+                pass
1123
+        return ret
1124
+
1125
+    __str__ = str.__str__
1126
+    __format__ = str.__format__  # type: ignore[assignment]
1127
+
1128
+
1021 1129
 @runtime_checkable
1022 1130
 class SSHAgentSocket(Protocol):
1023 1131
     """An abstract networking socket connected to an SSH agent.
... ...
@@ -24,6 +24,8 @@ from ctypes.wintypes import (  # type: ignore[attr-defined]
24 24
 )
25 25
 from typing import TYPE_CHECKING, cast
26 26
 
27
+from derivepassphrase import _types
28
+
27 29
 if TYPE_CHECKING:
28 30
     from collections.abc import Callable
29 31
     from typing import ClassVar
... ...
@@ -38,8 +40,6 @@ if TYPE_CHECKING:
38 40
         TypeVar,
39 41
     )
40 42
 
41
-    from derivepassphrase import _types
42
-
43 43
     SSHAgentSocketProviderT = TypeVar(
44 44
         "SSHAgentSocketProviderT", bound=_types.SSHAgentSocketProvider
45 45
     )
... ...
@@ -982,20 +982,55 @@ class SocketProvider:
982 982
         cls.registry.update(entries.maps[0])
983 983
 
984 984
 
985
-SocketProvider.registry.update({
986
-    "ssh_auth_sock_on_posix": SocketProvider.unix_domain_ssh_auth_sock,
987
-    "pageant_on_windows": SocketProvider.windows_named_pipe_for_pageant,  # noqa: E501
988
-    "openssh_on_windows": SocketProvider.windows_named_pipe_for_openssh,  # noqa: E501
989
-    "ssh_auth_sock_on_windows": SocketProvider.windows_named_pipe_ssh_auth_sock,  # noqa: E501
985
+SocketProvider.registry.update([
986
+    (
987
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX,
988
+        SocketProvider.unix_domain_ssh_auth_sock,
989
+    ),
990
+    (
991
+        _types.BuiltinSSHAgentSocketProvider.PAGEANT_ON_WINDOWS,
992
+        SocketProvider.windows_named_pipe_for_pageant,
993
+    ),
994
+    (
995
+        _types.BuiltinSSHAgentSocketProvider.OPENSSH_ON_WINDOWS,
996
+        SocketProvider.windows_named_pipe_for_openssh,
997
+    ),
998
+    (
999
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS,
1000
+        SocketProvider.windows_named_pipe_ssh_auth_sock,
1001
+    ),
990 1002
     # known instances
991
-    "stub_agent": None,
992
-    "stub_agent_with_address": None,
993
-    "stub_agent_with_address_and_deterministic_dsa": None,
1003
+    (_types.BuiltinSSHAgentSocketProvider.STUB_AGENT, None),
1004
+    (_types.BuiltinSSHAgentSocketProvider.STUB_AGENT_WITH_ADDRESS, None),
1005
+    (
1006
+        _types.BuiltinSSHAgentSocketProvider.STUB_AGENT_WITH_ADDRESS_AND_DETERMINISTIC_DSA,
1007
+        None,
1008
+    ),
994 1009
     # aliases
995
-    "ssh_auth_sock": "ssh_auth_sock_on_posix",
996
-    "unix_domain": "ssh_auth_sock_on_posix",
997
-    "posix": "unix_domain",
998
-    "windows_named_pipe": "ssh_auth_sock_on_windows",
999
-    "windows": "windows_named_pipe",
1000
-    "native": "windows" if os.name == "nt" else "posix",
1001
-})
1010
+    (
1011
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK,
1012
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX,
1013
+    ),
1014
+    (
1015
+        _types.BuiltinSSHAgentSocketProvider.UNIX_DOMAIN,
1016
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX,
1017
+    ),
1018
+    (
1019
+        _types.BuiltinSSHAgentSocketProvider.POSIX,
1020
+        _types.BuiltinSSHAgentSocketProvider.UNIX_DOMAIN,
1021
+    ),
1022
+    (
1023
+        _types.BuiltinSSHAgentSocketProvider.WINDOWS_NAMED_PIPE,
1024
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS,
1025
+    ),
1026
+    (
1027
+        _types.BuiltinSSHAgentSocketProvider.WINDOWS,
1028
+        _types.BuiltinSSHAgentSocketProvider.WINDOWS_NAMED_PIPE,
1029
+    ),
1030
+    (
1031
+        _types.BuiltinSSHAgentSocketProvider.NATIVE,
1032
+        _types.BuiltinSSHAgentSocketProvider.WINDOWS
1033
+        if os.name == "nt"
1034
+        else _types.BuiltinSSHAgentSocketProvider.POSIX,
1035
+    ),
1036
+])
... ...
@@ -354,8 +354,10 @@ class VaultTestConfig(NamedTuple):
354 354
 
355 355
 
356 356
 ssh_auth_sock_on_posix_entry = _types.SSHAgentSocketProviderEntry(
357
-    socketprovider.SocketProvider.resolve("ssh_auth_sock_on_posix"),
358
-    "ssh_auth_sock_on_posix",
357
+    socketprovider.SocketProvider.resolve(
358
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX
359
+    ),
360
+    _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX,
359 361
     (),
360 362
 )
361 363
 """
... ...
@@ -364,8 +366,10 @@ domain socket handler on POSIX systems.
364 366
 """
365 367
 
366 368
 ssh_auth_sock_on_windows_entry = _types.SSHAgentSocketProviderEntry(
367
-    socketprovider.SocketProvider.resolve("ssh_auth_sock_on_windows"),
368
-    "ssh_auth_sock_on_windows",
369
+    socketprovider.SocketProvider.resolve(
370
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS
371
+    ),
372
+    _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS,
369 373
     (),
370 374
 )
371 375
 """
... ...
@@ -384,7 +388,11 @@ is not a callable.
384 388
 """
385 389
 
386 390
 faulty_entry_name_exists = _types.SSHAgentSocketProviderEntry(
387
-    socketprovider.SocketProvider.resolve("windows"), "posix", ()
391
+    socketprovider.SocketProvider.resolve(
392
+        _types.BuiltinSSHAgentSocketProvider.WINDOWS
393
+    ),
394
+    _types.BuiltinSSHAgentSocketProvider.POSIX,
395
+    (),
388 396
 )
389 397
 """
390 398
 A faulty [`_types.SSHAgentSocketProviderEntry`][]: the indicated handler
... ...
@@ -392,9 +400,14 @@ is already registered with a different callable.
392 400
 """
393 401
 
394 402
 faulty_entry_alias_exists = _types.SSHAgentSocketProviderEntry(
395
-    socketprovider.SocketProvider.resolve("ssh_auth_sock_on_posix"),
396
-    "ssh_auth_sock_on_posix",
397
-    ("unix_domain", "windows_named_pipes"),
403
+    socketprovider.SocketProvider.resolve(
404
+        _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX
405
+    ),
406
+    _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX,
407
+    (
408
+        _types.BuiltinSSHAgentSocketProvider.UNIX_DOMAIN,
409
+        _types.BuiltinSSHAgentSocketProvider.WINDOWS_NAMED_PIPE,
410
+    ),
398 411
 )
399 412
 """
400 413
 A faulty [`_types.SSHAgentSocketProviderEntry`][]: the alias is already
... ...
@@ -334,7 +334,7 @@ class SystemSupportAction(str, enum.Enum):
334 334
             )
335 335
         elif self in {self.UNSET_NATIVE, self.UNSET_NATIVE_AND_ENSURE_USE}:
336 336
             self.check_or_ensure_use(
337
-                "native",
337
+                _types.BuiltinSSHAgentSocketProvider.NATIVE,
338 338
                 monkeypatch=monkeypatch,
339 339
                 ensure_use=(self == self.UNSET_NATIVE_AND_ENSURE_USE),
340 340
             )
... ...
@@ -343,14 +343,14 @@ class SystemSupportAction(str, enum.Enum):
343 343
             monkeypatch.delattr(ctypes, "windll", raising=False)
344 344
         elif self in {self.UNSET_AF_UNIX, self.UNSET_AF_UNIX_AND_ENSURE_USE}:
345 345
             self.check_or_ensure_use(
346
-                "posix",
346
+                _types.BuiltinSSHAgentSocketProvider.POSIX,
347 347
                 monkeypatch=monkeypatch,
348 348
                 ensure_use=(self == self.UNSET_AF_UNIX_AND_ENSURE_USE),
349 349
             )
350 350
             monkeypatch.delattr(socket, "AF_UNIX", raising=False)
351 351
         elif self in {self.UNSET_WINDLL, self.UNSET_WINDLL_AND_ENSURE_USE}:
352 352
             self.check_or_ensure_use(
353
-                "windows",
353
+                _types.BuiltinSSHAgentSocketProvider.WINDOWS,
354 354
                 monkeypatch=monkeypatch,
355 355
                 ensure_use=(self == self.UNSET_WINDLL_AND_ENSURE_USE),
356 356
             )
... ...
@@ -110,7 +110,12 @@ class Parametrize(types.SimpleNamespace):
110 110
         ],
111 111
     )
112 112
     EXISTING_REGISTRY_ENTRIES = pytest.mark.parametrize(
113
-        "existing", ["ssh_auth_sock_on_posix", "ssh_auth_sock_on_windows"]
113
+        "existing",
114
+        [
115
+            _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX,
116
+            _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS,
117
+        ],
118
+        ids=str,
114 119
     )
115 120
     SSH_STRING_EXCEPTIONS = pytest.mark.parametrize(
116 121
         ["input", "exc_type", "exc_pattern"],
... ...
@@ -809,7 +814,10 @@ class TestSSHAgentSocketProviderRegistry:
809 814
         # entry ultimately resolves to a different provider.  They needn't
810 815
         # technically be leaves of the socket provider forest, but "leaves"
811 816
         # emphasizes the quality we are looking for.
812
-        leaves = ["posix", "windows"]
817
+        leaves = [
818
+            _types.BuiltinSSHAgentSocketProvider.POSIX,
819
+            _types.BuiltinSSHAgentSocketProvider.WINDOWS,
820
+        ]
813 821
         assert all([isinstance(registry.get(leaf), str) for leaf in leaves]), (
814 822
             "Registry is shaped incompatibly; cannot determine base entries"
815 823
         )
... ...
@@ -887,14 +895,14 @@ class TestSSHAgentSocketProviderRegistry:
887 895
 
888 896
         provider = socketprovider.SocketProvider.resolve(existing)
889 897
         new_registry = {
890
-            "ssh_auth_sock_on_posix": socketprovider.SocketProvider.resolve(
891
-                "ssh_auth_sock_on_posix"
898
+            _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX: socketprovider.SocketProvider.resolve(
899
+                _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX
892 900
             ),
893
-            "ssh_auth_sock_on_windows": socketprovider.SocketProvider.resolve(
894
-                "ssh_auth_sock_on_windows"
901
+            _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS: socketprovider.SocketProvider.resolve(
902
+                _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS
895 903
             ),
896
-            "unix_domain": "ssh_auth_sock_on_posix",
897
-            "windows_named_pipe": "ssh_auth_sock_on_windows",
904
+            _types.BuiltinSSHAgentSocketProvider.UNIX_DOMAIN: _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX,
905
+            _types.BuiltinSSHAgentSocketProvider.WINDOWS_NAMED_PIPE: _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS,
898 906
         }
899 907
         names = [
900 908
             k
... ...
@@ -7,17 +7,14 @@
7 7
 from __future__ import annotations
8 8
 
9 9
 import contextlib
10
-from typing import TYPE_CHECKING
11 10
 
12 11
 import pytest
13 12
 from hypothesis import stateful, strategies
14 13
 
14
+from derivepassphrase import _types
15 15
 from derivepassphrase.ssh_agent import socketprovider
16 16
 from tests.machinery import pytest as pytest_machinery
17 17
 
18
-if TYPE_CHECKING:
19
-    from derivepassphrase import _types
20
-
21 18
 # All tests in this module are heavy-duty tests.
22 19
 pytestmark = [pytest_machinery.heavy_duty]
23 20
 
... ...
@@ -151,19 +148,30 @@ class SSHAgentSocketProviderRegistryStateMachine(
151 148
         self.registry: dict[
152 149
             str, _types.SSHAgentSocketProvider | str | None
153 150
         ] = {
154
-            "ssh_auth_sock_on_posix": self.orig_registry[
155
-                "ssh_auth_sock_on_posix"
156
-            ],
157
-            "ssh_auth_sock_on_windows": self.orig_registry[
158
-                "ssh_auth_sock_on_windows"
159
-            ],
160
-            "native": (
161
-                "windows_named_pipe"
162
-                if self.orig_registry["native"] == "windows"
163
-                else "unix_domain"
151
+            _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX: (
152
+                self.orig_registry[
153
+                    _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX
154
+                ]
155
+            ),
156
+            _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS: (
157
+                self.orig_registry[
158
+                    _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS
159
+                ]
160
+            ),
161
+            _types.BuiltinSSHAgentSocketProvider.NATIVE: (
162
+                _types.BuiltinSSHAgentSocketProvider.WINDOWS_NAMED_PIPE
163
+                if self.orig_registry[
164
+                    _types.BuiltinSSHAgentSocketProvider.NATIVE
165
+                ]
166
+                == _types.BuiltinSSHAgentSocketProvider.WINDOWS
167
+                else _types.BuiltinSSHAgentSocketProvider.UNIX_DOMAIN
168
+            ),
169
+            _types.BuiltinSSHAgentSocketProvider.UNIX_DOMAIN: (
170
+                _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_POSIX
171
+            ),
172
+            _types.BuiltinSSHAgentSocketProvider.WINDOWS_NAMED_PIPE: (
173
+                _types.BuiltinSSHAgentSocketProvider.SSH_AUTH_SOCK_ON_WINDOWS
164 174
             ),
165
-            "unix_domain": "ssh_auth_sock_on_posix",
166
-            "windows_named_pipe": "windows_named_pipe",
167 175
         }
168 176
         self.monkeypatch.setattr(
169 177
             socketprovider.SocketProvider, "registry", self.registry
170 178