Make all command/format/feature enums self-testing
Marco Ricci

Marco Ricci commited on 2025-08-03 11:42:44
Zeige 3 geänderte Dateien mit 179 Einfügungen und 74 Löschungen.


Make the enums `_types.Subcommand`, `_types.DerivationScheme`, etc. all
capable of testing whether they are active/enabled/supported, or not.
Specifically, make each enum definition include a name and an "is
enabled?" test function.  Then all the command/format/feature support
information is bundled at the same site in the code, is harder to get
out of sync, and is also nicer to read at the call sites.

We choose to implement a single test function that dispatches to the
enum value-specific test functions manually.  I originally considered
using a tuple of name and test function as the enum value, and then
automatically dispatching on the test function value, akin to how one
might implement this e.g. in Java.  However, the corresponding testing
code is written generically, permitting arbitrary texts as values for
the categories that these enums represent.  Having to explicitly convert
between strings and (non-string) enums makes the testing code much more
complicated (and thus harder itself to test), extremely brittle, and
the sample test cases markedly less readable.  Having a hand-written
dispatch function for each enum class, with relatively few cases per
class, seems like an acceptable price to pay if it makes for more easily
and more readably testable code.
... ...
@@ -18,21 +18,16 @@ import collections
18 18
 import importlib.metadata
19 19
 import inspect
20 20
 import logging
21
-import sys
22 21
 import warnings
23 22
 from typing import TYPE_CHECKING, Callable, Literal, TextIO, TypeVar
24 23
 
25 24
 import click
26 25
 import click.shell_completion
27
-import exceptiongroup
28 26
 from typing_extensions import Any, ParamSpec, override
29 27
 
30
-from derivepassphrase import _internals, _types, ssh_agent
28
+from derivepassphrase import _internals, _types
31 29
 from derivepassphrase._internals import cli_messages as _msg
32 30
 
33
-if sys.version_info < (3, 11):
34
-    from exceptiongroup import BaseExceptionGroup
35
-
36 31
 if TYPE_CHECKING:
37 32
     import types
