Marco Ricci commited on 2024-05-21 00:40:13
Zeige 3 geänderte Dateien mit 71 Einfügungen und 23 Löschungen.
... | ... |
@@ -122,7 +122,7 @@ class Vault: |
122 | 122 |
subtract_or_require(space, self._CHARSETS['space']) |
123 | 123 |
subtract_or_require(dash, self._CHARSETS['dash']) |
124 | 124 |
subtract_or_require(symbol, self._CHARSETS['symbol']) |
125 |
- if len(self._required) < self._length: |
|
125 |
+ if len(self._required) > self._length: |
|
126 | 126 |
raise ValueError('requested passphrase length too short') |
127 | 127 |
if not self._allowed: |
128 | 128 |
raise ValueError('no allowed characters left') |
... | ... |
@@ -175,9 +175,10 @@ class Vault: |
175 | 175 |
salt=message, iterations=8, dklen=length) |
176 | 176 |
|
177 | 177 |
def generate( |
178 |
- self, service_name: str, /, *, phrase: bytes | bytearray = b'', |
|
179 |
- ) -> bytes | bytearray: |
|
180 |
- """Generate a service passphrase. |
|
178 |
+ self, service_name: str | bytes | bytearray, /, *, |
|
179 |
+ phrase: bytes | bytearray = b'', |
|
180 |
+ ) -> bytes: |
|
181 |
+ r"""Generate a service passphrase. |
|
181 | 182 |
|
182 | 183 |
Args: |
183 | 184 |
service_name: |
... | ... |
@@ -193,6 +194,13 @@ class Vault: |
193 | 194 |
# exactly the right length. |
194 | 195 |
safety_factor = 2 |
195 | 196 |
hash_length = int(math.ceil(safety_factor * entropy_bound / 8)) |
197 |
+ # Ensure the phrase is a bytes object. Needed later for safe |
|
198 |
+ # concatenation. |
|
199 |
+ if isinstance(service_name, str): |
|
200 |
+ service_name = service_name.encode('utf-8') |
|
201 |
+ elif not isinstance(service_name, bytes): |
|
202 |
+ service_name = bytes(service_name) |
|
203 |
+ assert_type(service_name, bytes) |
|
196 | 204 |
if not phrase: |
197 | 205 |
phrase = self._phrase |
198 | 206 |
# Repeat the passphrase generation with ever-increasing hash |
... | ... |
@@ -203,8 +211,7 @@ class Vault: |
203 | 211 |
try: |
204 | 212 |
required = self._required[:] |
205 | 213 |
seq = sequin.Sequin(self.create_hash( |
206 |
- key=phrase, |
|
207 |
- message=(service_name.encode('utf-8') + self._UUID), |
|
214 |
+ key=phrase, message=(service_name + self._UUID), |
|
208 | 215 |
length=hash_length)) |
209 | 216 |
result = bytearray() |
210 | 217 |
while len(result) < self._length: |
... | ... |
@@ -229,12 +236,12 @@ class Vault: |
229 | 236 |
assert previous is not None # for the type checker |
230 | 237 |
charset = self._subtract(bytes([previous]), charset) |
231 | 238 |
# End checking for repeated characters. |
232 |
- index = seq.generate(len(charset)) |
|
233 |
- result.extend(charset[index:index+1]) |
|
239 |
+ pos = seq.generate(len(charset)) |
|
240 |
+ result.extend(charset[pos:pos+1]) |
|
234 | 241 |
except sequin.SequinExhaustedException: |
235 | 242 |
hash_length *= 2 |
236 | 243 |
else: |
237 |
- return result |
|
244 |
+ return bytes(result) |
|
238 | 245 |
|
239 | 246 |
@classmethod |
240 | 247 |
def phrase_from_signature( |
... | ... |
@@ -298,7 +305,7 @@ class Vault: |
298 | 305 |
for c in charset: |
299 | 306 |
try: |
300 | 307 |
pos = allowed.index(c) |
301 |
- except LookupError: |
|
308 |
+ except ValueError: |
|
302 | 309 |
pass |
303 | 310 |
else: |
304 | 311 |
allowed[pos:pos+1] = [] |
... | ... |
@@ -24,7 +24,7 @@ from __future__ import annotations |
24 | 24 |
import collections |
25 | 25 |
import math |
26 | 26 |
|
27 |
-from collections.abc import Sequence, MutableSequence |
|
27 |
+from collections.abc import Iterator, MutableSequence, Sequence |
|
28 | 28 |
from typing import assert_type, Literal, TypeAlias |
29 | 29 |
|
30 | 30 |
__all__ = ('Sequin', 'SequinExhaustedException') |
... | ... |
@@ -54,6 +54,7 @@ class Sequin: |
54 | 54 |
def __init__( |
55 | 55 |
self, |
56 | 56 |
sequence: str | bytes | bytearray | Sequence[int], |
57 |
+ /, *, is_bitstring: bool = False |
|
57 | 58 |
): |
58 | 59 |
"""Initialize the Sequin. |
59 | 60 |
|
... | ... |
@@ -65,6 +66,15 @@ class Sequin: |
65 | 66 |
(Conversion will fail if the text string contains |
66 | 67 |
non-ISO-8859-1 characters.) The numbers are then |
67 | 68 |
converted to bits. |
69 |
+ is_bitstring: |
|
70 |
+ If true, treat the input as a bitstring. By default, |
|
71 |
+ the input is treated as a string of 8-bit integers, from |
|
72 |
+ which the individual bits must still be extracted. |
|
73 |
+ |
|
74 |
+ Raises: |
|
75 |
+ ValueError: |
|
76 |
+ The sequence contains values outside the permissible |
|
77 |
+ range. |
|
68 | 78 |
|
69 | 79 |
""" |
70 | 80 |
def uint8_to_bits(value): |
... | ... |
@@ -72,13 +82,23 @@ class Sequin: |
72 | 82 |
for i in (0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01): |
73 | 83 |
yield 1 if value | i == value else 0 |
74 | 84 |
if isinstance(sequence, str): |
85 |
+ try: |
|
75 | 86 |
sequence = tuple(sequence.encode('iso-8859-1')) |
87 |
+ except UnicodeError as e: |
|
88 |
+ raise ValueError('sequence item out of range') from e |
|
76 | 89 |
else: |
77 | 90 |
sequence = tuple(sequence) |
78 | 91 |
assert_type(sequence, tuple[int, ...]) |
79 |
- self.bases: dict[int, MutableSequence[int]] = {} |
|
80 |
- gen = (bit for num in sequence for bit in uint8_to_bits(num)) |
|
81 |
- self.bases[2] = collections.deque(gen) |
|
92 |
+ self.bases: dict[int, collections.deque[int]] = {} |
|
93 |
+ def gen() -> Iterator[int]: |
|
94 |
+ for num in sequence: |
|
95 |
+ if num not in range(2 if is_bitstring else 256): |
|
96 |
+ raise ValueError('sequence item out of range') |
|
97 |
+ if is_bitstring: |
|
98 |
+ yield num |
|
99 |
+ else: |
|
100 |
+ yield from uint8_to_bits(num) |
|
101 |
+ self.bases[2] = collections.deque(gen()) |
|
82 | 102 |
|
83 | 103 |
def _all_or_nothing_shift( |
84 | 104 |
self, count: int, /, *, base: int = 2 |
... | ... |
@@ -99,7 +119,8 @@ class Sequin: |
99 | 119 |
sequences. |
100 | 120 |
|
101 | 121 |
Examples: |
102 |
- >>> seq = Sequin([1, 0, 1, 0, 0, 1, 0, 0, 0, 1]) |
|
122 |
+ >>> seq = Sequin([1, 0, 1, 0, 0, 1, 0, 0, 0, 1], |
|
123 |
+ ... is_bitstring=True) |
|
103 | 124 |
>>> seq.bases |
104 | 125 |
{2: deque([1, 0, 1, 0, 0, 1, 0, 0, 0, 1])} |
105 | 126 |
>>> seq._all_or_nothing_shift(3) |
... | ... |
@@ -122,15 +143,17 @@ class Sequin: |
122 | 143 |
seq = self.bases[base] |
123 | 144 |
except KeyError: |
124 | 145 |
return () |
125 |
- else: |
|
126 |
- chunk = tuple(seq[:count]) |
|
127 |
- if len(chunk) == count: |
|
128 |
- del seq[:count] |
|
146 |
+ stash: collections.deque[int] = collections.deque() |
|
147 |
+ try: |
|
148 |
+ for i in range(count): |
|
149 |
+ stash.append(seq.popleft()) |
|
150 |
+ except IndexError: |
|
151 |
+ seq.extendleft(reversed(stash)) |
|
152 |
+ return () |
|
129 | 153 |
# Clean up queues. |
130 | 154 |
if not seq: |
131 | 155 |
del self.bases[base] |
132 |
- return chunk |
|
133 |
- return () |
|
156 |
+ return tuple(stash) |
|
134 | 157 |
|
135 | 158 |
@staticmethod |
136 | 159 |
def _big_endian_number( |
... | ... |
@@ -192,16 +215,20 @@ class Sequin: |
192 | 215 |
Args: |
193 | 216 |
n: |
194 | 217 |
Generate numbers in the range 0, ..., `n` - 1. |
195 |
- (Inclusive.) |
|
218 |
+ (Inclusive.) Must be larger than 0. |
|
196 | 219 |
|
197 | 220 |
Returns: |
198 | 221 |
A pseudorandom number in the range 0, ..., `n` - 1. |
199 | 222 |
|
200 | 223 |
Raises: |
224 |
+ ValueError: |
|
225 |
+ The range is empty. |
|
201 | 226 |
SequinExhaustedException: |
202 | 227 |
The sequin is exhausted. |
203 | 228 |
|
204 | 229 |
""" |
230 |
+ if 2 not in self.bases: |
|
231 |
+ raise SequinExhaustedException('Sequin is exhausted') |
|
205 | 232 |
value = self._generate_inner(n, base=2) |
206 | 233 |
if value == n: |
207 | 234 |
raise SequinExhaustedException('Sequin is exhausted') |
... | ... |
@@ -225,7 +252,7 @@ class Sequin: |
225 | 252 |
Args: |
226 | 253 |
n: |
227 | 254 |
Generate numbers in the range 0, ..., `n` - 1. |
228 |
- (Inclusive.) |
|
255 |
+ (Inclusive.) Must be larger than 0. |
|
229 | 256 |
base: |
230 | 257 |
Use the base `base` sequence as a source for |
231 | 258 |
pseudorandom numbers. |
... | ... |
@@ -234,7 +261,15 @@ class Sequin: |
234 | 261 |
A pseudorandom number in the range 0, ..., `n` - 1 if |
235 | 262 |
possible, or `n` if the stream is exhausted. |
236 | 263 |
|
264 |
+ Raises: |
|
265 |
+ ValueError: |
|
266 |
+ The range is empty. |
|
267 |
+ |
|
237 | 268 |
""" |
269 |
+ if n < 1: |
|
270 |
+ raise ValueError('invalid target range') |
|
271 |
+ if base < 2: |
|
272 |
+ raise ValueError(f'invalid base: {base!r}') |
|
238 | 273 |
# p = base ** k, where k is the smallest integer such that |
239 | 274 |
# p >= n. We determine p and k inductively. |
240 | 275 |
p = 1 |
... | ... |
@@ -251,7 +286,10 @@ class Sequin: |
251 | 286 |
while v > n - 1: |
252 | 287 |
list_slice = self._all_or_nothing_shift(k, base=base) |
253 | 288 |
if not list_slice: |
289 |
+ if n != 1: |
|
254 | 290 |
return n |
291 |
+ else: |
|
292 |
+ v = 0 |
|
255 | 293 |
v = self._big_endian_number(list_slice, base=base) |
256 | 294 |
if v > n - 1: |
257 | 295 |
# If r is 0, then p == n, so v < n, or rather |
... | ... |
@@ -106,10 +106,13 @@ class SSHAgentClient: |
106 | 106 |
@classmethod |
107 | 107 |
def string(cls, payload: bytes | bytearray, /) -> bytes | bytearray: |
108 | 108 |
"""Format the payload as an SSH string, as per the agent protocol.""" |
109 |
+ try: |
|
109 | 110 |
ret = bytearray() |
110 | 111 |
ret.extend(cls.uint32(len(payload))) |
111 | 112 |
ret.extend(payload) |
112 | 113 |
return ret |
114 |
+ except Exception as e: |
|
115 |
+ raise TypeError('invalid payload type') from e |
|
113 | 116 |
|
114 | 117 |
@classmethod |
115 | 118 |
def unstring(cls, bytestring: bytes | bytearray, /) -> bytes | bytearray: |
116 | 119 |