3d09a07f18a46b597808e3f2030a566d4ef7d1e2
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

1) # SPDX-FileCopyrightText: 2024 Marco Ricci <m@the13thletter.info>
2) #
3) # SPDX-License-Identifier: MIT
4) 
5) """A bare-bones SSH agent client supporting signing and key listing."""
6) 
7) from __future__ import annotations
8) 
9) import collections
10) import enum
11) import errno
12) import os
13) import pathlib
14) import socket
15) 
16) from collections.abc import Sequence, MutableSequence
17) from typing import Any, NamedTuple, Self, TypeAlias
18) from ssh_agent_client.types import KeyCommentPair, SSH_AGENT, SSH_AGENTC
19) 
__all__ = ('SSHAgentClient',)
__author__ = 'Marco Ricci <m@the13thletter.info>'
__version__ = '0.1.0'

# Alias for the `socket` module.  `SSHAgentClient.__init__` takes a
# keyword parameter named `socket`, which shadows the module inside that
# method; the method therefore reaches the module through this name.
_socket = socket
25) 
26) class SSHAgentClient:
27)     """A bare-bones SSH agent client supporting signing and key listing.
28) 
29)     The main use case is requesting the agent sign some data, after
30)     checking that the necessary key is already loaded.
31) 
32)     The main fleshed out methods are `list_keys` and `sign`, which
33)     implement the `REQUEST_IDENTITIES` and `SIGN_REQUEST` requests.  If
34)     you *really* wanted to, there is enough infrastructure in place to
35)     issue other requests as defined in the protocol---it's merely the
36)     wrapper functions and the protocol numbers table that are missing.
37) 
38)     """
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

39)     _connection: socket.socket
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

40)     def __init__(
41)         self, /, *, socket: socket.socket | None = None, timeout: int = 125
42)     ) -> None:
43)         """Initialize the client.
44) 
45)         Args:
46)             socket:
47)                 An optional socket, connected to the SSH agent.  If not
48)                 given, we query the `SSH_AUTH_SOCK` environment
49)                 variable to auto-discover the correct socket address.
50)             timeout:
51)                 A connection timeout for the SSH agent.  Only used if
52)                 the socket is not yet connected.  The default value
53)                 gives ample time for agent connections forwarded via
54)                 SSH on high-latency networks (e.g. Tor).
55) 
56)         """
57)         if socket is not None:
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

58)             self._connection = socket
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

59)         else:
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

60)             self._connection = _socket.socket(family=_socket.AF_UNIX)
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

61)         try:
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

62)             # Test whether the socket is connected.
63)             self._connection.getpeername()
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

64)         except OSError as e:
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

65)             # This condition is hard to test purposefully, so exclude
66)             # from coverage.
67)             if e.errno != errno.ENOTCONN:  # pragma: no cover
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

68)                 raise
69)             try:
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

70)                 ssh_auth_sock = os.environ['SSH_AUTH_SOCK']
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

71)             except KeyError as e:
72)                 raise RuntimeError(
73)                     "Can't find running ssh-agent: missing SSH_AUTH_SOCK"
74)                 ) from e
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

75)             self._connection.settimeout(timeout)
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

76)             try:
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

77)                 self._connection.connect(ssh_auth_sock)
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

78)             except FileNotFoundError as e:
79)                 raise RuntimeError(
80)                     "Can't find running ssh-agent: unusable SSH_AUTH_SOCK"
81)                 ) from e
82) 
83)     def __enter__(self) -> Self:
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

84)         """Close socket connection upon context manager completion."""
85)         self._connection.__enter__()
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

86)         return self
87) 
88)     def __exit__(
89)         self, exc_type: Any, exc_val: Any, exc_tb: Any
90)     ) -> bool:
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

91)         """Close socket connection upon context manager completion."""
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

92)         return bool(
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

93)             self._connection.__exit__(
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

94)                 exc_type, exc_val, exc_tb)  # type: ignore[func-returns-value]
95)         )
96) 
97)     @staticmethod
Marco Ricci Add unit tests, both new an...

Marco Ricci authored 5 months ago

98)     def uint32(num: int, /) -> bytes:
99)         r"""Format the number as a `uint32`, as per the agent protocol.
100) 
101)         Args:
102)             num: A number.
103) 
104)         Returns:
105)             The number in SSH agent wire protocol format, i.e. as
106)             a 32-bit big endian number.
107) 
108)         Raises:
109)             OverflowError:
110)                 As per [`int.to_bytes`][].
111) 
112)         Examples:
113)             >>> SSHAgentClient.uint32(16777216)
114)             b'\x01\x00\x00\x00'
115) 
116)         """
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

117)         return int.to_bytes(num, 4, 'big', signed=False)
118) 
119)     @classmethod
120)     def string(cls, payload: bytes | bytearray, /) -> bytes | bytearray:
Marco Ricci Add unit tests, both new an...

Marco Ricci authored 5 months ago

