Lock the derivepassphrase CLI against concurrent updating
Marco Ricci

Marco Ricci commited on 2025-03-02 16:27:06
Zeige 3 geänderte Dateien mit 275 Einfügungen und 7 Löschungen.


In a very coarse manner, after command-line parsing, detect whether the
CLI operation is a read-write operation or a read-only one, and if it is
read-write, run the whole operation while holding a lock.

The lock is held on a temporary file whose basename is dependent on the
full path to the `derivepassphrase` configuration directory.  The file
is stored in the system temporary directory if possible, and will be
synchronized using `msvcrt.locking` or `fcntl.flock`, whichever is
appropriate.

Coverage checking for the Windows-specific application code is disabled
because we have no Windows runners for the test suite.  The test suite
itself merely adds tests for the temporary directory function, but is
otherwise unchanged.
... ...
@@ -16,8 +16,10 @@ Warning:
16 16
 from __future__ import annotations
17 17
 
18 18
 import base64
19
+import contextlib
19 20
 import copy
20 21
 import enum
22
+import hashlib
21 23
 import json
22 24
 import logging
23 25
 import os
... ...
@@ -159,6 +161,7 @@ def shell_complete_service(
159 161
 
160 162
 config_filename_table = {
161 163
     None: '.',
164
+    'write lock': '',
162 165
     'vault': 'vault.json',
163 166
     'user configuration': 'config.toml',
164 167
     # TODO(the-13th-letter): Remove the old settings.json file.
... ...
@@ -167,6 +170,129 @@ config_filename_table = {
167 170
     'notes backup': 'old-notes.txt',
168 171
 }
169 172
 
173
+LOCK_SIZE = 4096
174
+"""
175
+The size of the record to lock at the beginning of the file, for locking
176
+implementations that lock byte ranges instead of whole files.
177
+
178
+While POSIX specifies that [`fcntl`][] locks shall support a size of zero to
179
+denote "any conceivable file size", the locking system available in
180
+[`msvcrt`][] does not support this, and requires an explicit size.
181
+"""
182
+
183
+
184
+@contextlib.contextmanager
185
+def configuration_mutex() -> Iterator[None]:
186
+    """Enter a mutually exclusive context for configuration writes.
187
+
188
+    Within this context, no other cooperating instance of
189
+    `derivepassphrase` will attempt to write to its configuration
190
+    directory.  We achieve this by locking a specific temporary file
191
+    (whose name depends on the location of the configuration directory)
192
+    for the duration of the context.
193
+
194
+    Note: Locking specifics
195
+        The directory for the lock file is determined via
196
+        [`get_tempdir`][].  The lock filename is
197
+        `derivepassphrase-lock-<hash>.txt`, where `<hash>` is computed
198
+        as follows.  First, canonicalize the path to the configuration
199
+        directory with [`pathlib.Path.resolve`][].  Then encode the
200
+        result as per the filesystem encoding ([`os.fsencode`][]), and
201
+        hash it with SHA256.  Finally, convert the result to standard
202
+        base32 and use the first twelve characters, in lowercase, as
203
+        `<hash>`.
204
+
205
+        We use [`msvcrt.locking`][] on Windows platforms (`sys.platform
206
+        == "win32"`) and [`fcntl.flock`][] on all others.  All locks are
207
+        exclusive locks.  If the locking system requires a byte range,
208
+        we lock the first [`LOCK_SIZE`][] bytes.  For maximum
209
+        portability between locking implementations, we first open the
210
+        lock file for writing, which is sometimes necessary to lock
211
+        a file exclusively.  Thus locking will fail if we lack
212
+        permission to write to an already-existing lockfile.
213
+
214
+    """
215
+    lock_func: Callable[[int], None]
216
+    unlock_func: Callable[[int], None]
217
+    if sys.platform == 'win32':  # pragma: no cover
218
+        import msvcrt  # noqa: PLC0415
219
+
220
+        locking = msvcrt.locking
221
+        LK_LOCK = msvcrt.LK_LOCK  # noqa: N806
222
+        LK_UNLCK = msvcrt.LK_UNLCK  # noqa: N806
223
+
224
+        def lock_func(fileobj: int) -> None:
225
+            locking(fileobj, LK_LOCK, LOCK_SIZE)
226
+
227
+        def unlock_func(fileobj: int) -> None:
228
+            locking(fileobj, LK_UNLCK, LOCK_SIZE)
229
+
230
+    else:
231
+        import fcntl  # noqa: PLC0415
232
+
233
+        flock = fcntl.flock
234
+        LOCK_EX = fcntl.LOCK_EX  # noqa: N806
235
+        LOCK_UN = fcntl.LOCK_UN  # noqa: N806
236
+
237
+        def lock_func(fileobj: int) -> None:
238
+            flock(fileobj, LOCK_EX)
239
+
240
+        def unlock_func(fileobj: int) -> None:
241
+            flock(fileobj, LOCK_UN)
242
+
243
+    write_lock_file = config_filename('write lock')
244
+    write_lock_file.touch()
245
+    with write_lock_file.open('wb') as lock_fileobj:
246
+        lock_func(lock_fileobj.fileno())
247
+        try:
248
+            yield
249
+        finally:
250
+            unlock_func(lock_fileobj.fileno())
251
+
252
+
253
+def get_tempdir() -> pathlib.Path:
254
+    """Return a suitable temporary directory.
255
+
256
+    We implement the same algorithm as [`tempfile.gettempdir`][], except
257
+    that we default to the `derivepassphrase` configuration directory
258
+    instead of the current directory if no other choice is suitable, and
259
+    that we return [`pathlib.Path`][] objects directly.
260
+
261
+    """
262
+    paths_to_try: list[pathlib.PurePath] = []
263
+    env_paths_to_try = [
264
+        os.getenv('TMPDIR'),
265
+        os.getenv('TEMP'),
266
+        os.getenv('TMP'),
267
+    ]
268
+    paths_to_try.extend(
269
+        pathlib.PurePath(p) for p in env_paths_to_try if p is not None
270
+    )
271
+    posix_paths_to_try = [
272
+        pathlib.PurePosixPath('/tmp'),  # noqa: S108
273
+        pathlib.PurePosixPath('/var/tmp'),  # noqa: S108
274
+        pathlib.PurePosixPath('/usr/tmp'),
275
+    ]
276
+    windows_paths_to_try = [
277
+        pathlib.PureWindowsPath(r'C:\TEMP'),
278
+        pathlib.PureWindowsPath(r'C:\TMP'),
279
+        pathlib.PureWindowsPath(r'\TEMP'),
280
+        pathlib.PureWindowsPath(r'\TMP'),
281
+    ]
282
+    paths_to_try.extend(
283
+        windows_paths_to_try if sys.platform == 'win32' else posix_paths_to_try
284
+    )
285
+    for p in paths_to_try:
286
+        path = pathlib.Path(p)
287
+        try:
288
+            points_to_dir = path.is_dir()
289
+        except OSError:
290
+            continue
291
+        else:
292
+            if points_to_dir:
293
+                return path.resolve(strict=True)
294
+    return config_filename(subsystem=None)
295
+
170 296
 
171 297
 def config_filename(
172 298
     subsystem: str | None = 'old settings.json',
... ...
@@ -201,6 +327,14 @@ def config_filename(
201 327
         os.getenv(PROG_NAME.upper() + '_PATH')
202 328
         or click.get_app_dir(PROG_NAME, force_posix=True)
203 329
     )
330
+    if subsystem == 'write lock':
331
+        path_hash = base64.b32encode(
332
+            hashlib.sha256(os.fsencode(path.resolve())).digest()
333
+        )
334
+        path_hash_text = path_hash[:12].lower().decode('ASCII')
335
+        temp_path = get_tempdir()
336
+        filename_ = f'derivepassphrase-lock-{path_hash_text}.txt'
337
+        return temp_path / filename_
204 338
     try:
205 339
         filename = config_filename_table[subsystem]
206 340
     except (KeyError, TypeError):  # pragma: no cover
... ...
@@ -10,6 +10,7 @@ from __future__ import annotations
10 10
 
11 11
 import base64
12 12
 import collections
13
+import contextlib
13 14
 import functools
14 15
 import json
15 16
 import logging
... ...
@@ -34,6 +35,7 @@ from derivepassphrase._internals import cli_messages as _msg
34 35
 
35 36
 if TYPE_CHECKING:
36 37
     from collections.abc import (
38
+        Callable,
37 39
         Sequence,
38 40
     )
39 41
 
... ...
@@ -1029,7 +1031,21 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
1029 1031
             extra={'color': ctx.color},
1030 1032
         )
1031 1033
 
1032
-    if delete_service_settings:  # noqa: PLR1702
1034
+    readwrite_ops = [
1035
+        delete_service_settings,
1036
+        delete_globals,
1037
+        clear_all_settings,
1038
+        import_settings,
1039
+        store_config_only,
1040
+    ]
1041
+    mutex: Callable[[], contextlib.AbstractContextManager[None]] = (
1042
+        cli_helpers.configuration_mutex
1043
+        if any(readwrite_ops)
1044
+        else contextlib.nullcontext
1045
+    )
1046
+
1047
+    with mutex():  # noqa: PLR1702
1048
+        if delete_service_settings:
1033 1049
             assert service is not None
1034 1050
             configuration = get_config()
1035 1051
             if service in configuration['services']:
... ...
@@ -1231,7 +1247,9 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
1231 1247
                             this_ctx.parent is not None
1232 1248
                             and this_ctx.parent.info_name is not None
1233 1249
                         ):
1234
-                        prog_name_pieces.appendleft(this_ctx.parent.info_name)
1250
+                            prog_name_pieces.appendleft(
1251
+                                this_ctx.parent.info_name
1252
+                            )
1235 1253
                             this_ctx = this_ctx.parent
1236 1254
                         cli_helpers.print_config_as_sh_script(
1237 1255
                             configuration,
... ...
@@ -1281,7 +1299,9 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
1281 1299
                 },
1282 1300
                 cast(
1283 1301
                     'dict[str, Any]',
1284
-                configuration['services'].get(service, {}) if service else {},
1302
+                    configuration['services'].get(service, {})
1303
+                    if service
1304
+                    else {},
1285 1305
                 ),
1286 1306
                 cast('dict[str, Any]', configuration.get('global', {})),
1287 1307
             )
