Marco Ricci commited on 2024-09-30 14:35:47
Zeige 1 geänderte Dateien mit 51 Einfügungen und 44 Löschungen.
As of Python 3.12, any custom Python class can declare support for the buffer protocol. So instead of special-casing `bytes` and `bytearray`, and ignoring all other types, support arbitrary classes with buffer protocol support. Furthermore, explicitly return bytes objects (i.e., read-only copies) of all involved byte strings, because the buffer protocol ensures that copies are relatively cheap.
| ... | ... |
@@ -20,6 +20,8 @@ if TYPE_CHECKING: |
| 20 | 20 |
from collections.abc import Iterable, Sequence |
| 21 | 21 |
from types import TracebackType |
| 22 | 22 |
|
| 23 |
+ from typing_extensions import Buffer |
|
| 24 |
+ |
|
| 23 | 25 |
__all__ = ('SSHAgentClient',)
|
| 24 | 26 |
__author__ = 'Marco Ricci <software@the13thletter.info>' |
| 25 | 27 |
|
| ... | ... |
@@ -171,70 +173,69 @@ class SSHAgentClient: |
| 171 | 173 |
return int.to_bytes(num, 4, 'big', signed=False) |
| 172 | 174 |
|
| 173 | 175 |
@classmethod |
| 174 |
- def string(cls, payload: bytes | bytearray, /) -> bytes | bytearray: |
|
| 176 |
+ def string(cls, payload: Buffer, /) -> bytes: |
|
| 175 | 177 |
r"""Format the payload as an SSH string, as per the agent protocol. |
| 176 | 178 |
|
| 177 | 179 |
Args: |
| 178 |
- payload: A byte string. |
|
| 180 |
+ payload: A bytes-like object. |
|
| 179 | 181 |
|
| 180 | 182 |
Returns: |
| 181 |
- The payload, framed in the SSH agent wire protocol format. |
|
| 183 |
+ The payload, framed in the SSH agent wire protocol format, |
|
| 184 |
+ as a bytes object. |
|
| 182 | 185 |
|
| 183 | 186 |
Examples: |
| 184 |
- >>> bytes(SSHAgentClient.string(b'ssh-rsa')) |
|
| 187 |
+ >>> SSHAgentClient.string(b'ssh-rsa') |
|
| 185 | 188 |
b'\x00\x00\x00\x07ssh-rsa' |
| 186 | 189 |
|
| 187 | 190 |
""" |
| 188 | 191 |
try: |
| 192 |
+ payload = memoryview(payload) |
|
| 193 |
+ except TypeError as e: |
|
| 194 |
+ msg = 'invalid payload type' |
|
| 195 |
+ raise TypeError(msg) from e # noqa: DOC501 |
|
| 189 | 196 |
ret = bytearray() |
| 190 | 197 |
ret.extend(cls.uint32(len(payload))) |
| 191 | 198 |
ret.extend(payload) |
| 192 |
- except Exception as e: |
|
| 193 |
- msg = 'invalid payload type' |
|
| 194 |
- raise TypeError(msg) from e # noqa: DOC501 |
|
| 195 |
- return ret |
|
| 199 |
+ return bytes(ret) |
|
| 196 | 200 |
|
| 197 | 201 |
@classmethod |
| 198 |
- def unstring(cls, bytestring: bytes | bytearray, /) -> bytes | bytearray: |
|
| 202 |
+ def unstring(cls, bytestring: Buffer, /) -> bytes: |
|
| 199 | 203 |
r"""Unpack an SSH string. |
| 200 | 204 |
|
| 201 | 205 |
Args: |
| 202 |
- bytestring: A framed byte string. |
|
| 206 |
+ bytestring: A framed bytes-like object. |
|
| 203 | 207 |
|
| 204 | 208 |
Returns: |
| 205 |
- The unframed byte string, i.e., the payload. |
|
| 209 |
+ The payload, as a bytes object. |
|
| 206 | 210 |
|
| 207 | 211 |
Raises: |
| 208 | 212 |
ValueError: |
| 209 | 213 |
The byte string is not an SSH string. |
| 210 | 214 |
|
| 211 | 215 |
Examples: |
| 212 |
- >>> bytes(SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa')) |
|
| 216 |
+ >>> SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa') |
|
| 213 | 217 |
b'ssh-rsa' |
| 214 |
- >>> bytes( |
|
| 215 |
- ... SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519')) |
|
| 216 |
- ... ) |
|
| 218 |
+ >>> SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519')) |
|
| 217 | 219 |
b'ssh-ed25519' |
| 218 | 220 |
|
| 219 |
- """ # noqa: E501 |
|
| 221 |
+ """ |
|
| 222 |
+ bytestring = memoryview(bytestring) |
|
| 220 | 223 |
n = len(bytestring) |
| 221 | 224 |
msg = 'malformed SSH byte string' |
| 222 | 225 |
if n < HEAD_LEN or n != HEAD_LEN + int.from_bytes( |
| 223 | 226 |
bytestring[:HEAD_LEN], 'big', signed=False |
| 224 | 227 |
): |
| 225 | 228 |
raise ValueError(msg) |
| 226 |
- return bytestring[HEAD_LEN:] |
|
| 229 |
+ return bytes(bytestring[HEAD_LEN:]) |
|
| 227 | 230 |
|
| 228 | 231 |
@classmethod |
| 229 |
- def unstring_prefix( |
|
| 230 |
- cls, bytestring: bytes | bytearray, / |
|
| 231 |
- ) -> tuple[bytes | bytearray, bytes | bytearray]: |
|
| 232 |
+ def unstring_prefix(cls, bytestring: Buffer, /) -> tuple[bytes, bytes]: |
|
| 232 | 233 |
r"""Unpack an SSH string at the beginning of the byte string. |
| 233 | 234 |
|
| 234 | 235 |
Args: |
| 235 | 236 |
bytestring: |
| 236 |
- A (general) byte string, beginning with a framed/SSH |
|
| 237 |
- byte string. |
|
| 237 |
+ A bytes-like object, beginning with a framed/SSH byte |
|
| 238 |
+ string. |
|
| 238 | 239 |
|
| 239 | 240 |
Returns: |
| 240 | 241 |
A 2-tuple `(a, b)`, where `a` is the unframed byte |
| ... | ... |
@@ -246,18 +247,17 @@ class SSHAgentClient: |
| 246 | 247 |
The byte string does not begin with an SSH string. |
| 247 | 248 |
|
| 248 | 249 |
Examples: |
| 249 |
- >>> a, b = SSHAgentClient.unstring_prefix( |
|
| 250 |
+ >>> SSHAgentClient.unstring_prefix( |
|
| 250 | 251 |
... b'\x00\x00\x00\x07ssh-rsa____trailing data' |
| 251 | 252 |
... ) |
| 252 |
- >>> (bytes(a), bytes(b)) |
|
| 253 | 253 |
(b'ssh-rsa', b'____trailing data') |
| 254 |
- >>> a, b = SSHAgentClient.unstring_prefix( |
|
| 254 |
+ >>> SSHAgentClient.unstring_prefix( |
|
| 255 | 255 |
... SSHAgentClient.string(b'ssh-ed25519') |
| 256 | 256 |
... ) |
| 257 |
- >>> (bytes(a), bytes(b)) |
|
| 258 | 257 |
(b'ssh-ed25519', b'') |
| 259 | 258 |
|
| 260 | 259 |
""" |
| 260 |
+ bytestring = memoryview(bytestring).toreadonly() |
|
| 261 | 261 |
n = len(bytestring) |
| 262 | 262 |
msg = 'malformed SSH byte string' |
| 263 | 263 |
if n < HEAD_LEN: |
| ... | ... |
@@ -266,52 +266,52 @@ class SSHAgentClient: |
| 266 | 266 |
if m + HEAD_LEN > n: |
| 267 | 267 |
raise ValueError(msg) |
| 268 | 268 |
return ( |
| 269 |
- bytestring[HEAD_LEN : m + HEAD_LEN], |
|
| 270 |
- bytestring[m + HEAD_LEN :], |
|
| 269 |
+ bytes(bytestring[HEAD_LEN : m + HEAD_LEN]), |
|
| 270 |
+ bytes(bytestring[m + HEAD_LEN :]), |
|
| 271 | 271 |
) |
| 272 | 272 |
|
| 273 | 273 |
@overload |
| 274 | 274 |
def request( # pragma: no cover |
| 275 | 275 |
self, |
| 276 | 276 |
code: int | _types.SSH_AGENTC, |
| 277 |
- payload: bytes | bytearray, |
|
| 277 |
+ payload: Buffer, |
|
| 278 | 278 |
/, |
| 279 | 279 |
*, |
| 280 | 280 |
response_code: None = None, |
| 281 |
- ) -> tuple[int, bytes | bytearray]: ... |
|
| 281 |
+ ) -> tuple[int, bytes]: ... |
|
| 282 | 282 |
|
| 283 | 283 |
@overload |
| 284 | 284 |
def request( # pragma: no cover |
| 285 | 285 |
self, |
| 286 | 286 |
code: int | _types.SSH_AGENTC, |
| 287 |
- payload: bytes | bytearray, |
|
| 287 |
+ payload: Buffer, |
|
| 288 | 288 |
/, |
| 289 | 289 |
*, |
| 290 | 290 |
response_code: Iterable[_types.SSH_AGENT | int] = frozenset({
|
| 291 | 291 |
_types.SSH_AGENT.SUCCESS |
| 292 | 292 |
}), |
| 293 |
- ) -> tuple[int, bytes | bytearray]: ... |
|
| 293 |
+ ) -> tuple[int, bytes]: ... |
|
| 294 | 294 |
|
| 295 | 295 |
@overload |
| 296 | 296 |
def request( # pragma: no cover |
| 297 | 297 |
self, |
| 298 | 298 |
code: int | _types.SSH_AGENTC, |
| 299 |
- payload: bytes | bytearray, |
|
| 299 |
+ payload: Buffer, |
|
| 300 | 300 |
/, |
| 301 | 301 |
*, |
| 302 | 302 |
response_code: _types.SSH_AGENT | int = _types.SSH_AGENT.SUCCESS, |
| 303 |
- ) -> bytes | bytearray: ... |
|
| 303 |
+ ) -> bytes: ... |
|
| 304 | 304 |
|
| 305 | 305 |
def request( |
| 306 | 306 |
self, |
| 307 | 307 |
code: int | _types.SSH_AGENTC, |
| 308 |
- payload: bytes | bytearray, |
|
| 308 |
+ payload: Buffer, |
|
| 309 | 309 |
/, |
| 310 | 310 |
*, |
| 311 | 311 |
response_code: ( |
| 312 | 312 |
Iterable[_types.SSH_AGENT | int] | _types.SSH_AGENT | int | None |
| 313 | 313 |
) = None, |
| 314 |
- ) -> tuple[int, bytes | bytearray] | bytes | bytearray: |
|
| 314 |
+ ) -> tuple[int, bytes] | bytes: |
|
| 315 | 315 |
"""Issue a generic request to the SSH agent. |
| 316 | 316 |
|
| 317 | 317 |
Args: |
| ... | ... |
@@ -320,10 +320,12 @@ class SSHAgentClient: |
| 320 | 320 |
protocol numbers to use here (and which protocol numbers |
| 321 | 321 |
to expect in a response). |
| 322 | 322 |
payload: |
| 323 |
- A byte string containing the payload, or "contents", of |
|
| 324 |
- the request. Request-specific. `request` will add any |
|
| 325 |
- necessary wire framing around the request code and the |
|
| 326 |
- payload. |
|
| 323 |
+ A bytes-like object containing the payload, or |
|
| 324 |
+ "contents", of the request. Request-specific. |
|
| 325 |
+ |
|
| 326 |
+ It is our responsibility to add any necessary wire |
|
| 327 |
+ framing around the request code and the payload, |
|
| 328 |
+ not the caller's. |
|
| 327 | 329 |
response_code: |
| 328 | 330 |
An optional response code, or a set of response codes, |
| 329 | 331 |
that we expect. If given, and the actual response code |
| ... | ... |
@@ -351,6 +353,7 @@ class SSHAgentClient: |
| 351 | 353 |
response_code = frozenset({
|
| 352 | 354 |
c if isinstance(c, int) else c.value for c in response_code |
| 353 | 355 |
}) |
| 356 |
+ payload = memoryview(payload) |
|
| 354 | 357 |
request_message = bytearray([ |
| 355 | 358 |
code if isinstance(code, int) else code.value |
| 356 | 359 |
]) |
| ... | ... |
@@ -424,12 +427,12 @@ class SSHAgentClient: |
| 424 | 427 |
def sign( |
| 425 | 428 |
self, |
| 426 | 429 |
/, |
| 427 |
- key: bytes | bytearray, |
|
| 428 |
- payload: bytes | bytearray, |
|
| 430 |
+ key: Buffer, |
|
| 431 |
+ payload: Buffer, |
|
| 429 | 432 |
*, |
| 430 | 433 |
flags: int = 0, |
| 431 | 434 |
check_if_key_loaded: bool = False, |
| 432 |
- ) -> bytes | bytearray: |
|
| 435 |
+ ) -> bytes: |
|
| 433 | 436 |
"""Request the SSH agent sign the payload with the key. |
| 434 | 437 |
|
| 435 | 438 |
Args: |
| ... | ... |
@@ -467,6 +470,8 @@ class SSHAgentClient: |
| 467 | 470 |
loaded into the agent. |
| 468 | 471 |
|
| 469 | 472 |
""" |
| 473 |
+ key = memoryview(key) |
|
| 474 |
+ payload = memoryview(payload) |
|
| 470 | 475 |
if check_if_key_loaded: |
| 471 | 476 |
loaded_keys = frozenset({pair.key for pair in self.list_keys()})
|
| 472 | 477 |
if bytes(key) not in loaded_keys: |
| ... | ... |
@@ -475,10 +480,12 @@ class SSHAgentClient: |
| 475 | 480 |
request_data = bytearray(self.string(key)) |
| 476 | 481 |
request_data.extend(self.string(payload)) |
| 477 | 482 |
request_data.extend(self.uint32(flags)) |
| 478 |
- return self.unstring( |
|
| 483 |
+ return bytes( |
|
| 484 |
+ self.unstring( |
|
| 479 | 485 |
self.request( |
| 480 | 486 |
_types.SSH_AGENTC.SIGN_REQUEST.value, |
| 481 | 487 |
request_data, |
| 482 | 488 |
response_code=_types.SSH_AGENT.SIGN_RESPONSE, |
| 483 | 489 |
) |
| 484 | 490 |
) |
| 491 |
+ ) |
|
| 485 | 492 |