"""Unit tests for Settings: get_llm_config semantics and model_pricing validation."""

import logging

import pytest
from pydantic import ValidationError

from config import LLMTaskConfig, Settings


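# Contract assumed by these tests, inferred from the call sites below rather
# than from the actual config.py implementation:
#
#     settings.get_llm_config(task_name)
#     -> (provider, model, temperature, seed, max_tokens)
#
# A task entry in `llm_tasks` overrides the matching global `llm_*` field;
# a missing entry, or a field left as None, falls back to the global value.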
def _base() -> Settings:
    """Return a Settings instance with only required fields and defaults, no task overrides."""
    return Settings(brightdata_token="x", llm_tasks={})


def _with_task(**kwargs: object) -> Settings:
    """Return a Settings instance with a score_jobs task override."""
    return Settings(
        brightdata_token="x",
        llm_tasks={"score_jobs": LLMTaskConfig(**kwargs)},
    )


class TestGetLLMConfig:
    def test_no_task_entry_returns_globals(self) -> None:
        s = _base()
        # Use a task name that cannot be present in any config file.
        provider, model, temperature, seed, max_tokens = s.get_llm_config("__nonexistent__")
        assert provider == s.llm_provider
        assert model == s.llm_model
        assert temperature == s.llm_temperature
        assert seed == s.llm_seed
        assert max_tokens == s.llm_max_tokens

    def test_task_overrides_provider_and_model(self) -> None:
        s = _with_task(provider="openai", model="gpt-4o")
        provider, model, *_ = s.get_llm_config("score_jobs")
        assert provider == "openai"
        assert model == "gpt-4o"

    def test_task_partial_override_falls_back(self) -> None:
        # Only model is overridden; provider must fall back to the global value.
        s = _with_task(model="gpt-4o")
        provider, model, *_ = s.get_llm_config("score_jobs")
        assert provider == s.llm_provider
        assert model == "gpt-4o"

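    # The zero-valued tests below pin the fallback rule to an `is None` check
    # rather than truthiness. A minimal sketch of the assumed rule (not the
    # actual config.py code):
    #
    #     value = task_value if task_value is not None else global_value
    #
    # so explicit 0 / 0.0 overrides survive.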
    def test_temperature_zero_not_overridden(self) -> None:
        # 0.0 is a valid override; it must not fall back to the global None.
        s = _with_task(temperature=0.0)
        _, _, temperature, _, _ = s.get_llm_config("score_jobs")
        assert temperature == 0.0

    def test_seed_zero_not_overridden(self) -> None:
        # 0 is a valid override; it must not fall back to the global None.
        s = _with_task(seed=0)
        _, _, _, seed, _ = s.get_llm_config("score_jobs")
        assert seed == 0

    def test_max_tokens_override(self) -> None:
        s = _with_task(max_tokens=512)
        _, _, _, _, max_tokens = s.get_llm_config("score_jobs")
        assert max_tokens == 512

    def test_max_tokens_zero_not_overridden(self) -> None:
        # 0 is a valid override; it must not fall back to the global value.
        s = _with_task(max_tokens=0)
        _, _, _, _, max_tokens = s.get_llm_config("score_jobs")
        assert max_tokens == 0


class TestModelPricingValidator:
    """Validate model_pricing field: length, sign, and order checks."""

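    # Assumed validator shape, reconstructed from the assertions below (a
    # sketch, not the real config.py code):
    #
    #     @field_validator("model_pricing")
    #     @classmethod
    #     def _check_pricing(cls, v):
    #         for model, prices in v.items():
    #             if len(prices) != 2:
    #                 raise ValueError(f"{model}: expected [input, output]")
    #             if any(p < 0 for p in prices):
    #                 raise ValueError("prices must be non-negative")
    #             if prices[1] < prices[0]:
    #                 logger.warning("%s: values may be swapped", model)
    #         return v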
    def test_pricing_validator_rejects_wrong_length(self) -> None:
        with pytest.raises(ValidationError, match="expected \\[input, output\\]"):
            Settings(brightdata_token="x", model_pricing={"model-a": [3.0]})

    def test_pricing_validator_rejects_negative_price(self) -> None:
        with pytest.raises(ValidationError, match="non-negative"):
            Settings(brightdata_token="x", model_pricing={"model-a": [-1.0, 15.0]})

    def test_pricing_validator_warns_on_swapped_prices(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        # output < input looks swapped; the validator should warn but not raise.
        with caplog.at_level(logging.WARNING):
            s = Settings(
                brightdata_token="x", model_pricing={"model-a": [15.0, 3.0]}
            )
        assert s.model_pricing["model-a"] == [15.0, 3.0]
        assert "values may be swapped" in caplog.text

    def test_valid_pricing_accepted(self) -> None:
        s = Settings(
            brightdata_token="x",
            model_pricing={"model-a": [3.0, 15.0], "model-b": [0.8, 4.0]},
        )
        assert s.model_pricing["model-a"] == [3.0, 15.0]
        assert s.model_pricing["model-b"] == [0.8, 4.0]