Support suppressing or forcing color output
Marco Ricci

Marco Ricci commited on 2025-01-01 00:47:57
Zeige 2 geänderte Dateien mit 98 Einfügungen und 19 Löschungen.


We implement this with an eager, hidden "pseudo" option that checks for
the `NO_COLOR` and `FORCE_COLOR` options and sets the context's color
setting appropriately.

To actually make use of this color setting, all output must query the
context's color setting, as must all log output.  For the log output
specifically, we arrange to pass the `click` context in the extra dict
of each logging call, so that the handler has access to the color
setting too.
... ...
@@ -95,7 +95,11 @@ class ClickEchoStderrHandler(logging.Handler):
95 95
         [`sys.stderr`][].
96 96
 
97 97
         """
98
-        click.echo(self.format(record), err=True)
98
+        click.echo(
99
+            self.format(record),
100
+            err=True,
101
+            color=getattr(record, 'color', None),
102
+        )
99 103
 
100 104
 
101 105
 class CLIofPackageFormatter(logging.Formatter):
... ...
@@ -1016,6 +1020,31 @@ def version_option(f: Callable[P, R]) -> Callable[P, R]:
1016 1020
     )(f)
1017 1021
 
1018 1022
 
1023
+def color_forcing_callback(
1024
+    ctx: click.Context,
1025
+    param: click.Parameter,
1026
+    value: Any,  # noqa: ANN401
1027
+) -> None:
1028
+    """Force the `click` context to honor `NO_COLOR` and `FORCE_COLOR`."""
1029
+    del param, value
1030
+    if os.environ.get('NO_COLOR'):  # pragma: no cover
1031
+        ctx.color = False
1032
+    if os.environ.get('FORCE_COLOR'):  # pragma: no cover
1033
+        ctx.color = True
1034
+
1035
+
1036
+color_forcing_pseudo_option = click.option(
1037
+    '--_pseudo-option-color-forcing',
1038
+    '_color_forcing',
1039
+    is_flag=True,
1040
+    is_eager=True,
1041
+    expose_value=False,
1042
+    hidden=True,
1043
+    callback=color_forcing_callback,
1044
+    help='(pseudo-option)',
1045
+)
1046
+
1047
+
1019 1048
 class LoggingOption(OptionGroupOption):
1020 1049
     """Logging options for the CLI."""
1021 1050
 
... ...
@@ -1223,6 +1252,7 @@ class _TopLevelCLIEntryPoint(_DefaultToVaultGroup):
1223 1252
     ),
1224 1253
 )
1225 1254
 @version_option
1255
+@color_forcing_pseudo_option
1226 1256
 @standard_logging_options
1227 1257
 @click.pass_context
1228 1258
 def derivepassphrase(ctx: click.Context, /) -> None:
... ...
@@ -1240,7 +1270,10 @@ def derivepassphrase(ctx: click.Context, /) -> None:
1240 1270
     deprecation = logging.getLogger(f'{PROG_NAME}.deprecation')
1241 1271
     if ctx.invoked_subcommand is None:
1242 1272
         deprecation.warning(
1243
-            _msg.TranslatedString(_msg.WarnMsgTemplate.V10_SUBCOMMAND_REQUIRED)
1273
+            _msg.TranslatedString(
1274
+                _msg.WarnMsgTemplate.V10_SUBCOMMAND_REQUIRED
1275
+            ),
1276
+            extra={'color': ctx.color},
1244 1277
         )
1245 1278
         # See definition of click.Group.invoke, non-chained case.
1246 1279
         with ctx:
... ...
@@ -1272,6 +1305,7 @@ def derivepassphrase(ctx: click.Context, /) -> None:
1272 1305
     ),
1273 1306
 )
1274 1307
 @version_option
1308
+@color_forcing_pseudo_option
1275 1309
 @standard_logging_options
1276 1310
 @click.pass_context
1277 1311
 def derivepassphrase_export(ctx: click.Context, /) -> None:
... ...
@@ -1289,7 +1323,10 @@ def derivepassphrase_export(ctx: click.Context, /) -> None:
1289 1323
     deprecation = logging.getLogger(f'{PROG_NAME}.deprecation')
1290 1324
     if ctx.invoked_subcommand is None:
1291 1325
         deprecation.warning(
1292
-            _msg.TranslatedString(_msg.WarnMsgTemplate.V10_SUBCOMMAND_REQUIRED)
1326
+            _msg.TranslatedString(
1327
+                _msg.WarnMsgTemplate.V10_SUBCOMMAND_REQUIRED
1328
+            ),
1329
+            extra={'color': ctx.color},
1293 1330
         )
1294 1331
         # See definition of click.Group.invoke, non-chained case.
1295 1332
         with ctx:
... ...
@@ -1413,6 +1450,7 @@ def _shell_complete_vault_path(  # pragma: no cover
1413 1450
     cls=StandardOption,
1414 1451
 )
1415 1452
 @version_option
1453
+@color_forcing_pseudo_option
1416 1454
 @standard_logging_options
1417 1455
 @click.argument(
1418 1456
     'path',
... ...
@@ -1462,6 +1500,7 @@ def derivepassphrase_export_vault(
1462 1500
                     path=path,
1463 1501
                     fmt=fmt,
1464 1502
                 ),
1503
+                extra={'color': ctx.color},
1465 1504
             )
1466 1505
             continue
1467 1506
         except OSError as exc:
... ...
@@ -1472,6 +1511,7 @@ def derivepassphrase_export_vault(
1472 1511
                     error=exc.strerror,
1473 1512
                     filename=exc.filename,
1474 1513
                 ).maybe_without_filename(),
1514
+                extra={'color': ctx.color},
1475 1515
             )
1476 1516
             ctx.exit(1)
1477 1517
         except ModuleNotFoundError:
... ...
@@ -1480,12 +1520,14 @@ def derivepassphrase_export_vault(
1480 1520
                     _msg.ErrMsgTemplate.MISSING_MODULE,
1481 1521
                     module='cryptography',
1482 1522
                 ),
1523
+                extra={'color': ctx.color},
1483 1524
             )
1484 1525
             logger.info(
1485 1526
                 _msg.TranslatedString(
1486 1527
                     _msg.InfoMsgTemplate.PIP_INSTALL_EXTRA,
1487 1528
                     extra_name='export',
1488 1529
                 ),
1530
+                extra={'color': ctx.color},
1489 1531
             )
1490 1532
             ctx.exit(1)
1491 1533
         else:
... ...
@@ -1495,9 +1537,13 @@ def derivepassphrase_export_vault(
1495 1537
                         _msg.ErrMsgTemplate.INVALID_VAULT_CONFIG,
1496 1538
                         config=config,
1497 1539
                     ),
1540
+                    extra={'color': ctx.color},
1498 1541
                 )
1499 1542
                 ctx.exit(1)
1500
-            click.echo(json.dumps(config, indent=2, sort_keys=True))
1543
+            click.echo(
1544
+                json.dumps(config, indent=2, sort_keys=True),
1545
+                color=ctx.color,
1546
+            )
1501 1547
             break
1502 1548
     else:
1503 1549
         logger.error(
... ...
@@ -1505,6 +1551,7 @@ def derivepassphrase_export_vault(
1505 1551
                 _msg.ErrMsgTemplate.CANNOT_PARSE_AS_VAULT_CONFIG,
1506 1552
                 path=path,
1507 1553
             ).maybe_without_filename(),
1554
+            extra={'color': ctx.color},
1508 1555
         )
1509 1556
         ctx.exit(1)
1510 1557
 
... ...
@@ -1728,6 +1775,7 @@ def _prompt_for_selection(
1728 1775
     items: Sequence[str | bytes],
1729 1776
     heading: str = 'Possible choices:',
1730 1777
     single_choice_prompt: str = 'Confirm this choice?',
1778
+    ctx: click.Context | None = None,
1731 1779
 ) -> int:
1732 1780
     """Prompt user for a choice among the given items.
