# test_util.py
from typing import Any, Union

import pytest
from pydantic import BaseModel

from mlflow.entities.assessment import Feedback
from mlflow.exceptions import MlflowException
from mlflow.genai.judges import CategoricalRating
from mlflow.genai.optimize.util import (
    create_metric_from_scorers,
    infer_type_from_value,
    validate_train_data,
)
from mlflow.genai.scorers import scorer


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (None, type(None)),
        (True, bool),
        (42, int),
        (3.14, float),
        ("hello", str),
    ],
)
def test_infer_primitive_types(value, expected):
    """Each primitive value is inferred as its exact builtin type."""
    assert infer_type_from_value(value) == expected


@pytest.mark.parametrize(
    ("items", "expected"),
    [
        ([], list[Any]),
        ([1, 2, 3], list[int]),
        (["a", "b", "c"], list[str]),
        ([1.0, 2.0, 3.0], list[float]),
        ([True, False, True], list[bool]),
        ([1, "hello", True], list[Union[int, str, bool]]),  # noqa: UP007
        ([1, "hello", True], list[int | str | bool]),
        ([1, 2.0], list[int | float]),
        ([[1, 2], [3, 4]], list[list[int]]),
        ([["a"], ["b", "c"]], list[list[str]]),
    ],
)
def test_infer_list_types(items, expected):
    """Lists infer their element type, producing unions for mixed content."""
    assert infer_type_from_value(items) == expected


@pytest.mark.parametrize(
    ("payload", "field_types"),
    [
        ({"name": "John", "age": 30, "active": True}, {"name": str, "age": int, "active": bool}),
        ({"score": 95.5, "passed": True}, {"score": float, "passed": bool}),
    ],
)
def test_infer_simple_dict(payload, field_types):
    """A flat dict is inferred as a pydantic model with one typed field per key."""
    model_cls = infer_type_from_value(payload)

    assert isinstance(model_cls, type)
    assert issubclass(model_cls, BaseModel)

    annotations = model_cls.__annotations__
    for key, type_ in field_types.items():
        assert annotations[key] == type_


def test_infer_nested_dict():
    """Nested dicts produce nested pydantic models with typed leaf fields."""
    payload = {
        "user": {"name": "John", "scores": [85, 90, 95]},
        "settings": {"enabled": True, "theme": "dark"},
    }
    outer_cls = infer_type_from_value(payload)

    assert isinstance(outer_cls, type)
    assert issubclass(outer_cls, BaseModel)

    # Each nested dict is materialized as its own BaseModel subclass.
    user_cls = outer_cls.__annotations__["user"]
    settings_cls = outer_cls.__annotations__["settings"]
    assert issubclass(user_cls, BaseModel)
    assert issubclass(settings_cls, BaseModel)

    # Leaf fields keep their inferred primitive/container types.
    assert user_cls.__annotations__["name"] == str
    assert user_cls.__annotations__["scores"] == list[int]
    assert settings_cls.__annotations__["enabled"] == bool
    assert settings_cls.__annotations__["theme"] == str


@pytest.mark.parametrize(
    ("model_cls", "payload"),
    [
        (
            type("UserModel", (BaseModel,), {"__annotations__": {"name": str, "age": int}}),
            {"name": "John", "age": 30},
        ),
        (
            type("ProductModel", (BaseModel,), {"__annotations__": {"id": int, "price": float}}),
            {"id": 1, "price": 99.99},
        ),
    ],
)
def test_infer_pydantic_model(model_cls, payload):
    """An existing BaseModel instance infers back to its own class."""
    instance = model_cls(**payload)
    assert infer_type_from_value(instance) == model_cls


@pytest.mark.parametrize(
    "unsupported_cls",
    [
        type("CustomClass", (), {}),
        type("AnotherClass", (), {"custom_attr": 42}),
    ],
)
def test_infer_unsupported_type(unsupported_cls):
    """Instances of arbitrary user classes fall back to typing.Any."""
    assert infer_type_from_value(unsupported_cls()) == Any