... ...
@@ -1356,7 +1376,9 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
1356 1376
                 view = (
1357 1377
                     collections.ChainMap(*settings.maps[:2])
1358 1378
                     if service
1359
-                else collections.ChainMap(settings.maps[0], settings.maps[2])
1379
+                    else collections.ChainMap(
1380
+                        settings.maps[0], settings.maps[2]
1381
+                    )
1360 1382
                 )
1361 1383
                 if use_key:
1362 1384
                     view['key'] = key
... ...
@@ -1446,7 +1468,9 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
1446 1468
                             old_notes_value,
1447 1469
                         ])
1448 1470
                     else:
1449
-                    text = old_notes_value or str(notes_legacy_instructions)
1471
+                        text = old_notes_value or str(
1472
+                            notes_legacy_instructions
1473
+                        )
1450 1474
                     notes_value = click.edit(text=text, require_save=False)
1451 1475
                     assert notes_value is not None
1452 1476
                     if (
... ...
@@ -1456,7 +1480,9 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
1456 1480
                         backup_file = cli_helpers.config_filename(
1457 1481
                             subsystem='notes backup'
1458 1482
                         )
1459
-                    backup_file.write_text(old_notes_value, encoding='UTF-8')
1483
+                        backup_file.write_text(
1484
+                            old_notes_value, encoding='UTF-8'
1485
+                        )
1460 1486
                         logger.warning(
1461 1487
                             _msg.TranslatedString(
1462 1488
                                 _msg.WarnMsgTemplate.LEGACY_EDITOR_INTERFACE_NOTES_BACKUP,
... ...
@@ -1539,7 +1565,9 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
1539 1565
                     click.echo(f'{service_notes}\n', err=True, color=ctx.color)
1540 1566
                 click.echo(result.decode('ASCII'), color=ctx.color)
1541 1567
                 if not print_notes_before and service_notes.strip():
1542
-                click.echo(f'\n{service_notes}\n', err=True, color=ctx.color)
1568
+                    click.echo(
1569
+                        f'\n{service_notes}\n', err=True, color=ctx.color
1570
+                    )
1543 1571
 
1544 1572
 
1545 1573
 if __name__ == '__main__':
... ...
@@ -18,6 +18,7 @@ import re
18 18
 import shlex
19 19
 import shutil
20 20
 import socket
21
+import tempfile
21 22
 import textwrap
22 23
 import types
23 24
 import warnings
... ...
@@ -4886,6 +4887,111 @@ Boo.
4886 4887
         assert _types.is_vault_config(config)
4887 4888
         return self.export_as_sh_helper(config)
4888 4889
 
4890
+    @hypothesis.given(
4891
+        env_var=strategies.sampled_from(['TMPDIR', 'TEMP', 'TMP']),
4892
+        suffix=strategies.text(
4893
+            tuple(' 0123456789abcdefghijklmnopqrstuvwxyz'),
4894
+            min_size=12,
4895
+            max_size=12,
4896
+        ),
4897
+    )
4898
+    @hypothesis.example(env_var='', suffix='.')
4899
+    def test_140a_get_tempdir(
4900
+        self,
4901
+        env_var: str,
4902
+        suffix: str,
4903
+    ) -> None:
4904
+        """[`cli_helpers.get_tempdir`][] returns a temporary directory.
4905
+
4906
+        If it is not the same as the temporary directory determined by
4907
+        [`tempfile.gettempdir`][], then assert that
4908
+        `tempfile.gettempdir` returned the current directory and
4909
+        `cli_helpers.get_tempdir` returned the configuration directory.
4910
+
4911
+        """
4912
+        runner = click.testing.CliRunner(mix_stderr=False)
4913
+        # TODO(the-13th-letter): Rewrite using parenthesized
4914
+        # with-statements.
4915
+        # https://the13thletter.info/derivepassphrase/latest/pycompatibility/#after-eol-py3.9
4916
+        with contextlib.ExitStack() as stack:
4917
+            monkeypatch = stack.enter_context(pytest.MonkeyPatch.context())
4918
+            stack.enter_context(
4919
+                tests.isolated_vault_config(
4920
+                    monkeypatch=monkeypatch,
4921
+                    runner=runner,
4922
+                    vault_config={'services': {}},
4923
+                )
4924
+            )
4925
+            monkeypatch.delenv('TMPDIR', raising=False)
4926
+            monkeypatch.delenv('TEMP', raising=False)
4927
+            monkeypatch.delenv('TMP', raising=False)
4928
+            if env_var:
4929
+                monkeypatch.setenv(env_var, str(pathlib.Path.cwd() / suffix))
4930
+            system_tempdir = os.fsdecode(tempfile.gettempdir())
4931
+            our_tempdir = cli_helpers.get_tempdir()
4932
+            assert system_tempdir == os.fsdecode(our_tempdir) or (
4933
+                # TODO(the-13th-letter): `tests.isolated_config`
4934
+                # guarantees that `Path.cwd() == config_filename(None)`.
4935
+                # So this sub-branch ought to never trigger in our
4936
+                # tests.
4937
+                system_tempdir == os.getcwd()  # noqa: PTH109
4938
+                and our_tempdir == cli_helpers.config_filename(subsystem=None)
4939
+            )
4940
+
4941
+    def test_140b_get_tempdir_force_default(self) -> None:
4942
+        """[`cli_helpers.get_tempdir`][] returns a temporary directory.
4943
+
4944
+        If all candidates are mocked to fail for the standard temporary
4945
+        directory choices, then we return the `derivepassphrase`
4946
+        configuration directory.
4947
+
4948
+        """
4949
+        runner = click.testing.CliRunner(mix_stderr=False)
4950
+        # TODO(the-13th-letter): Rewrite using parenthesized
4951
+        # with-statements.
4952
+        # https://the13thletter.info/derivepassphrase/latest/pycompatibility/#after-eol-py3.9
4953
+        with contextlib.ExitStack() as stack:
4954
+            monkeypatch = stack.enter_context(pytest.MonkeyPatch.context())
4955
+            stack.enter_context(
4956
+                tests.isolated_vault_config(
4957
+                    monkeypatch=monkeypatch,
4958
+                    runner=runner,
4959
+                    vault_config={'services': {}},
4960
+                )
4961
+            )
4962
+            monkeypatch.delenv('TMPDIR', raising=False)
4963
+            monkeypatch.delenv('TEMP', raising=False)
4964
+            monkeypatch.delenv('TMP', raising=False)
4965
+            config_dir = cli_helpers.config_filename(subsystem=None)
4966
+
4967
+            def is_dir_false(
4968
+                self: pathlib.Path,
4969
+                /,
4970
+                *,
4971
+                follow_symlinks: bool = False,
4972
+            ) -> bool:
4973
+                del self, follow_symlinks
4974
+                return False
4975
+
4976
+            def is_dir_error(
4977
+                self: pathlib.Path,
4978
+                /,
4979
+                *,
4980
+                follow_symlinks: bool = False,
4981
+            ) -> bool:
4982
+                del follow_symlinks
4983
+                raise OSError(
4984
+                    errno.EACCES,
4985
+                    os.strerror(errno.EACCES),
4986
+                    str(self),
4987
+                )
4988
+
4989
+            monkeypatch.setattr(pathlib.Path, 'is_dir', is_dir_false)
4990
+            assert cli_helpers.get_tempdir() == config_dir
4991
+
4992
+            monkeypatch.setattr(pathlib.Path, 'is_dir', is_dir_error)
4993
+            assert cli_helpers.get_tempdir() == config_dir
4994
+
4889 4995
     @Parametrize.DELETE_CONFIG_INPUT
4890 4996
     def test_203_repeated_config_deletion(
4891 4997
         self,
4892 4998