Add missing `warning_callback` argument to `key_to_phrase` call
Marco Ricci

Marco Ricci commited on 2025-12-13 15:47:45
Zeige 4 geänderte Dateien mit 60 Einfügungen und 9 Löschungen.


The callback parameters in `cli_helpers.key_to_phrase` for handling
warning and error messages are now mandatory to specify.  All calls in
production code and in test code(!) now explicitly handle warnings and
errors.

Since changing `cli_helpers.key_to_phrase` in
2413d9dc10ede315c295ab7520a19b21d597a668 to support exception groups,
the function takes additional callback parameters to do its warning and
error handling/reporting.  For compatibility reasons, the signature
included default values which suppressed warning messages and exited the
process on error messages.

However, these default values made it too
easy to *forget* proper handling of warning and error messages in the
new implementation.  In particular, two call sites in production code,
just several lines apart, had differing warning handling, and the test
suite had two major areas where warning handling was completely absent,
relying on the "suppressed warning messages" default.  All these
behaviors were unwanted and wrong, but difficult to spot, because the
test code too was wrong.  By making the error and warning handling
callbacks mandatory to specify and removing the implicit suppression
behavior, inadvertent suppression of warning and error messages becomes
much more difficult.

To further ensure this doesn't happen again accidentally, the tests now
also assert that certain expected warning messages are emitted, i.e.,
that the callback is actually exercised.
... ...
@@ -828,8 +828,8 @@ def select_ssh_key(
828 828
     /,
829 829
     *,
830 830
     ctx: click.Context | None = None,
831
-    error_callback: Callable[..., NoReturn] = default_error_callback,
832
-    warning_callback: Callable[..., None] = lambda *_args: None,
831
+    error_callback: Callable[..., NoReturn],
832
+    warning_callback: Callable[..., None],
833 833
 ) -> bytes | bytearray:
834 834
     """Interactively select an SSH key for passphrase derivation.
835 835
 
... ...
@@ -1068,8 +1068,8 @@ def key_to_phrase(
1068 1068
     key: str | Buffer,
1069 1069
     /,
1070 1070
     *,
1071
-    error_callback: Callable[..., NoReturn] = default_error_callback,
1072
-    warning_callback: Callable[..., None] = lambda *_args: None,
1071
+    error_callback: Callable[..., NoReturn],
1072
+    warning_callback: Callable[..., None],
1073 1073
     conn: ssh_agent.SSHAgentClient
1074 1074
     | _types.SSHAgentSocket
1075 1075
     | Sequence[str]
... ...
@@ -1082,6 +1082,26 @@ def key_to_phrase(
1082 1082
     obtained from the key, because this is the first point of contact
1083 1083
     with the SSH agent.
1084 1084
 
1085
+    Args:
1086
+        key:
1087
+            The SSH key for which the equivalent master passphrase shall
1088
+            be computed, encoded in base64.
1089
+        conn:
1090
+            An optional connection hint to the SSH agent.  See
1091
+            [`ssh_agent.SSHAgentClient.ensure_agent_subcontext`][].
1092
+        error_callback:
1093
+            A callback function for an error message, if any.  The
1094
+            callback function is responsible for aborting this function
1095
+            call after acknowledging, formatting and/or forwarding the
1096
+            error message; it would typically call [`sys.exit`][] or
1097
+            raise an exception of its own, based on the provided error
1098
+            message.
1099
+        warning_callback:
1100
+            A callback function for a warning message, if any.  The
1101
+            callback function is responsible for formatting the warning
1102
+            message and dispatching it into the warning system, if so
1103
+            desired.
1104
+
1085 1105
     """
1086 1106
     key = base64.standard_b64decode(key)
