Lock derivepassphrase internals against concurrent updating
Marco Ricci

Marco Ricci commited on 2025-02-28 20:08:07
Zeige 2 geänderte Dateien mit 80 Einfügungen und 29 Löschungen.


Specifically, protect the attributes of Sequin and
VaultNativeConfigParser objects (and derived objects) from concurrent
modification via threading locks, and document that they are
thread-safe.
... ...
@@ -32,6 +32,7 @@ import json
32 32
 import logging
33 33
 import os
34 34
 import pathlib
35
+import threading
35 36
 import warnings
36 37
 from typing import TYPE_CHECKING
37 38
 
... ...
@@ -176,6 +177,7 @@ class VaultNativeConfigParser(abc.ABC):
176 177
         if not password:
177 178
             msg = 'Password must not be empty'
178 179
             raise ValueError(msg)
180
+        self._consistency_lock = threading.RLock()
179 181
         self._contents = bytes(contents)
180 182
         self._iv_size = 0
181 183
         self._mac_size = 0
... ...
@@ -204,6 +206,7 @@ class VaultNativeConfigParser(abc.ABC):
204 206
                 unexpected extra contents, or invalid padding.)
205 207
 
206 208
         """
209
+        with self._consistency_lock:
207 210
             if self._data is self._sentinel:
208 211
                 self._parse_contents()
209 212
                 self._derive_keys()
... ...
@@ -283,26 +286,35 @@ class VaultNativeConfigParser(abc.ABC):
283 286
             ),
284 287
         )
285 288
 
286
-        if len(self._contents) < self._iv_size + 16 + self._mac_size:
287
-            msg = 'Invalid vault configuration file: file is truncated'
288
-            raise ValueError(msg)
289
-
290 289
         def cut(buffer: bytes, cutpoint: int) -> tuple[bytes, bytes]:
291 290
             return buffer[:cutpoint], buffer[cutpoint:]
292 291
 
293
-        cutpos1 = len(self._contents) - self._mac_size
294
-        cutpos2 = self._iv_size
292
+        with self._consistency_lock:
293
+            contents = self._contents
294
+            iv_size = self._iv_size
295
+            mac_size = self._mac_size
296
+
297
+            if len(contents) < iv_size + 16 + mac_size:
298
+                msg = 'Invalid vault configuration file: file is truncated'
299
+                raise ValueError(msg)
295 300
 
296
-        self._message, self._message_tag = cut(self._contents, cutpos1)
297
-        self._iv, self._payload = cut(self._message, cutpos2)
301
+            cutpos1 = len(contents) - mac_size
302
+            cutpos2 = iv_size
303
+            message, message_tag = cut(contents, cutpos1)
304
+            iv, payload = cut(message, cutpos2)
305
+
306
+            self._message = message
307
+            self._message_tag = message_tag
308
+            self._iv = iv
309
+            self._payload = payload
298 310
 
299 311
         logger.debug(
300 312
             _msg.TranslatedString(
301 313
                 _msg.DebugMsgTemplate.VAULT_NATIVE_PARSE_BUFFER,
302
-                contents=_h(self._contents),
303
-                iv=_h(self._iv),
304
-                payload=_h(self._payload),
305
-                mac=_h(self._message_tag),
314
+                contents=_h(contents),
315
+                iv=_h(iv),
316
+                payload=_h(payload),
317
+                mac=_h(message_tag),
306 318
             ),
307 319
         )
308 320
 
... ...
@@ -318,6 +330,7 @@ class VaultNativeConfigParser(abc.ABC):
318 330
                 _msg.InfoMsgTemplate.VAULT_NATIVE_DERIVING_KEYS,
319 331
             ),
320 332
         )
333
+        with self._consistency_lock:
321 334
             self._generate_keys()
322 335
             assert len(self._encryption_key) == self._encryption_key_size, (
323 336
                 'Derived encryption key is invalid'
... ...
@@ -356,18 +369,20 @@ class VaultNativeConfigParser(abc.ABC):
356 369
                 _msg.InfoMsgTemplate.VAULT_NATIVE_CHECKING_MAC,
357 370
             ),
358 371
         )
372
+        with self._consistency_lock:
359 373
             mac = hmac.HMAC(self._signing_key, hashes.SHA256())
360 374
             mac_input = self._hmac_input()
375
+            mac_expected = self._message_tag
361 376
         logger.debug(
362 377
             _msg.TranslatedString(
363 378
                 _msg.DebugMsgTemplate.VAULT_NATIVE_CHECKING_MAC_DETAILS,
364 379
                 mac_input=_h(mac_input),
365
-                mac=_h(self._message_tag),
380
+                mac=_h(mac_expected),
366 381
             ),
367 382
         )
368 383
         mac.update(mac_input)
369 384
         try:
370
-            mac.verify(self._message_tag)
385
+            mac.verify(mac_expected)
371 386
         except crypt_exceptions.InvalidSignature:
372 387
             msg = 'File does not contain a valid signature'
373 388
             raise ValueError(msg) from None
... ...
@@ -399,9 +414,12 @@ class VaultNativeConfigParser(abc.ABC):
399 414
                 _msg.InfoMsgTemplate.VAULT_NATIVE_DECRYPTING_CONTENTS,
400 415
             ),
401 416
         )
417
+        with self._consistency_lock:
418
+            payload = self._payload
419
+            iv_size = self._iv_size
402 420
         decryptor = self._make_decryptor()
403 421
         padded_plaintext = bytearray()
404
-        padded_plaintext.extend(decryptor.update(self._payload))
422
+        padded_plaintext.extend(decryptor.update(payload))
405 423
         padded_plaintext.extend(decryptor.finalize())
406 424
         logger.debug(
407 425
             _msg.TranslatedString(
... ...
@@ -409,7 +427,7 @@ class VaultNativeConfigParser(abc.ABC):
409 427
                 contents=_h(padded_plaintext),
410 428
             ),
411 429
         )
412
-        unpadder = padding.PKCS7(self._iv_size * 8).unpadder()
430
+        unpadder = padding.PKCS7(iv_size * 8).unpadder()
413 431
         plaintext = bytearray()
414 432
         plaintext.extend(unpadder.update(padded_plaintext))
415 433
         plaintext.extend(unpadder.finalize())
... ...
@@ -480,8 +498,13 @@ class VaultNativeV03ConfigParser(VaultNativeConfigParser):
480 498
             moderately determined attackers!
481 499
 
482 500
         """