1733 1781
 
... ...
@@ -1747,6 +1795,9 @@ def _prompt_for_selection(
1747 1795
         single_choice_prompt:
1748 1796
             The confirmation prompt if there is only a single possible
1749 1797
             choice.  Defaults to a reasonable standard prompt.
1798
+        ctx:
1799
+            An optional `click` context, from which output device
1800
+            properties and color preferences will be queried.
1750 1801
 
1751 1802
     Returns:
1752 1803
         An index into the items sequence, indicating the user's
... ...
@@ -1759,12 +1810,13 @@ def _prompt_for_selection(
1759 1810
 
1760 1811
     """
1761 1812
     n = len(items)
1813
+    color = ctx.color if ctx is not None else None
1762 1814
     if heading:
1763
-        click.echo(click.style(heading, bold=True))
1815
+        click.echo(click.style(heading, bold=True), color=color)
1764 1816
     for i, x in enumerate(items, start=1):
1765
-        click.echo(click.style(f'[{i}]', bold=True), nl=False)
1766
-        click.echo(' ', nl=False)
1767
-        click.echo(x)
1817
+        click.echo(click.style(f'[{i}]', bold=True), nl=False, color=color)
1818
+        click.echo(' ', nl=False, color=color)
1819
+        click.echo(x, color=color)
1768 1820
     if n > 1:
1769 1821
         choices = click.Choice([''] + [str(i) for i in range(1, n + 1)])
1770 1822
         choice = click.prompt(
... ...
@@ -1796,7 +1848,10 @@ def _prompt_for_selection(
1796 1848
 
1797 1849
 
1798 1850
 def _select_ssh_key(
1799
-    conn: ssh_agent.SSHAgentClient | socket.socket | None = None, /
1851
+    conn: ssh_agent.SSHAgentClient | socket.socket | None = None,
1852
+    /,
1853
+    *,
1854
+    ctx: click.Context | None = None,
1800 1855
 ) -> bytes | bytearray:
1801 1856
     """Interactively select an SSH key for passphrase derivation.
1802 1857
 
... ...
@@ -1808,6 +1863,9 @@ def _select_ssh_key(
1808 1863
         conn:
1809 1864
             An optional connection hint to the SSH agent.  See
1810 1865
             [`ssh_agent.SSHAgentClient.ensure_agent_subcontext`][].
1866
+        ctx:
1867
+            An `click` context, queried for output device properties and
1868
+            color preferences when issuing the prompt.
1811 1869
 
1812 1870
     Returns:
1813 1871
         The selected SSH key.
... ...
@@ -1852,6 +1910,7 @@ def _select_ssh_key(
1852 1910
         key_listing,
1853 1911
         heading='Suitable SSH keys:',
1854 1912
         single_choice_prompt='Use this key?',
1913
+        ctx=ctx,
1855 1914
     )
1856 1915
     return suitable_keys[choice].key
1857 1916
 
... ...
@@ -1917,6 +1976,7 @@ def _check_for_misleading_passphrase(
1917 1976
     value: dict[str, Any],
1918 1977
     *,
1919 1978
     main_config: dict[str, Any],
1979
+    ctx: click.Context | None = None,
1920 1980
 ) -> None:
1921 1981
     form_key = 'unicode-normalization-form'
1922 1982
     default_form: str = main_config.get('vault', {}).get(
... ...
@@ -1954,6 +2014,7 @@ def _check_for_misleading_passphrase(
1954 2014
                 formatted_key,
1955 2015
                 form,
1956 2016
                 stacklevel=2,
2017
+                extra={'color': ctx.color if ctx is not None else None},
1957 2018
             )
1958 2019
 
1959 2020
 
... ...
@@ -2522,6 +2583,7 @@ DEFAULT_NOTES_MARKER = '# - - - - - >8 - - - - -'
2522 2583
     cls=CompatibilityOption,
2523 2584
 )
2524 2585
 @version_option
2586
+@color_forcing_pseudo_option
2525 2587
 @standard_logging_options
2526 2588
 @click.argument(
2527 2589
     'service',
... ...
@@ -2725,7 +2787,9 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
2725 2787
     def err(msg: Any, /, **kwargs: Any) -> NoReturn:  # noqa: ANN401
2726 2788
         stacklevel = kwargs.pop('stacklevel', 1)
2727 2789
         stacklevel += 1
2728
-        logger.error(msg, stacklevel=stacklevel, **kwargs)
2790
+        extra = kwargs.pop('extra', {})
2791
+        extra.setdefault('color', ctx.color)
2792
+        logger.error(msg, stacklevel=stacklevel, extra=extra, **kwargs)
2729 2793
         ctx.exit(1)
2730 2794
 
2731 2795
     def get_config() -> _types.VaultConfig:
... ...
@@ -2746,6 +2810,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
2746 2810
                     old=old_name,
2747 2811
                     new=new_name,
2748 2812
                 ),
2813
+                extra={'color': ctx.color},
2749 2814
             )
2750 2815
             if isinstance(exc, OSError):
2751 2816
                 logger.warning(
... ...
@@ -2755,6 +2820,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
2755 2820
                         error=exc.strerror,
2756 2821
                         filename=exc.filename,
2757 2822
                     ).maybe_without_filename(),
2823
+                    extra={'color': ctx.color},
2758 2824
                 )
2759 2825
             else:
2760 2826
                 deprecation.info(
... ...
@@ -2762,6 +2828,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
2762 2828
                         _msg.InfoMsgTemplate.SUCCESSFULLY_MIGRATED,
2763 2829
                         path=new_name,
2764 2830
                     ),
2831
+                    extra={'color': ctx.color},
2765 2832
                 )
