Expose some functionality from `Vault` as internal methods
Marco Ricci

Marco Ricci committed on 2024-06-22 21:19:30
Zeige 2 geänderte Dateien mit 161 Einfügungen und 24 Löschungen.


Expose some functionality from `derivepassphrase.Vault` as internal
methods, to facilitate testing and to avoid reimplementing the same
functionality again in the command-line interface.  This includes hash
length estimation and SSH key suitability checking.
... ...
@@ -132,19 +132,70 @@ class Vault:
132 132
         for _ in range(len(self._required), self._length):
133 133
             self._required.append(bytes(self._allowed))
134 134
 
135
-    def _entropy_upper_bound(self) -> int:
135
+    def _entropy(self) -> float:
136 136
         """Estimate the passphrase entropy, given the current settings.
137 137
 
138 138
         The entropy is the base 2 logarithm of the amount of
139
-        possibilities.  We operate directly on the logarithms, and round
140
-        each summand up, overestimating the true entropy.
139
+        possibilities.  We operate directly on the logarithms, and use
140
+        sorting and [`math.fsum`][] to keep high accuracy.
141
+
142
+        Note:
143
+            We actually overestimate the entropy here because of poor
144
+            handling of character repetitions.  In the extreme, assuming
145
+            that only one character were allowed, then because there is
146
+            only one possible string of each given length, the entropy
147
+            of that string `s` is always zero.  However, we calculate
148
+            the entropy as `math.log2(math.factorial(len(s)))`, i.e. we
149
+            assume the characters at the respective string position are
150
+            distinguishable from each other.
151
+
152
+        Returns:
153
+            A valid (and somewhat close) upper bound to the entropy.
141 154
 
142 155
         """
143 156
         factors: list[int] = []
157
+        if not self._required or any(not x for x in self._required):
158
+            return float('-inf')
144 159
         for i, charset in enumerate(self._required):
145 160
             factors.append(i + 1)
146 161
             factors.append(len(charset))
147
-        return sum(int(math.ceil(math.log2(f))) for f in factors)
162
+        factors.sort()
163
+        return math.fsum(math.log2(f) for f in factors)
164
+
165
+    def _estimate_sufficient_hash_length(
166
+        self, safety_factor: float = 2.0,
167
+    ) -> int:
168
+        """Estimate the sufficient hash length, given the current settings.
169
+
170
+        Using the entropy (via `_entropy`) and a safety factor, give an
171
+        initial estimate of the length to use for `create_hash` such
172
+        that using a `Sequin` with this hash will not exhaust it during
173
+        passphrase generation.
174
+
175
+        Args:
176
+            safety_factor: The safety factor.  Must be at least 1.
177
+
178
+        Returns:
179
+            The estimated sufficient hash length.
180
+
181
+        Warning:
182
+            This is a heuristic, not an exact computation; it may
183
+            underestimate the true necessary hash length.  It is
184
+            intended as a starting point for searching for a sufficient
185
+            hash length, usually by doubling the hash length each time
186
+            it turns out to be insufficient.
187
+
188
+        """
189
+        try:
190
+            safety_factor = float(safety_factor)
191
+        except TypeError as e:
192
+            raise TypeError(f'invalid safety factor: not a float: '
193
+                            f'{safety_factor!r}') from e
194
+        if not math.isfinite(safety_factor) or safety_factor < 1.0:
195
+            raise ValueError(f'invalid safety factor {safety_factor!r}')
196
+        # Ensure the bound is strictly positive.
197
+        entropy_bound = max(1, self._entropy())
198
+        return int(math.ceil(safety_factor * entropy_bound / 8))
148 199
 
149 200
     @classmethod
