Fix numerous argument type or range errors
Marco Ricci

Marco Ricci commited on 2024-05-21 00:40:13
Zeige 3 geänderte Dateien mit 71 Einfügungen und 23 Löschungen.

... ...
@@ -122,7 +122,7 @@ class Vault:
122 122
         subtract_or_require(space, self._CHARSETS['space'])
123 123
         subtract_or_require(dash, self._CHARSETS['dash'])
124 124
         subtract_or_require(symbol, self._CHARSETS['symbol'])
125
-        if len(self._required) < self._length:
125
+        if len(self._required) > self._length:
126 126
             raise ValueError('requested passphrase length too short')
127 127
         if not self._allowed:
128 128
             raise ValueError('no allowed characters left')
... ...
@@ -175,9 +175,10 @@ class Vault:
175 175
                                    salt=message, iterations=8, dklen=length)
176 176
 
177 177
     def generate(
178
-        self, service_name: str, /, *, phrase: bytes | bytearray = b'',
179
-    ) -> bytes | bytearray:
180
-        """Generate a service passphrase.
178
+        self, service_name: str | bytes | bytearray, /, *,
179
+        phrase: bytes | bytearray = b'',
180
+    ) -> bytes:
181
+        r"""Generate a service passphrase.
181 182
 
182 183
         Args:
183 184
             service_name:
... ...
@@ -193,6 +194,13 @@ class Vault:
193 194
         # exactly the right length.
194 195
         safety_factor = 2
195 196
         hash_length = int(math.ceil(safety_factor * entropy_bound / 8))
197
+        # Ensure the phrase is a bytes object.  Needed later for safe
198
+        # concatenation.
199
+        if isinstance(service_name, str):
200
+            service_name = service_name.encode('utf-8')
201
+        elif not isinstance(service_name, bytes):
202
+            service_name = bytes(service_name)
203
+        assert_type(service_name, bytes)
196 204
         if not phrase:
197 205
             phrase = self._phrase
198 206
         # Repeat the passphrase generation with ever-increasing hash
... ...
@@ -203,8 +211,7 @@ class Vault:
203 211
             try:
204 212
                 required = self._required[:]
205 213
                 seq = sequin.Sequin(self.create_hash(
206
-                    key=phrase,
207
-                    message=(service_name.encode('utf-8') + self._UUID),
214
+                    key=phrase, message=(service_name + self._UUID),
208 215
                     length=hash_length))
209 216
                 result = bytearray()
210 217
                 while len(result) < self._length:
... ...
@@ -229,12 +236,12 @@ class Vault:
229 236
                         assert previous is not None  # for the type checker
230 237
                         charset = self._subtract(bytes([previous]), charset)
231 238
                     # End checking for repeated characters.
232
-                    index = seq.generate(len(charset))
233
-                    result.extend(charset[index:index+1])
239
+                    pos = seq.generate(len(charset))
240
+                    result.extend(charset[pos:pos+1])
234 241
             except sequin.SequinExhaustedException:
235 242
                 hash_length *= 2
236 243
             else:
237
-                return result
244
+                return bytes(result)
238 245
 
239 246
     @classmethod
