Support looking up a socket provider, even if merely registered
Marco Ricci

Marco Ricci commited on 2025-08-02 14:22:51
Zeige 3 geänderte Dateien mit 59 Einfügungen und 48 Löschungen.


Support looking up a socket provider registry entry without bailing if
the entry is merely registered, but not implemented.  Call this
operation "lookup", as opposed to "resolve".  Use this operation
directly whereever it makes sense (which currently is only in testing
code and in the implementation of "resolve").
... ...
@@ -186,11 +186,13 @@ class SocketProvider:
186 186
             """
187 187
             for alias in [name, *aliases]:
188 188
                 try:
189
-                    existing = cls.resolve(alias)
189
+                    existing = cls.lookup(alias)
190 190
                 except (NoSuchProviderError, NotImplementedError):
191 191
                     cls.registry[alias] = f if alias == name else name
192 192
                 else:
193
-                    if existing != f:
193
+                    if existing is None:
194
+                        cls.registry[alias] = f if alias == name else name
195
+                    elif existing != f:
194 196
                         msg = (
195 197
                             f'The SSH agent socket provider {alias!r} '
196 198
                             f'is already registered.'
... ...
@@ -200,10 +202,36 @@ class SocketProvider:
200 202
 
201 203
         return decorator
202 204
 
205
+    @classmethod
206
+    def lookup(
207
+        cls, provider: _types.SSHAgentSocketProvider | str | None, /
208
+    ) -> _types.SSHAgentSocketProvider | None:
209
+        """Look up a socket provider entry.
210
+
211
+        Args:
212
+            provider: The provider to look up.
213
+
214
+        Returns:
215
+            The callable indicated by this provider, if it is
216
+            implemented, or `None`, if it is merely registered.
217
+
218
+        Raises:
219
+            NoSuchProviderError:
220
+                The provider is not registered.
221
+
222
+        """
223
+        ret = provider
224
+        while isinstance(ret, str):
225
+            try:
226
+                ret = cls.registry[ret]
227
+            except KeyError as exc:
228
+                raise NoSuchProviderError(ret) from exc
229
+        return ret
230
+
203 231
     @classmethod
204 232
     def resolve(
205
-        cls, provider: Callable[[], _types.SSHAgentSocket] | str | None, /
206
-    ) -> Callable[[], _types.SSHAgentSocket]:
233
+        cls, provider: _types.SSHAgentSocketProvider | str | None, /
234
+    ) -> _types.SSHAgentSocketProvider:
207 235
         """Resolve a socket provider to a proper callable.
208 236
 
209 237
         Args:
... ...
@@ -220,12 +248,7 @@ class SocketProvider:
220 248
                 applicable to this `derivepassphrase` installation.
221 249
 
