Let `Vault` accept Buffer-type values wherever it accepts bytes
Marco Ricci

Marco Ricci commited on 2025-01-29 15:07:21
Zeige 2 geänderte Dateien mit 88 Einfügungen und 27 Löschungen.


Any method that accepts `bytes` and `bytearray` now also accepts
arbitrary Buffer-type classes such as `memoryview` and `array.array`.
The return values (mostly `bytes`, sometimes `bytearray`) remain
unchanged.
... ...
@@ -22,6 +22,8 @@ if TYPE_CHECKING:
22 22
     import socket
23 23
     from collections.abc import Callable
24 24
 
25
+    from typing_extensions import Buffer
26
+
25 27
 __author__ = 'Marco Ricci <software@the13thletter.info>'
26 28
 
27 29
 
... ...
@@ -104,7 +106,7 @@ class Vault:
104 106
     def __init__(  # noqa: PLR0913
105 107
         self,
106 108
         *,
107
-        phrase: bytes | bytearray | str = b'',
109
+        phrase: Buffer | str = b'',
108 110
         length: int = 20,
109 111
         repeat: int = 0,
110 112
         lower: int | None = None,
... ...
@@ -257,7 +259,7 @@ class Vault:
257 259
         return math.ceil(safety_factor * entropy_bound / 8)
258 260
 
259 261
     @staticmethod
260
-    def _get_binary_string(s: bytes | bytearray | str, /) -> bytes:
262
+    def _get_binary_string(s: Buffer | str, /) -> bytes:
261 263
         """Convert the input string to a read-only, binary string.
262 264
 
263 265
         If it is a text string, return the string's UTF-8
... ...
@@ -277,8 +279,8 @@ class Vault:
277 279
     @classmethod
278 280
     def create_hash(
279 281
         cls,
280
-        phrase: bytes | bytearray | str,
281
-        service: bytes | bytearray | str,
282
+        phrase: Buffer | str,
283
+        service: Buffer | str,
282 284
         *,
283 285
         length: int = 32,
284 286
     ) -> bytes:
... ...
@@ -332,7 +334,7 @@ class Vault:
332 334
 
333 335
         """
334 336
         phrase = cls._get_binary_string(phrase)
335
-        assert not isinstance(phrase, str)
337
+        assert isinstance(phrase, bytes)
336 338
         salt = cls._get_binary_string(service) + cls._UUID
337 339
         return hashlib.pbkdf2_hmac(
338 340
             hash_name='sha1',
... ...
@@ -344,10 +346,10 @@ class Vault:
344 346
 
345 347
     def generate(
346 348
         self,
347
-        service_name: bytes | bytearray | str,
349
+        service_name: Buffer | str,
348 350
         /,
349 351
         *,
350
-        phrase: bytes | bytearray | str = b'',
352
+        phrase: Buffer | str = b'',
351 353
     ) -> bytes:
352 354
         r"""Generate a service passphrase.
353 355
 
... ...
@@ -452,7 +454,7 @@ class Vault:
452 454
 
453 455
     @staticmethod
454 456
     def is_suitable_ssh_key(
455
-        key: bytes | bytearray,
457
+        key: Buffer,
456 458
         /,
457 459
         *,
458 460
         client: ssh_agent.SSHAgentClient | None = None,
... ...
@@ -477,6 +479,7 @@ class Vault:
477 479
             restricted to the indicated SSH agent).
478 480
 
479 481
         """
482
+        key = bytes(key)
480 483
         TestFunc: TypeAlias = 'Callable[[bytes | bytearray], bool]'
481 484
         deterministic_signature_types: dict[str, TestFunc]
482 485
         deterministic_signature_types = {
... ...
@@ -515,7 +518,7 @@ class Vault:
515 518
     @classmethod
516 519
     def phrase_from_key(
517 520
         cls,
518
-        key: bytes | bytearray,
521
+        key: Buffer,
519 522
         /,
520 523
         *,
521 524
         conn: ssh_agent.SSHAgentClient | socket.socket | None = None,
... ...
@@ -593,8 +596,8 @@ class Vault:
593 596
     @classmethod
594 597
     def phrases_are_interchangable(
595 598
         cls,
596
-        phrase1: bytes | bytearray,
597
-        phrase2: bytes | bytearray,
599
+        phrase1: Buffer,
600
+        phrase2: Buffer,
598 601
         /,
599 602
     ) -> bool:
600 603
         """Return true if the passphrases are interchangable to Vault.
... ...
@@ -640,7 +643,7 @@ class Vault:
640 643
     @classmethod
641 644
     def _phrase_to_hmac_key(
642 645
         cls,
643
-        phrase: bytes | bytearray | str,
646
+        phrase: Buffer | str,
644 647
         /,
645 648
     ) -> bytes:
646 649
         r"""Return the HMAC key belonging to a passphrase.
... ...
@@ -669,8 +672,8 @@ class Vault:
669 672
 
670 673
     @staticmethod
671 674
     def _subtract(
672
-        charset: bytes | bytearray,
673
-        allowed: bytes | bytearray,
675
+        charset: Buffer,
676
+        allowed: Buffer,
674 677
     ) -> bytearray:
675 678
         """Remove the characters in charset from allowed.
676 679
 
... ...
@@ -696,6 +699,8 @@ class Vault:
696 699
             allowed if isinstance(allowed, bytearray) else bytearray(allowed)
697 700
         )
698 701
         assert_type(allowed, bytearray)
702
+        charset = memoryview(charset).toreadonly().cast('c')
703
+        assert_type(charset, 'memoryview[bytes]')
699 704
         msg_dup_characters = 'duplicate characters in set'
700 705
         if len(frozenset(allowed)) != len(allowed):
701 706
             raise ValueError(msg_dup_characters)
... ...
@@ -6,6 +6,7 @@
6 6
 
7 7
 from __future__ import annotations
8 8
 
9
+import array
9 10
 import hashlib
10 11
 import math
11 12
 from typing import TYPE_CHECKING
... ...
@@ -19,15 +20,17 @@ import tests
19 20
 from derivepassphrase import vault
20 21
 
21 22
 if TYPE_CHECKING:
22
-    from collections.abc import Iterator
23
+    from collections.abc import Callable, Iterator
24
+
25
+    from typing_extensions import Buffer
23 26
 
24 27
 BLOCK_SIZE = hashlib.sha1().block_size
25 28
 DIGEST_SIZE = hashlib.sha1().digest_size
26 29
 
27 30
 
28 31
 def phrases_are_interchangable(
29
-    phrase1: bytes | bytearray | str,
30
-    phrase2: bytes | bytearray | str,
32
+    phrase1: Buffer | str,
33
+    phrase2: Buffer | str,
31 34
     /,
32 35
 ) -> bool:
33 36
     """Work-alike of [`vault.Vault.phrases_are_interchangable`][].
... ...
@@ -358,6 +361,12 @@ class TestVault:
358 361
             b'google'
359 362
         ) == vault.Vault(phrase=self.phrase).generate(bytearray(b'google'))
360 363
 
364
+    def test_202c_reproducibility_and_buffer_like_service_name(self) -> None:
365
+        """Deriving a passphrase works equally for memory views."""
366
+        assert vault.Vault(phrase=self.phrase).generate(
367
+            b'google'
368
+        ) == vault.Vault(phrase=self.phrase).generate(memoryview(b'google'))
369
+
361 370
     @hypothesis.given(
362 371
         phrase=strategies.text(
363 372
             strategies.characters(min_codepoint=32, max_codepoint=126),
... ...
@@ -370,18 +379,65 @@ class TestVault:
370 379
             max_size=32,
371 380
         ),
372 381
     )
373
-    def test_202c_reproducibility_and_binary_service_name(
382
+    def test_203a_reproducibility_and_binary_phrases(
374 383
         self,
375 384
         phrase: str,
376 385
         service: str,
377 386
     ) -> None:
378
-        """Deriving a passphrase works equally for byte arrays/strings."""
379
-        assert vault.Vault(phrase=phrase).generate(service) == vault.Vault(
380
-            phrase=phrase
381
-        ).generate(service.encode('utf-8'))
382
-        assert vault.Vault(phrase=phrase).generate(service) == vault.Vault(
383
-            phrase=phrase
384
-        ).generate(bytearray(service.encode('utf-8')))
387
+        """Binary and text master passphrases generate the same passphrases."""
388
+        buffer_types: dict[str, Callable[..., Buffer]] = {
389
+            'bytes': bytes,
390
+            'bytearray': bytearray,
391
+            'memoryview': memoryview,
392
+            'array.array': lambda data: array.array('B', data),
393
+        }
394
+        for type_name, buffer_type in buffer_types.items():
395
+            str_phrase = phrase
396
+            bytes_phrase = phrase.encode('utf-8')
397
+            assert vault.Vault(phrase=str_phrase).generate(
398
+                service
399
+            ) == vault.Vault(phrase=buffer_type(bytes_phrase)).generate(
400
+                service
401
+            ), (
402
+                f'{str_phrase!r} and {type_name}({bytes_phrase!r}) '
403
+                'master passphrases generate different passphrases'
404
+            )
405
+
406
+    @hypothesis.given(
407
+        phrase=strategies.text(
408
+            strategies.characters(min_codepoint=32, max_codepoint=126),
409
+            min_size=1,
410
+            max_size=32,
411
+        ),
412
+        service=strategies.text(
413
+            strategies.characters(min_codepoint=32, max_codepoint=126),
414
+            min_size=1,
415
+            max_size=32,
416
+        ),
417
+    )
418
+    def test_203b_reproducibility_and_binary_service_name(
419
+        self,
420
+        phrase: str,
421
+        service: str,
422
+    ) -> None:
423
+        """Binary and text service names generate the same passphrases."""
424
+        buffer_types: dict[str, Callable[..., Buffer]] = {
425
+            'bytes': bytes,
426
+            'bytearray': bytearray,
427
+            'memoryview': memoryview,
428
+            'array.array': lambda data: array.array('B', data),
429
+        }
430
+        for type_name, buffer_type in buffer_types.items():
431
+            str_service = service
432
+            bytes_service = service.encode('utf-8')
433
+            assert vault.Vault(phrase=phrase).generate(
434
+                str_service
435
+            ) == vault.Vault(phrase=phrase).generate(
436
+                buffer_type(bytes_service)
437
+            ), (
438
+                f'{str_service!r} and {type_name}({bytes_service!r}) '
439
+                'service name generate different passphrases'
440
+            )
385 441
 
386 442
     @hypothesis.given(
387 443
         phrase=strategies.text(
... ...
@@ -396,7 +452,7 @@ class TestVault:
396 452
             unique=True,
397 453
         ),
398 454
     )
399
-    def test_203a_service_name_dependence(
455
+    def test_204a_service_name_dependence(
400 456
         self,
401 457
         phrase: str,
402 458
         services: list[bytes],
... ...
@@ -421,7 +477,7 @@ class TestVault:
421 477
             unique=True,
422 478
         ),
423 479
     )
424
-    def test_203b_service_name_dependence_with_config(
480
+    def test_204b_service_name_dependence_with_config(
425 481
         self,
426 482
         phrase: str,
427 483
         config: dict[str, int],
428 484