1087 1107
     with exceptiongroup.catch({  # noqa: SIM117
... ...
@@ -1379,7 +1379,9 @@ class _VaultContext:  # noqa: PLR0904
1379 1379
             phrase = cast("str", overrides["phrase"])
1380 1380
         elif settings.get("key"):
1381 1381
             phrase = cli_helpers.key_to_phrase(
1382
-                cast("str", settings["key"]), error_callback=self.err
1382
+                cast("str", settings["key"]),
1383
+                error_callback=self.err,
1384
+                warning_callback=self.warning,
1383 1385
             )
1384 1386
         elif settings.get("phrase"):
1385 1387
             phrase = cast("str", settings["phrase"])
... ...
@@ -445,6 +445,7 @@ class Parametrize(types.SimpleNamespace):
445 445
             "system_support_action",
446 446
             "sign_action",
447 447
             "pattern",
448
+            "warnings_patterns",
448 449
         ],
449 450
         [
450 451
             pytest.param(
... ...
@@ -453,6 +454,7 @@ class Parametrize(types.SimpleNamespace):
453 454
                 None,
454 455
                 SignAction.FAIL,
455 456
                 "not loaded into the agent",
457
+                [],
456 458
                 id="key-not-loaded",
457 459
             ),
458 460
             pytest.param(
... ...
@@ -461,6 +463,7 @@ class Parametrize(types.SimpleNamespace):
461 463
                 None,
462 464
                 SignAction.FAIL,
463 465
                 "SSH agent failed to or refused to",
466
+                [],
464 467
                 id="list-keys-refused",
465 468
             ),
466 469
             pytest.param(
... ...
@@ -469,6 +472,7 @@ class Parametrize(types.SimpleNamespace):
469 472
                 None,
470 473
                 SignAction.FAIL,
471 474
                 "SSH agent failed to or refused to",
475
+                [],
472 476
                 id="list-keys-protocol-error",
473 477
             ),
474 478
             pytest.param(
... ...
@@ -477,6 +481,7 @@ class Parametrize(types.SimpleNamespace):
477 481
                 None,
478 482
                 SignAction.FAIL,
479 483
                 "Cannot find any running SSH agent",
484
+                [],
480 485
                 id="agent-address-missing",
481 486
             ),
482 487
             pytest.param(
... ...
@@ -485,6 +490,7 @@ class Parametrize(types.SimpleNamespace):
485 490
                 None,
486 491
                 SignAction.FAIL,
487 492
                 "Cannot connect to the SSH agent",
493
+                [],
488 494
                 id="agent-address-mangled",
489 495
             ),
490 496
             pytest.param(
... ...
@@ -493,6 +499,7 @@ class Parametrize(types.SimpleNamespace):
493 499
                 SystemSupportAction.UNSET_NATIVE,
494 500
                 SignAction.FAIL,
495 501
                 "does not support communicating with it",
502
+                [],
496 503
                 id="no-agent-support",
497 504
             ),
498 505
             pytest.param(
... ...
@@ -501,6 +508,7 @@ class Parametrize(types.SimpleNamespace):
501 508
                 SystemSupportAction.UNSET_PROVIDER_LIST,
502 509
                 SignAction.FAIL,
503 510
                 "does not support communicating with it",
511
+                [],
504 512
                 id="no-agent-support",
505 513
             ),
506 514
             pytest.param(
... ...
@@ -509,6 +517,7 @@ class Parametrize(types.SimpleNamespace):
509 517
                 SystemSupportAction.UNSET_AF_UNIX_AND_ENSURE_USE,
510 518
                 SignAction.FAIL,
511 519
                 "does not support communicating with it",
520
+                ["Cannot connect to an SSH agent via UNIX domain sockets"],
512 521
                 id="no-agent-support",
513 522
             ),
514 523
             pytest.param(
... ...
@@ -517,6 +526,7 @@ class Parametrize(types.SimpleNamespace):
517 526
                 SystemSupportAction.UNSET_WINDLL_AND_ENSURE_USE,
518 527
                 SignAction.FAIL,
519 528
                 "does not support communicating with it",
529
+                ["Cannot connect to an SSH agent via Windows named pipes"],
520 530
                 id="no-agent-support",
521 531
             ),
522 532
             pytest.param(
... ...
@@ -525,6 +535,7 @@ class Parametrize(types.SimpleNamespace):
525 535
                 None,
526 536
                 SignAction.FAIL_RUNTIME,
527 537
                 "violates the communication protocol",
538
+                [],
528 539
                 id="sign-violates-protocol",
529 540
             ),
530 541
         ],
... ...
@@ -1427,17 +1438,24 @@ class TestMisc:
1427 1438
         address_action: SocketAddressAction | None,
1428 1439
         sign_action: SignAction,
1429 1440
         pattern: str,
1441
+        warnings_patterns: list[str],
1430 1442
     ) -> None:
1431 1443
         """All errors in [`cli_helpers.key_to_phrase`][] are handled."""
1432 1444
 
1445
+        captured_warnings: list[str] = []
1446
+
1433 1447
         class ErrCallback(BaseException):
1434 1448
             def __init__(self, *args: Any, **kwargs: Any) -> None:
1435 1449
                 super().__init__(*args[:1])
1436 1450
                 self.args = args
1437 1451
                 self.kwargs = kwargs
1438 1452
 
1439
-        def err(*args: Any, **_kwargs: Any) -> NoReturn:
1440
-            raise ErrCallback(*args, **_kwargs)
1453
+        def err(*args: Any, **kwargs: Any) -> NoReturn:
1454
+            raise ErrCallback(*args, **kwargs)
1455
+
1456
+        def warn(*args: Any) -> None:
1457
+            if args:  # pragma: no branch
1458
+                captured_warnings.append(str(args[0]))
1441 1459
 
1442 1460
         with pytest.MonkeyPatch.context() as monkeypatch:
1443 1461
             loaded_keys = list(
... ...
@@ -1454,7 +1472,14 @@ class TestMisc:
1454 1472
             if system_support_action:
1455 1473
                 system_support_action(monkeypatch)
1456 1474
             with pytest.raises(ErrCallback, match=pattern) as excinfo:
1457
-                cli_helpers.key_to_phrase(loaded_key, error_callback=err)
1475
+                cli_helpers.key_to_phrase(
1476
+                    loaded_key, error_callback=err, warning_callback=warn
1477
+                )
1478
+
1479
+            for pat in warnings_patterns:
1480
+                assert any([pat in string for string in captured_warnings]), (
1481
+                    f"expected some warning message to match {pat}"
1482
+                )
1458 1483
             if list_keys_action == ListKeysAction.FAIL_RUNTIME:
1459 1484
                 assert excinfo.value.kwargs
1460 1485
                 assert isinstance(
... ...
@@ -1029,7 +1029,11 @@ class TestSuitableKeys:
1029 1029
         @click.command()
1030 1030
         def driver() -> None:
1031 1031
             """Call [`cli_helpers.select_ssh_key`][] directly, as a command."""
1032
-            key = cli_helpers.select_ssh_key(client)
1032
+            key = cli_helpers.select_ssh_key(
1033
+                client,
1034
+                error_callback=lambda s, *_args: pytest.fail(str(s)),
1035
+                warning_callback=lambda s, *_args: pytest.fail(str(s)),
1036
+            )
1033 1037
             click.echo(base64.standard_b64encode(key).decode("ASCII"))
1034 1038
 
1035 1039
         runner = machinery.CliRunner(mix_stderr=True)
1036 1040