Fix character set subtraction logic
Marco Ricci

Marco Ricci commited on 2024-06-08 19:06:56
Zeige 2 geänderte Dateien mit 22 Einfügungen und 5 Löschungen.


Use a static method, and treat both the original character set and the
subtracted character set as sets (i.e. no repetitions allowed).
... ...
@@ -305,27 +305,35 @@ class Vault:
305 305
             ret = client.sign(key, cls._UUID)
306 306
         return ret
307 307
 
308
+    @staticmethod
308 309
     def _subtract(
309
-        self, charset: bytes | bytearray, allowed: bytes | bytearray,
310
+        charset: bytes | bytearray, allowed: bytes | bytearray,
310 311
     ) -> bytearray:
311 312
         """Remove the characters in charset from allowed.
312 313
 
313 314
         This preserves the relative order of characters in `allowed`.
314 315
 
315 316
         Args:
316
-            charset: Characters to remove.
317
-            allowed: Character set to remove the other characters from.
317
+            charset:
318
+                Characters to remove.  Must not contain duplicate
319
+                characters.
320
+            allowed:
321
+                Character set to remove the other characters from.  Must
322
+                not contain duplicate characters.
318 323
 
319 324
         Returns:
320
-            The pruned character set.
325
+            The pruned "allowed" character set.
321 326
 
322 327
         Raises:
323
-            ValueError: `charset` contained duplicate characters.
328
+            ValueError:
329
+                `allowed` or `charset` contained duplicate characters.
324 330
 
325 331
         """
326 332
         allowed = (allowed if isinstance(allowed, bytearray)
327 333
                    else bytearray(allowed))
328 334
         assert_type(allowed, bytearray)
335
+        if len(frozenset(allowed)) != len(allowed):
336
+            raise ValueError('duplicate characters in set')
329 337
         if len(frozenset(charset)) != len(charset):
330 338
             raise ValueError('duplicate characters in set')
331 339
         for c in charset:
... ...
@@ -113,3 +113,12 @@ def test_220_very_limited_character_set():
113 113
     generated = Vault(phrase=b'', length=24, lower=0, upper=0,
114 114
                       space=0, symbol=0).generate('testing')
115 115
     assert b'763252593304946694588866' == generated
116
+
117
+def test_300_character_set_subtraction():
118
+    assert Vault._subtract(b'be', b'abcdef') == bytearray(b'acdf')
119
+
120
+def test_301_character_set_subtraction_duplicate():
121
+    with pytest.raises(ValueError, match='duplicate characters'):
122
+        Vault._subtract(b'abcdef', b'aabbccddeeff')
123
+    with pytest.raises(ValueError, match='duplicate characters'):
124
+        Vault._subtract(b'aabbccddeeff', b'abcdef')
116 125