Marco Ricci commited on 2025-08-02 14:22:51
Zeige 3 geänderte Dateien mit 59 Einfügungen und 48 Löschungen.
Support looking up a socket provider registry entry without bailing if the entry is merely registered, but not implemented. Call this operation "lookup", as opposed to "resolve". Use this operation directly whereever it makes sense (which currently is only in testing code and in the implementation of "resolve").
... | ... |
@@ -186,11 +186,13 @@ class SocketProvider: |
186 | 186 |
""" |
187 | 187 |
for alias in [name, *aliases]: |
188 | 188 |
try: |
189 |
- existing = cls.resolve(alias) |
|
189 |
+ existing = cls.lookup(alias) |
|
190 | 190 |
except (NoSuchProviderError, NotImplementedError): |
191 | 191 |
cls.registry[alias] = f if alias == name else name |
192 | 192 |
else: |
193 |
- if existing != f: |
|
193 |
+ if existing is None: |
|
194 |
+ cls.registry[alias] = f if alias == name else name |
|
195 |
+ elif existing != f: |
|
194 | 196 |
msg = ( |
195 | 197 |
f'The SSH agent socket provider {alias!r} ' |
196 | 198 |
f'is already registered.' |
... | ... |
@@ -200,10 +202,36 @@ class SocketProvider: |
200 | 202 |
|
201 | 203 |
return decorator |
202 | 204 |
|
205 |
+ @classmethod |
|
206 |
+ def lookup( |
|
207 |
+ cls, provider: _types.SSHAgentSocketProvider | str | None, / |
|
208 |
+ ) -> _types.SSHAgentSocketProvider | None: |
|
209 |
+ """Look up a socket provider entry. |
|
210 |
+ |
|
211 |
+ Args: |
|
212 |
+ provider: The provider to look up. |
|
213 |
+ |
|
214 |
+ Returns: |
|
215 |
+ The callable indicated by this provider, if it is |
|
216 |
+ implemented, or `None`, if it is merely registered. |
|
217 |
+ |
|
218 |
+ Raises: |
|
219 |
+ NoSuchProviderError: |
|
220 |
+ The provider is not registered. |
|
221 |
+ |
|
222 |
+ """ |
|
223 |
+ ret = provider |
|
224 |
+ while isinstance(ret, str): |
|
225 |
+ try: |
|
226 |
+ ret = cls.registry[ret] |
|
227 |
+ except KeyError as exc: |
|
228 |
+ raise NoSuchProviderError(ret) from exc |
|
229 |
+ return ret |
|
230 |
+ |
|
203 | 231 |
@classmethod |
204 | 232 |
def resolve( |
205 |
- cls, provider: Callable[[], _types.SSHAgentSocket] | str | None, / |
|
206 |
- ) -> Callable[[], _types.SSHAgentSocket]: |
|
233 |
+ cls, provider: _types.SSHAgentSocketProvider | str | None, / |
|
234 |
+ ) -> _types.SSHAgentSocketProvider: |
|
207 | 235 |
"""Resolve a socket provider to a proper callable. |
208 | 236 |
|
209 | 237 |
Args: |
... | ... |
@@ -220,12 +248,7 @@ class SocketProvider: |
220 | 248 |
applicable to this `derivepassphrase` installation. |
221 | 249 |
|
222 | 250 |
""" |
223 |
- ret = provider |
|
224 |
- while isinstance(ret, str): |
|
225 |
- try: |
|
226 |
- ret = cls.registry[ret] |
|
227 |
- except KeyError as exc: |
|
228 |
- raise NoSuchProviderError(ret) from exc |
|
251 |
+ ret = cls.lookup(provider) |
|
229 | 252 |
if ret is None: |
230 | 253 |
msg = ( |
231 | 254 |
f'The {ret!r} socket provider is not functional on or ' |
... | ... |
@@ -764,11 +764,9 @@ class SystemSupportAction(str, enum.Enum): |
764 | 764 |
| None |
765 | 765 |
) |
766 | 766 |
try: |
767 |
- intended = socketprovider.SocketProvider.resolve(provider) |
|
767 |
+ intended = socketprovider.SocketProvider.lookup(provider) |
|
768 | 768 |
except socketprovider.NoSuchProviderError as exc: |
769 | 769 |
intended = exc |
770 |
- except NotImplementedError: |
|
771 |
- intended = None |
|
772 | 770 |
actual: ( |
773 | 771 |
_types.SSHAgentSocketProvider |
774 | 772 |
| socketprovider.NoSuchProviderError |
... | ... |
@@ -776,10 +774,10 @@ class SystemSupportAction(str, enum.Enum): |
776 | 774 |
) |
777 | 775 |
for name in ssh_agent.SSHAgentClient.SOCKET_PROVIDERS: |
778 | 776 |
try: |
779 |
- actual = socketprovider.SocketProvider.resolve(name) |
|
777 |
+ actual = socketprovider.SocketProvider.lookup(name) |
|
780 | 778 |
except socketprovider.NoSuchProviderError as exc: |
781 | 779 |
actual = exc |
782 |
- except NotImplementedError: |
|
780 |
+ if actual is None: |
|
783 | 781 |
continue |
784 | 782 |
break |
785 | 783 |
else: |
... | ... |
@@ -927,9 +927,12 @@ class TestStaticFunctionality: |
927 | 927 |
"""Resolving entries in the socket provider registry works.""" |
928 | 928 |
registry = socketprovider.SocketProvider.registry |
929 | 929 |
resolve = socketprovider.SocketProvider.resolve |
930 |
+ lookup = socketprovider.SocketProvider.lookup |
|
930 | 931 |
with pytest.MonkeyPatch.context() as monkeypatch: |
931 | 932 |
monkeypatch.setitem(registry, 'stub_agent', None) |
933 |
+ assert callable(lookup('native')) |
|
932 | 934 |
assert callable(resolve('native')) |
935 |
+ assert lookup('stub_agent') is None |
|
933 | 936 |
with pytest.raises(NotImplementedError): |
934 | 937 |
resolve('stub_agent') |
935 | 938 |
|
... | ... |
@@ -942,6 +945,7 @@ class TestStaticFunctionality: |
942 | 945 |
"""Resolving a chain of providers works.""" |
943 | 946 |
registry = socketprovider.SocketProvider.registry |
944 | 947 |
resolve = socketprovider.SocketProvider.resolve |
948 |
+ lookup = socketprovider.SocketProvider.lookup |
|
945 | 949 |
try: |
946 | 950 |
implementation = resolve('native') |
947 | 951 |
except NotImplementedError: # pragma: no cover |
... | ... |
@@ -959,11 +963,15 @@ class TestStaticFunctionality: |
959 | 963 |
for link in chain: |
960 | 964 |
monkeypatch.setitem(registry, link, target) |
961 | 965 |
target = link |
966 |
+ for link in chain: |
|
967 |
+ assert lookup(link) == ( |
|
968 |
+ implementation if terminal != 'unimplemented' else None |
|
969 |
+ ) |
|
962 | 970 |
if terminal == 'unimplemented': |
963 | 971 |
with pytest.raises(NotImplementedError): |
964 |
- resolve(chain[-1]) |
|
972 |
+ resolve(link) |
|
965 | 973 |
else: |
966 |
- assert resolve(chain[-1]) == implementation |
|
974 |
+ assert resolve(link) == implementation |
|
967 | 975 |
|
968 | 976 |
@hypothesis.given( |
969 | 977 |
terminal=strategies.sampled_from([ |
... | ... |
@@ -996,6 +1004,7 @@ class TestStaticFunctionality: |
996 | 1004 |
"""Resolving a chain of providers works.""" |
997 | 1005 |
registry = socketprovider.SocketProvider.registry |
998 | 1006 |
resolve = socketprovider.SocketProvider.resolve |
1007 |
+ lookup = socketprovider.SocketProvider.lookup |
|
999 | 1008 |
try: |
1000 | 1009 |
implementation = resolve('native') |
1001 | 1010 |
except NotImplementedError: # pragma: no cover |
... | ... |
@@ -1014,11 +1023,15 @@ class TestStaticFunctionality: |
1014 | 1023 |
for link in chain: |
1015 | 1024 |
monkeypatch.setitem(registry, link, target) |
1016 | 1025 |
target = link |
1026 |
+ for link in chain: |
|
1027 |
+ assert lookup(link) == ( |
|
1028 |
+ implementation if terminal != 'unimplemented' else None |
|
1029 |
+ ) |
|
1017 | 1030 |
if terminal == 'unimplemented': |
1018 | 1031 |
with pytest.raises(NotImplementedError): |
1019 |
- resolve(chain[-1]) |
|
1032 |
+ resolve(link) |
|
1020 | 1033 |
else: |
1021 |
- assert resolve(chain[-1]) == implementation |
|
1034 |
+ assert resolve(link) == implementation |
|
1022 | 1035 |
|
1023 | 1036 |
@Parametrize.GOOD_ENTRY_POINTS |
1024 | 1037 |
def test_230_find_all_socket_providers( |
... | ... |
@@ -1664,32 +1677,6 @@ class TestAgentInteraction: |
1664 | 1677 |
client.query_extensions() |
1665 | 1678 |
|
1666 | 1679 |
|
1667 |
-def safe_resolve(name: str) -> _types.SSHAgentSocketProvider | None: |
|
1668 |
- """Safely resolve an SSH agent socket provider name. |
|
1669 |
- |
|
1670 |
- If the provider is merely reserved, return `None` instead of raising |
|
1671 |
- an exception. Otherwise behave like |
|
1672 |
- [`socketprovider.SocketProvider.resolve`][] does. |
|
1673 |
- |
|
1674 |
- Args: |
|
1675 |
- name: The name of the provider to resolve. |
|
1676 |
- |
|
1677 |
- Returns: |
|
1678 |
- The SSH agent socket provider if registered, `None` if merely |
|
1679 |
- reserved, and an error if not found. |
|
1680 |
- |
|
1681 |
- Raises: |
|
1682 |
- socketprovider.NoSuchProviderError: |
|
1683 |
- No such provider was found in the registry. This includes |
|
1684 |
- entries that are merely reserved. |
|
1685 |
- |
|
1686 |
- """ |
|
1687 |
- try: |
|
1688 |
- return socketprovider.SocketProvider.resolve(name) |
|
1689 |
- except NotImplementedError: # pragma: no cover |
|
1690 |
- return None |
|
1691 |
- |
|
1692 |
- |
|
1693 | 1680 |
@strategies.composite |
1694 | 1681 |
def draw_alias_chain( |
1695 | 1682 |
draw: strategies.DrawFn, |
... | ... |
@@ -1841,7 +1828,9 @@ class SSHAgentSocketProviderRegistryStateMachine( |
1841 | 1828 |
) |
1842 | 1829 |
def get_registry_keys(self) -> stateful.MultipleResults[str]: |
1843 | 1830 |
"""Read the standard keys from the registry.""" |
1844 |
- self.model.update({k: safe_resolve(k) for k in self.registry}) |
|
1831 |
+ self.model.update({ |
|
1832 |
+ k: socketprovider.SocketProvider.lookup(k) for k in self.registry |
|
1833 |
+ }) |
|
1845 | 1834 |
return stateful.multiple(*self.registry.keys()) |
1846 | 1835 |
|
1847 | 1836 |
@stateful.rule( |
... | ... |
@@ -1855,12 +1844,13 @@ class SSHAgentSocketProviderRegistryStateMachine( |
1855 | 1844 |
|
1856 | 1845 |
@stateful.invariant() |
1857 | 1846 |
def check_consistency(self) -> None: |
1847 |
+ lookup = socketprovider.SocketProvider.lookup |
|
1858 | 1848 |
assert self.registry.keys() == self.model.keys() |
1859 | 1849 |
for k in self.model: |
1860 |
- resolved = safe_resolve(k) |
|
1850 |
+ resolved = lookup(k) |
|
1861 | 1851 |
modelled = self.model[k] |
1862 | 1852 |
step1 = self.registry[k] |
1863 |
- manually = safe_resolve(step1) if isinstance(step1, str) else step1 |
|
1853 |
+ manually = lookup(step1) if isinstance(step1, str) else step1 |
|
1864 | 1854 |
assert resolved == modelled |
1865 | 1855 |
assert resolved == manually |
1866 | 1856 |
|
1867 | 1857 |