240 247
     def phrase_from_signature(
... ...
@@ -298,7 +305,7 @@ class Vault:
298 305
         for c in charset:
299 306
             try:
300 307
                 pos = allowed.index(c)
301
-            except LookupError:
308
+            except ValueError:
302 309
                 pass
303 310
             else:
304 311
                 allowed[pos:pos+1] = []
... ...
@@ -24,7 +24,7 @@ from __future__ import annotations
24 24
 import collections
25 25
 import math
26 26
 
27
-from collections.abc import Sequence, MutableSequence
27
+from collections.abc import Iterator, MutableSequence, Sequence
28 28
 from typing import assert_type, Literal, TypeAlias
29 29
 
30 30
 __all__ = ('Sequin', 'SequinExhaustedException')
... ...
@@ -54,6 +54,7 @@ class Sequin:
54 54
     def __init__(
55 55
         self,
56 56
         sequence: str | bytes | bytearray | Sequence[int],
57
+        /, *, is_bitstring: bool = False
57 58
     ):
58 59
         """Initialize the Sequin.
59 60
 
... ...
@@ -65,6 +66,15 @@ class Sequin:
65 66
                 (Conversion will fail if the text string contains
66 67
                 non-ISO-8859-1 characters.)  The numbers are then
67 68
                 converted to bits.
69
+            is_bitstring:
70
+                If true, treat the input as a bitstring.  By default,
71
+                the input is treated as a string of 8-bit integers, from
72
+                which the individual bits must still be extracted.
73
+
74
+        Raises:
75
+            ValueError:
76
+                The sequence contains values outside the permissible
77
+                range.
68 78
 
69 79
         """
70 80
         def uint8_to_bits(value):
... ...
@@ -72,13 +82,23 @@ class Sequin:
72 82
             for i in (0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01):
73 83
                 yield 1 if value | i == value else 0
74 84
         if isinstance(sequence, str):
85
+            try:
75 86
                 sequence = tuple(sequence.encode('iso-8859-1'))
87
+            except UnicodeError as e:
88
+                raise ValueError('sequence item out of range') from e
76 89
         else:
77 90
             sequence = tuple(sequence)
78 91
         assert_type(sequence, tuple[int, ...])
79
-        self.bases: dict[int, MutableSequence[int]] = {}
80
-        gen = (bit for num in sequence for bit in uint8_to_bits(num))
81
-        self.bases[2] = collections.deque(gen)
92
+        self.bases: dict[int, collections.deque[int]] = {}
93
+        def gen() -> Iterator[int]:
94
+            for num in sequence:
95
+                if num not in range(2 if is_bitstring else 256):
96
+                    raise ValueError('sequence item out of range')
97
+                if is_bitstring:
98
+                    yield num
99
+                else:
100
+                    yield from uint8_to_bits(num)
101
+        self.bases[2] = collections.deque(gen())
82 102
 
83 103
     def _all_or_nothing_shift(
84 104
         self, count: int, /, *, base: int = 2
... ...
@@ -99,7 +119,8 @@ class Sequin:
99 119
             sequences.
100 120
 
101 121
         Examples:
102
-            >>> seq = Sequin([1, 0, 1, 0, 0, 1, 0, 0, 0, 1])
122
+            >>> seq = Sequin([1, 0, 1, 0, 0, 1, 0, 0, 0, 1],
123
+            ...              is_bitstring=True)
103 124
             >>> seq.bases
104 125
             {2: deque([1, 0, 1, 0, 0, 1, 0, 0, 0, 1])}
105 126
             >>> seq._all_or_nothing_shift(3)
... ...
@@ -122,15 +143,17 @@ class Sequin:
122 143
             seq = self.bases[base]
123 144
         except KeyError:
124 145
             return ()
125
-        else:
126
-            chunk = tuple(seq[:count])
127
-            if len(chunk) == count:
128
-                del seq[:count]
146
+        stash: collections.deque[int] = collections.deque()
147
+        try:
148
+            for i in range(count):
149
+                stash.append(seq.popleft())
150
+        except IndexError:
151
+            seq.extendleft(reversed(stash))
152
+            return ()
129 153
         # Clean up queues.
130 154
         if not seq:
131 155
             del self.bases[base]
132
-                return chunk
133
-            return ()
156
+        return tuple(stash)
134 157
 
135 158
     @staticmethod
136 159
     def _big_endian_number(
... ...
@@ -192,16 +215,20 @@ class Sequin:
192 215
         Args:
193 216
             n:
194 217
                 Generate numbers in the range 0, ..., `n` - 1.
195
-                (Inclusive.)
218
+                (Inclusive.)  Must be larger than 0.
196 219
 
197 220
         Returns:
198 221
             A pseudorandom number in the range 0, ..., `n` - 1.
199 222
 
200 223
         Raises:
224
+            ValueError:
225
+                The range is empty.
201 226
             SequinExhaustedException:
202 227
                 The sequin is exhausted.
203 228
 
204 229
         """
230
+        if 2 not in self.bases:
231
+            raise SequinExhaustedException('Sequin is exhausted')
205 232
         value = self._generate_inner(n, base=2)
206 233
         if value == n:
207 234
             raise SequinExhaustedException('Sequin is exhausted')
... ...
@@ -225,7 +252,7 @@ class Sequin:
225 252
         Args:
226 253
             n:
227 254
                 Generate numbers in the range 0, ..., `n` - 1.
228
-                (Inclusive.)
255
+                (Inclusive.)  Must be larger than 0.
229 256
             base:
230 257
                 Use the base `base` sequence as a source for
231 258
                 pseudorandom numbers.
... ...
@@ -234,7 +261,15 @@ class Sequin:
234 261
             A pseudorandom number in the range 0, ..., `n` - 1 if
235 262
             possible, or `n` if the stream is exhausted.
236 263
 
264
+        Raises:
265
+            ValueError:
266
+                The range is empty.
267
+
237 268
         """
269
+        if n < 1:
270
+            raise ValueError('invalid target range')
271
+        if base < 2:
272
+            raise ValueError(f'invalid base: {base!r}')
238 273
         # p = base ** k, where k is the smallest integer such that
239 274
         # p >= n.  We determine p and k inductively.
240 275
         p = 1
... ...
@@ -251,7 +286,10 @@ class Sequin:
251 286
         while v > n - 1:
252 287
             list_slice = self._all_or_nothing_shift(k, base=base)
253 288
             if not list_slice:
289
+                if n != 1:
254 290
                     return n
291
+                else:
292
+                    v = 0
255 293
             v = self._big_endian_number(list_slice, base=base)
256 294
             if v > n - 1:
257 295
                 # If r is 0, then p == n, so v < n, or rather
... ...
@@ -106,10 +106,13 @@ class SSHAgentClient:
106 106
     @classmethod
107 107
     def string(cls, payload: bytes | bytearray, /) -> bytes | bytearray:
108 108
         """Format the payload as an SSH string, as per the agent protocol."""
109
+        try:
109 110
             ret = bytearray()
110 111
             ret.extend(cls.uint32(len(payload)))
111 112
             ret.extend(payload)
112 113
             return ret
114
+        except Exception as e:
115
+            raise TypeError('invalid payload type') from e
113 116
 
114 117
     @classmethod
115 118
     def unstring(cls, bytestring: bytes | bytearray, /) -> bytes | bytearray:
116 119