Marco Ricci commited on 2024-09-30 14:35:47
Zeige 1 geänderte Dateien mit 51 Einfügungen und 44 Löschungen.
As of Python 3.12, any custom Python class can declare support for the buffer protocol. So instead of special-casing `bytes` and `bytearray`, and ignoring all other types, support arbitrary classes with buffer protocol support. Furthermore, explicitly return bytes objects (i.e., read-only copies) of all involved byte strings, because the buffer protocol ensures that copies are relatively cheap.
... | ... |
@@ -20,6 +20,8 @@ if TYPE_CHECKING: |
20 | 20 |
from collections.abc import Iterable, Sequence |
21 | 21 |
from types import TracebackType |
22 | 22 |
|
23 |
+ from typing_extensions import Buffer |
|
24 |
+ |
|
23 | 25 |
__all__ = ('SSHAgentClient',) |
24 | 26 |
__author__ = 'Marco Ricci <software@the13thletter.info>' |
25 | 27 |
|
... | ... |
@@ -171,70 +173,69 @@ class SSHAgentClient: |
171 | 173 |
return int.to_bytes(num, 4, 'big', signed=False) |
172 | 174 |
|
173 | 175 |
@classmethod |
174 |
- def string(cls, payload: bytes | bytearray, /) -> bytes | bytearray: |
|
176 |
+ def string(cls, payload: Buffer, /) -> bytes: |
|
175 | 177 |
r"""Format the payload as an SSH string, as per the agent protocol. |
176 | 178 |
|
177 | 179 |
Args: |
178 |
- payload: A byte string. |
|
180 |
+ payload: A bytes-like object. |
|
179 | 181 |
|
180 | 182 |
Returns: |
181 |
- The payload, framed in the SSH agent wire protocol format. |
|
183 |
+ The payload, framed in the SSH agent wire protocol format, |
|
184 |
+ as a bytes object. |
|
182 | 185 |
|
183 | 186 |
Examples: |
184 |
- >>> bytes(SSHAgentClient.string(b'ssh-rsa')) |
|
187 |
+ >>> SSHAgentClient.string(b'ssh-rsa') |
|
185 | 188 |
b'\x00\x00\x00\x07ssh-rsa' |
186 | 189 |
|
187 | 190 |
""" |
188 | 191 |
try: |
192 |
+ payload = memoryview(payload) |
|
193 |
+ except TypeError as e: |
|
194 |
+ msg = 'invalid payload type' |
|
195 |
+ raise TypeError(msg) from e # noqa: DOC501 |
|
189 | 196 |
ret = bytearray() |
190 | 197 |
ret.extend(cls.uint32(len(payload))) |
191 | 198 |
ret.extend(payload) |
192 |
- except Exception as e: |
|
193 |
- msg = 'invalid payload type' |
|
194 |
- raise TypeError(msg) from e # noqa: DOC501 |
|
195 |
- return ret |
|
199 |
+ return bytes(ret) |
|
196 | 200 |
|
197 | 201 |
@classmethod |
198 |
- def unstring(cls, bytestring: bytes | bytearray, /) -> bytes | bytearray: |
|
202 |
+ def unstring(cls, bytestring: Buffer, /) -> bytes: |
|
199 | 203 |
r"""Unpack an SSH string. |
200 | 204 |
|
201 | 205 |
Args: |
202 |
- bytestring: A framed byte string. |
|
206 |
+ bytestring: A framed bytes-like object. |
|
203 | 207 |
|
204 | 208 |
Returns: |
205 |
- The unframed byte string, i.e., the payload. |
|
209 |
+ The payload, as a bytes object. |
|
206 | 210 |
|
207 | 211 |
Raises: |
208 | 212 |
ValueError: |
209 | 213 |
The byte string is not an SSH string. |
210 | 214 |
|
211 | 215 |
Examples: |
212 |
- >>> bytes(SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa')) |
|
216 |
+ >>> SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa') |
|
213 | 217 |
b'ssh-rsa' |
214 |
- >>> bytes( |
|
215 |
- ... SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519')) |
|
216 |
- ... ) |
|
218 |
+ >>> SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519')) |
|
217 | 219 |
b'ssh-ed25519' |
218 | 220 |
|
219 |
- """ # noqa: E501 |
|
221 |
+ """ |
|
222 |
+ bytestring = memoryview(bytestring) |
|
220 | 223 |
n = len(bytestring) |
221 | 224 |
msg = 'malformed SSH byte string' |
222 | 225 |
if n < HEAD_LEN or n != HEAD_LEN + int.from_bytes( |
223 | 226 |
bytestring[:HEAD_LEN], 'big', signed=False |
224 | 227 |
): |
225 | 228 |
raise ValueError(msg) |
226 |
- return bytestring[HEAD_LEN:] |
|
229 |
+ return bytes(bytestring[HEAD_LEN:]) |
|
227 | 230 |
|
228 | 231 |
@classmethod |
229 |
- def unstring_prefix( |
|
230 |
- cls, bytestring: bytes | bytearray, / |
|
231 |
- ) -> tuple[bytes | bytearray, bytes | bytearray]: |
|
232 |
+ def unstring_prefix(cls, bytestring: Buffer, /) -> tuple[bytes, bytes]: |
|
232 | 233 |
r"""Unpack an SSH string at the beginning of the byte string. |
233 | 234 |
|
234 | 235 |
Args: |
235 | 236 |
bytestring: |
236 |
- A (general) byte string, beginning with a framed/SSH |
|
237 |
- byte string. |
|
237 |
+ A bytes-like object, beginning with a framed/SSH byte |
|
238 |
+ string. |
|
238 | 239 |
|
239 | 240 |
Returns: |
240 | 241 |
A 2-tuple `(a, b)`, where `a` is the unframed byte |
... | ... |
@@ -246,18 +247,17 @@ class SSHAgentClient: |
246 | 247 |
The byte string does not begin with an SSH string. |
247 | 248 |
|
248 | 249 |
Examples: |
249 |
- >>> a, b = SSHAgentClient.unstring_prefix( |
|
250 |
+ >>> SSHAgentClient.unstring_prefix( |
|
250 | 251 |
... b'\x00\x00\x00\x07ssh-rsa____trailing data' |
251 | 252 |
... ) |
252 |
- >>> (bytes(a), bytes(b)) |
|
253 | 253 |
(b'ssh-rsa', b'____trailing data') |
254 |
- >>> a, b = SSHAgentClient.unstring_prefix( |
|
254 |
+ >>> SSHAgentClient.unstring_prefix( |
|
255 | 255 |
... SSHAgentClient.string(b'ssh-ed25519') |
256 | 256 |
... ) |
257 |
- >>> (bytes(a), bytes(b)) |
|
258 | 257 |
(b'ssh-ed25519', b'') |
259 | 258 |
|
260 | 259 |
""" |
260 |
+ bytestring = memoryview(bytestring).toreadonly() |
|
261 | 261 |
n = len(bytestring) |
262 | 262 |
msg = 'malformed SSH byte string' |
263 | 263 |
if n < HEAD_LEN: |
... | ... |
@@ -266,52 +266,52 @@ class SSHAgentClient: |
266 | 266 |
if m + HEAD_LEN > n: |
267 | 267 |
raise ValueError(msg) |
268 | 268 |
return ( |
269 |
- bytestring[HEAD_LEN : m + HEAD_LEN], |
|
270 |
- bytestring[m + HEAD_LEN :], |
|
269 |
+ bytes(bytestring[HEAD_LEN : m + HEAD_LEN]), |
|
270 |
+ bytes(bytestring[m + HEAD_LEN :]), |
|
271 | 271 |
) |
272 | 272 |
|
273 | 273 |
@overload |
274 | 274 |
def request( # pragma: no cover |
275 | 275 |
self, |
276 | 276 |
code: int | _types.SSH_AGENTC, |
277 |
- payload: bytes | bytearray, |
|
277 |
+ payload: Buffer, |
|
278 | 278 |
/, |
279 | 279 |
*, |
280 | 280 |
response_code: None = None, |
281 |
- ) -> tuple[int, bytes | bytearray]: ... |
|
281 |
+ ) -> tuple[int, bytes]: ... |
|
282 | 282 |
|
283 | 283 |
@overload |
284 | 284 |
def request( # pragma: no cover |
285 | 285 |
self, |
286 | 286 |
code: int | _types.SSH_AGENTC, |
287 |
- payload: bytes | bytearray, |
|
287 |
+ payload: Buffer, |
|
288 | 288 |
/, |
289 | 289 |
*, |
290 | 290 |
response_code: Iterable[_types.SSH_AGENT | int] = frozenset({ |
291 | 291 |
_types.SSH_AGENT.SUCCESS |
292 | 292 |
}), |
293 |
- ) -> tuple[int, bytes | bytearray]: ... |
|
293 |
+ ) -> tuple[int, bytes]: ... |
|
294 | 294 |
|
295 | 295 |
@overload |
296 | 296 |
def request( # pragma: no cover |
297 | 297 |
self, |
298 | 298 |
code: int | _types.SSH_AGENTC, |
299 |
- payload: bytes | bytearray, |
|
299 |
+ payload: Buffer, |
|
300 | 300 |
/, |
301 | 301 |
*, |
302 | 302 |
response_code: _types.SSH_AGENT | int = _types.SSH_AGENT.SUCCESS, |
303 |
- ) -> bytes | bytearray: ... |
|
303 |
+ ) -> bytes: ... |
|
304 | 304 |
|
305 | 305 |
def request( |
306 | 306 |
self, |
307 | 307 |
code: int | _types.SSH_AGENTC, |
308 |
- payload: bytes | bytearray, |
|
308 |
+ payload: Buffer, |
|
309 | 309 |
/, |
310 | 310 |
*, |
311 | 311 |
response_code: ( |
312 | 312 |
Iterable[_types.SSH_AGENT | int] | _types.SSH_AGENT | int | None |
313 | 313 |
) = None, |
314 |
- ) -> tuple[int, bytes | bytearray] | bytes | bytearray: |
|
314 |
+ ) -> tuple[int, bytes] | bytes: |
|
315 | 315 |
"""Issue a generic request to the SSH agent. |
316 | 316 |
|
317 | 317 |
Args: |
... | ... |
@@ -320,10 +320,12 @@ class SSHAgentClient: |
320 | 320 |
protocol numbers to use here (and which protocol numbers |
321 | 321 |
to expect in a response). |
322 | 322 |
payload: |
323 |
- A byte string containing the payload, or "contents", of |
|
324 |
- the request. Request-specific. `request` will add any |
|
325 |
- necessary wire framing around the request code and the |
|
326 |
- payload. |
|
323 |
+ A bytes-like object containing the payload, or |
|
324 |
+ "contents", of the request. Request-specific. |
|
325 |
+ |
|
326 |
+ It is our responsibility to add any necessary wire |
|
327 |
+ framing around the request code and the payload, |
|
328 |
+ not the caller's. |
|
327 | 329 |
response_code: |
328 | 330 |
An optional response code, or a set of response codes, |
329 | 331 |
that we expect. If given, and the actual response code |
... | ... |
@@ -351,6 +353,7 @@ class SSHAgentClient: |
351 | 353 |
response_code = frozenset({ |
352 | 354 |
c if isinstance(c, int) else c.value for c in response_code |
353 | 355 |
}) |
356 |
+ payload = memoryview(payload) |
|
354 | 357 |
request_message = bytearray([ |
355 | 358 |
code if isinstance(code, int) else code.value |
356 | 359 |
]) |
... | ... |
@@ -424,12 +427,12 @@ class SSHAgentClient: |
424 | 427 |
def sign( |
425 | 428 |
self, |
426 | 429 |
/, |
427 |
- key: bytes | bytearray, |
|
428 |
- payload: bytes | bytearray, |
|
430 |
+ key: Buffer, |
|
431 |
+ payload: Buffer, |
|
429 | 432 |
*, |
430 | 433 |
flags: int = 0, |
431 | 434 |
check_if_key_loaded: bool = False, |
432 |
- ) -> bytes | bytearray: |
|
435 |
+ ) -> bytes: |
|
433 | 436 |
"""Request the SSH agent sign the payload with the key. |
434 | 437 |
|
435 | 438 |
Args: |
... | ... |
@@ -467,6 +470,8 @@ class SSHAgentClient: |
467 | 470 |
loaded into the agent. |
468 | 471 |
|
469 | 472 |
""" |
473 |
+ key = memoryview(key) |
|
474 |
+ payload = memoryview(payload) |
|
470 | 475 |
if check_if_key_loaded: |
471 | 476 |
loaded_keys = frozenset({pair.key for pair in self.list_keys()}) |
472 | 477 |
if bytes(key) not in loaded_keys: |
... | ... |
@@ -475,10 +480,12 @@ class SSHAgentClient: |
475 | 480 |
request_data = bytearray(self.string(key)) |
476 | 481 |
request_data.extend(self.string(payload)) |
477 | 482 |
request_data.extend(self.uint32(flags)) |
478 |
- return self.unstring( |
|
483 |
+ return bytes( |
|
484 |
+ self.unstring( |
|
479 | 485 |
self.request( |
480 | 486 |
_types.SSH_AGENTC.SIGN_REQUEST.value, |
481 | 487 |
request_data, |
482 | 488 |
response_code=_types.SSH_AGENT.SIGN_RESPONSE, |
483 | 489 |
) |
484 | 490 |
) |
491 |
+ ) |
|
485 | 492 |