Add tests for explicit SSH agent socket provider selection
Marco Ricci

Marco Ricci commited on 2026-02-08 16:02:18
Zeige 2 geänderte Dateien mit 157 Einfügungen und 11 Löschungen.


Extend the vault CLI tests for basic SSH key usage to also cover the SSH
agent socket provider choice case.  We split up the monolithic `_test`
helper function into separate `_setup_environment` and `_check_result`
functions, so that we can share common setup code (including fixtures
and parametrizations), but use different result testing code.  We also
extend the `_setup_environment` code to handle main user configuration
mocking and SSH agent socket provider registry mocking as well, even if
the basic key tests don't use this functionality.
... ...
@@ -226,6 +226,7 @@ class Parametrize(types.SimpleNamespace):
226 226
                     "--vault-legacy-editor-interface",
227 227
                     "--print-notes-before",
228 228
                     "--print-notes-after",
229
+                    "--ssh-agent-socket-provider",
229 230
                 }),
230 231
                 id="derivepassphrase-vault",
231 232
             ),
... ...
@@ -8,6 +8,7 @@ from __future__ import annotations
8 8
 
9 9
 import contextlib
10 10
 import json
11
+import textwrap
11 12
 import types
12 13
 from typing import TYPE_CHECKING
13 14
 
... ...
@@ -18,6 +19,7 @@ from derivepassphrase import _types, cli, ssh_agent, vault
18 19
 from derivepassphrase._internals import (
19 20
     cli_helpers,
20 21
 )
22
+from derivepassphrase.ssh_agent import socketprovider
21 23
 from tests import data, machinery
22 24
 from tests.data import callables
23 25
 from tests.machinery import pytest as pytest_machinery
... ...
@@ -329,6 +331,16 @@ class Parametrize(types.SimpleNamespace):
329 331
     KEY_INDEX = pytest.mark.parametrize(
330 332
         "key_index", [1, 2, 3], ids=lambda i: f"index{i}"
331 333
     )
334
+    EXPLICIT_SSH_AGENT_SOCKET_PROVIDER = pytest.mark.parametrize(
335
+        ["on_command_line", "in_config"],
336
+        [
337
+            pytest.param(False, True, id="configuration_only"),
338
+            pytest.param(True, False, id="command_line_only"),
339
+            pytest.param(
340
+                True, True, id="command_line_overrides_configuration"
341
+            ),
342
+        ],
343
+    )
332 344
     VAULT_CHARSET_OPTION = pytest.mark.parametrize(
333 345
         "option",
334 346
         [
... ...
@@ -543,17 +555,20 @@ class TestPhraseBasic:
543 555
 class TestKeyBasic:
544 556
     """Tests for SSH key configuration: basic."""
545 557
 
546
-    def _test(
558
+    @contextlib.contextmanager
559
+    def _setup_environment(
547 560
         self,
548
-        command_line: list[str],
549 561
         /,
550 562
         *,
551 563
         config: _types.VaultConfig = {  # noqa: B006
552
-            "services": {DUMMY_SERVICE: DUMMY_CONFIG_SETTINGS}
564
+            "services": {
565
+                DUMMY_SERVICE: {**DUMMY_CONFIG_SETTINGS},
553 566
             },
554
-        multiline: bool = False,
555
-        input: str | bytes | None = None,
556
-    ) -> None:
567
+        },
568
+        main_config_str: str | None = None,
569
+        registry: dict[str, _types.SSHAgentSocketProvider | str | None]
570
+        | None = None,
571
+    ) -> Iterator[machinery.CliRunner]:
557 572
         runner = machinery.CliRunner(mix_stderr=False)
558 573
         # TODO(the-13th-letter): Rewrite using parenthesized
559 574
         # with-statements.
... ...
@@ -565,6 +580,7 @@ class TestKeyBasic:
565 580
                     monkeypatch=monkeypatch,
566 581
                     runner=runner,
567 582
                     vault_config=config,
583
+                    main_config_str=main_config_str,
568 584
                 )
569 585
             )
570 586
             monkeypatch.setattr(
... ...
@@ -577,12 +593,20 @@ class TestKeyBasic:
577 593
                 "phrase_from_key",
578 594
                 callables.phrase_from_key,
579 595
             )
580
-            result = runner.invoke(
581
-                cli.derivepassphrase_vault,
582
-                command_line,
583
-                input=input,
584
-                catch_exceptions=False,
596
+            if registry is not None:
597
+                monkeypatch.setattr(
598
+                    socketprovider.SocketProvider, "registry", registry
599
+                )
600
+                monkeypatch.setattr(
601
+                    ssh_agent.SSHAgentClient,
602
+                    "SOCKET_PROVIDERS",
603
+                    ("incorrect",),
585 604
                 )
605
+            yield runner
606
+
607
+    def _check_result(
608
+        self, result: machinery.ReadableResult, /, *, multiline: bool = False
609
+    ) -> None:
586 610
         if multiline:
587 611
             assert result.clean_exit(), "expected clean exit"
588 612
         else:
... ...
@@ -602,6 +626,31 @@ class TestKeyBasic:
602 626
             "expected known output"
603 627
         )
604 628
 