150 201
     def create_hash(
... ...
@@ -225,12 +276,8 @@ class Vault:
225 276
             b': 4TVH#5:aZl8LueOT\\{'
226 277
 
227 278
         """
228
-        entropy_bound = self._entropy_upper_bound()
229
-        # Use a safety factor, because a sequin will potentially throw
230
-        # bits away and we cannot rely on having generated a hash of
231
-        # exactly the right length.
232
-        safety_factor = 2
233
-        hash_length = int(math.ceil(safety_factor * entropy_bound / 8))
279
+        hash_length = self._estimate_sufficient_hash_length()
280
+        assert hash_length >= 1
234 281
         # Ensure the phrase is a bytes object.  Needed later for safe
235 282
         # concatenation.
236 283
         if isinstance(service_name, str):
... ...
@@ -267,11 +314,36 @@ class Vault:
267 314
                                                      charset)
268 315
                     pos = seq.generate(len(charset))
269 316
                     result.extend(charset[pos:pos+1])
270
-            except sequin.SequinExhaustedException:  # pragma: no cover
317
+            except sequin.SequinExhaustedException:
271 318
                 hash_length *= 2
272 319
             else:
273 320
                 return bytes(result)
274 321
 
322
+    @staticmethod
323
+    def _is_suitable_ssh_key(key: bytes | bytearray, /) -> bool:
324
+        """Check whether the key is suitable for passphrase derivation.
325
+
326
+        Currently, this only checks whether signatures with this key
327
+        type are deterministic.
328
+
329
+        Args:
330
+            key: SSH public key to check.
331
+
332
+        Returns:
333
+            True if and only if the key is suitable for use in deriving
334
+            a passphrase deterministically.
335
+
336
+        """
337
+        deterministic_signature_types = {
338
+            'ssh-ed25519':
339
+                lambda k: k.startswith(b'\x00\x00\x00\x0bssh-ed25519'),
340
+            'ssh-ed448':
341
+                lambda k: k.startswith(b'\x00\x00\x00\x09ssh-ed448'),
342
+            'ssh-rsa':
343
+                lambda k: k.startswith(b'\x00\x00\x00\x07ssh-rsa'),
344
+        }
345
+        return any(v(key) for v in deterministic_signature_types.values())
346
+
275 347
     @classmethod
276 348
     def phrase_from_signature(
277 349
         cls, key: bytes | bytearray, /
... ...
@@ -314,15 +386,7 @@ class Vault:
314 386
             True
315 387
 
316 388
         """
317
-        deterministic_signature_types = {
318
-            'ssh-ed25519':
319
-                lambda k: k.startswith(b'\x00\x00\x00\x0bssh-ed25519'),
320
-            'ssh-ed448':
321
-                lambda k: k.startswith(b'\x00\x00\x00\x09ssh-ed448'),
322
-            'ssh-rsa':
323
-                lambda k: k.startswith(b'\x00\x00\x00\x07ssh-rsa'),
324
-        }
325
-        if not any(v(key) for v in deterministic_signature_types.values()):
389
+        if not cls._is_suitable_ssh_key(key):
326 390
             raise ValueError(
327 391
                 'unsuitable SSH key: bad key, or signature not deterministic')
328 392
         with ssh_agent_client.SSHAgentClient() as client:
... ...
@@ -4,17 +4,22 @@
4 4
 
5 5
 """Test passphrase generation via derivepassphrase.Vault."""
6 6
 
7
-import pytest
7
+from __future__ import annotations
8
+
9
+import math
8 10
 
9 11
 import derivepassphrase
10 12
 import sequin
13
+import pytest
11 14
 
12 15
 Vault = derivepassphrase.Vault
13 16
 phrase = b'She cells C shells bye the sea shoars'
17
+google_phrase = rb': 4TVH#5:aZl8LueOT\{'
18
+twitter_phrase = rb"[ (HN_N:lI&<ro=)3'g9"
14 19
 
15
-@pytest.mark.parametrize('service,expected', [
16
-    (b'google', rb': 4TVH#5:aZl8LueOT\{'),
17
-    ('twitter', rb"[ (HN_N:lI&<ro=)3'g9"),
20
+@pytest.mark.parametrize(['service', 'expected'], [
21
+    (b'google', google_phrase),
22
+    ('twitter', twitter_phrase),
18 23
 ])
19 24
 def test_200_basic_configuration(service, expected):
20 25
     assert Vault(phrase=phrase).generate(service) == expected
... ...
@@ -122,3 +127,71 @@ def test_301_character_set_subtraction_duplicate():
122 127
         Vault._subtract(b'abcdef', b'aabbccddeeff')
123 128
     with pytest.raises(ValueError, match='duplicate characters'):
124 129
         Vault._subtract(b'aabbccddeeff', b'abcdef')
130
+
131
+@pytest.mark.parametrize(['length', 'settings', 'entropy'], [
132
+    (20, {}, math.log2(math.factorial(20)) + 20 * math.log2(94)),
133
+    (
134
+        20,
135
+        {'upper': 0, 'number': 0, 'space': 0, 'symbol': 0},
136
+        math.log2(math.factorial(20)) + 20 * math.log2(26)
137
+    ),
138
+    (0, {}, float('-inf')),
139
+    (0, {'lower': 0, 'number': 0, 'space': 0, 'symbol': 0}, float('-inf')),
140
+    (1, {}, math.log2(94)),
141
+    (1, {'upper': 0, 'lower': 0, 'number': 0, 'symbol': 0}, 0.0),
142
+])
143
+def test_400_entropy(
144
+    length: int, settings: dict[str, int], entropy: int
145
+) -> None:
146
+    v = Vault(length=length, **settings)
147
+    assert math.isclose(v._entropy(), entropy)
148
+    assert v._estimate_sufficient_hash_length() > 0
149
+    if math.isfinite(entropy) and entropy:
150
+        assert v._estimate_sufficient_hash_length(1.0) == math.ceil(entropy / 8)
151
+    assert v._estimate_sufficient_hash_length(8.0) >= entropy
152
+
153
+def test_401_hash_length_estimation(
154
+) -> None:
155
+    v = Vault(phrase=phrase)
156
+    with pytest.raises(ValueError,
157
+                       match='invalid safety factor'):
158
+        assert v._estimate_sufficient_hash_length(-1.0)
159
+    with pytest.raises(TypeError,
160
+                       match='invalid safety factor: not a float'):
161
+        assert v._estimate_sufficient_hash_length(None)  # type: ignore
162
+    v2 = Vault(phrase=phrase, lower=0, upper=0, number=0, symbol=0,
163
+               space=1, length=1)
164
+    assert v2._entropy() == 0.0
165
+    assert v2._estimate_sufficient_hash_length() > 0
166
+
167
+@pytest.mark.parametrize(['service', 'expected'], [
168
+    (b'google', google_phrase),
169
+    ('twitter', twitter_phrase),
170
+])
171
+def test_402_hash_length_expansion(
172
+    monkeypatch: Any, service: str | bytes, expected: bytes
173
+) -> None:
174
+    v = Vault(phrase=phrase)
175
+    monkeypatch.setattr(v,
176
+                        '_estimate_sufficient_hash_length',
177
+                        lambda *args, **kwargs: 1)
178
+    assert v._estimate_sufficient_hash_length
179
+    assert v.generate(service) == expected
180
+
181
+@pytest.mark.parametrize(['s', 'raises'], [
182
+    ('ñ', True), ('Düsseldorf', True),
183
+    ('liberté, egalité, fraternité', True), ('ASCII', False),
184
+    ('Düsseldorf'.encode('UTF-8'), False),
185
+    (bytearray([2, 3, 5, 7, 11, 13]), False),
186
+])
187
+def test_403_binary_strings(s: str | bytes | bytearray, raises: bool) -> None:
188
+    binstr = derivepassphrase.Vault._get_binary_string
189
+    if raises:
190
+        with pytest.raises(derivepassphrase.AmbiguousByteRepresentationError):
191
+            binstr(s)
192
+    elif isinstance(s, str):
193
+        assert binstr(s) == s.encode('UTF-8')
194
+        assert binstr(binstr(s)) == s.encode('UTF-8')
195
+    else:
196
+        assert binstr(s) == bytes(s)
197
+        assert binstr(binstr(s)) == bytes(s)
125 198