# tests/models/test_model_config.py
 1  import os
 2  from unittest import mock
 3  
 4  import pytest
 5  
 6  from mlflow.exceptions import MlflowException
 7  from mlflow.models import ModelConfig
 8  
# Directory containing this test module; the YAML fixtures live in a
# "configs" subdirectory alongside it.
dir_path = os.path.dirname(os.path.abspath(__file__))
# Two valid fixture files. Per the comments in
# test_config_setup_with_mlflow_langchain_path, config.yaml sets
# llm_parameters.max_tokens to 500 while config_2.yaml sets it to 200,
# letting tests detect which source won.
VALID_CONFIG_PATH = os.path.join(dir_path, "configs/config.yaml")
VALID_CONFIG_PATH_2 = os.path.join(dir_path, "configs/config_2.yaml")
12  
13  
def test_config_not_set():
    """Constructing ModelConfig with no config source at all must fail."""
    expected = "Config file is not provided which is needed to load the model."
    with pytest.raises(FileNotFoundError, match=expected):
        ModelConfig()
19  
20  
def test_config_not_found():
    """A development_config path that does not exist raises at construction."""
    missing_path = "nonexistent.yaml"
    with pytest.raises(FileNotFoundError, match="Config file 'nonexistent.yaml' not found."):
        ModelConfig(development_config=missing_path)
24  
25  
def test_config_invalid_yaml(tmp_path):
    """Malformed YAML surfaces as an MlflowException on first key access."""
    bad_yaml = "invalid_yaml: \n  - this is not valid \n-yaml"
    config_file = tmp_path / "invalid_config.yaml"
    config_file.write_text(bad_yaml)

    model_config = ModelConfig(development_config=str(config_file))
    # The parse error is deferred until the config is actually read.
    with pytest.raises(MlflowException, match="Error parsing YAML file: "):
        model_config.get("key")
32  
33  
def test_config_key_not_found():
    """Looking up a key absent from a valid config raises KeyError."""
    model_config = ModelConfig(development_config=VALID_CONFIG_PATH)
    with pytest.raises(KeyError, match="Key 'key' not found in configuration: "):
        model_config.get("key")
38  
39  
def test_config_setup_correctly():
    """A valid YAML file loads and nested values are retrievable via get()."""
    model_config = ModelConfig(development_config=VALID_CONFIG_PATH)
    llm_params = model_config.get("llm_parameters")
    assert llm_params.get("temperature") == 0.01
43  
44  
def test_config_setup_correctly_with_mlflow_langchain():
    """When __mlflow_model_config__ is set, a bad development_config is ignored."""
    patch_target = "mlflow.models.model_config.__mlflow_model_config__"
    with mock.patch(patch_target, new=VALID_CONFIG_PATH):
        # development_config points at a nonexistent file, yet construction
        # succeeds because the patched module-level config takes over.
        model_config = ModelConfig(development_config="nonexistent.yaml")
        assert model_config.get("llm_parameters").get("temperature") == 0.01
49  
50  
def test_config_setup_with_mlflow_langchain_path():
    """__mlflow_model_config__ takes precedence over development_config."""
    patch_target = "mlflow.models.model_config.__mlflow_model_config__"
    with mock.patch(patch_target, new=VALID_CONFIG_PATH_2):
        # config.yaml sets max_tokens to 500, whereas config_2.yaml sets it
        # to 200 — seeing 200 proves the patched value won.
        model_config = ModelConfig(development_config=VALID_CONFIG_PATH)
        llm_params = model_config.get("llm_parameters")
        assert llm_params.get("max_tokens") == 200
58  
59  
def test_config_development_config_must_be_specified_with_keyword():
    """development_config is keyword-only; passing it positionally fails."""
    with pytest.raises(TypeError, match="1 positional argument but 2 were given"):
        ModelConfig(VALID_CONFIG_PATH_2)
63  
64  
def test_config_development_config_is_a_dict():
    """A plain dict is accepted directly in place of a YAML file path."""
    inline_config = {"llm_parameters": {"temperature": 0.01}}
    model_config = ModelConfig(development_config=inline_config)
    assert model_config.get("llm_parameters").get("temperature") == 0.01
68  
69  
def test_config_setup_correctly_errors_with_no_config_path():
    """An empty __mlflow_model_config__ overrides development_config, so no
    usable config source remains and construction fails."""
    patch_target = "mlflow.models.model_config.__mlflow_model_config__"
    message = "Config file is not provided which is needed to load the model."
    with mock.patch(patch_target, new=""):
        with pytest.raises(FileNotFoundError, match=message):
            ModelConfig(development_config=VALID_CONFIG_PATH)
77  
78  
def test_config_development_config_to_dict():
    """to_dict() round-trips a dict config and exposes a YAML config's contents."""
    # Dict-based config comes back unchanged.
    inline_config = {"llm_parameters": {"temperature": 0.01}}
    assert ModelConfig(development_config=inline_config).to_dict() == inline_config

    # YAML-based config is fully materialized into a plain dict.
    expected = {
        "embedding_model_query_instructions": (
            "Represent this sentence for searching relevant passages:"
        ),
        "llm_model": "databricks-dbrx-instruct",
        "llm_prompt_template": "You are a trustful assistant.",
        "retriever_config": {"k": 5, "use_mmr": False},
        "llm_parameters": {"temperature": 0.01, "max_tokens": 500},
        "llm_prompt_template_variables": ["chat_history", "context", "question"],
    }
    yaml_config = ModelConfig(development_config=VALID_CONFIG_PATH)
    assert yaml_config.to_dict() == expected