121)         r"""Format the payload as an SSH string, as per the agent protocol.
122) 
123)         Args:
124)             payload: A byte string.
125) 
126)         Returns:
127)             The payload, framed in the SSH agent wire protocol format.
128) 
129)         Examples:
130)             >>> bytes(SSHAgentClient.string(b'ssh-rsa'))
131)             b'\x00\x00\x00\x07ssh-rsa'
132) 
133)         """
Marco Ricci Fix numerous argument type...

Marco Ricci authored 5 months ago

134)         try:
135)             ret = bytearray()
136)             ret.extend(cls.uint32(len(payload)))
137)             ret.extend(payload)
138)             return ret
139)         except Exception as e:
140)             raise TypeError('invalid payload type') from e
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

141) 
142)     @classmethod
143)     def unstring(cls, bytestring: bytes | bytearray, /) -> bytes | bytearray:
Marco Ricci Add unit tests, both new an...

Marco Ricci authored 5 months ago

144)         r"""Unpack an SSH string.
145) 
146)         Args:
147)             bytestring: A framed byte string.
148) 
149)         Returns:
150)             The unframed byte string, i.e., the payload.
151) 
152)         Raises:
153)             ValueError:
Marco Ricci Add function for SSH framed...

Marco Ricci authored 4 months ago

154)                 The byte string is not an SSH string.
Marco Ricci Add unit tests, both new an...

Marco Ricci authored 5 months ago

155) 
156)         Examples:
157)             >>> bytes(SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa'))
158)             b'ssh-rsa'
159)             >>> bytes(SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519')))
160)             b'ssh-ed25519'
161) 
162)         """
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

163)         n = len(bytestring)
164)         if n < 4:
165)             raise ValueError('malformed SSH byte string')
166)         elif n != 4 + int.from_bytes(bytestring[:4], 'big', signed=False):
167)             raise ValueError('malformed SSH byte string')
168)         return bytestring[4:]
169) 
Marco Ricci Add function for SSH framed...

Marco Ricci authored 4 months ago

170)     @classmethod
171)     def unstring_prefix(
172)         cls, bytestring: bytes | bytearray, /
173)     ) -> tuple[bytes | bytearray, bytes | bytearray]:
174)         r"""Unpack an SSH string at the beginning of the byte string.
175) 
176)         Args:
177)             bytestring:
178)                 A (general) byte string, beginning with a framed/SSH
179)                 byte string.
180) 
181)         Returns:
182)             A 2-tuple `(a, b)`, where `a` is the unframed byte
183)             string/payload at the beginning of input byte string, and
184)             `b` is the remainder of the input byte string.
185) 
186)         Raises:
187)             ValueError:
188)                 The byte string does not begin with an SSH string.
189) 
190)         Examples:
191)             >>> a, b = SSHAgentClient.unstring_prefix(
192)             ...     b'\x00\x00\x00\x07ssh-rsa____trailing data')
193)             >>> (bytes(a), bytes(b))
194)             (b'ssh-rsa', b'____trailing data')
195)             >>> a, b = SSHAgentClient.unstring_prefix(
196)             ...     SSHAgentClient.string(b'ssh-ed25519'))
197)             >>> (bytes(a), bytes(b))
198)             (b'ssh-ed25519', b'')
199) 
200)         """
201)         n = len(bytestring)
202)         if n < 4:
203)             raise ValueError('malformed SSH byte string')
204)         m = int.from_bytes(bytestring[:4], 'big', signed=False)
205)         if m + 4 > n:
206)             raise ValueError('malformed SSH byte string')
207)         return (bytestring[4:m + 4], bytestring[m + 4:])
208) 
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

209)     def request(
210)         self, code: int, payload: bytes | bytearray, /
211)     ) -> tuple[int, bytes | bytearray]:
212)         """Issue a generic request to the SSH agent.
213) 
214)         Args:
215)             code:
216)                 The request code.  See the SSH agent protocol for
217)                 protocol numbers to use here (and which protocol numbers
218)                 to expect in a response).
219)             payload:
220)                 A byte string containing the payload, or "contents", of
221)                 the request.  Request-specific.  `request` will add any
222)                 necessary wire framing around the request code and the
223)                 payload.
224) 
225)         Returns:
226)             A 2-tuple consisting of the response code and the payload,
227)             with all wire framing removed.
228) 
229)         Raises:
230)             EOFError:
231)                 The response from the SSH agent is truncated or missing.
232) 
233)         """
234)         request_message = bytearray([code])
235)         request_message.extend(payload)
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

236)         self._connection.sendall(self.string(request_message))
237)         chunk = self._connection.recv(4)
Marco Ricci Add prototype implementation

Marco Ricci authored 5 months ago

238)         if len(chunk) < 4:
239)             raise EOFError('cannot read response length')
240)         response_length = int.from_bytes(chunk, 'big', signed=False)
Marco Ricci Remove public attributes of...

Marco Ricci authored 4 months ago

241)         response = self._connection.recv(response_length)