Marco Ricci commited on 2024-07-22 13:37:03
Zeige 10 geänderte Dateien mit 141 Einfügungen und 73 Löschungen.
... | ... |
@@ -11,6 +11,8 @@ import collections |
11 | 11 |
import hashlib |
12 | 12 |
import math |
13 | 13 |
import unicodedata |
14 |
+from collections.abc import Callable |
|
15 |
+from typing import TypeAlias |
|
14 | 16 |
|
15 | 17 |
from typing_extensions import assert_type |
16 | 18 |
|
... | ... |
@@ -23,7 +25,8 @@ __version__ = '0.1.2' |
23 | 25 |
|
24 | 26 |
class AmbiguousByteRepresentationError(ValueError): |
25 | 27 |
"""The object has an ambiguous byte representation.""" |
26 |
- def __init__(self): |
|
28 |
+ |
|
29 |
+ def __init__(self) -> None: |
|
27 | 30 |
super().__init__('text string has ambiguous byte representation') |
28 | 31 |
|
29 | 32 |
|
... | ... |
@@ -425,6 +428,8 @@ class Vault: |
425 | 428 |
a passphrase deterministically. |
426 | 429 |
|
427 | 430 |
""" |
431 |
+ TestFunc: TypeAlias = Callable[[bytes | bytearray], bool] |
|
432 |
+ deterministic_signature_types: dict[str, TestFunc] |
|
428 | 433 |
deterministic_signature_types = { |
429 | 434 |
'ssh-ed25519': lambda k: k.startswith( |
430 | 435 |
b'\x00\x00\x00\x0bssh-ed25519' |
... | ... |
@@ -28,6 +28,7 @@ from typing_extensions import ( |
28 | 28 |
|
29 | 29 |
import derivepassphrase as dpp |
30 | 30 |
import ssh_agent_client |
31 |
+import ssh_agent_client.types |
|
31 | 32 |
from derivepassphrase import types as dpp_types |
32 | 33 |
|
33 | 34 |
if TYPE_CHECKING: |
... | ... |
@@ -159,7 +160,7 @@ def _get_suitable_ssh_keys( |
159 | 160 |
|
160 | 161 |
""" |
161 | 162 |
client: ssh_agent_client.SSHAgentClient |
162 |
- client_context: contextlib.AbstractContextManager |
|
163 |
+ client_context: contextlib.AbstractContextManager[Any] |
|
163 | 164 |
match conn: |
164 | 165 |
case ssh_agent_client.SSHAgentClient(): |
165 | 166 |
client = conn |
... | ... |
@@ -324,8 +325,15 @@ def _prompt_for_passphrase() -> str: |
324 | 325 |
The user input. |
325 | 326 |
|
326 | 327 |
""" |
327 |
- return click.prompt( |
|
328 |
- 'Passphrase', default='', hide_input=True, show_default=False, err=True |
|
328 |
+ return cast( |
|
329 |
+ str, |
|
330 |
+ click.prompt( |
|
331 |
+ 'Passphrase', |
|
332 |
+ default='', |
|
333 |
+ hide_input=True, |
|
334 |
+ show_default=False, |
|
335 |
+ err=True, |
|
336 |
+ ), |
|
329 | 337 |
) |
330 | 338 |
|
331 | 339 |
|
... | ... |
@@ -349,8 +357,8 @@ class OptionGroupOption(click.Option): |
349 | 357 |
option_group_name: str = '' |
350 | 358 |
epilog: str = '' |
351 | 359 |
|
352 |
- def __init__(self, *args, **kwargs): # type: ignore |
|
353 |
- if self.__class__ == __class__: |
|
360 |
+ def __init__(self, *args: Any, **kwargs: Any) -> None: |
|
361 |
+ if self.__class__ == __class__: # type: ignore[name-defined] |
|
354 | 362 |
raise NotImplementedError |
355 | 363 |
super().__init__(*args, **kwargs) |
356 | 364 |
|
... | ... |
@@ -809,7 +817,7 @@ def derivepassphrase( |
809 | 817 |
for name in param.opts + param.secondary_opts: |
810 | 818 |
params_by_str[name] = param |
811 | 819 |
|
812 |
- def is_param_set(param: click.Parameter): |
|
820 |
+ def is_param_set(param: click.Parameter) -> bool: |
|
813 | 821 |
return bool(ctx.params.get(param.human_readable_name)) |
814 | 822 |
|
815 | 823 |
def check_incompatible_options( |
... | ... |
@@ -89,7 +89,7 @@ class Sequin: |
89 | 89 |
""" |
90 | 90 |
msg = 'sequence item out of range' |
91 | 91 |
|
92 |
- def uint8_to_bits(value): |
|
92 |
+ def uint8_to_bits(value: int) -> Iterator[int]: |
|
93 | 93 |
"""Yield individual bits of an 8-bit number, MSB first.""" |
94 | 94 |
for i in (0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01): |
95 | 95 |
yield 1 if value | i == value else 0 |
... | ... |
@@ -388,5 +388,5 @@ class SequinExhaustedError(Exception): |
388 | 388 |
|
389 | 389 |
""" |
390 | 390 |
|
391 |
- def __init__(self): |
|
391 |
+ def __init__(self) -> None: |
|
392 | 392 |
super().__init__('Sequin is exhausted') |
... | ... |
@@ -34,7 +34,7 @@ _socket = socket |
34 | 34 |
class TrailingDataError(RuntimeError): |
35 | 35 |
"""The result contained trailing data.""" |
36 | 36 |
|
37 |
- def __init__(self): |
|
37 |
+ def __init__(self) -> None: |
|
38 | 38 |
super().__init__('Overlong response from SSH agent') |
39 | 39 |
|
40 | 40 |
|
... | ... |
@@ -274,7 +274,7 @@ class SSHAgentClient: |
274 | 274 |
raise EOFError(msg) |
275 | 275 |
return response[0], response[1:] |
276 | 276 |
|
277 |
- def list_keys(self) -> Sequence[ssh_types.KeyCommentPair]: |
|
277 |
+ def list_keys(self) -> Sequence[types.KeyCommentPair]: |
|
278 | 278 |
"""Request a list of keys known to the SSH agent. |
279 | 279 |
|
280 | 280 |
Returns: |
... | ... |
@@ -290,9 +290,9 @@ class SSHAgentClient: |
290 | 290 |
|
291 | 291 |
""" |
292 | 292 |
response_code, response = self.request( |
293 |
- ssh_types.SSH_AGENTC.REQUEST_IDENTITIES.value, b'' |
|
293 |
+ types.SSH_AGENTC.REQUEST_IDENTITIES.value, b'' |
|
294 | 294 |
) |
295 |
- if response_code != ssh_types.SSH_AGENT.IDENTITIES_ANSWER.value: |
|
295 |
+ if response_code != types.SSH_AGENT.IDENTITIES_ANSWER.value: |
|
296 | 296 |
msg = ( |
297 | 297 |
f'error return from SSH agent: ' |
298 | 298 |
f'{response_code = }, {response = }' |
... | ... |
@@ -313,7 +313,7 @@ class SSHAgentClient: |
313 | 313 |
return bytes(buf) |
314 | 314 |
|
315 | 315 |
key_count = int.from_bytes(shift(4), 'big') |
316 |
- keys: collections.deque[ssh_types.KeyCommentPair] |
|
316 |
+ keys: collections.deque[types.KeyCommentPair] |
|
317 | 317 |
keys = collections.deque() |
318 | 318 |
for _ in range(key_count): |
319 | 319 |
key_size = int.from_bytes(shift(4), 'big') |
... | ... |
@@ -321,7 +321,7 @@ class SSHAgentClient: |
321 | 321 |
comment_size = int.from_bytes(shift(4), 'big') |
322 | 322 |
comment = shift(comment_size) |
323 | 323 |
# Both `key` and `comment` are not wrapped as SSH strings. |
324 |
- keys.append(ssh_types.KeyCommentPair(key, comment)) |
|
324 |
+ keys.append(types.KeyCommentPair(key, comment)) |
|
325 | 325 |
if response_stream: |
326 | 326 |
raise TrailingDataError |
327 | 327 |
return keys |
... | ... |
@@ -379,9 +379,9 @@ class SSHAgentClient: |
379 | 379 |
request_data.extend(self.string(payload)) |
380 | 380 |
request_data.extend(self.uint32(flags)) |
381 | 381 |
response_code, response = self.request( |
382 |
- ssh_types.SSH_AGENTC.SIGN_REQUEST.value, request_data |
|
382 |
+ types.SSH_AGENTC.SIGN_REQUEST.value, request_data |
|
383 | 383 |
) |
384 |
- if response_code != ssh_types.SSH_AGENT.SIGN_RESPONSE.value: |
|
384 |
+ if response_code != types.SSH_AGENT.SIGN_RESPONSE.value: |
|
385 | 385 |
msg = f'signing data failed: {response_code = }, {response = }' |
386 | 386 |
raise RuntimeError(msg) |
387 | 387 |
return self.unstring(response) |
... | ... |
@@ -402,7 +402,7 @@ def isolated_config( |
402 | 402 |
monkeypatch: Any, |
403 | 403 |
runner: click.testing.CliRunner, |
404 | 404 |
config: Any, |
405 |
-): |
|
405 |
+) -> Iterator[None]: |
|
406 | 406 |
prog_name = derivepassphrase.cli.PROG_NAME |
407 | 407 |
env_name = prog_name.replace(' ', '_').upper() + '_PATH' |
408 | 408 |
with runner.isolated_filesystem(): |
... | ... |
@@ -28,31 +28,33 @@ class TestVault: |
28 | 28 |
('twitter', twitter_phrase), |
29 | 29 |
], |
30 | 30 |
) |
31 |
- def test_200_basic_configuration(self, service, expected): |
|
31 |
+ def test_200_basic_configuration( |
|
32 |
+ self, service: bytes | str, expected: bytes |
|
33 |
+ ) -> None: |
|
32 | 34 |
assert Vault(phrase=self.phrase).generate(service) == expected |
33 | 35 |
|
34 |
- def test_201_phrase_dependence(self): |
|
36 |
+ def test_201_phrase_dependence(self) -> None: |
|
35 | 37 |
assert ( |
36 | 38 |
Vault(phrase=(self.phrase + b'X')).generate('google') |
37 | 39 |
== b'n+oIz6sL>K*lTEWYRO%7' |
38 | 40 |
) |
39 | 41 |
|
40 |
- def test_202_reproducibility_and_bytes_service_name(self): |
|
42 |
+ def test_202_reproducibility_and_bytes_service_name(self) -> None: |
|
41 | 43 |
assert Vault(phrase=self.phrase).generate(b'google') == Vault( |
42 | 44 |
phrase=self.phrase |
43 | 45 |
).generate('google') |
44 | 46 |
|
45 |
- def test_203_reproducibility_and_bytearray_service_name(self): |
|
47 |
+ def test_203_reproducibility_and_bytearray_service_name(self) -> None: |
|
46 | 48 |
assert Vault(phrase=self.phrase).generate(b'google') == Vault( |
47 | 49 |
phrase=self.phrase |
48 | 50 |
).generate(bytearray(b'google')) |
49 | 51 |
|
50 |
- def test_210_nonstandard_length(self): |
|
52 |
+ def test_210_nonstandard_length(self) -> None: |
|
51 | 53 |
assert ( |
52 | 54 |
Vault(phrase=self.phrase, length=4).generate('google') == b'xDFu' |
53 | 55 |
) |
54 | 56 |
|
55 |
- def test_211_repetition_limit(self): |
|
57 |
+ def test_211_repetition_limit(self) -> None: |
|
56 | 58 |
assert ( |
57 | 59 |
Vault( |
58 | 60 |
phrase=b'', length=24, symbol=0, number=0, repeat=1 |
... | ... |
@@ -60,37 +62,37 @@ class TestVault: |
60 | 62 |
== b'IVTDzACftqopUXqDHPkuCIhV' |
61 | 63 |
) |
62 | 64 |
|
63 |
- def test_212_without_symbols(self): |
|
65 |
+ def test_212_without_symbols(self) -> None: |
|
64 | 66 |
assert ( |
65 | 67 |
Vault(phrase=self.phrase, symbol=0).generate('google') |
66 | 68 |
== b'XZ4wRe0bZCazbljCaMqR' |
67 | 69 |
) |
68 | 70 |
|
69 |
- def test_213_no_numbers(self): |
|
71 |
+ def test_213_no_numbers(self) -> None: |
|
70 | 72 |
assert ( |
71 | 73 |
Vault(phrase=self.phrase, number=0).generate('google') |
72 | 74 |
== b'_*$TVH.%^aZl(LUeOT?>' |
73 | 75 |
) |
74 | 76 |
|
75 |
- def test_214_no_lowercase_letters(self): |
|
77 |
+ def test_214_no_lowercase_letters(self) -> None: |
|
76 | 78 |
assert ( |
77 | 79 |
Vault(phrase=self.phrase, lower=0).generate('google') |
78 | 80 |
== b':{?)+7~@OA:L]!0E$)(+' |
79 | 81 |
) |
80 | 82 |
|
81 |
- def test_215_at_least_5_digits(self): |
|
83 |
+ def test_215_at_least_5_digits(self) -> None: |
|
82 | 84 |
assert ( |
83 | 85 |
Vault(phrase=self.phrase, length=8, number=5).generate('songkick') |
84 | 86 |
== b'i0908.7[' |
85 | 87 |
) |
86 | 88 |
|
87 |
- def test_216_lots_of_spaces(self): |
|
89 |
+ def test_216_lots_of_spaces(self) -> None: |
|
88 | 90 |
assert ( |
89 | 91 |
Vault(phrase=self.phrase, space=12).generate('songkick') |
90 | 92 |
== b' c 6 Bq % 5fR ' |
91 | 93 |
) |
92 | 94 |
|
93 |
- def test_217_all_character_classes(self): |
|
95 |
+ def test_217_all_character_classes(self) -> None: |
|
94 | 96 |
assert ( |
95 | 97 |
Vault( |
96 | 98 |
phrase=self.phrase, |
... | ... |
@@ -104,7 +106,7 @@ class TestVault: |
104 | 106 |
== b': : fv_wqt>a-4w1S R' |
105 | 107 |
) |
106 | 108 |
|
107 |
- def test_218_only_numbers_and_very_high_repetition_limit(self): |
|
109 |
+ def test_218_only_numbers_and_very_high_repetition_limit(self) -> None: |
|
108 | 110 |
generated = Vault( |
109 | 111 |
phrase=b'', |
110 | 112 |
length=40, |
... | ... |
@@ -130,13 +132,13 @@ class TestVault: |
130 | 132 |
for substring in forbidden_substrings: |
131 | 133 |
assert substring not in generated |
132 | 134 |
|
133 |
- def test_219_very_limited_character_set(self): |
|
135 |
+ def test_219_very_limited_character_set(self) -> None: |
|
134 | 136 |
generated = Vault( |
135 | 137 |
phrase=b'', length=24, lower=0, upper=0, space=0, symbol=0 |
136 | 138 |
).generate('testing') |
137 | 139 |
assert generated == b'763252593304946694588866' |
138 | 140 |
|
139 |
- def test_220_character_set_subtraction(self): |
|
141 |
+ def test_220_character_set_subtraction(self) -> None: |
|
140 | 142 |
assert Vault._subtract(b'be', b'abcdef') == bytearray(b'acdf') |
141 | 143 |
|
142 | 144 |
@pytest.mark.parametrize( |
... | ... |
@@ -230,13 +232,13 @@ class TestVault: |
230 | 232 |
assert binstr(s) == bytes(s) |
231 | 233 |
assert binstr(binstr(s)) == bytes(s) |
232 | 234 |
|
233 |
- def test_310_too_many_symbols(self): |
|
235 |
+ def test_310_too_many_symbols(self) -> None: |
|
234 | 236 |
with pytest.raises( |
235 | 237 |
ValueError, match='requested passphrase length too short' |
236 | 238 |
): |
237 | 239 |
Vault(phrase=self.phrase, symbol=100) |
238 | 240 |
|
239 |
- def test_311_no_viable_characters(self): |
|
241 |
+ def test_311_no_viable_characters(self) -> None: |
|
240 | 242 |
with pytest.raises(ValueError, match='no allowed characters left'): |
241 | 243 |
Vault( |
242 | 244 |
phrase=self.phrase, |
... | ... |
@@ -248,7 +250,7 @@ class TestVault: |
248 | 250 |
symbol=0, |
249 | 251 |
) |
250 | 252 |
|
251 |
- def test_320_character_set_subtraction_duplicate(self): |
|
253 |
+ def test_320_character_set_subtraction_duplicate(self) -> None: |
|
252 | 254 |
with pytest.raises(ValueError, match='duplicate characters'): |
253 | 255 |
Vault._subtract(b'abcdef', b'aabbccddeeff') |
254 | 256 |
with pytest.raises(ValueError, match='duplicate characters'): |
... | ... |
@@ -8,7 +8,7 @@ import contextlib |
8 | 8 |
import json |
9 | 9 |
import os |
10 | 10 |
import socket |
11 |
-from typing import TYPE_CHECKING, cast |
|
11 |
+from typing import TYPE_CHECKING |
|
12 | 12 |
|
13 | 13 |
import click.testing |
14 | 14 |
import pytest |
... | ... |
@@ -196,7 +196,7 @@ for opt, config in SINGLES.items(): |
196 | 196 |
|
197 | 197 |
|
198 | 198 |
class TestCLI: |
199 |
- def test_200_help_output(self): |
|
199 |
+ def test_200_help_output(self) -> None: |
|
200 | 200 |
runner = click.testing.CliRunner(mix_stderr=False) |
201 | 201 |
result = runner.invoke( |
202 | 202 |
cli.derivepassphrase, ['--help'], catch_exceptions=False |
... | ... |
@@ -390,11 +390,10 @@ class TestCLI: |
390 | 390 |
) |
391 | 391 |
def test_210_invalid_argument_range(self, option: str) -> None: |
392 | 392 |
runner = click.testing.CliRunner(mix_stderr=False) |
393 |
- value: str | int |
|
394 | 393 |
for value in '-42', 'invalid': |
395 | 394 |
result = runner.invoke( |
396 | 395 |
cli.derivepassphrase, |
397 |
- [option, cast(str, value), '-p', DUMMY_SERVICE], |
|
396 |
+ [option, value, '-p', DUMMY_SERVICE], |
|
398 | 397 |
input=DUMMY_PASSPHRASE, |
399 | 398 |
catch_exceptions=False, |
400 | 399 |
) |
... | ... |
@@ -872,7 +871,7 @@ contents go here |
872 | 871 |
): |
873 | 872 |
custom_error = 'custom error message' |
874 | 873 |
|
875 |
- def raiser(): |
|
874 |
+ def raiser() -> None: |
|
876 | 875 |
raise RuntimeError(custom_error) |
877 | 876 |
|
878 | 877 |
monkeypatch.setattr(cli, '_select_ssh_key', raiser) |
... | ... |
@@ -925,7 +924,7 @@ class TestCLIUtils: |
925 | 924 |
@click.command() |
926 | 925 |
@click.option('--heading', default='Our menu:') |
927 | 926 |
@click.argument('items', nargs=-1) |
928 |
- def driver(heading, items): |
|
927 |
+ def driver(heading: str, items: list[str]) -> None: |
|
929 | 928 |
# from https://montypython.fandom.com/wiki/Spam#The_menu |
930 | 929 |
items = items or [ |
931 | 930 |
'Egg and bacon', |
... | ... |
@@ -1001,7 +1000,7 @@ Your selection? (1-10, leave empty to abort):\x20 |
1001 | 1000 |
@click.command() |
1002 | 1001 |
@click.option('--item', default='baked beans') |
1003 | 1002 |
@click.argument('prompt') |
1004 |
- def driver(item, prompt): |
|
1003 |
+ def driver(item: str, prompt: str) -> None: |
|
1005 | 1004 |
try: |
1006 | 1005 |
cli._prompt_for_selection( |
1007 | 1006 |
[item], heading='', single_choice_prompt=prompt |
... | ... |
@@ -27,7 +27,9 @@ class TestStaticFunctionality: |
27 | 27 |
([1, 7, 5, 5], 8, 0o1755), |
28 | 28 |
], |
29 | 29 |
) |
30 |
- def test_200_big_endian_number(self, sequence, base, expected): |
|
30 |
+ def test_200_big_endian_number( |
|
31 |
+ self, sequence: list[int], base: int, expected: int |
|
32 |
+ ) -> None: |
|
31 | 33 |
assert ( |
32 | 34 |
sequin.Sequin._big_endian_number(sequence, base=base) |
33 | 35 |
) == expected |
... | ... |
@@ -41,8 +43,12 @@ class TestStaticFunctionality: |
41 | 43 |
], |
42 | 44 |
) |
43 | 45 |
def test_300_big_endian_number_exceptions( |
44 |
- self, exc_type, exc_pattern, sequence, base |
|
45 |
- ): |
|
46 |
+ self, |
|
47 |
+ exc_type: type[Exception], |
|
48 |
+ exc_pattern: str, |
|
49 |
+ sequence: list[int], |
|
50 |
+ base: int, |
|
51 |
+ ) -> None: |
|
46 | 52 |
with pytest.raises(exc_type, match=exc_pattern): |
47 | 53 |
sequin.Sequin._big_endian_number(sequence, base=base) |
48 | 54 |
|
... | ... |
@@ -61,11 +67,16 @@ class TestSequin: |
61 | 67 |
('OK', False, bitseq('0100111101001011')), |
62 | 68 |
], |
63 | 69 |
) |
64 |
- def test_200_constructor(self, sequence, is_bitstring, expected): |
|
70 |
+ def test_200_constructor( |
|
71 |
+ self, |
|
72 |
+ sequence: str | bytes | bytearray | list[int], |
|
73 |
+ is_bitstring: bool, |
|
74 |
+ expected: list[int], |
|
75 |
+ ) -> None: |
|
65 | 76 |
seq = sequin.Sequin(sequence, is_bitstring=is_bitstring) |
66 | 77 |
assert seq.bases == {2: collections.deque(expected)} |
67 | 78 |
|
68 |
- def test_201_generating(self): |
|
79 |
+ def test_201_generating(self) -> None: |
|
69 | 80 |
seq = sequin.Sequin( |
70 | 81 |
[1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1], is_bitstring=True |
71 | 82 |
) |
... | ... |
@@ -83,7 +94,7 @@ class TestSequin: |
83 | 94 |
with pytest.raises(ValueError, match='invalid target range'): |
84 | 95 |
seq.generate(0) |
85 | 96 |
|
86 |
- def test_210_internal_generating(self): |
|
97 |
+ def test_210_internal_generating(self) -> None: |
|
87 | 98 |
seq = sequin.Sequin( |
88 | 99 |
[1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1], is_bitstring=True |
89 | 100 |
) |
... | ... |
@@ -101,7 +112,7 @@ class TestSequin: |
101 | 112 |
with pytest.raises(ValueError, match='invalid base:'): |
102 | 113 |
seq._generate_inner(16, base=1) |
103 | 114 |
|
104 |
- def test_211_shifting(self): |
|
115 |
+ def test_211_shifting(self) -> None: |
|
105 | 116 |
seq = sequin.Sequin([1, 0, 1, 0, 0, 1, 0, 0, 0, 1], is_bitstring=True) |
106 | 117 |
assert seq.bases == { |
107 | 118 |
2: collections.deque([1, 0, 1, 0, 0, 1, 0, 0, 0, 1]) |
... | ... |
@@ -130,7 +141,11 @@ class TestSequin: |
130 | 141 |
], |
131 | 142 |
) |
132 | 143 |
def test_300_constructor_exceptions( |
133 |
- self, sequence, is_bitstring, exc_type, exc_pattern |
|
134 |
- ): |
|
144 |
+ self, |
|
145 |
+ sequence: list[int] | str, |
|
146 |
+ is_bitstring: bool, |
|
147 |
+ exc_type: type[Exception], |
|
148 |
+ exc_pattern: str, |
|
149 |
+ ) -> None: |
|
135 | 150 |
with pytest.raises(exc_type, match=exc_pattern): |
136 | 151 |
sequin.Sequin(sequence, is_bitstring=is_bitstring) |
... | ... |
@@ -11,6 +11,7 @@ import io |
11 | 11 |
import os |
12 | 12 |
import socket |
13 | 13 |
import subprocess |
14 |
+from typing import TYPE_CHECKING |
|
14 | 15 |
|
15 | 16 |
import click |
16 | 17 |
import click.testing |
... | ... |
@@ -22,6 +23,9 @@ import derivepassphrase.cli |
22 | 23 |
import ssh_agent_client |
23 | 24 |
import tests |
24 | 25 |
|
26 |
+if TYPE_CHECKING: |
|
27 |
+ from collections.abc import Iterator |
|
28 |
+ |
|
25 | 29 |
|
26 | 30 |
class TestStaticFunctionality: |
27 | 31 |
@pytest.mark.parametrize( |
... | ... |
@@ -31,13 +35,15 @@ class TestStaticFunctionality: |
31 | 35 |
for val in tests.SUPPORTED_KEYS.values() |
32 | 36 |
], |
33 | 37 |
) |
34 |
- def test_100_key_decoding(self, public_key, public_key_data): |
|
38 |
+ def test_100_key_decoding( |
|
39 |
+ self, public_key: bytes, public_key_data: bytes |
|
40 |
+ ) -> None: |
|
35 | 41 |
keydata = base64.b64decode(public_key.split(None, 2)[1]) |
36 | 42 |
assert ( |
37 | 43 |
keydata == public_key_data |
38 | 44 |
), "recorded public key data doesn't match" |
39 | 45 |
|
40 |
- def test_200_constructor_no_running_agent(self, monkeypatch): |
|
46 |
+ def test_200_constructor_no_running_agent(self, monkeypatch: Any) -> None: |
|
41 | 47 |
monkeypatch.delenv('SSH_AUTH_SOCK', raising=False) |
42 | 48 |
sock = socket.socket(family=socket.AF_UNIX) |
43 | 49 |
with pytest.raises( |
... | ... |
@@ -51,7 +57,7 @@ class TestStaticFunctionality: |
51 | 57 |
(16777216, b'\x01\x00\x00\x00'), |
52 | 58 |
], |
53 | 59 |
) |
54 |
- def test_210_uint32(self, input, expected): |
|
60 |
+ def test_210_uint32(self, input: int, expected: bytes | bytearray) -> None: |
|
55 | 61 |
uint32 = ssh_agent_client.SSHAgentClient.uint32 |
56 | 62 |
assert uint32(input) == expected |
57 | 63 |
|
... | ... |
@@ -66,7 +72,9 @@ class TestStaticFunctionality: |
66 | 72 |
), |
67 | 73 |
], |
68 | 74 |
) |
69 |
- def test_211_string(self, input, expected): |
|
75 |
+ def test_211_string( |
|
76 |
+ self, input: bytes | bytearray, expected: bytes | bytearray |
|
77 |
+ ) -> None: |
|
70 | 78 |
string = ssh_agent_client.SSHAgentClient.string |
71 | 79 |
assert bytes(string(input)) == expected |
72 | 80 |
|
... | ... |
@@ -80,7 +88,9 @@ class TestStaticFunctionality: |
80 | 88 |
), |
81 | 89 |
], |
82 | 90 |
) |
83 |
- def test_212_unstring(self, input, expected): |
|
91 |
+ def test_212_unstring( |
|
92 |
+ self, input: bytes | bytearray, expected: bytes | bytearray |
|
93 |
+ ) -> None: |
|
84 | 94 |
unstring = ssh_agent_client.SSHAgentClient.unstring |
85 | 95 |
unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix |
86 | 96 |
assert bytes(unstring(input)) == expected |
... | ... |
@@ -96,7 +106,9 @@ class TestStaticFunctionality: |
96 | 106 |
(-1, OverflowError, "can't convert negative int to unsigned"), |
97 | 107 |
], |
98 | 108 |
) |
99 |
- def test_310_uint32_exceptions(self, value, exc_type, exc_pattern): |
|
109 |
+ def test_310_uint32_exceptions( |
|
110 |
+ self, value: int, exc_type: type[Exception], exc_pattern: str |
|
111 |
+ ) -> None: |
|
100 | 112 |
uint32 = ssh_agent_client.SSHAgentClient.uint32 |
101 | 113 |
with pytest.raises(exc_type, match=exc_pattern): |
102 | 114 |
uint32(value) |
... | ... |
@@ -107,7 +119,9 @@ class TestStaticFunctionality: |
107 | 119 |
('some string', TypeError, 'invalid payload type'), |
108 | 120 |
], |
109 | 121 |
) |
110 |
- def test_311_string_exceptions(self, input, exc_type, exc_pattern): |
|
122 |
+ def test_311_string_exceptions( |
|
123 |
+ self, input: Any, exc_type: type[Exception], exc_pattern: str |
|
124 |
+ ) -> None: |
|
111 | 125 |
string = ssh_agent_client.SSHAgentClient.string |
112 | 126 |
with pytest.raises(exc_type, match=exc_pattern): |
113 | 127 |
string(input) |
... | ... |
@@ -133,8 +147,13 @@ class TestStaticFunctionality: |
133 | 147 |
], |
134 | 148 |
) |
135 | 149 |
def test_312_unstring_exceptions( |
136 |
- self, input, exc_type, exc_pattern, has_trailer, parts |
|
137 |
- ): |
|
150 |
+ self, |
|
151 |
+ input: bytes | bytearray, |
|
152 |
+ exc_type: type[Exception], |
|
153 |
+ exc_pattern: str, |
|
154 |
+ has_trailer: bool, |
|
155 |
+ parts: tuple[bytes | bytearray, bytes | bytearray] | None, |
|
156 |
+ ) -> None: |
|
138 | 157 |
unstring = ssh_agent_client.SSHAgentClient.unstring |
139 | 158 |
unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix |
140 | 159 |
with pytest.raises(exc_type, match=exc_pattern): |
... | ... |
@@ -151,7 +170,9 @@ class TestAgentInteraction: |
151 | 170 |
@pytest.mark.parametrize( |
152 | 171 |
['keytype', 'data_dict'], list(tests.SUPPORTED_KEYS.items()) |
153 | 172 |
) |
154 |
- def test_200_sign_data_via_agent(self, keytype, data_dict): |
|
173 |
+ def test_200_sign_data_via_agent( |
|
174 |
+ self, keytype: str, data_dict: tests.SSHTestKey |
|
175 |
+ ) -> None: |
|
155 | 176 |
del keytype # Unused. |
156 | 177 |
private_key = data_dict['private_key'] |
157 | 178 |
try: |
... | ... |
@@ -200,7 +221,9 @@ class TestAgentInteraction: |
200 | 221 |
@pytest.mark.parametrize( |
201 | 222 |
['keytype', 'data_dict'], list(tests.UNSUITABLE_KEYS.items()) |
202 | 223 |
) |
203 |
- def test_201_sign_data_via_agent_unsupported(self, keytype, data_dict): |
|
224 |
+ def test_201_sign_data_via_agent_unsupported( |
|
225 |
+ self, keytype: str, data_dict: tests.SSHTestKey |
|
226 |
+ ) -> None: |
|
204 | 227 |
del keytype # Unused. |
205 | 228 |
private_key = data_dict['private_key'] |
206 | 229 |
try: |
... | ... |
@@ -243,7 +266,7 @@ class TestAgentInteraction: |
243 | 266 |
derivepassphrase.Vault.phrase_from_key(public_key_data) |
244 | 267 |
|
245 | 268 |
@staticmethod |
246 |
- def _params(): |
|
269 |
+ def _params() -> Iterator[tuple[bytes, bool]]: |
|
247 | 270 |
for value in tests.SUPPORTED_KEYS.values(): |
248 | 271 |
key = value['public_key_data'] |
249 | 272 |
yield (key, False) |
... | ... |
@@ -254,8 +277,10 @@ class TestAgentInteraction: |
254 | 277 |
yield (key, True) |
255 | 278 |
|
256 | 279 |
@pytest.mark.parametrize(['key', 'single'], list(_params())) |
257 |
- def test_210_ssh_key_selector(self, monkeypatch, key, single): |
|
258 |
- def key_is_suitable(key: bytes): |
|
280 |
+ def test_210_ssh_key_selector( |
|
281 |
+ self, monkeypatch: Any, key: bytes, single: bool |
|
282 |
+ ) -> None: |
|
283 |
+ def key_is_suitable(key: bytes) -> bool: |
|
259 | 284 |
return key in { |
260 | 285 |
v['public_key_data'] for v in tests.SUPPORTED_KEYS.values() |
261 | 286 |
} |
... | ... |
@@ -288,7 +313,7 @@ class TestAgentInteraction: |
288 | 313 |
b64_key = base64.standard_b64encode(key).decode('ASCII') |
289 | 314 |
|
290 | 315 |
@click.command() |
291 |
- def driver(): |
|
316 |
+ def driver() -> None: |
|
292 | 317 |
key = derivepassphrase.cli._select_ssh_key() |
293 | 318 |
click.echo(base64.standard_b64encode(key).decode('ASCII')) |
294 | 319 |
|
... | ... |
@@ -310,7 +335,7 @@ class TestAgentInteraction: |
310 | 335 |
|
311 | 336 |
del _params |
312 | 337 |
|
313 |
- def test_300_constructor_bad_running_agent(self, monkeypatch): |
|
338 |
+ def test_300_constructor_bad_running_agent(self, monkeypatch: Any) -> None: |
|
314 | 339 |
monkeypatch.setenv('SSH_AUTH_SOCK', os.environ['SSH_AUTH_SOCK'] + '~') |
315 | 340 |
sock = socket.socket(family=socket.AF_UNIX) |
316 | 341 |
with pytest.raises(OSError): # noqa: PT011 |
... | ... |
@@ -323,7 +348,9 @@ class TestAgentInteraction: |
323 | 348 |
b'\x00\x00\x00\x1f some bytes missing', |
324 | 349 |
], |
325 | 350 |
) |
326 |
- def test_310_truncated_server_response(self, monkeypatch, response): |
|
351 |
+ def test_310_truncated_server_response( |
|
352 |
+ self, monkeypatch: Any, response: bytes |
|
353 |
+ ) -> None: |
|
327 | 354 |
client = ssh_agent_client.SSHAgentClient() |
328 | 355 |
response_stream = io.BytesIO(response) |
329 | 356 |
|
... | ... |
@@ -354,8 +381,13 @@ class TestAgentInteraction: |
354 | 381 |
], |
355 | 382 |
) |
356 | 383 |
def test_320_list_keys_error_responses( |
357 |
- self, monkeypatch, response_code, response, exc_type, exc_pattern |
|
358 |
- ): |
|
384 |
+ self, |
|
385 |
+ monkeypatch: Any, |
|
386 |
+ response_code: int, |
|
387 |
+ response: bytes | bytearray, |
|
388 |
+ exc_type: type[Exception], |
|
389 |
+ exc_pattern: str, |
|
390 |
+ ) -> None: |
|
359 | 391 |
client = ssh_agent_client.SSHAgentClient() |
360 | 392 |
monkeypatch.setattr( |
361 | 393 |
client, |
... | ... |
@@ -386,8 +418,14 @@ class TestAgentInteraction: |
386 | 418 |
], |
387 | 419 |
) |
388 | 420 |
def test_330_sign_error_responses( |
389 |
- self, monkeypatch, key, check, response, exc_type, exc_pattern |
|
390 |
- ): |
|
421 |
+ self, |
|
422 |
+ monkeypatch: Any, |
|
423 |
+ key: bytes | bytearray, |
|
424 |
+ check: bool, |
|
425 |
+ response: tuple[int, bytes | bytearray], |
|
426 |
+ exc_type: type[Exception], |
|
427 |
+ exc_pattern: str, |
|
428 |
+ ) -> None: |
|
391 | 429 |
client = ssh_agent_client.SSHAgentClient() |
392 | 430 |
monkeypatch.setattr(client, 'request', lambda a, b: response) # noqa: ARG005 |
393 | 431 |
KeyCommentPair = ssh_agent_client.types.KeyCommentPair # noqa: N806 |