2766 2833
             return backup_config
2767 2834
         except OSError as exc:
... ...
@@ -2882,7 +2949,8 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
2882 2949
             _msg.TranslatedString(
2883 2950
                 _msg.WarnMsgTemplate.EMPTY_SERVICE_NOT_SUPPORTED,
2884 2951
                 service_metavar=service_metavar,
2885
-            )
2952
+            ),
2953
+            extra={'color': ctx.color},
2886 2954
         )
2887 2955
 
2888 2956
     if edit_notes:
... ...
@@ -2982,6 +3050,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
2982 3050
                         path=_types.json_path(step.path),
2983 3051
                         new=json.dumps(step.new_value),
2984 3052
                     ),
3053
+                    extra={'color': ctx.color},
2985 3054
                 )
2986 3055
             else:
2987 3056
                 logger.warning(
... ...
@@ -2990,6 +3059,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
2990 3059
                         path=_types.json_path(step.path),
2991 3060
                         old=json.dumps(step.old_value),
2992 3061
                     ),
3062
+                    extra={'color': ctx.color},
2993 3063
                 )
2994 3064
         if '' in maybe_config['services']:
2995 3065
             logger.warning(
... ...
@@ -2998,18 +3068,21 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
2998 3068
                     service_metavar=service_metavar,
2999 3069
                     PROG_NAME=PROG_NAME,
3000 3070
                 ),
