Add proper support for Buffer types in the SSH agent client
Marco Ricci

Marco Ricci commited on 2024-09-30 14:35:47
Zeige 1 geänderte Dateien mit 51 Einfügungen und 44 Löschungen.


As of Python 3.12, any custom Python class can declare support for the
buffer protocol.  So instead of special-casing `bytes` and `bytearray`,
and ignoring all other types, support arbitrary classes with buffer
protocol support.  Furthermore, explicitly return bytes objects (i.e.,
read-only copies) of all involved byte strings, because the buffer
protocol ensures that copies are relatively cheap.
... ...
@@ -20,6 +20,8 @@ if TYPE_CHECKING:
20 20
     from collections.abc import Iterable, Sequence
21 21
     from types import TracebackType
22 22
 
23
+    from typing_extensions import Buffer
24
+
23 25
 __all__ = ('SSHAgentClient',)
24 26
 __author__ = 'Marco Ricci <software@the13thletter.info>'
25 27
 
... ...
@@ -171,70 +173,69 @@ class SSHAgentClient:
171 173
         return int.to_bytes(num, 4, 'big', signed=False)
172 174
 
173 175
     @classmethod
174
-    def string(cls, payload: bytes | bytearray, /) -> bytes | bytearray:
176
+    def string(cls, payload: Buffer, /) -> bytes:
175 177
         r"""Format the payload as an SSH string, as per the agent protocol.
176 178
 
177 179
         Args:
178
-            payload: A byte string.
180
+            payload: A bytes-like object.
179 181
 
180 182
         Returns:
181
-            The payload, framed in the SSH agent wire protocol format.
183
+            The payload, framed in the SSH agent wire protocol format,
184
+            as a bytes object.
182 185
 
183 186
         Examples:
184
-            >>> bytes(SSHAgentClient.string(b'ssh-rsa'))
187
+            >>> SSHAgentClient.string(b'ssh-rsa')
185 188
             b'\x00\x00\x00\x07ssh-rsa'
186 189
 
187 190
         """
188 191
         try:
192
+            payload = memoryview(payload)
193
+        except TypeError as e:
194
+            msg = 'invalid payload type'
195
+            raise TypeError(msg) from e  # noqa: DOC501
189 196
         ret = bytearray()
190 197
         ret.extend(cls.uint32(len(payload)))
191 198
         ret.extend(payload)
192
-        except Exception as e:
193
-            msg = 'invalid payload type'
194
-            raise TypeError(msg) from e  # noqa: DOC501
195
-        return ret
199
+        return bytes(ret)
196 200
 
197 201
     @classmethod
198
-    def unstring(cls, bytestring: bytes | bytearray, /) -> bytes | bytearray:
202
+    def unstring(cls, bytestring: Buffer, /) -> bytes:
199 203
         r"""Unpack an SSH string.
200 204
 
201 205
         Args:
202
-            bytestring: A framed byte string.
206
+            bytestring: A framed bytes-like object.
203 207
 
204 208
         Returns:
205
-            The unframed byte string, i.e., the payload.
209
+            The payload, as a bytes object.
206 210
 
207 211
         Raises:
208 212
             ValueError:
209 213
                 The byte string is not an SSH string.
210 214
 
211 215
         Examples:
212
-            >>> bytes(SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa'))
216
+            >>> SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa')
213 217
             b'ssh-rsa'
214
-            >>> bytes(
215
-            ...     SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519'))
216
-            ... )
218
+            >>> SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519'))
217 219
             b'ssh-ed25519'
218 220
 
219
-        """  # noqa: E501
221
+        """
222
+        bytestring = memoryview(bytestring)
220 223
         n = len(bytestring)
221 224
         msg = 'malformed SSH byte string'
222 225
         if n < HEAD_LEN or n != HEAD_LEN + int.from_bytes(
223 226
             bytestring[:HEAD_LEN], 'big', signed=False
224 227
         ):
225 228
             raise ValueError(msg)
226
-        return bytestring[HEAD_LEN:]
229
+        return bytes(bytestring[HEAD_LEN:])
227 230
 
228 231
     @classmethod
229
-    def unstring_prefix(
230
-        cls, bytestring: bytes | bytearray, /
231
-    ) -> tuple[bytes | bytearray, bytes | bytearray]:
232
+    def unstring_prefix(cls, bytestring: Buffer, /) -> tuple[bytes, bytes]:
232 233
         r"""Unpack an SSH string at the beginning of the byte string.
233 234
 
234 235
         Args:
235 236
             bytestring:
236
-                A (general) byte string, beginning with a framed/SSH
237
-                byte string.
237
+                A bytes-like object, beginning with a framed/SSH byte
238
+                string.
238 239
 
239 240
         Returns:
240 241
             A 2-tuple `(a, b)`, where `a` is the unframed byte
... ...
@@ -246,18 +247,17 @@ class SSHAgentClient:
246 247
                 The byte string does not begin with an SSH string.
247 248
 
248 249
         Examples:
249
-            >>> a, b = SSHAgentClient.unstring_prefix(
250
+            >>> SSHAgentClient.unstring_prefix(
250 251
             ...     b'\x00\x00\x00\x07ssh-rsa____trailing data'
251 252
             ... )
252
-            >>> (bytes(a), bytes(b))
253 253
             (b'ssh-rsa', b'____trailing data')
254
-            >>> a, b = SSHAgentClient.unstring_prefix(
254
+            >>> SSHAgentClient.unstring_prefix(
255 255
             ...     SSHAgentClient.string(b'ssh-ed25519')
256 256
             ... )
257
-            >>> (bytes(a), bytes(b))
258 257
             (b'ssh-ed25519', b'')
259 258
 
260 259
         """
260
+        bytestring = memoryview(bytestring).toreadonly()
261 261
         n = len(bytestring)
262 262
         msg = 'malformed SSH byte string'
263 263
         if n < HEAD_LEN:
... ...
@@ -266,52 +266,52 @@ class SSHAgentClient:
266 266
         if m + HEAD_LEN > n:
267 267
             raise ValueError(msg)
268 268
         return (
269
-            bytestring[HEAD_LEN : m + HEAD_LEN],
270
-            bytestring[m + HEAD_LEN :],
269
+            bytes(bytestring[HEAD_LEN : m + HEAD_LEN]),
270
+            bytes(bytestring[m + HEAD_LEN :]),
271 271
         )
272 272
 
273 273
     @overload
274 274
     def request(  # pragma: no cover
275 275
         self,
276 276
         code: int | _types.SSH_AGENTC,
277
-        payload: bytes | bytearray,
277
+        payload: Buffer,
278 278
         /,
279 279
         *,
280 280
         response_code: None = None,
281
-    ) -> tuple[int, bytes | bytearray]: ...
281
+    ) -> tuple[int, bytes]: ...
282 282
 
283 283
     @overload
284 284
     def request(  # pragma: no cover
285 285
         self,
286 286
         code: int | _types.SSH_AGENTC,
287
-        payload: bytes | bytearray,
287
+        payload: Buffer,
288 288
         /,
289 289
         *,
290 290
         response_code: Iterable[_types.SSH_AGENT | int] = frozenset({
291 291
             _types.SSH_AGENT.SUCCESS
292 292
         }),
293
-    ) -> tuple[int, bytes | bytearray]: ...
293
+    ) -> tuple[int, bytes]: ...
294 294
 
295 295
     @overload
296 296
     def request(  # pragma: no cover
297 297
         self,
298 298
         code: int | _types.SSH_AGENTC,
299
-        payload: bytes | bytearray,
299
+        payload: Buffer,
300 300
         /,
301 301
         *,
302 302
         response_code: _types.SSH_AGENT | int = _types.SSH_AGENT.SUCCESS,
303
-    ) -> bytes | bytearray: ...
303
+    ) -> bytes: ...
304 304
 
305 305
     def request(
306 306
         self,
307 307
         code: int | _types.SSH_AGENTC,
308
-        payload: bytes | bytearray,
308
+        payload: Buffer,
309 309
         /,
310 310
         *,
311 311
         response_code: (
312 312
             Iterable[_types.SSH_AGENT | int] | _types.SSH_AGENT | int | None
313 313
         ) = None,
314
-    ) -> tuple[int, bytes | bytearray] | bytes | bytearray:
314
+    ) -> tuple[int, bytes] | bytes:
315 315
         """Issue a generic request to the SSH agent.
316 316
 
317 317
         Args:
... ...
@@ -320,10 +320,12 @@ class SSHAgentClient:
320 320
                 protocol numbers to use here (and which protocol numbers
321 321
                 to expect in a response).
322 322
             payload:
323
-                A byte string containing the payload, or "contents", of
324
-                the request.  Request-specific.  `request` will add any
325
-                necessary wire framing around the request code and the
326
-                payload.
323
+                A bytes-like object containing the payload, or
324
+                "contents", of the request.  Request-specific.
325
+
326
+                It is our responsibility to add any necessary wire
327
+                framing around the request code and the payload,
328
+                not the caller's.
327 329
             response_code:
328 330
                 An optional response code, or a set of response codes,
329 331
                 that we expect.  If given, and the actual response code
... ...
@@ -351,6 +353,7 @@ class SSHAgentClient:
351 353
             response_code = frozenset({
352 354
                 c if isinstance(c, int) else c.value for c in response_code
353 355
             })
356
+        payload = memoryview(payload)
354 357
         request_message = bytearray([
355 358
             code if isinstance(code, int) else code.value
356 359
         ])
... ...
@@ -424,12 +427,12 @@ class SSHAgentClient:
424 427
     def sign(
425 428
         self,
426 429
         /,
427
-        key: bytes | bytearray,
428
-        payload: bytes | bytearray,
430
+        key: Buffer,
431
+        payload: Buffer,
429 432
         *,
430 433
         flags: int = 0,
431 434
         check_if_key_loaded: bool = False,
432
-    ) -> bytes | bytearray:
435
+    ) -> bytes:
433 436
         """Request the SSH agent sign the payload with the key.
434 437
 
435 438
         Args:
... ...
@@ -467,6 +470,8 @@ class SSHAgentClient:
467 470
                 loaded into the agent.
468 471
 
469 472
         """
473
+        key = memoryview(key)
474
+        payload = memoryview(payload)
470 475
         if check_if_key_loaded:
471 476
             loaded_keys = frozenset({pair.key for pair in self.list_keys()})
472 477
             if bytes(key) not in loaded_keys:
... ...
@@ -475,10 +480,12 @@ class SSHAgentClient:
475 480
         request_data = bytearray(self.string(key))
476 481
         request_data.extend(self.string(payload))
477 482
         request_data.extend(self.uint32(flags))
478
-        return self.unstring(
483
+        return bytes(
484
+            self.unstring(
479 485
                 self.request(
480 486
                     _types.SSH_AGENTC.SIGN_REQUEST.value,
481 487
                     request_data,
482 488
                     response_code=_types.SSH_AGENT.SIGN_RESPONSE,
483 489
                 )
484 490
             )
491
+        )
485 492