Decouple deterministic signatures from general SSH agent detection
Marco Ricci

Marco Ricci commited on 2024-11-26 14:12:53
Zeige 1 geänderte Dateien mit 60 Einfügungen und 4 Löschungen.


Instead of tying deterministic signatures directly to the detection of
Pageant specifically, add a general mechanism for attempting to infer
the connected SSH agent from its reported list of extensions.  This
moves the question of *how* we detect certain SSH agents out of the
deterministic signature checking function.

Alas, OpenSSH does not support the extension query message we issue,
despite them supporting the extension system in general *and* stewarding
the SSH agent protocol specification which defines this message
normatively.  So our implementation must tolerate a moderate level of
spec violation.
... ...
@@ -18,6 +18,7 @@ from derivepassphrase import _types
18 18
 
19 19
 if TYPE_CHECKING:
20 20
     from collections.abc import Iterable, Iterator, Sequence
21
+    from collections.abc import Set as AbstractSet
21 22
     from types import TracebackType
22 23
 
23 24
     from typing_extensions import Buffer
... ...
@@ -338,6 +339,18 @@ class SSHAgentClient:
338 339
             msg = f'invalid connection hint: {conn!r}'
339 340
             raise TypeError(msg)  # noqa: DOC501
340 341
 
342
+    def _agent_is_pageant(self) -> bool:
343
+        """Return True if we are connected to Pageant.
344
+
345
+        Warning:
346
+            This is a heuristic, not a verified query or computation.
347
+
348
+        """
349
+        return (
350
+            b'list-extended@putty.projects.tartarus.org'
351
+            in self.query_extensions()
352
+        )
353
+
341 354
     def has_deterministic_signatures(self) -> bool:
342 355
         """Check whether the agent returns deterministic signatures.
343 356
 
... ...
@@ -351,11 +364,12 @@ class SSHAgentClient:
351 364
             | Pageant (PuTTY) | `list-extended@putty.projects.tartarus.org` extension request |
352 365
 
353 366
         """  # noqa: E501
354
-        returncode, _payload = self.request(
355
-            _types.SSH_AGENTC.EXTENSION,
356
-            self.string(b'list-extended@putty.projects.tartarus.org'),
367
+        known_good_agents = {
368
+            'Pageant': self._agent_is_pageant,
369
+        }
370
+        return any(  # pragma: no branch
371
+            v() for v in known_good_agents.values()
357 372
         )
358
-        return returncode == _types.SSH_AGENT.SUCCESS.value
359 373
 
360 374
     @overload
361 375
     def request(  # pragma: no cover
... ...
@@ -578,3 +592,45 @@ class SSHAgentClient:
578 592
                 )
579 593
             )
580 594
         )
595
+
596
+    def query_extensions(self) -> AbstractSet[bytes]:
597
+        """Request a list of extensions supported by the SSH agent.
598
+
599
+        Args:
600
+            raise_if_no_extension_support:
601
+                If true, and if the agent does not support querying
602
+                extensions, then raise an error.  If false, silently
603
+                return an empty result.
604
+
605
+        Returns:
606
+            A read-only sequence of extension names.
607
+
608
+        Raises:
609
+            EOFError:
610
+                The response from the SSH agent is truncated or missing.
611
+            OSError:
612
+                There was a communication error with the SSH agent.
613
+            SSHAgentFailedError:
614
+                The agent failed to complete the request.
615
+
616
+        """
617
+        try:
618
+            response_data = self.request(
619
+                _types.SSH_AGENTC.EXTENSION,
620
+                self.string(b'query'),
621
+                response_code={
622
+                    _types.SSH_AGENT.EXTENSION_RESPONSE,
623
+                    _types.SSH_AGENT.SUCCESS,
624
+                },
625
+            )
626
+        except SSHAgentFailedError:
627
+            # Cannot query extension support.  Assume no extensions.
628
+            # This isn't necessarily true, e.g. for OpenSSH's ssh-agent.
629
+            return frozenset()
630
+        extensions: set[bytes] = set()
631
+        _query, response_data = self.unstring_prefix(response_data)
632
+        assert bytes(_query) == b'query'
633
+        while response_data:
634
+            extension, response_data = self.unstring_prefix(response_data)
635
+            extensions.add(bytes(extension))
636
+        return frozenset(extensions)
581 637