38 33
     from collections.abc import (
... ...
@@ -1059,17 +1054,19 @@ def derivepassphrase_version_option_callback(
1059 1054
 ) -> None:
1060 1055
     if value and not ctx.resilient_parsing:
1061 1056
         common_version_output(ctx, param, value)
1062
-        derivation_schemes = dict.fromkeys(_types.DerivationScheme, True)
1057
+        derivation_schemes = set(_types.DerivationScheme)
1063 1058
         supported_subcommands = set(_types.Subcommand)
1064 1059
         click.echo()
1065 1060
         version_info_types: dict[_msg.Label, list[str]] = {
1066 1061
             _msg.Label.SUPPORTED_DERIVATION_SCHEMES: [
1067
-                k for k, v in derivation_schemes.items() if v
1062
+                k for k in derivation_schemes if k.test()
1068 1063
             ],
1069 1064
             _msg.Label.UNAVAILABLE_DERIVATION_SCHEMES: [
1070
-                k for k, v in derivation_schemes.items() if not v
1065
+                k for k in derivation_schemes if not k.test()
1066
+            ],
1067
+            _msg.Label.SUPPORTED_SUBCOMMANDS: [
1068
+                k for k in supported_subcommands if k.test()
1071 1069
             ],
1072
-            _msg.Label.SUPPORTED_SUBCOMMANDS: sorted(supported_subcommands),
1073 1070
         }
1074 1071
         print_version_info_types(version_info_types, ctx=ctx)
1075 1072
         ctx.exit()
... ...
@@ -1083,17 +1080,22 @@ def export_version_option_callback(
1083 1080
     if value and not ctx.resilient_parsing:
1084 1081
         common_version_output(ctx, param, value)
1085 1082
         supported_subcommands = set(_types.ExportSubcommand)
1086
-        foreign_configuration_formats = {
1087
-            _types.ForeignConfigurationFormat.VAULT_STOREROOM: False,
1088
-            _types.ForeignConfigurationFormat.VAULT_V02: False,
1089
-            _types.ForeignConfigurationFormat.VAULT_V03: False,
1090
-        }
1083
+        foreign_configuration_formats = [
1084
+            _types.ForeignConfigurationFormat.VAULT_STOREROOM,
1085
+            _types.ForeignConfigurationFormat.VAULT_V02,
1086
+            _types.ForeignConfigurationFormat.VAULT_V03,
1087
+        ]
1091 1088
         click.echo()
1092 1089
         version_info_types: dict[_msg.Label, list[str]] = {
1090
+            # Marked as known-but-unavailable, because they are used by
1091
+            # subcommands of `derivepassphrase export`, not by the
1092
+            # command itself.
1093 1093
             _msg.Label.UNAVAILABLE_FOREIGN_CONFIGURATION_FORMATS: [
1094
-                k for k, v in foreign_configuration_formats.items() if not v
1094
+                k for k in foreign_configuration_formats
1095
+            ],
1096
+            _msg.Label.SUPPORTED_SUBCOMMANDS: [
1097
+                k for k in supported_subcommands if k.test()
1095 1098
             ],
1096
-            _msg.Label.SUPPORTED_SUBCOMMANDS: sorted(supported_subcommands),
1097 1099
         }
1098 1100
         print_version_info_types(version_info_types, ctx=ctx)
1099 1101
         ctx.exit()
... ...
@@ -1106,67 +1108,30 @@ def export_vault_version_option_callback(
1106 1108
 ) -> None:
1107 1109
     if value and not ctx.resilient_parsing:
1108 1110
         common_version_output(ctx, param, value)
1109
-        foreign_configuration_formats = {
1110
-            _types.ForeignConfigurationFormat.VAULT_STOREROOM: False,
1111
-            _types.ForeignConfigurationFormat.VAULT_V02: False,
1112
-            _types.ForeignConfigurationFormat.VAULT_V03: False,
1113
-        }
1114
-        known_extras = {
1115
-            _types.PEP508Extra.EXPORT: False,
1116
-        }
1117
-        from derivepassphrase.exporter import storeroom, vault_native  # noqa: I001,PLC0415
1118
-
1119
-        foreign_configuration_formats[
1120
-            _types.ForeignConfigurationFormat.VAULT_STOREROOM
1121
-        ] = not storeroom.STUBBED
1122
-        foreign_configuration_formats[
1123
-            _types.ForeignConfigurationFormat.VAULT_V02
1124
-        ] = not vault_native.STUBBED
1125
-        foreign_configuration_formats[
1126
-            _types.ForeignConfigurationFormat.VAULT_V03
1127
-        ] = not vault_native.STUBBED
1128
-        known_extras[_types.PEP508Extra.EXPORT] = (
1129
-            not storeroom.STUBBED and not vault_native.STUBBED
1130
-        )
1111
+        foreign_configuration_formats = [
1112
+            _types.ForeignConfigurationFormat.VAULT_STOREROOM,
1113
+            _types.ForeignConfigurationFormat.VAULT_V02,
1114
+            _types.ForeignConfigurationFormat.VAULT_V03,
1115
+        ]
1116
+        known_extras = [
1117
+            _types.PEP508Extra.EXPORT,
1118
+        ]
1131 1119
         click.echo()
1132 1120
         version_info_types: dict[_msg.Label, list[str]] = {
1133 1121
             _msg.Label.SUPPORTED_FOREIGN_CONFIGURATION_FORMATS: [
1134
-                k for k, v in foreign_configuration_formats.items() if v
1122
+                k for k in foreign_configuration_formats if k.test()
1135 1123
             ],
1136 1124
             _msg.Label.UNAVAILABLE_FOREIGN_CONFIGURATION_FORMATS: [
1137
-                k for k, v in foreign_configuration_formats.items() if not v
1125
+                k for k in foreign_configuration_formats if not k.test()
1138 1126
             ],
1139 1127
             _msg.Label.ENABLED_PEP508_EXTRAS: [
1140
-                k for k, v in known_extras.items() if v
1128
+                k for k in known_extras if k.test()
1141 1129
             ],
1142 1130
         }
1143 1131
         print_version_info_types(version_info_types, ctx=ctx)
1144 1132
         ctx.exit()
1145 1133
 
1146 1134
 
1147
-def _test_for_ssh_key_feature() -> bool:
1148
-    """Return true if we support SSH keys.
1149
-
1150
-    This is the feature test for [`_types.Feature.SSH_KEY`][].  We test
1151
-    this by attempting to construct an SSH agent client, reporting
1152
-    whether this can principally work, or not.
1153
-
1154
-    """
1155
-    ret = True
1156
-
1157
-    def handle_notimplementederror(_exc: BaseExceptionGroup) -> None:
1158
-        nonlocal ret
1159
-        ret = False
1160
-
1161
-    with exceptiongroup.catch({  # noqa: SIM117
1162
-        NotImplementedError: handle_notimplementederror,
1163
-        Exception: lambda _exc: None,
1164
-    }):
1165
-        with ssh_agent.SSHAgentClient.ensure_agent_subcontext():
1166
-            pass
1167
-    return ret
1168
-
1169
-
1170 1135
 def vault_version_option_callback(
1171 1136
     ctx: click.Context,
1172 1137
     param: click.Parameter,
... ...
@@ -1174,16 +1139,14 @@ def vault_version_option_callback(
1174 1139
 ) -> None:
1175 1140
     if value and not ctx.resilient_parsing:
1176 1141
         common_version_output(ctx, param, value)
1177
-        features = {
1178
-            _types.Feature.SSH_KEY: _test_for_ssh_key_feature(),
1179
-        }
1142
+        features = [
1143
+            _types.Feature.SSH_KEY,
1144
+        ]
1180 1145
         click.echo()
1181 1146
         version_info_types: dict[_msg.Label, list[str]] = {
1182
-            _msg.Label.SUPPORTED_FEATURES: [
1183
-                k for k, v in features.items() if v
1184
-            ],
1147
+            _msg.Label.SUPPORTED_FEATURES: [k for k in features if k.test()],
1185 1148
             _msg.Label.UNAVAILABLE_FEATURES: [
1186
-                k for k, v in features.items() if not v
1149
+                k for k in features if not k.test()
1187 1150
             ],
1188 1151
         }
1189 1152
         print_version_info_types(version_info_types, ctx=ctx)
... ...
@@ -765,6 +765,36 @@ class StoreroomMasterKeys(NamedTuple, Generic[T_Buffer]):
765 765
         )
766 766
 
767 767
 
768
+class FeatureTestEnum(Protocol):
769
+    """An [`enum.Enum`][] subclass supporting feature tests.
770
+
771
+    Each value of the enum supports the `test` method, which tests
772
+    whether the feature is enabled or not.  (The specific test function
773
+    may in general require arguments to actually execute the test,
774
+    though no function currently does.)
775
+
776
+    """
777
+
778
+    def test(self, *args: Any, **kwargs: Any) -> bool: ...  # noqa: ANN401
779
+
780
+
781
+def _feature_test_function(f: Callable[..., bool]) -> Callable[..., bool]:
782
+    """Mark a function as a feature test function.
783
+
784
+    This decorator exists purely to hold some shared commentary.
785
+
786
+    Test functions may accept arbitrary arguments, dependent on which
787
+    feature they test for.  They should make sure to raise no
788
+    exceptions.  They may use delayed imports to avoid the import cost
789
+    until it is actually warranted.
790
+
791
+    Returns:
792
+        The callable.
793
+
794
+    """
795
+    return f
796
+
797
+
768 798
 class PEP508Extra(str, enum.Enum):
769 799
     """PEP 508 extras supported by `derivepassphrase`.
770 800
 
... ...
@@ -778,6 +808,20 @@ class PEP508Extra(str, enum.Enum):
778 808
     EXPORT = 'export'
779 809
     """"""
780 810
 
811
+    @_feature_test_function
812
+    @staticmethod
813
+    def _test_export() -> bool:
814
+        """Return true if [`PEP508Extra.EXPORT`][] is currently supported."""
815
+        import importlib.util  # noqa: PLC0415
816
+
817
+        return importlib.util.find_spec('cryptography') is not None
818
+
819
+    def test(self, *_args: Any, **_kwargs: Any) -> bool:  # noqa: ANN401
820
+        """Return true if this extra is enabled."""
821
+        if self == PEP508Extra.EXPORT:
822
+            return self._test_export()
823
+        return False  # pragma: no cover [unused]
824
+
781 825
     __str__ = str.__str__
782 826
     __format__ = str.__format__  # type: ignore[assignment]
783 827
 
... ...
@@ -799,6 +843,46 @@ class Feature(str, enum.Enum):
799 843
     SSH_KEY = 'master SSH key'
800 844
     """"""
801 845
 
846
+    @_feature_test_function
847
+    @staticmethod
848
+    def _test_ssh_key() -> bool:
849
+        """Return true if [`_types.Feature.SSH_KEY`][] is currently supported.
850
+
851
+        We test this by attempting to construct an SSH agent client,
852
+        reporting whether this can principally work, or not.
853
+
854
+        """
855
+        import sys  # noqa: PLC0415
856
+
857
+        import exceptiongroup  # noqa: PLC0415
858
+
859
+        from derivepassphrase import ssh_agent  # noqa: PLC0415
860
+
861
+        if sys.version_info < (3, 11):
862
+            from exceptiongroup import BaseExceptionGroup  # noqa: PLC0415
863
+
864
+        ret = True
865
+
866
+        def handle_notimplementederror(
867
+            _exc: BaseExceptionGroup,
868
+        ) -> None:  # pragma: no cover [unused]
869
+            nonlocal ret
870
+            ret = False
871
+
872
+        with exceptiongroup.catch({  # noqa: SIM117
873
+            NotImplementedError: handle_notimplementederror,
874
+            Exception: lambda _exc: None,
875
+        }):
876
+            with ssh_agent.SSHAgentClient.ensure_agent_subcontext():
877
+                pass
878
+        return ret
879
+
880
+    def test(self, *_args: Any, **_kwargs: Any) -> bool:  # noqa: ANN401
881
+        """Return true if this feature is enabled."""
882
+        if self == Feature.SSH_KEY:
883
+            return self._test_ssh_key()
884
+        return False  # pragma: no cover [unused]
885
+
802 886
     __str__ = str.__str__
803 887
     __format__ = str.__format__  # type: ignore[assignment]
804 888
 
... ...
@@ -815,6 +899,10 @@ class DerivationScheme(str, enum.Enum):
815 899
     VAULT = 'vault'
816 900
     """"""
817 901
 
902
+    def test(self, *_args: Any, **_kwargs: Any) -> bool:  # noqa: ANN401
903
+        """Return true if this derivation scheme is enabled."""
904
+        return self == DerivationScheme.VAULT
905
+
818 906
     __str__ = str.__str__
819 907
     __format__ = str.__format__  # type: ignore[assignment]
820 908
 
... ...
@@ -837,11 +925,39 @@ class ForeignConfigurationFormat(str, enum.Enum):
837 925
 
838 926
     VAULT_STOREROOM = 'vault storeroom'
839 927
     """"""
928
+
929
+    @_feature_test_function
930
+    @staticmethod
931
+    def _test_vault_storeroom() -> bool:
932
+        """Return true if [`ForeignConfigurationFormat.VAULT_STOREROOM`][] is currently supported."""  # noqa: E501
933
+        from derivepassphrase.exporter import storeroom  # noqa: PLC0415
934
+
935
+        return not storeroom.STUBBED
936
+
840 937
     VAULT_V02 = 'vault v0.2'
841 938
     """"""
842 939
     VAULT_V03 = 'vault v0.3'
843 940
     """"""
844 941
 
942
+    @_feature_test_function
943
+    @staticmethod
944
+    def _test_vault_v02_v03() -> bool:
945
+        """Return true if [`ForeignConfigurationFormat.VAULT_V02`][] and [`ForeignConfigurationFormat.VAULT_V03`][] is currently supported."""  # noqa: E501
946
+        from derivepassphrase.exporter import vault_native  # noqa: PLC0415
947
+
948
+        return not vault_native.STUBBED
949
+
950
+    def test(self, *_args: Any, **_kwargs: Any) -> bool:  # noqa: ANN401
951
+        """Return true if this foreign configuration format is enabled."""
952
+        if self == ForeignConfigurationFormat.VAULT_STOREROOM:
953
+            return self._test_vault_storeroom()
954
+        if self in {
955
+            ForeignConfigurationFormat.VAULT_V02,
956
+            ForeignConfigurationFormat.VAULT_V03,
957
+        }:
958
+            return self._test_vault_v02_v03()
959
+        return False  # pragma: no cover [unused]
960
+
845 961
     __str__ = str.__str__
846 962
     __format__ = str.__format__  # type: ignore[assignment]
847 963
 
... ...
@@ -858,6 +974,10 @@ class ExportSubcommand(str, enum.Enum):
858 974
     VAULT = 'vault'
859 975
     """"""
860 976
 
977
+    def test(self, *_args: Any, **_kwargs: Any) -> bool:  # noqa: ANN401
978
+        """Return true if this `export` subcommand is enabled."""
979
+        return self == ExportSubcommand.VAULT
980
+
861 981
     __str__ = str.__str__
862 982
     __format__ = str.__format__  # type: ignore[assignment]
863 983
 
... ...
@@ -878,6 +998,10 @@ class Subcommand(str, enum.Enum):
878 998
     VAULT = 'vault'
879 999
     """"""
880 1000
 
1001
+    def test(self, *_args: Any, **_kwargs: Any) -> bool:  # noqa: ANN401
1002
+        """Return true if this subcommand is enabled."""
1003
+        return self in {Subcommand.VAULT, Subcommand.EXPORT}
1004
+
881 1005
     __str__ = str.__str__
882 1006
     __format__ = str.__format__  # type: ignore[assignment]
883 1007
 
... ...
@@ -28,6 +28,7 @@ import warnings
28 28
 from typing import TYPE_CHECKING, cast
29 29
 
30 30
 import click.testing
31
+import exceptiongroup
31 32
 import hypothesis
32 33
 import pytest
33 34
 from hypothesis import stateful, strategies
... ...
@@ -2333,7 +2334,9 @@ class TestAllCLI:
2333 2334
                 _types.ForeignConfigurationFormat.VAULT_V02: not vault_native.STUBBED,
2334 2335
                 _types.ForeignConfigurationFormat.VAULT_V03: not vault_native.STUBBED,
2335 2336
             })
2336
-            if not storeroom.STUBBED and not vault_native.STUBBED:
2337
+        with contextlib.suppress(ModuleNotFoundError):
2338
+            import cryptography  # noqa: F401,PLC0415
2339
+
2337 2340
             actually_enabled_extras.add(_types.PEP508Extra.EXPORT)
2338 2341
         assert not version_data.derivation_schemes
2339 2342
         assert (
... ...
@@ -2376,8 +2379,23 @@ class TestAllCLI:
2376 2379
         assert result.clean_exit(empty_stderr=True), 'expected clean exit'
2377 2380
         assert result.stdout.strip(), 'expected version output'
2378 2381
         version_data = parse_version_output(result.stdout)
2382
+
2383
+        ssh_key_supported = True
2384
+
2385
+        def react_to_notimplementederror(
2386
+            _exc: BaseException,
2387
+        ) -> None:  # pragma: no cover[unused]
2388
+            nonlocal ssh_key_supported
2389
+            ssh_key_supported = False
2390
+
2391
+        with exceptiongroup.catch({  # noqa: SIM117
2392
+            NotImplementedError: react_to_notimplementederror,
2393
+            Exception: lambda *_args: None,
2394
+        }):
2395
+            with ssh_agent.SSHAgentClient.ensure_agent_subcontext():
2396
+                pass
2379 2397
         features: dict[str, bool] = {
2380
-            _types.Feature.SSH_KEY: hasattr(socket, 'AF_UNIX'),
2398
+            _types.Feature.SSH_KEY: ssh_key_supported,
2381 2399
         }
2382 2400
         assert not version_data.derivation_schemes
2383 2401
         assert not version_data.foreign_configuration_formats
2384 2402