483
-        self._encryption_key = self._pbkdf2(self._password, self.KEY_SIZE, 100)
484
-        self._signing_key = self._pbkdf2(self._password, self.KEY_SIZE, 200)
501
+        with self._consistency_lock:
502
+            self._encryption_key = self._pbkdf2(
503
+                self._password, self.KEY_SIZE, 100
504
+            )
505
+            self._signing_key = self._pbkdf2(
506
+                self._password, self.KEY_SIZE, 200
507
+            )
485 508
             self._encryption_key_size = self._signing_key_size = self.KEY_SIZE
486 509
 
487 510
     def _hmac_input(self) -> bytes:
... ...
@@ -500,8 +523,11 @@ class VaultNativeV03ConfigParser(VaultNativeConfigParser):
500 523
         (MAC-verified) message payload.
501 524
 
502 525
         """
526
+        with self._consistency_lock:
527
+            encryption_key = self._encryption_key
528
+            iv = self._iv
503 529
         return ciphers.Cipher(
504
-            algorithms.AES256(self._encryption_key), modes.CBC(self._iv)
530
+            algorithms.AES256(encryption_key), modes.CBC(iv)
505 531
         ).decryptor()
506 532
 
507 533
 
... ...
@@ -545,14 +571,17 @@ class VaultNativeV02ConfigParser(VaultNativeConfigParser):
545 571
                 properly.
546 572
 
547 573
         """
574
+        with self._consistency_lock:
548 575
             super()._parse_contents()
549
-        self._payload = base64.standard_b64decode(self._payload)
550
-        self._message_tag = bytes.fromhex(self._message_tag.decode('ASCII'))
576
+            payload = self._payload = base64.standard_b64decode(self._payload)
577
+            message_tag = self._message_tag = bytes.fromhex(
578
+                self._message_tag.decode('ASCII')
579
+            )
551 580
         logger.debug(
552 581
             _msg.TranslatedString(
553 582
                 _msg.DebugMsgTemplate.VAULT_NATIVE_V02_PAYLOAD_MAC_POSTPROCESSING,
554
-                payload=_h(self._payload),
555
-                mac=_h(self._message_tag),
583
+                payload=_h(payload),
584
+                mac=_h(message_tag),
556 585
             ),
557 586
         )
558 587
 
... ...
@@ -580,6 +609,7 @@ class VaultNativeV02ConfigParser(VaultNativeConfigParser):
580 609
             access by even moderately determined attackers!
581 610
 