222 250
         """
223
-        ret = provider
224
-        while isinstance(ret, str):
225
-            try:
226
-                ret = cls.registry[ret]
227
-            except KeyError as exc:
228
-                raise NoSuchProviderError(ret) from exc
251
+        ret = cls.lookup(provider)
229 252
         if ret is None:
230 253
             msg = (
231 254
                 f'The {ret!r} socket provider is not functional on or '
... ...
@@ -764,11 +764,9 @@ class SystemSupportAction(str, enum.Enum):
764 764
                 | None
765 765
             )
766 766
             try:
767
-                intended = socketprovider.SocketProvider.resolve(provider)
767
+                intended = socketprovider.SocketProvider.lookup(provider)
768 768
             except socketprovider.NoSuchProviderError as exc:
769 769
                 intended = exc
770
-            except NotImplementedError:
771
-                intended = None
772 770
             actual: (
773 771
                 _types.SSHAgentSocketProvider
774 772
                 | socketprovider.NoSuchProviderError
... ...
@@ -776,10 +774,10 @@ class SystemSupportAction(str, enum.Enum):
776 774
             )
777 775
             for name in ssh_agent.SSHAgentClient.SOCKET_PROVIDERS:
778 776
                 try:
779
-                    actual = socketprovider.SocketProvider.resolve(name)
777
+                    actual = socketprovider.SocketProvider.lookup(name)
780 778
                 except socketprovider.NoSuchProviderError as exc:
781 779
                     actual = exc
782
-                except NotImplementedError:
780
+                if actual is None:
783 781
                     continue
784 782
                 break
785 783
             else:
... ...
@@ -927,9 +927,12 @@ class TestStaticFunctionality:
927 927
         """Resolving entries in the socket provider registry works."""
928 928
         registry = socketprovider.SocketProvider.registry
929 929
         resolve = socketprovider.SocketProvider.resolve
930
+        lookup = socketprovider.SocketProvider.lookup
930 931
         with pytest.MonkeyPatch.context() as monkeypatch:
931 932
             monkeypatch.setitem(registry, 'stub_agent', None)
933
+            assert callable(lookup('native'))
932 934
             assert callable(resolve('native'))
935
+            assert lookup('stub_agent') is None
933 936
             with pytest.raises(NotImplementedError):
934 937
                 resolve('stub_agent')
935 938
 
... ...
@@ -942,6 +945,7 @@ class TestStaticFunctionality:
942 945
         """Resolving a chain of providers works."""
943 946
         registry = socketprovider.SocketProvider.registry
944 947
         resolve = socketprovider.SocketProvider.resolve
948
+        lookup = socketprovider.SocketProvider.lookup
945 949
         try:
946 950
             implementation = resolve('native')
947 951
         except NotImplementedError:  # pragma: no cover
... ...
@@ -959,11 +963,15 @@ class TestStaticFunctionality:
959 963
             for link in chain:
960 964
                 monkeypatch.setitem(registry, link, target)
961 965
                 target = link
966
+            for link in chain:
967
+                assert lookup(link) == (
968
+                    implementation if terminal != 'unimplemented' else None
969
+                )
962 970
                 if terminal == 'unimplemented':
963 971
                     with pytest.raises(NotImplementedError):
964
-                    resolve(chain[-1])
972
+                        resolve(link)
965 973
                 else:
966
-                assert resolve(chain[-1]) == implementation
974
+                    assert resolve(link) == implementation
967 975
 
968 976
     @hypothesis.given(
969 977
         terminal=strategies.sampled_from([
... ...
@@ -996,6 +1004,7 @@ class TestStaticFunctionality:
996 1004
         """Resolving a chain of providers works."""
997 1005
         registry = socketprovider.SocketProvider.registry
998 1006
         resolve = socketprovider.SocketProvider.resolve
1007
+        lookup = socketprovider.SocketProvider.lookup
999 1008
         try:
1000 1009
             implementation = resolve('native')
1001 1010
         except NotImplementedError:  # pragma: no cover
... ...
@@ -1014,11 +1023,15 @@ class TestStaticFunctionality:
1014 1023
             for link in chain:
1015 1024
                 monkeypatch.setitem(registry, link, target)
1016 1025
                 target = link
1026
+            for link in chain:
1027
+                assert lookup(link) == (
1028
+                    implementation if terminal != 'unimplemented' else None
1029
+                )
1017 1030
                 if terminal == 'unimplemented':
1018 1031
                     with pytest.raises(NotImplementedError):
1019
-                    resolve(chain[-1])
1032
+                        resolve(link)
1020 1033
                 else:
1021
-                assert resolve(chain[-1]) == implementation
1034
+                    assert resolve(link) == implementation
1022 1035
 
1023 1036
     @Parametrize.GOOD_ENTRY_POINTS
1024 1037
     def test_230_find_all_socket_providers(
... ...
@@ -1664,32 +1677,6 @@ class TestAgentInteraction:
1664 1677
                 client.query_extensions()
1665 1678
 
1666 1679
 
1667
-def safe_resolve(name: str) -> _types.SSHAgentSocketProvider | None:
1668
-    """Safely resolve an SSH agent socket provider name.
1669
-
1670
-    If the provider is merely reserved, return `None` instead of raising
1671
-    an exception.  Otherwise behave like
1672
-    [`socketprovider.SocketProvider.resolve`][] does.
1673
-
1674
-    Args:
1675
-        name: The name of the provider to resolve.
1676
-
1677
-    Returns:
1678
-        The SSH agent socket provider if registered, `None` if merely
1679
-        reserved, and an error if not found.
1680
-
1681
-    Raises:
1682
-        socketprovider.NoSuchProviderError:
1683
-            No such provider was found in the registry.  This includes
1684
-            entries that are merely reserved.
1685
-
1686
-    """
1687
-    try:
1688
-        return socketprovider.SocketProvider.resolve(name)
1689
-    except NotImplementedError:  # pragma: no cover
1690
-        return None
1691
-
1692
-
1693 1680
 @strategies.composite
1694 1681
 def draw_alias_chain(
1695 1682
     draw: strategies.DrawFn,
... ...
@@ -1841,7 +1828,9 @@ class SSHAgentSocketProviderRegistryStateMachine(
1841 1828
     )
1842 1829
     def get_registry_keys(self) -> stateful.MultipleResults[str]:
1843 1830
         """Read the standard keys from the registry."""
1844
-        self.model.update({k: safe_resolve(k) for k in self.registry})
1831
+        self.model.update({
1832
+            k: socketprovider.SocketProvider.lookup(k) for k in self.registry
1833
+        })
1845 1834
         return stateful.multiple(*self.registry.keys())
1846 1835
 
1847 1836
     @stateful.rule(
... ...
@@ -1855,12 +1844,13 @@ class SSHAgentSocketProviderRegistryStateMachine(
1855 1844
 
1856 1845
     @stateful.invariant()
1857 1846
     def check_consistency(self) -> None:
1847
+        lookup = socketprovider.SocketProvider.lookup
1858 1848
         assert self.registry.keys() == self.model.keys()
1859 1849
         for k in self.model:
1860
-            resolved = safe_resolve(k)
1850
+            resolved = lookup(k)
1861 1851
             modelled = self.model[k]
1862 1852
             step1 = self.registry[k]
1863
-            manually = safe_resolve(step1) if isinstance(step1, str) else step1
1853
+            manually = lookup(step1) if isinstance(step1, str) else step1
1864 1854
             assert resolved == modelled
1865 1855
             assert resolved == manually
1866 1856
 
1867 1857