Overhaul the validation function for vault(1) configurations
Marco Ricci

Marco Ricci committed on 2024-09-30 10:25:58
Showing 3 changed files with 316 additions and 24 deletions.


Rewrite `derivepassphrase._types.is_vault_config` into a proper
validation function, `validate_vault_config`, that throws errors and
optionally disallows extension or unknown settings.  The old
`is_vault_config` function is then implemented in terms of the new
function.

Use this opportunity to change the return annotation to
`typing_extensions.TypeIs`, because that is what was semantically
intended anyway.

Naturally, throwing actual errors instead of returning `False` means
that the error handling gets bulkier and more fine-grained, which in
turn means that extra tests are necessary to stay at high test coverage
levels.
... ...
@@ -0,0 +1,13 @@
1
+### Changed
2
+
3
+  - Rewrite functionality for checking for valid vault(1) configurations:
4
+    include an actual validation function which throws errors upon
5
+    encountering format violations, and which allows specifying which types
6
+    of extensions (unknown settings, `derivepassphrase`-only settings) to
7
+    tolerate during validation.
8
+
9
+    This is a **breaking API change** because the function return annotation
10
+    changed, from [`typing.TypeGuard`][] to [`typing_extensions.TypeIs`][].
11
+    These were the originally intended semantics, but when
12
+    `derivepassphrase` was first designed, the Python type system did not
13
+    support this kind of partial type narrowing.
... ...
@@ -7,15 +7,24 @@
7 7
 from __future__ import annotations
8 8
 
9 9
 import enum
10
-from typing import Literal, NamedTuple, TypeGuard
10
+from typing import TYPE_CHECKING
11 11
 
12 12
 from typing_extensions import (
13
-    Any,
13
+    NamedTuple,
14 14
     NotRequired,
15
-    Required,
16 15
     TypedDict,
17 16
 )
18 17
 
