Rename vault v0.2/v0.3 classes, and fix API weirdnesses and test coverage
Marco Ricci

Marco Ricci committed on 2024-08-31 21:26:49
Zeige 3 geänderte Dateien mit 128 Einfügungen und 49 Löschungen.


The `Reader`, `V02Reader` and `V03Reader` classes are now named
`VaultNativeConfigParser`, `VaultNativeV02ConfigParser` and
`VaultNativeV03ConfigParser`, respectively, because `Reader` is a very
nondescript name.  The interface is now somewhat more Pythonic, and uses
fewer internal branches.  Further tests of the internals have been added,
so the module has 100% test coverage.
... ...
@@ -47,7 +47,7 @@ def _load_data(
47 47
                 raise ModuleNotFoundError
48 48
             with open(path, 'rb') as infile:
49 49
                 contents = base64.standard_b64decode(infile.read())
50
-            return module.V02Reader(contents, key).run()
50
+            return module.VaultNativeV02ConfigParser(contents, key)()
51 51
         case 'v0.3':
52 52
             module = importlib.import_module(
53 53
                 'derivepassphrase.exporter.vault_v03_and_below'
... ...
@@ -56,7 +56,7 @@ def _load_data(
56 56
                 raise ModuleNotFoundError
57 57
             with open(path, 'rb') as infile:
58 58
                 contents = base64.standard_b64decode(infile.read())
59
-            return module.V03Reader(contents, key).run()
59
+            return module.VaultNativeV03ConfigParser(contents, key)()
60 60
         case 'storeroom':
61 61
             module = importlib.import_module(
62 62
                 'derivepassphrase.exporter.storeroom'
... ...
@@ -7,7 +7,14 @@ import base64
7 7
 import json
8 8
 import logging
9 9
 import warnings
10
-from typing import TYPE_CHECKING, Any
10
+from typing import TYPE_CHECKING
11
+
12
+from derivepassphrase import exporter, vault
13
+
14
+if TYPE_CHECKING:
15
+    from typing import Any
16
+
17
+    from typing_extensions import Buffer
11 18
 
12 19
 if TYPE_CHECKING:
13 20
     from cryptography import exceptions as crypt_exceptions
... ...
@@ -46,8 +53,6 @@ else:
46 53
     else:
47 54
         STUBBED = False
48 55
 
49
-from derivepassphrase import exporter, vault
50
-
51 56
 logger = logging.getLogger(__name__)
52 57
 
53 58
 
... ...
@@ -55,32 +60,37 @@ def _h(bs: bytes | bytearray) -> str:
55 60
     return 'bytes.fromhex({!r})'.format(bs.hex(' '))
56 61
 
57 62
 
58
-class Reader(abc.ABC):
59
-    def __init__(
60
-        self, contents: bytes | bytearray, password: str | bytes | bytearray
61
-    ) -> None:
63
+class VaultNativeConfigParser(abc.ABC):
64
+    def __init__(self, contents: Buffer, password: str | Buffer) -> None:
62 65
         if not password:
63
-            msg = 'No password given; check VAULT_KEY environment variable'
66
+            msg = 'Password must not be empty'
64 67
             raise ValueError(msg)
65
-        self.contents = contents
66
-        self.password = password
68
+        self._contents = bytes(contents)
67 69
         self.iv_size = 0
68 70
         self.mac_size = 0
69 71
         self.encryption_key = b''
70 72
         self.encryption_key_size = 0
71 73
         self.signing_key = b''
72 74
         self.signing_key_size = 0
73
-
74
-    def run(self) -> Any:
75
+        self.message = b''
76
+        self.message_tag = b''
77
+        self.iv = b''
78
+        self.payload = b''
79
+        self._password = password
80
+        self._sentinel: object = object()
81
+        self._data: Any = self._sentinel
82
+
83
+    def __call__(self) -> Any:
84
+        if self._data is self._sentinel:
75 85
             self._parse_contents()
76 86
             self._derive_keys()
77 87
             self._check_signature()
78
-        self._decrypt_payload()
88
+            self._data = self._decrypt_payload()
79 89
         return self._data
80 90
 
81 91
     @staticmethod
82 92
     def pbkdf2(
83
-        password: str | bytes | bytearray, key_size: int, iterations: int
93
+        password: str | Buffer, key_size: int, iterations: int
84 94
     ) -> bytes:
85 95
         if isinstance(password, str):
86 96
             password = password.encode('utf-8')
... ...
@@ -89,7 +99,7 @@ class Reader(abc.ABC):
89 99
             length=key_size // 2,
90 100
             salt=vault.Vault._UUID,  # noqa: SLF001
91 101
             iterations=iterations,
92
-        ).derive(password)
102
+        ).derive(bytes(password))
93 103
         logger.debug(
94 104
             'binary = pbkdf2(%s, %s, %s, %s, %s) = %s -> %s',
95 105
             repr(password),
... ...
@@ -105,21 +115,22 @@ class Reader(abc.ABC):
105 115
     def _parse_contents(self) -> None:
106 116
         logger.info('Parsing IV, payload and signature from the file contents')
107 117
 
108
-        if len(self.contents) < self.iv_size + 16 + self.mac_size:
109
-            msg = 'File contents are too small to parse'
118
+        if len(self._contents) < self.iv_size + 16 + self.mac_size:
119
+            msg = 'Invalid vault configuration file: file is truncated'
110 120
             raise ValueError(msg)
111 121
 
112
-        cutpos1 = self.iv_size
113
-        cutpos2 = len(self.contents) - self.mac_size
122
+        def cut(buffer: bytes, cutpoint: int) -> tuple[bytes, bytes]:
123
+            return buffer[:cutpoint], buffer[cutpoint:]
124
+
125
+        cutpos1 = len(self._contents) - self.mac_size
126
+        cutpos2 = self.iv_size
114 127
 
115
-        self.message = self.contents[:cutpos2]
116
-        self.message_tag = self.contents[cutpos2:]
117
-        self.iv = self.message[:cutpos1]
118
-        self.payload = self.message[cutpos1:]
128
+        self.message, self.message_tag = cut(self._contents, cutpos1)
129
+        self.iv, self.payload = cut(self.message, cutpos2)
119 130
 
120 131
         logger.debug(
121 132
             'buffer %s = [[%s, %s], %s]',
122
-            _h(self.contents),
133
+            _h(self._contents),
123 134
             _h(self.iv),
124 135
             _h(self.payload),
125 136
             _h(self.message_tag),
... ...
@@ -130,10 +141,10 @@ class Reader(abc.ABC):
130 141
         self._generate_keys()
131 142
         assert (
132 143
             len(self.encryption_key) == self.encryption_key_size
133
-        ), 'Derived encryption key is not valid'
144
+        ), 'Derived encryption key is invalid'
134 145
         assert (
135 146
             len(self.signing_key) == self.signing_key_size
136
-        ), 'Derived signing key is not valid'
147
+        ), 'Derived signing key is invalid'
137 148
 
138 149
     @abc.abstractmethod
139 150
     def _generate_keys(self) -> None:
... ...
@@ -152,14 +163,14 @@ class Reader(abc.ABC):
152 163
         try:
153 164
             mac.verify(self.message_tag)
154 165
         except crypt_exceptions.InvalidSignature:
155
-            msg = 'File does not contain a valid HMAC-SHA256 signature'
166
+            msg = 'File does not contain a valid signature'
156 167
             raise ValueError(msg) from None
157 168
 
158 169
     @abc.abstractmethod
159 170
     def _hmac_input(self) -> bytes:
160 171
         raise AssertionError
161 172
 
162
-    def _decrypt_payload(self) -> None:
173
+    def _decrypt_payload(self) -> Any:
163 174
         decryptor = self._make_decryptor()
164 175
         padded_plaintext = bytearray()
165 176
         padded_plaintext.extend(decryptor.update(self.payload))
... ...
@@ -170,14 +181,14 @@ class Reader(abc.ABC):
170 181
         plaintext.extend(unpadder.update(padded_plaintext))
171 182
         plaintext.extend(unpadder.finalize())
172 183
         logger.debug('plaintext = %s', _h(plaintext))
173
-        self._data = json.loads(plaintext)
184
+        return json.loads(plaintext)
174 185
 
175 186
     @abc.abstractmethod
176 187
     def _make_decryptor(self) -> ciphers.CipherContext:
177 188
         raise AssertionError
178 189
 
179 190
 
180
-class V03Reader(Reader):
191
+class VaultNativeV03ConfigParser(VaultNativeConfigParser):
181 192
     KEY_SIZE = 32
182 193
 
183 194
     def __init__(self, *args: Any, **kwargs: Any) -> None:
... ...
@@ -185,13 +196,15 @@ class V03Reader(Reader):
185 196
         self.iv_size = 16
186 197
         self.mac_size = 32
187 198
 
188
-    def run(self) -> Any:
199
+    def __call__(self) -> Any:
200
+        if self._data is self._sentinel:
189 201
             logger.info('Attempting to parse as v0.3 configuration')
190
-        return super().run()
202
+            return super().__call__()
203
+        return self._data
191 204
 
192 205
     def _generate_keys(self) -> None:
193
-        self.encryption_key = self.pbkdf2(self.password, self.KEY_SIZE, 100)
194
-        self.signing_key = self.pbkdf2(self.password, self.KEY_SIZE, 200)
206
+        self.encryption_key = self.pbkdf2(self._password, self.KEY_SIZE, 100)
207
+        self.signing_key = self.pbkdf2(self._password, self.KEY_SIZE, 200)
195 208
         self.encryption_key_size = self.signing_key_size = self.KEY_SIZE
196 209
 
197 210
     def _hmac_input(self) -> bytes:
... ...
@@ -203,15 +216,17 @@ class V03Reader(Reader):
203 216
         ).decryptor()
204 217
 
205 218
 
206
-class V02Reader(Reader):
219
+class VaultNativeV02ConfigParser(VaultNativeConfigParser):
207 220
     def __init__(self, *args: Any, **kwargs: Any) -> None:
208 221
         super().__init__(*args, **kwargs)
209 222
         self.iv_size = 16
210 223
         self.mac_size = 64
211 224
 
212
-    def run(self) -> Any:
225
+    def __call__(self) -> Any:
226
+        if self._data is self._sentinel:
213 227
             logger.info('Attempting to parse as v0.2 configuration')
214
-        return super().run()
228
+            return super().__call__()
229
+        return self._data
215 230
 
216 231
     def _parse_contents(self) -> None:
217 232
         super()._parse_contents()
... ...
@@ -220,8 +235,8 @@ class V02Reader(Reader):
220 235
         self.message_tag = bytes.fromhex(self.message_tag.decode('ASCII'))
221 236
 
222 237
     def _generate_keys(self) -> None:
223
-        self.encryption_key = self.pbkdf2(self.password, 8, 16)
224
-        self.signing_key = self.pbkdf2(self.password, 16, 16)
238
+        self.encryption_key = self.pbkdf2(self._password, 8, 16)
239
+        self.signing_key = self.pbkdf2(self._password, 16, 16)
225 240
         self.encryption_key_size = 8
226 241
         self.signing_key_size = 16
227 242
 
... ...
@@ -229,13 +244,12 @@ class V02Reader(Reader):
229 244
         return base64.standard_b64encode(self.message)
230 245
 
231 246
     def _make_decryptor(self) -> ciphers.CipherContext:
232
-        def evp_bytestokey_md5_one_iteration(
233
-            data: bytes, salt: bytes | None, key_size: int, iv_size: int
247
+        def evp_bytestokey_md5_one_iteration_no_salt(
248
+            data: bytes, key_size: int, iv_size: int
234 249
         ) -> tuple[bytes, bytes]:
235 250
             total_size = key_size + iv_size
236 251
             buffer = bytearray()
237 252
             last_block = b''
238
-            if salt is None:
239 253
             salt = b''
240 254
             logging.debug(
241 255
                 (
... ...
@@ -271,8 +285,8 @@ class V02Reader(Reader):
271 285
             return bytes(buffer[:key_size]), bytes(buffer[key_size:total_size])
272 286
 
273 287
         data = base64.standard_b64encode(self.iv + self.encryption_key)
274
-        encryption_key, iv = evp_bytestokey_md5_one_iteration(
275
-            data, salt=None, key_size=32, iv_size=16
288
+        encryption_key, iv = evp_bytestokey_md5_one_iteration_no_salt(
289
+            data, key_size=32, iv_size=16
276 290
         )
277 291
         return ciphers.Cipher(
278 292
             algorithms.AES256(encryption_key), modes.CBC(iv)
... ...
@@ -287,7 +301,7 @@ if __name__ == '__main__':
287 301
         contents = base64.standard_b64decode(infile.read())
288 302
     password = exporter.get_vault_key()
289 303
     try:
290
-        config = V03Reader(contents, password).run()
304
+        config = VaultNativeV03ConfigParser(contents, password)()
291 305
     except ValueError:
292
-        config = V02Reader(contents, password).run()
306
+        config = VaultNativeV02ConfigParser(contents, password)()
293 307
     print(json.dumps(config, indent=2, sort_keys=True))  # noqa: T201
... ...
@@ -4,6 +4,7 @@
4 4
 
5 5
 from __future__ import annotations
6 6
 
7
+import base64
7 8
 import json
8 9
 from typing import TYPE_CHECKING
9 10
 
... ...
@@ -11,11 +12,12 @@ import click.testing
11 12
 import pytest
12 13
 
13 14
 import tests
14
-from derivepassphrase.exporter import cli, storeroom
15
+from derivepassphrase.exporter import cli, storeroom, vault_v03_and_below
15 16
 
16 17
 cryptography = pytest.importorskip('cryptography', minversion='38.0')
17 18
 
18 19
 if TYPE_CHECKING:
20
+    from collections.abc import Callable
19 21
     from typing import Any
20 22
 
21 23
 
... ...
@@ -303,3 +305,66 @@ class TestStoreroom:
303 305
             pytest.raises(RuntimeError, match='Object key mismatch'),
304 306
         ):
305 307
                 storeroom.export_storeroom_data()
308
+
309
+
310
+class TestVaultNativeConfig:
311
+    @pytest.mark.parametrize(
312
+        ['iterations', 'result'],
313
+        [
314
+            (100, b'6ede361e81e9c061efcdd68aeb768b80'),
315
+            (200, b'bcc7d01e075b9ffb69e702bf701187c1'),
316
+        ],
317
+    )
318
+    def test_200_pbkdf2_manually(self, iterations: int, result: bytes) -> None:
319
+        assert vault_v03_and_below.VaultNativeConfigParser.pbkdf2(tests.VAULT_MASTER_KEY.encode('utf-8'), 32, iterations) == result
320
+
321
+    @pytest.mark.parametrize(
322
+        ['parser_class', 'config', 'result'],
323
+        [
324
+            pytest.param(
325
+                vault_v03_and_below.VaultNativeV02ConfigParser,
326
+                tests.VAULT_V02_CONFIG,
327
+                tests.VAULT_V02_CONFIG_DATA,
328
+                id='0.2',
329
+            ),
330
+            pytest.param(
331
+                vault_v03_and_below.VaultNativeV03ConfigParser,
332
+                tests.VAULT_V03_CONFIG,
333
+                tests.VAULT_V03_CONFIG_DATA,
334
+                id='0.3',
335
+            ),
336
+        ],
337
+    )
338
+    def test_300_result_caching(
339
+        self,
340
+        monkeypatch: pytest.MonkeyPatch,
341
+        parser_class: type[vault_v03_and_below.VaultNativeConfigParser],
342
+        config: str,
343
+        result: dict[str, Any],
344
+    ) -> None:
345
+
346
+        def null_func(name: str) -> Callable[..., None]:
347
+            def func(*_args: Any, **_kwargs: Any) -> None:  # pragma: no cover
348
+                msg = f'disallowed and stubbed out function {name} called'
349
+                raise AssertionError(msg)
350
+            return func
351
+
352
+        runner = click.testing.CliRunner(mix_stderr=False)
353
+        with tests.isolated_vault_exporter_config(
354
+            monkeypatch=monkeypatch,
355
+            runner=runner,
356
+            vault_config=config,
357
+        ):
358
+            parser = parser_class(base64.b64decode(config), tests.VAULT_MASTER_KEY)
359
+            assert parser() == result
360
+            # Now stub out all functions used to calculate the above result.
361
+            monkeypatch.setattr(parser, '_parse_contents', null_func('_parse_contents'))
362
+            monkeypatch.setattr(parser, '_derive_keys', null_func('_derive_keys'))
363
+            monkeypatch.setattr(parser, '_check_signature', null_func('_check_signature'))
364
+            monkeypatch.setattr(parser, '_decrypt_payload', null_func('_decrypt_payload'))
365
+            assert parser() == result
366
+            assert vault_v03_and_below.VaultNativeConfigParser.__call__(parser) == result
367
+
368
+    def test_400_no_password(self) -> None:
369
+        with pytest.raises(ValueError, match='Password must not be empty'):
370
+            vault_v03_and_below.VaultNativeV03ConfigParser(b'', b'')
306 371