Add function for SSH framed byte strings with trailing data
Marco Ricci

Marco Ricci commited on 2024-06-22 21:19:30
Zeige 2 geänderte Dateien mit 58 Einfügungen und 4 Löschungen.


The new function `ssh_agent_client.unstring_prefix` does not abort if
there is trailing data after the SSH framed byte string.
... ...
@@ -167,6 +167,45 @@ class SSHAgentClient:
167 167
             raise ValueError('malformed SSH byte string')
168 168
         return bytestring[4:]
169 169
 
170
+    @classmethod
171
+    def unstring_prefix(
172
+        cls, bytestring: bytes | bytearray, /
173
+    ) -> tuple[bytes | bytearray, bytes | bytearray]:
174
+        r"""Unpack an SSH string at the beginning of the byte string.
175
+
176
+        Args:
177
+            bytestring:
178
+                A (general) byte string, beginning with a framed/SSH
179
+                byte string.
180
+
181
+        Returns:
182
+            A 2-tuple `(a, b)`, where `a` is the unframed byte
183
+            string/payload at the beginning of input byte string, and
184
+            `b` is the remainder of the input byte string.
185
+
186
+        Raises:
187
+            ValueError:
188
+                The byte string does not begin with an SSH string.
189
+
190
+        Examples:
191
+            >>> a, b = SSHAgentClient.unstring_prefix(
192
+            ...     b'\x00\x00\x00\x07ssh-rsa____trailing data')
193
+            >>> (bytes(a), bytes(b))
194
+            (b'ssh-rsa', b'____trailing data')
195
+            >>> a, b = SSHAgentClient.unstring_prefix(
196
+            ...     SSHAgentClient.string(b'ssh-ed25519'))
197
+            >>> (bytes(a), bytes(b))
198
+            (b'ssh-ed25519', b'')
199
+
200
+        """
201
+        n = len(bytestring)
202
+        if n < 4:
203
+            raise ValueError('malformed SSH byte string')
204
+        m = int.from_bytes(bytestring[:4], 'big', signed=False)
205
+        if m + 4 > n:
206
+            raise ValueError('malformed SSH byte string')
207
+        return (bytestring[4:m + 4], bytestring[m + 4:])
208
+
170 209
     def request(
171 210
         self, code: int, payload: bytes | bytearray, /
172 211
     ) -> tuple[int, bytes | bytearray]:
... ...
@@ -331,20 +331,35 @@ def test_client_string(input, expected):
331 331
 ])
332 332
 def test_client_unstring(input, expected):
333 333
     unstring = ssh_agent_client.SSHAgentClient.unstring
334
+    unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
334 335
     assert bytes(unstring(input)) == expected
336
+    assert tuple(bytes(x) for x in unstring_prefix(input)) == (expected, b'')
335 337
 
336
-@pytest.mark.parametrize(['input', 'exc_type', 'exc_pattern'], [
337
-    (b'ssh', ValueError, 'malformed SSH byte string'),
338
-    (b'\x00\x00\x00\x08ssh-rsa', ValueError, 'malformed SSH byte string'),
338
+@pytest.mark.parametrize(
339
+    ['input', 'exc_type', 'exc_pattern', 'has_trailer', 'parts'], [
340
+        (b'ssh', ValueError, 'malformed SSH byte string', False, None),
341
+        (
342
+            b'\x00\x00\x00\x08ssh-rsa',
343
+            ValueError, 'malformed SSH byte string',
344
+            False, None,
345
+        ),
339 346
         (
340 347
             b'\x00\x00\x00\x04XXX trailing text',
341 348
             ValueError, 'malformed SSH byte string',
349
+            True, (b'XXX ', b'trailing text'),
342 350
         ),
343 351
 ])
344
-def test_client_unstring_exceptions(input, exc_type, exc_pattern):
352
+def test_client_unstring_exceptions(input, exc_type, exc_pattern,
353
+                                    has_trailer, parts):
345 354
     unstring = ssh_agent_client.SSHAgentClient.unstring
355
+    unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
346 356
     with pytest.raises(exc_type, match=exc_pattern):
347 357
         unstring(input)
358
+    if has_trailer:
359
+        assert tuple(bytes(x) for x in unstring_prefix(input)) == parts
360
+    else:
361
+        with pytest.raises(exc_type, match=exc_pattern):
362
+            unstring_prefix(input)
348 363
 
349 364
 def test_key_decoding():
350 365
     public_key = SUPPORTED['ed25519']['public_key']
351 366