# test_model_config.py
"""Tests for ``mlflow.models.ModelConfig``.

Covers loading configuration from a YAML file path or from a dict passed as
``development_config``, the precedence of the
``mlflow.models.model_config.__mlflow_model_config__`` override, and the
errors raised for missing, unparsable, or incomplete configuration.
"""

import os
import re
from unittest import mock

import pytest

from mlflow.exceptions import MlflowException
from mlflow.models import ModelConfig

dir_path = os.path.dirname(os.path.abspath(__file__))
# Fixture YAML files stored next to this module. config.yaml sets
# llm_parameters.max_tokens to 500; config_2.yaml sets it to 200, which lets
# the override-precedence test tell the two files apart.
VALID_CONFIG_PATH = os.path.join(dir_path, "configs/config.yaml")
VALID_CONFIG_PATH_2 = os.path.join(dir_path, "configs/config_2.yaml")


def test_config_not_set():
    # Constructing without any development_config must fail loudly.
    with pytest.raises(
        FileNotFoundError,
        match=re.escape("Config file is not provided which is needed to load the model."),
    ):
        ModelConfig()


def test_config_not_found():
    with pytest.raises(
        FileNotFoundError, match=re.escape("Config file 'nonexistent.yaml' not found.")
    ):
        ModelConfig(development_config="nonexistent.yaml")


def test_config_invalid_yaml(tmp_path):
    # Deliberately malformed YAML content.
    tmp_file = tmp_path / "invalid_config.yaml"
    tmp_file.write_text("invalid_yaml: \n - this is not valid \n-yaml")
    # Construction succeeds; the parse error only surfaces on first access.
    config = ModelConfig(development_config=str(tmp_file))
    with pytest.raises(MlflowException, match=re.escape("Error parsing YAML file: ")):
        config.get("key")


def test_config_key_not_found():
    config = ModelConfig(development_config=VALID_CONFIG_PATH)
    with pytest.raises(KeyError, match=re.escape("Key 'key' not found in configuration: ")):
        config.get("key")


def test_config_setup_correctly():
    config = ModelConfig(development_config=VALID_CONFIG_PATH)
    assert config.get("llm_parameters").get("temperature") == 0.01


def test_config_setup_correctly_with_mlflow_langchain():
    # When __mlflow_model_config__ is set it is used instead of development_config,
    # so the nonexistent development path below is never read.
    with mock.patch("mlflow.models.model_config.__mlflow_model_config__", new=VALID_CONFIG_PATH):
        config = ModelConfig(development_config="nonexistent.yaml")
        assert config.get("llm_parameters").get("temperature") == 0.01


def test_config_setup_with_mlflow_langchain_path():
    with mock.patch("mlflow.models.model_config.__mlflow_model_config__", new=VALID_CONFIG_PATH_2):
        # Here config.yaml has max_tokens set to 500, whereas config_2.yaml
        # has it set to 200. Preference is given to __mlflow_model_config__.
        config = ModelConfig(development_config=VALID_CONFIG_PATH)
        assert config.get("llm_parameters").get("max_tokens") == 200


def test_config_development_config_must_be_specified_with_keyword():
    # development_config is keyword-only; passing it positionally is a TypeError.
    with pytest.raises(TypeError, match=re.escape("1 positional argument but 2 were given")):
        ModelConfig(VALID_CONFIG_PATH_2)


def test_config_development_config_is_a_dict():
    config = ModelConfig(development_config={"llm_parameters": {"temperature": 0.01}})
    assert config.get("llm_parameters").get("temperature") == 0.01


def test_config_setup_correctly_errors_with_no_config_path():
    # An empty-string override counts as "no config provided", even though a
    # valid development_config is passed.
    with mock.patch("mlflow.models.model_config.__mlflow_model_config__", new=""):
        with pytest.raises(
            FileNotFoundError,
            match=re.escape("Config file is not provided which is needed to load the model."),
        ):
            ModelConfig(development_config=VALID_CONFIG_PATH)


def test_config_development_config_to_dict():
    config = ModelConfig(development_config={"llm_parameters": {"temperature": 0.01}})
    assert config.to_dict() == {"llm_parameters": {"temperature": 0.01}}

    config = ModelConfig(development_config=VALID_CONFIG_PATH)
    assert config.to_dict() == {
        "embedding_model_query_instructions": "Represent this sentence for searching "
        "relevant passages:",
        "llm_model": "databricks-dbrx-instruct",
        "llm_prompt_template": "You are a trustful assistant.",
        "retriever_config": {"k": 5, "use_mmr": False},
        "llm_parameters": {"temperature": 0.01, "max_tokens": 500},
        "llm_prompt_template_variables": ["chat_history", "context", "question"],
    }