3071
+                extra={'color': ctx.color},
3001 3072
             )
3002 3073
         try:
3003 3074
             _check_for_misleading_passphrase(
3004 3075
                 ('global',),
3005 3076
                 cast(dict[str, Any], maybe_config.get('global', {})),
3006 3077
                 main_config=user_config,
3078
+                ctx=ctx,
3007 3079
             )
3008 3080
             for key, value in maybe_config['services'].items():
3009 3081
                 _check_for_misleading_passphrase(
3010 3082
                     ('services', key),
3011 3083
                     cast(dict[str, Any], value),
3012 3084
                     main_config=user_config,
3085
+                    ctx=ctx,
3013 3086
                 )
3014 3087
         except AssertionError as exc:
3015 3088
             err(
... ...
@@ -3026,7 +3099,8 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
3026 3099
             logger.warning(
3027 3100
                 _msg.TranslatedString(
3028 3101
                     _msg.WarnMsgTemplate.GLOBAL_PASSPHRASE_INEFFECTIVE,
3029
-                )
3102
+                ),
3103
+                extra={'color': ctx.color},
3030 3104
             )
3031 3105
         for service_name, service_obj in maybe_config['services'].items():
3032 3106
             has_key = _types.js_truthiness(
... ...
@@ -3041,6 +3115,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
3041 3115
                         _msg.WarnMsgTemplate.SERVICE_PASSPHRASE_INEFFECTIVE,
3042 3116
                         service=json.dumps(service_name),
3043 3117
                     ),
