# model_config.py
import os
from typing import Any

import yaml

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE

# Module-level slot written by ``_set_model_config`` and read by ``ModelConfig.__init__``.
# ``None`` means "nothing has been set", so ``development_config`` is used instead.
__mlflow_model_config__ = None


class ModelConfig:
    """
    ModelConfig used in code to read a YAML configuration file or a dictionary.

    Args:
        development_config: Path to the YAML configuration file or a dictionary containing the
            configuration. If the configuration is not provided, an error is raised.

    .. code-block:: python
        :caption: Example usage in model code

        from mlflow.models import ModelConfig

        # Load the configuration from a dictionary
        config = ModelConfig(development_config={"key1": "value1"})
        print(config.get("key1"))


    .. code-block:: yaml
        :caption: yaml file for model configuration

        key1: value1
        another_key:
            - value2
            - value3

    .. code-block:: python
        :caption: Example yaml usage in model code

        from mlflow.models import ModelConfig

        # Load the configuration from a file
        config = ModelConfig(development_config="config.yaml")
        print(config.get("key1"))


    When invoking the ModelConfig locally in a model file, development_config can be passed in
    which would be used as configuration for the model.


    .. code-block:: python
        :caption: Example to use ModelConfig when logging model as code: agent.py

        import mlflow
        from mlflow.models import ModelConfig

        config = ModelConfig(development_config={"key1": "value1"})


        class TestModel(mlflow.pyfunc.PythonModel):
            def predict(self, context, model_input, params=None):
                return config.get("key1")


        mlflow.models.set_model(TestModel())


    But this development_config configuration file will be overridden when logging a model.
    When no model_config is passed in while logging the model, an error will be raised when
    trying to load the model using ModelConfig.
    Note: development_config is not used when logging the model.


    .. code-block:: python
        :caption: Example to use agent.py to log the model: deploy.py

        model_config = {"key1": "value2"}
        with mlflow.start_run():
            model_info = mlflow.pyfunc.log_model(
                name="model", python_model="agent.py", model_config=model_config
            )

        loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

        # This will print "value2" as the model_config passed in while logging the model
        print(loaded_model.predict(None))
    """

    def __init__(self, *, development_config: str | dict[str, Any] | None = None):
        """Resolve which configuration to use and validate that it is usable.

        Args:
            development_config: Path to a YAML file or a dictionary. Only used when the
                module-level ``__mlflow_model_config__`` is ``None``.

        Raises:
            FileNotFoundError: If the resolved configuration is empty/missing, or if it is
                not a dictionary and does not point to an existing file.
        """
        logged_config = globals().get("__mlflow_model_config__", None)
        # ``__mlflow_model_config__`` has three states:
        # 1. None: nothing was set, so fall back to ``development_config`` if available.
        # 2. "" (empty string): the user explicitly did not set a model config while
        #    logging the model, so using ModelConfig must raise an error.
        # 3. Any other value: the user set a model config while logging the model, so
        #    that value takes precedence over ``development_config``.
        if logged_config is not None:
            self.config = logged_config
        else:
            self.config = development_config

        # Covers None, "" and an empty dictionary alike.
        if not self.config:
            raise FileNotFoundError(
                "Config file is not provided which is needed to load the model. "
                "Please provide a valid path."
            )

        # Anything that is not a dictionary is treated as a path to a file.
        if not isinstance(self.config, dict) and not os.path.isfile(self.config):
            raise FileNotFoundError(f"Config file '{self.config}' not found.")

    def _read_config(self):
        """Return the configuration contents, reading the YAML file when needed.

        The file is re-read on every call; nothing is cached on the instance.

        Raises:
            FileNotFoundError: If the configuration file no longer exists when opened.
            MlflowException: If the YAML content cannot be parsed. The original
                ``yaml.YAMLError`` is attached as ``__cause__``.

        Returns:
            The dictionary itself when the configuration is a dictionary; otherwise
            whatever ``yaml.safe_load`` produces for the file (which may be ``None``
            for an empty document).
        """
        if isinstance(self.config, dict):
            return self.config

        # YAML is a Unicode format; pin UTF-8 instead of relying on the platform's
        # locale-dependent default encoding.
        with open(self.config, encoding="utf-8") as file:
            try:
                return yaml.safe_load(file)
            except yaml.YAMLError as e:
                raise MlflowException(
                    f"Error parsing YAML file: {e}", error_code=INVALID_PARAMETER_VALUE
                ) from e

    def to_dict(self):
        """Returns the configuration as a dictionary."""
        return self._read_config()

    def get(self, key):
        """Gets the value of a top-level parameter in the configuration.

        Args:
            key: Name of the top-level entry to look up.

        Raises:
            KeyError: If the configuration is empty or does not contain ``key``.

        Returns:
            The value stored under ``key``.
        """
        config_data = self._read_config()

        if config_data and key in config_data:
            return config_data[key]
        else:
            raise KeyError(f"Key '{key}' not found in configuration: {config_data}.")


def _set_model_config(model_config):
    """Store ``model_config`` in the module-level ``__mlflow_model_config__`` slot.

    Subsequently created ``ModelConfig`` instances use this value instead of their
    ``development_config`` argument whenever it is not ``None``.
    """
    globals()["__mlflow_model_config__"] = model_config