/ tests / test_crypto.py
test_crypto.py
  1  """Tests for crypto utility functions."""
  2  from restai.utils.crypto import (
  3      LLM_SENSITIVE_KEYS,
  4      PROJECT_SENSITIVE_KEYS,
  5      SYNC_SOURCE_SENSITIVE_KEYS,
  6      decrypt_field,
  7      decrypt_sensitive_options,
  8      encrypt_field,
  9      encrypt_sensitive_options,
 10      generate_recovery_codes,
 11      hash_api_key,
 12      hash_recovery_code,
 13      verify_api_key_hash,
 14      verify_recovery_code,
 15  )
 16  
 17  
 18  def test_encrypt_decrypt_field_round_trip():
 19      plaintext = "my-secret-value"
 20      encrypted = encrypt_field(plaintext)
 21      assert encrypted != plaintext
 22      assert encrypted.startswith("$ENC$")
 23      decrypted = decrypt_field(encrypted)
 24      assert decrypted == plaintext
 25  
 26  
 27  def test_decrypt_field_on_plaintext_returns_as_is():
 28      """Backward compatibility: plaintext without $ENC$ prefix is returned unchanged."""
 29      raw = "legacy-plaintext-key"
 30      assert decrypt_field(raw) == raw
 31  
 32  
 33  def test_encrypt_field_is_idempotent():
 34      """Calling encrypt_field twice should not double-encrypt."""
 35      plaintext = "idempotent-test"
 36      once = encrypt_field(plaintext)
 37      twice = encrypt_field(once)
 38      assert once == twice
 39      assert decrypt_field(twice) == plaintext
 40  
 41  
 42  def test_encrypt_decrypt_sensitive_options_with_project_keys():
 43      opts = {
 44          "telegram_token": "tok_123",
 45          "connection": "postgresql://user:pass@host/db",
 46          "unrelated_key": "should-stay-plain",
 47      }
 48      encrypted = encrypt_sensitive_options(opts, PROJECT_SENSITIVE_KEYS)
 49  
 50      assert encrypted["telegram_token"].startswith("$ENC$")
 51      assert encrypted["connection"].startswith("$ENC$")
 52      assert encrypted["unrelated_key"] == "should-stay-plain"
 53  
 54      decrypted = decrypt_sensitive_options(encrypted, PROJECT_SENSITIVE_KEYS)
 55      assert decrypted["telegram_token"] == "tok_123"
 56      assert decrypted["connection"] == "postgresql://user:pass@host/db"
 57      assert decrypted["unrelated_key"] == "should-stay-plain"
 58  
 59  
 60  def test_encrypt_sensitive_options_with_nested_sync_sources():
 61      opts = {
 62          "sync_sources": [
 63              {
 64                  "s3_secret_key": "secret123",
 65                  "confluence_api_token": "conf-tok",
 66                  "bucket": "my-bucket",
 67              }
 68          ]
 69      }
 70      encrypted = encrypt_sensitive_options(opts, PROJECT_SENSITIVE_KEYS)
 71      src = encrypted["sync_sources"][0]
 72      assert src["s3_secret_key"].startswith("$ENC$")
 73      assert src["confluence_api_token"].startswith("$ENC$")
 74      assert src["bucket"] == "my-bucket"
 75  
 76      decrypted = decrypt_sensitive_options(encrypted, PROJECT_SENSITIVE_KEYS)
 77      dsrc = decrypted["sync_sources"][0]
 78      assert dsrc["s3_secret_key"] == "secret123"
 79      assert dsrc["confluence_api_token"] == "conf-tok"
 80  
 81  
 82  def test_llm_sensitive_keys_contains_expected():
 83      expected = {"api_key", "key", "password", "secret"}
 84      for k in expected:
 85          assert k in LLM_SENSITIVE_KEYS, f"{k} missing from LLM_SENSITIVE_KEYS"
 86  
 87  
 88  def test_hash_api_key_salted():
 89      """Each call produces a different hash (random salt), but verify works."""
 90      key = "sk-test-1234567890"
 91      h1 = hash_api_key(key)
 92      h2 = hash_api_key(key)
 93      assert h1 != h2, "Salted hashes should differ"
 94      assert h1.startswith("$pbkdf2$")
 95      assert verify_api_key_hash(key, h1)
 96      assert verify_api_key_hash(key, h2)
 97      assert not verify_api_key_hash("wrong-key", h1)
 98  
 99  
def test_hash_api_key_legacy_sha256_fallback():
    """verify_api_key_hash still accepts legacy unsalted SHA-256 hex digests."""
    import hashlib

    api_key = "legacy-key"
    unsalted_digest = hashlib.sha256(api_key.encode()).hexdigest()
    assert verify_api_key_hash(api_key, unsalted_digest)
    assert not verify_api_key_hash("wrong", unsalted_digest)
107  
108  
def test_generate_recovery_codes_count_and_uniqueness():
    """generate_recovery_codes returns ``count`` distinct 8-character ASCII alphanumeric codes."""
    codes = generate_recovery_codes(count=10)
    assert len(codes) == 10
    assert len(set(codes)) == 10, "Recovery codes should be unique"
    for code in codes:
        assert len(code) == 8
        # str.isalnum() is Unicode-aware (it accepts e.g. "é" or "²"), so pair
        # it with isascii() to pin codes to the ASCII [A-Za-z0-9] alphabet.
        assert code.isascii() and code.isalnum(), f"not ASCII alphanumeric: {code!r}"
116  
117  
def test_hash_recovery_code_salted():
    """Recovery-code hashes are salted, both verify, matching ignores case, and wrong codes fail."""
    code = "abcd1234"
    h1 = hash_recovery_code(code)
    h2 = hash_recovery_code(code)
    assert h1 != h2, "Salted hashes should differ"
    # Check both salted hashes, matching the API-key salting test, so a
    # broken second salt/derivation cannot slip through.
    for digest in (h1, h2):
        assert digest.startswith("$pbkdf2$")
        assert verify_recovery_code(code, digest)
    assert verify_recovery_code("ABCD1234", h1), "Case-insensitive"
    assert not verify_recovery_code("wrong", h1)
128  
129  
def test_hash_recovery_code_legacy_sha256_fallback():
    """verify_recovery_code still accepts legacy unsalted SHA-256 hex digests."""
    from hashlib import sha256

    legacy_code = "mycode"
    stored_digest = sha256(legacy_code.encode()).hexdigest()
    assert verify_recovery_code(legacy_code, stored_digest)
    assert not verify_recovery_code("wrong", stored_digest)