3118
+                    extra={'color': ctx.color},
3044 3119
                 )
3045 3120
         if overwrite_config:
3046 3121
             put_config(maybe_config)
... ...
@@ -3146,9 +3221,9 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
3146 3221
         )
3147 3222
         if use_key:
3148 3223
             try:
3149
-                key = base64.standard_b64encode(_select_ssh_key()).decode(
3150
-                    'ASCII'
3151
-                )
3224
+                key = base64.standard_b64encode(
3225
+                    _select_ssh_key(ctx=ctx)
3226
+                ).decode('ASCII')
3152 3227
             except IndexError:
3153 3228
                 err(
3154 3229
                     _msg.TranslatedString(
... ...
@@ -3219,6 +3294,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
3219 3294
                         ('services', service) if service else ('global',),
3220 3295
                         {'phrase': phrase},
3221 3296
                         main_config=user_config,
3297
+                        ctx=ctx,
3222 3298
                     )
3223 3299
                 except AssertionError as exc:
3224 3300
                     err(
... ...
@@ -3234,13 +3310,15 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
3234 3310
                             _msg.TranslatedString(
3235 3311
                                 _msg.WarnMsgTemplate.SERVICE_PASSPHRASE_INEFFECTIVE,
3236 3312
                                 service=json.dumps(service),
3237
-                            )
3313
+                            ),
3314
+                            extra={'color': ctx.color},
3238 3315
                         )
3239 3316
                     else:
3240 3317
                         logger.warning(
3241 3318
                             _msg.TranslatedString(
3242 3319
                                 _msg.WarnMsgTemplate.GLOBAL_PASSPHRASE_INEFFECTIVE
3243
-                            )
3320
+                            ),
3321
+                            extra={'color': ctx.color},
3244 3322
                         )
3245 3323
             if not view.maps[0] and not unset_settings:
3246 3324
                 settings_type = 'service' if service else 'global'
... ...
@@ -3292,6 +3370,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
3292 3370
                         _ORIGIN.INTERACTIVE,
3293 3371
                         {'phrase': phrase},
3294 3372
                         main_config=user_config,
3373
+                        ctx=ctx,
3295 3374
                     )
3296 3375
                 except AssertionError as exc:
3297 3376
                     err(
... ...
@@ -3327,7 +3406,7 @@ def derivepassphrase_vault(  # noqa: C901,PLR0912,PLR0913,PLR0914,PLR0915
3327 3406
                 raise click.UsageError(str(err_msg))
3328 3407
             kwargs.pop('key', '')
3329 3408
             result = vault.Vault(**kwargs).generate(service)
3330
-            click.echo(result.decode('ASCII'))
3409
+            click.echo(result.decode('ASCII'), color=ctx.color)
3331 3410
 
3332 3411
 
3333 3412
 if __name__ == '__main__':
... ...
@@ -1294,7 +1294,7 @@ contents go here
1294 1294
         ):
1295 1295
             custom_error = 'custom error message'
1296 1296
 
1297
-            def raiser() -> None:
1297
+            def raiser(*_args: Any, **_kwargs: Any) -> None:
1298 1298
                 raise RuntimeError(custom_error)
1299 1299
 
1300 1300
             monkeypatch.setattr(cli, '_select_ssh_key', raiser)
1301 1301