test_crypto.py
"""Unit tests for the helpers exposed by ``restai.utils.crypto``."""
from restai.utils.crypto import (
    LLM_SENSITIVE_KEYS,
    PROJECT_SENSITIVE_KEYS,
    SYNC_SOURCE_SENSITIVE_KEYS,
    decrypt_field,
    decrypt_sensitive_options,
    encrypt_field,
    encrypt_sensitive_options,
    generate_recovery_codes,
    hash_api_key,
    hash_recovery_code,
    verify_api_key_hash,
    verify_recovery_code,
)


def test_encrypt_decrypt_field_round_trip():
    """An encrypted value is ``$ENC$``-prefixed and decrypts to the input."""
    secret = "my-secret-value"
    token = encrypt_field(secret)

    assert token != secret
    assert token.startswith("$ENC$")

    recovered = decrypt_field(token)
    assert recovered == secret


def test_decrypt_field_on_plaintext_returns_as_is():
    """Values lacking the ``$ENC$`` prefix pass through decryption untouched."""
    legacy_value = "legacy-plaintext-key"
    result = decrypt_field(legacy_value)
    assert result == legacy_value


def test_encrypt_field_is_idempotent():
    """Re-encrypting an already encrypted value must be a no-op."""
    secret = "idempotent-test"
    first_pass = encrypt_field(secret)
    second_pass = encrypt_field(first_pass)

    assert first_pass == second_pass
    assert decrypt_field(second_pass) == secret


def test_encrypt_decrypt_sensitive_options_with_project_keys():
    """Only project-sensitive option keys are encrypted; all round-trip."""
    # Expected values are kept as separate immutable strings so the checks
    # never depend on whether the helpers mutate the dict they are given.
    token_plain = "tok_123"
    dsn_plain = "postgresql://user:pass@host/db"
    untouched = "should-stay-plain"
    options = {
        "telegram_token": token_plain,
        "connection": dsn_plain,
        "unrelated_key": untouched,
    }

    sealed = encrypt_sensitive_options(options, PROJECT_SENSITIVE_KEYS)
    for sensitive_name in ("telegram_token", "connection"):
        assert sealed[sensitive_name].startswith("$ENC$")
    assert sealed["unrelated_key"] == untouched

    opened = decrypt_sensitive_options(sealed, PROJECT_SENSITIVE_KEYS)
    assert opened["telegram_token"] == token_plain
    assert opened["connection"] == dsn_plain
    assert opened["unrelated_key"] == untouched


def test_encrypt_sensitive_options_with_nested_sync_sources():
    """Secrets nested inside ``sync_sources`` entries are encrypted too."""
    s3_secret = "secret123"
    confluence_token = "conf-tok"
    bucket_name = "my-bucket"
    source_entry = {
        "s3_secret_key": s3_secret,
        "confluence_api_token": confluence_token,
        "bucket": bucket_name,
    }
    options = {"sync_sources": [source_entry]}

    sealed = encrypt_sensitive_options(options, PROJECT_SENSITIVE_KEYS)
    sealed_source = sealed["sync_sources"][0]
    for sensitive_name in ("s3_secret_key", "confluence_api_token"):
        assert sealed_source[sensitive_name].startswith("$ENC$")
    assert sealed_source["bucket"] == bucket_name

    opened = decrypt_sensitive_options(sealed, PROJECT_SENSITIVE_KEYS)
    opened_source = opened["sync_sources"][0]
    assert opened_source["s3_secret_key"] == s3_secret
    assert opened_source["confluence_api_token"] == confluence_token


def test_llm_sensitive_keys_contains_expected():
    """The LLM sensitive-key collection includes the core credential names."""
    required = {"api_key", "key", "password", "secret"}
    for k in required:
        assert k in LLM_SENSITIVE_KEYS, f"{k} missing from LLM_SENSITIVE_KEYS"


def test_hash_api_key_salted():
    """Two hashes of one key differ, yet both verify against that key."""
    api_key = "sk-test-1234567890"
    first_digest = hash_api_key(api_key)
    second_digest = hash_api_key(api_key)

    assert first_digest != second_digest, "Salted hashes should differ"
    assert first_digest.startswith("$pbkdf2$")
    for digest in (first_digest, second_digest):
        assert verify_api_key_hash(api_key, digest)
    assert not verify_api_key_hash("wrong-key", first_digest)


def test_hash_api_key_legacy_sha256_fallback():
    """A plain SHA-256 hex digest is still accepted as an API key hash."""
    from hashlib import sha256

    api_key = "legacy-key"
    unsalted_digest = sha256(api_key.encode()).hexdigest()

    assert verify_api_key_hash(api_key, unsalted_digest)
    assert not verify_api_key_hash("wrong", unsalted_digest)


def test_generate_recovery_codes_count_and_uniqueness():
    """Requesting ten codes yields ten distinct 8-character alphanumerics."""
    recovery_codes = generate_recovery_codes(count=10)

    assert len(recovery_codes) == 10
    distinct_codes = set(recovery_codes)
    assert len(distinct_codes) == 10, "Recovery codes should be unique"
    for recovery_code in recovery_codes:
        assert len(recovery_code) == 8
        assert recovery_code.isalnum()


def test_hash_recovery_code_salted():
    """Recovery code hashes are salted and verify case-insensitively."""
    recovery_code = "abcd1234"
    first_digest = hash_recovery_code(recovery_code)
    second_digest = hash_recovery_code(recovery_code)

    assert first_digest != second_digest, "Salted hashes should differ"
    assert first_digest.startswith("$pbkdf2$")
    assert verify_recovery_code(recovery_code, first_digest)
    assert verify_recovery_code("ABCD1234", first_digest), "Case-insensitive"
    assert not verify_recovery_code("wrong", first_digest)


def test_hash_recovery_code_legacy_sha256_fallback():
    """A plain SHA-256 hex digest is still accepted as a recovery code hash."""
    from hashlib import sha256

    recovery_code = "mycode"
    unsalted_digest = sha256(recovery_code.encode()).hexdigest()

    assert verify_recovery_code(recovery_code, unsalted_digest)
    assert not verify_recovery_code("wrong", unsalted_digest)