@pytest.mark.parametrize(
    ("payload", "model_name"),
    [
        ({"name": "John", "age": 30}, "UserData"),
        ({"id": 1, "value": "test"}, "TestModel"),
    ],
)
def test_model_name_parameter(payload, model_name):
    """The model_name argument names the generated pydantic class."""
    generated = infer_type_from_value(payload, model_name=model_name)
    assert generated.__name__ == model_name


@pytest.mark.parametrize(
    ("score", "expected_score"),
    [
        (CategoricalRating.YES, 1.0),
        (CategoricalRating.NO, 0.0),
        ("yes", 1.0),
        ("no", 0.0),
        (True, 1.0),
        (False, 0.0),
        (1, 1.0),
        (0, 0.0),
        (1.0, 1.0),
        (0.0, 0.0),
    ],
)
def test_create_metric_from_scorers_with_single_score(score, expected_score):
    """A single scorer's feedback value is normalized to a float metric score."""

    @scorer(name="test_scorer")
    def test_scorer(inputs, outputs):
        return Feedback(name="test_scorer", value=score, rationale="test rationale")

    metric = create_metric_from_scorers([test_scorer])
    result = metric({"input": "test"}, {"output": "result"}, {}, None)

    # result is (overall score, per-scorer rationales, per-scorer scores).
    assert result[0] == expected_score
    assert result[1] == {"test_scorer": "test rationale"}
    assert result[2] == {"test_scorer": expected_score}


def test_create_metric_from_scorers_with_multiple_categorical_ratings():
    """Scores from several scorers are averaged into one overall score."""

    @scorer(name="scorer1")
    def scorer1(inputs, outputs):
        return Feedback(name="scorer1", value=CategoricalRating.YES, rationale="rationale1")

    @scorer(name="scorer2")
    def scorer2(inputs, outputs):
        return Feedback(name="scorer2", value=CategoricalRating.YES, rationale="rationale2")

    metric = create_metric_from_scorers([scorer1, scorer2])
    result = metric({"input": "test"}, {"output": "result"}, {}, None)

    # Overall score is the mean of both YES ratings: (1.0 + 1.0) / 2 = 1.0.
    assert result[0] == 1.0
    assert result[1] == {"scorer1": "rationale1", "scorer2": "rationale2"}
    assert result[2] == {"scorer1": 1.0, "scorer2": 1.0}


@pytest.mark.parametrize(
    ("train_data", "scorers", "expected_error"),
    [
        # 'inputs' present but empty
        (
            [{"inputs": {}, "outputs": "result"}],
            [],
            "Record 0 is missing required 'inputs' field or it is empty",
        ),
        # 'inputs' key absent entirely
        (
            [{"outputs": "result"}],
            [],
            "Record 0 is missing required 'inputs' field or it is empty",
        ),
    ],
)
def test_validate_train_data_errors(train_data, scorers, expected_error):
    """Records without a non-empty 'inputs' field are rejected."""
    import pandas as pd

    frame = pd.DataFrame(train_data)
    with pytest.raises(MlflowException, match=expected_error):
        validate_train_data(frame, scorers, lambda **kwargs: None)
@pytest.mark.parametrize(
    "train_data",
    [
        # outputs provided
        [{"inputs": {"text": "hello"}, "outputs": "result"}],
        # expectations provided instead of outputs
        [{"inputs": {"text": "hello"}, "expectations": {"expected": "result"}}],
        # multiple records mixing outputs and expectations
        [
            {"inputs": {"text": "hello"}, "outputs": "result1"},
            {"inputs": {"text": "world"}, "expectations": {"expected": "result2"}},
        ],
        # falsy-but-present output value (False) still counts as provided
        [{"inputs": {"text": "hello"}, "outputs": False}],
    ],
)
def test_validate_train_data_success(train_data):
    """Well-formed training records pass validation without raising."""
    import pandas as pd

    frame = pd.DataFrame(train_data)
    validate_train_data(frame, [], lambda **kwargs: None)