629
+    def _test(
630
+        self,
631
+        command_line: list[str],
632
+        /,
633
+        *,
634
+        config: _types.VaultConfig = {  # noqa: B006
635
+            "services": {DUMMY_SERVICE: DUMMY_CONFIG_SETTINGS}
636
+        },
637
+        main_config_str: str | None = None,
638
+        registry: dict[str, _types.SSHAgentSocketProvider | str | None]
639
+        | None = None,
640
+        multiline: bool = False,
641
+        input: str | bytes | None = None,
642
+    ) -> None:
643
+        with self._setup_environment(
644
+            config=config, main_config_str=main_config_str, registry=registry
645
+        ) as runner:
646
+            result = runner.invoke(
647
+                cli.derivepassphrase_vault,
648
+                command_line,
649
+                input=input,
650
+                catch_exceptions=False,
651
+            )
652
+        self._check_result(result, multiline=multiline)
653
+
605 654
     @Parametrize.CONFIG_WITH_KEY
606 655
     def test_key_from_config(
607 656
         self,
... ...
@@ -621,6 +670,102 @@ class TestKeyBasic:
621 670
         self._test(["-k", "--", DUMMY_SERVICE], input="1\n", multiline=True)
622 671
 
623 672
 
673
+@pytest.fixture
674
+def provider_registry() -> dict[
675
+    str, _types.SSHAgentSocketProvider | str | None
676
+]:
677
+    """Set up a controlled SSH agent socket provider registry."""
678
+
679
+    def err() -> _types.SSHAgentSocket:
680
+        pytest.fail("Attempting to use the wrong SSH agent socket provider!")
681
+
682
+    return {
683
+        "correct": machinery.StubbedSSHAgentSocket,
684
+        "incorrect": err,
685
+    }
686
+
687
+
688
+class TestKeyExplicitSSHAgentSocketProvider(TestKeyBasic):
689
+    """Tests for SSH key configuration: explicit SSH agent socket providers."""
690
+
691
+    ARG = "--ssh-agent-socket-provider=correct"
692
+    ARG_NONEXISTANT = "--ssh-agent-socket-provider=nonexistant"
693
+    MAIN_CONFIG_STR = textwrap.dedent(r"""
694
+    [vault]
695
+    ssh-agent-socket-provider = "correct"
696
+    """)
697
+    WRONG_CONFIG_STR = textwrap.dedent(r"""
698
+    [vault]
699
+    ssh-agent-socket-provider = "incorrect"
700
+    """)
701
+    NONEXISTANT_CONFIG_STR = textwrap.dedent(r"""
702
+    [vault]
703
+    ssh-agent-socket-provider = "nonexistant"
704
+    """)
705
+
706
+    @Parametrize.EXPLICIT_SSH_AGENT_SOCKET_PROVIDER
707
+    @Parametrize.CONFIG_WITH_KEY
708
+    def test_explicit_ssh_agent_socket_provider(
709
+        self,
710
+        provider_registry: dict[
711
+            str, _types.SSHAgentSocketProvider | str | None
712
+        ],
713
+        config: _types.VaultConfig,
714
+        on_command_line: bool,
715
+        in_config: bool,
716
+    ) -> None:
717
+        args = [self.ARG] if on_command_line else []
718
+        main_config_str = (
719
+            self.WRONG_CONFIG_STR
720
+            if in_config and on_command_line
721
+            else self.MAIN_CONFIG_STR
722
+            if in_config
723
+            else None
724
+        )
725
+        self._test(
726
+            [*args, "--", DUMMY_SERVICE],
727
+            config=config,
728
+            registry=provider_registry,
729
+            main_config_str=main_config_str,
730
+        )
731
+
732
+    @Parametrize.EXPLICIT_SSH_AGENT_SOCKET_PROVIDER
733
+    @Parametrize.CONFIG_WITH_KEY
734
+    def test_explicit_ssh_agent_socket_provider_not_found(
735
+        self,
736
+        provider_registry: dict[
737
+            str, _types.SSHAgentSocketProvider | str | None
738
+        ],
739
+        config: _types.VaultConfig,
740
+        on_command_line: bool,
741
+        in_config: bool,
742
+    ) -> None:
743
+        assert "nonexistant" not in provider_registry, (
744
+            '"nonexistant" name actually found in registry?!'
745
+        )
746
+        args = [self.ARG_NONEXISTANT] if on_command_line else []
747
+        main_config_str = (
748
+            self.MAIN_CONFIG_STR
749
+            if in_config and on_command_line
750
+            else self.NONEXISTANT_CONFIG_STR
751
+            if in_config
752
+            else None
753
+        )
754
+        with self._setup_environment(
755
+            config=config,
756
+            registry=provider_registry,
757
+            main_config_str=main_config_str,
758
+        ) as runner:
759
+            result = runner.invoke(
760
+                cli.derivepassphrase_vault,
761
+                [*args, "--", DUMMY_SERVICE],
762
+                catch_exceptions=False,
763
+            )
764
+        assert result.error_exit(
765
+            error=" is not in derivepassphrase's provider registry."
766
+        )
767
+
768
+
624 769
 class TestPhraseAndKeyOverriding:
625 770
     """Tests for master passphrase and SSH key configuration: overriding."""
626 771
 
627 772