18
+if TYPE_CHECKING:
19
+    from collections.abc import Sequence
20
+    from typing import Literal
21
+
22
+    from typing_extensions import (
23
+        Any,
24
+        Required,
25
+        TypeIs,
26
+    )
27
+
19 28
 __all__ = (
20 29
     'SSH_AGENT',
21 30
     'SSH_AGENTC',
... ...
@@ -126,46 +135,182 @@ class VaultConfig(TypedDict, _VaultConfig, total=False):
126 135
     services: Required[dict[str, VaultConfigServicesSettings]]
127 136
 
128 137
 
129
-def is_vault_config(obj: Any) -> TypeGuard[VaultConfig]:  # noqa: ANN401,C901,PLR0911,PLR0912
130
-    """Check if `obj` is a valid vault config, according to typing.
138
+def validate_vault_config(  # noqa: C901,PLR0912,PLR0915
139
+    obj: Any,  # noqa: ANN401
140
+    /,
141
+    *,
142
+    allow_unknown_settings: bool = False,
143
+    allow_derivepassphrase_extensions: bool = False,
144
+) -> None:
145
+    """Check that `obj` is a valid vault config.
131 146
 
132 147
     Args:
133
-        obj: The object to test.
134
-
135
-    Returns:
136
-        True if this is a vault config, false otherwise.
148
+        obj:
149
+            The object to test.
150
+        allow_unknown_settings:
151
+            If false, abort on unknown settings.
152
+        allow_derivepassphrase_extensions:
153
+            If true, allow `derivepassphrase` extensions.
154
+
155
+    Raises:
156
+        TypeError:
157
+            An entry in the vault config, or the vault config itself,
158
+            has the wrong type.
159
+        ValueError:
160
+            An entry in the vault config is not allowed, or has a
161
+            disallowed value.
137 162
 
138 163
     """
164
+
165
+    def as_json_path_string(json_path: Sequence[str], /) -> str:
166
+        return ''.join('.' + repr(x) for x in json_path)
167
+
168
+    err_obj_not_a_dict = 'vault config is not a dict'
169
+    err_non_str_service_name = (
170
+        'vault config contains non-string service name {!r}'
171
+    )
172
+
173
+    def err_not_a_dict(json_path: Sequence[str], /) -> str:
174
+        json_path_str = as_json_path_string(json_path)
175
+        return f'vault config entry {json_path_str} is not a dict'
176
+
177
+    def err_not_a_string(json_path: Sequence[str], /) -> str:
178
+        json_path_str = as_json_path_string(json_path)
179
+        return f'vault config entry {json_path_str} is not a string'
180
+
181
+    def err_not_an_int(json_path: Sequence[str], /) -> str:
182
+        json_path_str = as_json_path_string(json_path)
183
+        return f'vault config entry {json_path_str} is not an integer'
184
+
185
+    err_key_and_phrase = (
186
+        '"key" and "phrase" specified on the same vault config level'
187
+    )
188
+
189
+    def err_derivepassphrase_extension(
190
+        key: str, json_path: Sequence[str], /
191
+    ) -> str:
192
+        json_path_str = as_json_path_string(json_path)
193
+        return (
194
+            f'vault config entry {json_path_str} uses '
195
+            f'`derivepassphrase` extension {key!r}'
196
+        )
197
+
198
+    def err_unknown_setting(key: str, json_path: Sequence[str], /) -> str:
199
+        json_path_str = as_json_path_string(json_path)
200
+        return (
201
+            f'vault config entry {json_path_str} uses '
202
+            f'unknown setting {key!r}'
203
+        )
204
+
205
+    def err_bad_number(
206
+        key: str,
207
+        json_path: Sequence[str],
208
+        /,
209
+        *,
210
+        strictly_positive: bool = False,
211
+    ) -> str:
212
+        json_path_str = as_json_path_string((*json_path, key))
213
+        return f'vault config entry {json_path_str} is ' + (
214
+            'not positive' if strictly_positive else 'negative'
215
+        )
216
+
139 217
     if not isinstance(obj, dict):
140
-        return False
218
+        raise TypeError(err_obj_not_a_dict)
141 219
     if 'global' in obj:
142 220
         o_global = obj['global']
143 221
         if not isinstance(o_global, dict):
144
-            return False
145
-        for key in ('key', 'phrase', 'unicode_normalization_form'):
146
-            if key in o_global and not isinstance(o_global[key], str):
147
-                return False
222
+            raise TypeError(err_not_a_dict(['global']))
223
+        for key, value in o_global.items():
224
+            match key:
225
+                case 'key' | 'phrase':
226
+                    if not isinstance(value, str):
227
+                        raise TypeError(err_not_a_dict(['global', key]))
228
+                case 'unicode_normalization_form':
229
+                    if not isinstance(value, str):
230
+                        raise TypeError(err_not_a_dict(['global', key]))
231
+                    if not allow_derivepassphrase_extensions:
232
+                        raise ValueError(
233
+                            err_derivepassphrase_extension(key, ('global',))
234
+                        )
235
+                case _ if not allow_unknown_settings:
236
+                    raise ValueError(err_unknown_setting(key, ('global',)))
148 237
         if 'key' in o_global and 'phrase' in o_global:
149
-            return False
238
+            raise ValueError(err_key_and_phrase)
150 239
     if not isinstance(obj.get('services'), dict):
151
-        return False
240
+        raise TypeError(err_not_a_dict(['services']))
152 241
     for sv_name, service in obj['services'].items():
153 242
         if not isinstance(sv_name, str):
154
-            return False
243
+            raise TypeError(err_non_str_service_name.format(sv_name))
155 244
         if not isinstance(service, dict):
156
-            return False
245
+            raise TypeError(err_not_a_dict(['services', sv_name]))
157 246
         for key, value in service.items():
158 247
             match key:
159 248
                 case 'notes' | 'phrase' | 'key':
160 249
                     if not isinstance(value, str):
161
-                        return False
250
+                        raise TypeError(
251
+                            err_not_a_string(['services', sv_name, key])
252
+                        )
162 253
                 case 'length':
163
-                    if not isinstance(value, int) or value < 1:
164
-                        return False
165
-                case _:
166
-                    if not isinstance(value, int) or value < 0:
167
-                        return False
254
+                    if not isinstance(value, int):
255
+                        raise TypeError(
256
+                            err_not_an_int(['services', sv_name, key])
257
+                        )
258
+                    if value < 1:
259
+                        raise ValueError(
260
+                            err_bad_number(
261
+                                key,
262
+                                ['services', sv_name],
263
+                                strictly_positive=True,
264
+                            )
265
+                        )
266
+                case (
267
+                    'repeat'
268
+                    | 'lower'
269
+                    | 'upper'
270
+                    | 'number'
271
+                    | 'space'
272
+                    | 'dash'
273
+                    | 'symbol'
274
+                ):
275
+                    if not isinstance(value, int):
276
+                        raise TypeError(
277
+                            err_not_an_int(['services', sv_name, key])
278
+                        )
279
+                    if value < 0:
280
+                        raise ValueError(
281
+                            err_bad_number(
282
+                                key,
283
+                                ['services', sv_name],
284
+                                strictly_positive=False,
285
+                            )
286
+                        )
287
+                case _ if not allow_unknown_settings:
288
+                    raise ValueError(
289
+                        err_unknown_setting(key, ['services', sv_name])
290
+                    )
168 291
         if 'key' in service and 'phrase' in service:
292
+            raise ValueError(err_key_and_phrase)
293
+
294
+
295
+def is_vault_config(obj: Any) -> TypeIs[VaultConfig]:  # noqa: ANN401
296
+    """Check if `obj` is a valid vault config, according to typing.
297
+
298
+    Args:
299
+        obj: The object to test.
300
+
301
+    Returns:
302
+        True if this is a vault config, false otherwise.
303
+
304
+    """
305
+    try:
306
+        validate_vault_config(
307
+            obj,
308
+            allow_unknown_settings=True,
309
+            allow_derivepassphrase_extensions=True,
310
+        )
311
+    except (TypeError, ValueError) as exc:
312
+        if 'vault config ' not in str(exc):  # pragma: no cover
313
+            raise  # noqa: DOC501
169 314
         return False
170 315
     return True
171 316
 
... ...
@@ -44,6 +44,10 @@ from derivepassphrase import _types
44 44
             {'services': {'sv': {'length': -10}}},
45 45
             'bad config value: services.sv.length',
46 46
         ),
47
+        (
48
+            {'services': {'sv': {'lower': '10'}}},
49
+            'bad config value: services.sv.lower',
50
+        ),
47 51
         (
48 52
             {'services': {'sv': {'upper': -10}}},
49 53
             'bad config value: services.sv.upper',
... ...
@@ -86,6 +90,48 @@ from derivepassphrase import _types
86 90
             },
87 91
             '',
88 92
         ),
93
+        (
94
+            {
95
+                'global': {'key': '...', 'unicode_normalization_form': 'NFC'},
96
+                'services': {
97
+                    'sv1': {'phrase': 'abc', 'length': 10, 'upper': 1},
98
+                    'sv2': {'length': 10, 'repeat': 1, 'lower': 1},
99
+                },
100
+            },
101
+            '',
102
+        ),
103
+        (
104
+            {
105
+                'global': {'key': '...', 'unicode_normalization_form': None},
106
+                'services': {},
107
+            },
108
+            'bad config value: global.unicode_normalization_form',
109
+        ),
110
+        (
111
+            {
112
+                'global': {'key': '...', 'unknown_key': None},
113
+                'services': {
114
+                    'sv1': {'phrase': 'abc', 'length': 10, 'upper': 1},
115
+                    'sv2': {'length': 10, 'repeat': 1, 'lower': 1},
116
+                },
117
+            },
118
+            '',
119
+        ),
120
+        (
121
+            {
122
+                'global': {'key': '...', 'unicode_normalization_form': 'NFC'},
123
+                'services': {
124
+                    'sv1': {'phrase': 'abc', 'length': 10, 'upper': 1},
125
+                    'sv2': {
126
+                        'length': 10,
127
+                        'repeat': 1,
128
+                        'lower': 1,
129
+                        'unknown_key': None,
130
+                    },
131
+                },
132
+            },
133
+            '',
134
+        ),
89 135
     ],
90 136
 )
91 137
 def test_200_is_vault_config(obj: Any, comment: str) -> None:
... ...
@@ -95,3 +141,91 @@ def test_200_is_vault_config(obj: Any, comment: str) -> None:
95 141
         if comment
96 142
         else 'failed on valid example'
97 143
     )
144
+
145
+
146
+@pytest.mark.parametrize(
147
+    ['obj', 'allow_unknown_settings', 'allow_derivepassphrase_extensions'],
148
+    [
149
+        (
150
+            {
151
+                'global': {'key': '...', 'unicode_normalization_form': 'NFC'},
152
+                'services': {
153
+                    'sv1': {'phrase': 'abc', 'length': 10, 'upper': 1},
154
+                    'sv2': {'length': 10, 'repeat': 1, 'lower': 1},
155
+                },
156
+            },
157
+            False,
158
+            False,
159
+        ),
160
+        (
161
+            {
162
+                'global': {'key': '...', 'unknown_key': None},
163
+                'services': {
164
+                    'sv1': {'phrase': 'abc', 'length': 10, 'upper': 1},
165
+                    'sv2': {'length': 10, 'repeat': 1, 'lower': 1},
166
+                },
167
+            },
168
+            False,
169
+            False,
170
+        ),
171
+        (
172
+            {
173
+                'global': {'key': '...', 'unicode_normalization_form': 'NFC'},
174
+                'services': {
175
+                    'sv1': {'phrase': 'abc', 'length': 10, 'upper': 1},
176
+                    'sv2': {
177
+                        'length': 10,
178
+                        'repeat': 1,
179
+                        'lower': 1,
180
+                        'unknown_key': None,
181
+                    },
182
+                },
183
+            },
184
+            False,
185
+            False,
186
+        ),
187
+        (
188
+            {
189
+                'global': {'key': '...', 'unicode_normalization_form': 'NFC'},
190
+                'services': {
191
+                    'sv1': {'phrase': 'abc', 'length': 10, 'upper': 1},
192
+                    'sv2': {
193
+                        'length': 10,
194
+                        'repeat': 1,
195
+                        'lower': 1,
196
+                        'unknown_key': None,
197
+                    },
198
+                },
199
+            },
200
+            False,
201
+            True,
202
+        ),
203
+        (
204
+            {
205
+                'global': {'key': '...', 'unicode_normalization_form': 'NFC'},
206
+                'services': {
207
+                    'sv1': {'phrase': 'abc', 'length': 10, 'upper': 1},
208
+                    'sv2': {
209
+                        'length': 10,
210
+                        'repeat': 1,
211
+                        'lower': 1,
212
+                        'unknown_key': None,
213
+                    },
214
+                },
215
+            },
216
+            True,
217
+            False,
218
+        ),
219
+    ],
220
+)
221
+def test_400_validate_vault_config(
222
+    obj: Any,
223
+    allow_unknown_settings: bool,
224
+    allow_derivepassphrase_extensions: bool,
225
+) -> None:
226
+    with pytest.raises((TypeError, ValueError), match='vault config '):
227
+        _types.validate_vault_config(
228
+            obj,
229
+            allow_unknown_settings=allow_unknown_settings,
230
+            allow_derivepassphrase_extensions=allow_derivepassphrase_extensions,
231
+        )
98 232