582 611
         """
612
+        with self._consistency_lock:
583 613
             self._encryption_key = self._pbkdf2(self._password, 8, 16)
584 614
             self._signing_key = self._pbkdf2(self._password, 16, 16)
585 615
             self._encryption_key_size = 8
... ...
@@ -713,6 +743,7 @@ class VaultNativeV02ConfigParser(VaultNativeConfigParser):
713 743
             determined attackers!
714 744
 
715 745
         """
746
+        with self._consistency_lock:
716 747
             data = base64.standard_b64encode(self._iv + self._encryption_key)
717 748
         encryption_key, iv = self._evp_bytestokey_md5_one_iteration_no_salt(
718 749
             data, key_size=32, iv_size=16
... ...
@@ -23,6 +23,7 @@ The main API is the [`Sequin`][] class, which is thoroughly documented.
23 23
 from __future__ import annotations
24 24
 
25 25
 import collections
26
+import threading
26 27
 from typing import TYPE_CHECKING
27 28
 
28 29
 from typing_extensions import assert_type
... ...
@@ -99,18 +100,19 @@ class Sequin:
99 100
         else:
100 101
             sequence = tuple(sequence)
101 102
         assert_type(sequence, tuple[int, ...])
102
-        self.bases: dict[int, collections.deque[int]] = {}
103
-
104
-        def gen() -> Iterator[int]:
103
+        consistency_lock = threading.RLock()
104
+        bitstream: collections.deque[int] = collections.deque()
105 105
         for num in sequence:
106 106
             if num not in range(2 if is_bitstring else 256):
107 107
                 raise ValueError(msg)
108 108
             if is_bitstring:
109
-                    yield num
109
+                bitstream.append(num)
110 110
             else:
111
-                    yield from uint8_to_bits(num)
111
+                bitstream.extend(uint8_to_bits(num))
112 112
 
113
-        self.bases[2] = collections.deque(gen())
113
+        with consistency_lock:
114
+            self.consistency_lock = consistency_lock
115
+            self.bases = {2: bitstream}
114 116
 
115 117
     def _all_or_nothing_shift(
116 118
         self, count: int, /, *, base: int = 2
... ...
@@ -126,6 +128,9 @@ class Sequin:
126 128
             consume them from the sequence and return them.  Otherwise,
127 129
             consume nothing, and return nothing.
128 130
 
131
+        Info: Thread-safety
132
+            This call is thread-safe.
133
+
129 134
         Notes:
130 135
             We currently remove now-empty sequences from the registry of
131 136
             sequences.
... ...
@@ -150,6 +155,7 @@ class Sequin:
150 155
             False
151 156
 
152 157
         """
158
+        with self.consistency_lock:
153 159
             try:
154 160
                 seq = self.bases[base]
155 161
             except KeyError:
... ...
@@ -181,6 +187,9 @@ class Sequin:
181 187
             ValueError: `base` is an invalid base.
182 188
             ValueError: Not all integers are valid base `base` digits.
183 189
 
190
+        Info: Thread-safety
191
+            This call is thread-safe.
192
+
184 193
         Examples:
185 194
             >>> Sequin._big_endian_number([1, 2, 3, 4, 5, 6, 7, 8], base=10)
186 195
             12345678
... ...
@@ -235,6 +244,9 @@ class Sequin:
235 244
             SequinExhaustedError:
236 245
                 The sequin is exhausted.
237 246
 
247
+        Info: Thread-safety
248
+            This call is thread-safe.
249
+
238 250
         Examples:
239 251
             >>> seq = Sequin(
240 252
             ...     [1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1],
... ...
@@ -268,6 +280,7 @@ class Sequin:
268 280
             SequinExhaustedError: Sequin is exhausted
269 281
 
270 282
         """
283
+        with self.consistency_lock:
271 284
             if 2 not in self.bases:  # noqa: PLR2004
272 285
                 raise SequinExhaustedError
273 286
             value = self._generate_inner(n, base=2)
... ...
@@ -304,6 +317,9 @@ class Sequin:
304 317
             ValueError:
305 318
                 The range is empty.
306 319
 
320
+        Warning: Thread-safety
321
+            This call is **not thread-safe**.
322
+
307 323
         Examples:
308 324
             >>> seq = Sequin(
309 325
             ...     [1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1],
... ...
@@ -379,7 +395,11 @@ class Sequin:
379 395
 
380 396
         Sets up the base `base` sequence if it does not yet exist.
381 397
 
398
+        Info: Thread-safety
399
+            This call is thread-safe.
400
+
382 401
         """
402
+        with self.consistency_lock:
383 403
             if base not in self.bases:
384 404
                 self.bases[base] = collections.deque()
385 405
             self.bases[base].append(value)
386 406