Marco Ricci commited on 2024-06-22 21:19:30
Zeige 2 geänderte Dateien mit 58 Einfügungen und 4 Löschungen.
The new function `ssh_agent_client.unstring_prefix` does not abort if there is trailing data after the SSH framed byte string.
| ... | ... |
@@ -167,6 +167,45 @@ class SSHAgentClient: |
| 167 | 167 |
raise ValueError('malformed SSH byte string')
|
| 168 | 168 |
return bytestring[4:] |
| 169 | 169 |
|
| 170 |
+ @classmethod |
|
| 171 |
+ def unstring_prefix( |
|
| 172 |
+ cls, bytestring: bytes | bytearray, / |
|
| 173 |
+ ) -> tuple[bytes | bytearray, bytes | bytearray]: |
|
| 174 |
+ r"""Unpack an SSH string at the beginning of the byte string. |
|
| 175 |
+ |
|
| 176 |
+ Args: |
|
| 177 |
+ bytestring: |
|
| 178 |
+ A (general) byte string, beginning with a framed/SSH |
|
| 179 |
+ byte string. |
|
| 180 |
+ |
|
| 181 |
+ Returns: |
|
| 182 |
+ A 2-tuple `(a, b)`, where `a` is the unframed byte |
|
| 183 |
+ string/payload at the beginning of input byte string, and |
|
| 184 |
+ `b` is the remainder of the input byte string. |
|
| 185 |
+ |
|
| 186 |
+ Raises: |
|
| 187 |
+ ValueError: |
|
| 188 |
+ The byte string does not begin with an SSH string. |
|
| 189 |
+ |
|
| 190 |
+ Examples: |
|
| 191 |
+ >>> a, b = SSHAgentClient.unstring_prefix( |
|
| 192 |
+ ... b'\x00\x00\x00\x07ssh-rsa____trailing data') |
|
| 193 |
+ >>> (bytes(a), bytes(b)) |
|
| 194 |
+ (b'ssh-rsa', b'____trailing data') |
|
| 195 |
+ >>> a, b = SSHAgentClient.unstring_prefix( |
|
| 196 |
+ ... SSHAgentClient.string(b'ssh-ed25519')) |
|
| 197 |
+ >>> (bytes(a), bytes(b)) |
|
| 198 |
+ (b'ssh-ed25519', b'') |
|
| 199 |
+ |
|
| 200 |
+ """ |
|
| 201 |
+ n = len(bytestring) |
|
| 202 |
+ if n < 4: |
|
| 203 |
+ raise ValueError('malformed SSH byte string')
|
|
| 204 |
+ m = int.from_bytes(bytestring[:4], 'big', signed=False) |
|
| 205 |
+ if m + 4 > n: |
|
| 206 |
+ raise ValueError('malformed SSH byte string')
|
|
| 207 |
+ return (bytestring[4:m + 4], bytestring[m + 4:]) |
|
| 208 |
+ |
|
| 170 | 209 |
def request( |
| 171 | 210 |
self, code: int, payload: bytes | bytearray, / |
| 172 | 211 |
) -> tuple[int, bytes | bytearray]: |
| ... | ... |
@@ -331,20 +331,35 @@ def test_client_string(input, expected): |
| 331 | 331 |
]) |
| 332 | 332 |
def test_client_unstring(input, expected): |
| 333 | 333 |
unstring = ssh_agent_client.SSHAgentClient.unstring |
| 334 |
+ unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix |
|
| 334 | 335 |
assert bytes(unstring(input)) == expected |
| 336 |
+ assert tuple(bytes(x) for x in unstring_prefix(input)) == (expected, b'') |
|
| 335 | 337 |
|
| 336 |
-@pytest.mark.parametrize(['input', 'exc_type', 'exc_pattern'], [ |
|
| 337 |
- (b'ssh', ValueError, 'malformed SSH byte string'), |
|
| 338 |
- (b'\x00\x00\x00\x08ssh-rsa', ValueError, 'malformed SSH byte string'), |
|
| 338 |
+@pytest.mark.parametrize( |
|
| 339 |
+ ['input', 'exc_type', 'exc_pattern', 'has_trailer', 'parts'], [ |
|
| 340 |
+ (b'ssh', ValueError, 'malformed SSH byte string', False, None), |
|
| 341 |
+ ( |
|
| 342 |
+ b'\x00\x00\x00\x08ssh-rsa', |
|
| 343 |
+ ValueError, 'malformed SSH byte string', |
|
| 344 |
+ False, None, |
|
| 345 |
+ ), |
|
| 339 | 346 |
( |
| 340 | 347 |
b'\x00\x00\x00\x04XXX trailing text', |
| 341 | 348 |
ValueError, 'malformed SSH byte string', |
| 349 |
+ True, (b'XXX ', b'trailing text'), |
|
| 342 | 350 |
), |
| 343 | 351 |
]) |
| 344 |
-def test_client_unstring_exceptions(input, exc_type, exc_pattern): |
|
| 352 |
+def test_client_unstring_exceptions(input, exc_type, exc_pattern, |
|
| 353 |
+ has_trailer, parts): |
|
| 345 | 354 |
unstring = ssh_agent_client.SSHAgentClient.unstring |
| 355 |
+ unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix |
|
| 346 | 356 |
with pytest.raises(exc_type, match=exc_pattern): |
| 347 | 357 |
unstring(input) |
| 358 |
+ if has_trailer: |
|
| 359 |
+ assert tuple(bytes(x) for x in unstring_prefix(input)) == parts |
|
| 360 |
+ else: |
|
| 361 |
+ with pytest.raises(exc_type, match=exc_pattern): |
|
| 362 |
+ unstring_prefix(input) |
|
| 348 | 363 |
|
| 349 | 364 |
def test_key_decoding(): |
| 350 | 365 |
public_key = SUPPORTED['ed25519']['public_key'] |
| 351 | 366 |