"""Tests for ``mlflow.llama_index.pyfunc_wrapper``.

Covers ``_format_predict_input`` coercion rules for chat / query / retriever
engines, spark_udf serving of saved llama_index models, and pyfunc wrapping of
llama_index Workflows (requires the ``single_index``, ``spark`` and
``model_path`` fixtures from conftest).
"""

import llama_index.core
import numpy as np
import pandas as pd
import pytest
from llama_index.core import QueryBundle
from llama_index.core.llms import ChatMessage
from packaging.version import Version

import mlflow
from mlflow.llama_index.pyfunc_wrapper import (
    _CHAT_MESSAGE_HISTORY_PARAMETER_NAME,
    CHAT_ENGINE_NAME,
    QUERY_ENGINE_NAME,
    RETRIEVER_ENGINE_NAME,
    create_pyfunc_wrapper,
)


################## Inference Input #################
def test_format_predict_input_str_chat(single_index):
    # Plain strings pass through the chat wrapper unchanged.
    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
    formatted_data = wrapped_model._format_predict_input("string")
    assert formatted_data == "string"


def test_format_predict_input_dict_chat(single_index):
    # Dict payloads stay dicts for the chat engine.
    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
    formatted_data = wrapped_model._format_predict_input({"query": "string"})
    assert isinstance(formatted_data, dict)


def test_format_predict_input_message_history_chat(single_index):
    # Chat-history dicts must be converted to llama_index ChatMessage objects.
    payload = {
        "message": "string",
        _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [{"role": "user", "content": "hi"}] * 3,
    }
    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
    formatted_data = wrapped_model._format_predict_input(payload)

    assert isinstance(formatted_data, dict)
    assert formatted_data["message"] == payload["message"]
    assert isinstance(formatted_data[_CHAT_MESSAGE_HISTORY_PARAMETER_NAME], list)
    assert all(
        isinstance(x, ChatMessage) for x in formatted_data[_CHAT_MESSAGE_HISTORY_PARAMETER_NAME]
    )


@pytest.mark.parametrize(
    "data",
    [
        [
            {
                "query": "string",
                _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [{"role": "user", "content": "hi"}] * 3,
            }
        ]
        * 3,
        pd.DataFrame(
            [
                {
                    "query": "string",
                    _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [{"role": "user", "content": "hi"}] * 3,
                }
            ]
            * 3
        ),
    ],
)
def test_format_predict_input_message_history_chat_iterable(single_index, data):
    # Iterable payloads (list of dicts or DataFrame) become a list of dicts,
    # each with its history coerced to ChatMessage objects.
    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
    formatted_data = wrapped_model._format_predict_input(data)

    if isinstance(data, pd.DataFrame):
        data = data.to_dict("records")

    assert isinstance(formatted_data, list)
    assert formatted_data[0]["query"] == data[0]["query"]
    assert isinstance(formatted_data[0][_CHAT_MESSAGE_HISTORY_PARAMETER_NAME], list)
    assert all(
        isinstance(x, ChatMessage) for x in formatted_data[0][_CHAT_MESSAGE_HISTORY_PARAMETER_NAME]
    )


def test_format_predict_input_message_history_chat_invalid_type(single_index):
    # History entries that are not dicts must raise a descriptive ValueError.
    payload = {
        "message": "string",
        _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: ["invalid history string", "user: hi"],
    }
    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
    with pytest.raises(ValueError, match="It must be a list of dicts"):
        _ = wrapped_model._format_predict_input(payload)


@pytest.mark.parametrize(
    "data",
    [
        "string",
        ["string"],  # iterables of length 1 should be treated as non-iterables
        {"query_str": "string"},
        {"query_str": "string", "custom_embedding_strs": ["string"], "embedding": [1.0]},
        pd.DataFrame({
            "query_str": ["string"],
            "custom_embedding_strs": [["string"]],
            "embedding": [[1.0]],
        }),
    ],
)
def test_format_predict_input_no_iterable_query(single_index, data):
    # Scalar-like inputs collapse to a single QueryBundle for the query engine.
    wrapped_model = create_pyfunc_wrapper(single_index, QUERY_ENGINE_NAME)
    formatted_data = wrapped_model._format_predict_input(data)
    assert isinstance(formatted_data, QueryBundle)


@pytest.mark.parametrize(
    "data",
    [
        ["string", "string"],
        [{"query_str": "string"}] * 4,
        [{"query_str": "string", "custom_embedding_strs": ["string"], "embedding": [1.0]}] * 4,
        [
            pd.DataFrame({
                "query_str": ["string"],
                "custom_embedding_strs": [["string"]],
                "embedding": [[1.0]],
            })
        ]
        * 2,
    ],
)
def test_format_predict_input_iterable_query(single_index, data):
    # Multi-element iterables map to a list of QueryBundles.
    wrapped_model = create_pyfunc_wrapper(single_index, QUERY_ENGINE_NAME)
    formatted_data = wrapped_model._format_predict_input(data)

    assert isinstance(formatted_data, list)
    assert all(isinstance(x, QueryBundle) for x in formatted_data)


@pytest.mark.parametrize(
    "data",
    [
        "string",
        ["string"],  # iterables of length 1 should be treated as non-iterables
        {"query_str": "string"},
        {"query_str": "string", "custom_embedding_strs": ["string"], "embedding": [1.0]},
        pd.DataFrame({
            "query_str": ["string"],
            "custom_embedding_strs": [["string"]],
            "embedding": [[1.0]],
        }),
    ],
)
def test_format_predict_input_no_iterable_retriever(single_index, data):
    # Same scalar coercion rules apply to the retriever engine.
    wrapped_model = create_pyfunc_wrapper(single_index, RETRIEVER_ENGINE_NAME)
    formatted_data = wrapped_model._format_predict_input(data)
    assert isinstance(formatted_data, QueryBundle)


@pytest.mark.parametrize(
    "data",
    [
        ["string", "string"],
        [{"query_str": "string"}] * 4,
        [{"query_str": "string", "custom_embedding_strs": ["string"], "embedding": [1.0]}] * 4,
        [
            pd.DataFrame({
                "query_str": ["string"],
                "custom_embedding_strs": [["string"]],
                "embedding": [[1.0]],
            })
        ]
        * 2,
    ],
)
def test_format_predict_input_iterable_retriever(single_index, data):
    # Same iterable coercion rules apply to the retriever engine.
    wrapped_model = create_pyfunc_wrapper(single_index, RETRIEVER_ENGINE_NAME)
    formatted_data = wrapped_model._format_predict_input(data)
    assert isinstance(formatted_data, list)
    assert all(isinstance(x, QueryBundle) for x in formatted_data)


@pytest.mark.parametrize(
    "engine_type",
    ["query", "retriever"],
)
def test_format_predict_input_correct(single_index, engine_type):
    # Every supported scalar-like container shape yields a QueryBundle.
    wrapped_model = create_pyfunc_wrapper(single_index, engine_type)

    assert isinstance(
        wrapped_model._format_predict_input(pd.DataFrame({"query_str": ["hi"]})), QueryBundle
    )
    assert isinstance(wrapped_model._format_predict_input(np.array(["hi"])), QueryBundle)
    assert isinstance(wrapped_model._format_predict_input({"query_str": ["hi"]}), QueryBundle)
    assert isinstance(wrapped_model._format_predict_input({"query_str": "hi"}), QueryBundle)
    assert isinstance(wrapped_model._format_predict_input(["hi"]), QueryBundle)
    assert isinstance(wrapped_model._format_predict_input("hi"), QueryBundle)


@pytest.mark.parametrize(
    "engine_type",
    ["query", "retriever"],
)
def test_format_predict_input_correct_schema_complex(single_index, engine_type):
    # Full QueryBundle schema (query_str, image_path, embeddings) is accepted
    # both as a DataFrame and as a plain dict.
    wrapped_model = create_pyfunc_wrapper(single_index, engine_type)

    payload = {
        "query_str": "hi",
        "image_path": "some/path",
        "custom_embedding_strs": [["a"]],
        "embedding": [[1.0]],
    }
    assert isinstance(wrapped_model._format_predict_input(pd.DataFrame(payload)), QueryBundle)
    payload.update({
        "custom_embedding_strs": ["a"],
        "embedding": [1.0],
    })
    assert isinstance(wrapped_model._format_predict_input(payload), QueryBundle)


@pytest.mark.parametrize(
    ("engine_type", "input_example"),
    [
        ("query", {"query_str": "hello!"}),
        ("retriever", {"query_str": "hello!"}),
    ],
)
def test_spark_udf_retriever_and_query_engine(
    model_path, spark, single_index, engine_type, input_example
):
    # End-to-end: save the index, load it as a spark_udf, and score one row.
    mlflow.llama_index.save_model(
        llama_index_model=single_index,
        engine_type=engine_type,
        path=model_path,
        input_example=input_example,
    )
    udf = mlflow.pyfunc.spark_udf(spark, model_path, result_type="string")
    df = spark.createDataFrame([{"query_str": "hi"}])
    df = df.withColumn("predictions", udf())
    pdf = df.toPandas()
    assert len(pdf["predictions"].tolist()) == 1
    assert isinstance(pdf["predictions"].tolist()[0], str)


def test_spark_udf_chat(model_path, spark, single_index):
    # Chat engine served via spark_udf with a message + history input schema.
    engine_type = "chat"
    input_example = pd.DataFrame({
        "message": ["string"],
        _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [[{"role": "user", "content": "string"}]],
    })
    mlflow.llama_index.save_model(
        llama_index_model=single_index,
        engine_type=engine_type,
        path=model_path,
        input_example=input_example,
    )
    udf = mlflow.pyfunc.spark_udf(spark, model_path, result_type="string")
    df = spark.createDataFrame(input_example)
    df = df.withColumn("predictions", udf())
    pdf = df.toPandas()
    assert len(pdf["predictions"].tolist()) == 1
    assert isinstance(pdf["predictions"].tolist()[0], str)


@pytest.mark.skipif(
    Version(llama_index.core.__version__) < Version("0.11.0"),
    reason="Workflow was introduced in 0.11.0",
)
@pytest.mark.asyncio
async def test_wrap_workflow():
    from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step

    class MyWorkflow(Workflow):
        @step
        async def my_step(self, ev: StartEvent) -> StopEvent:
            return StopEvent(result=f"Hi, {ev.name}!")

    w = MyWorkflow(timeout=10, verbose=False)
    wrapper = create_pyfunc_wrapper(w)
    assert wrapper.get_raw_model() == w

    result = wrapper.predict({"name": "Alice"})
    assert result == "Hi, Alice!"

    results = wrapper.predict([
        {"name": "Bob"},
        {"name": "Charlie"},
    ])
    assert results == ["Hi, Bob!", "Hi, Charlie!"]

    # A single-row DataFrame unwraps to a scalar result...
    results = wrapper.predict(pd.DataFrame({"name": ["David"]}))
    assert results == "Hi, David!"

    # ...while a multi-row DataFrame returns a list of results.
    results = wrapper.predict(pd.DataFrame({"name": ["Eve", "Frank"]}))
    assert results == ["Hi, Eve!", "Hi, Frank!"]


@pytest.mark.skipif(
    Version(llama_index.core.__version__) < Version("0.11.0"),
    reason="Workflow was introduced in 0.11.0",
)
@pytest.mark.asyncio
async def test_wrap_workflow_raise_exception():
    from llama_index.core.workflow import (
        StartEvent,
        StopEvent,
        Workflow,
        WorkflowRuntimeError,
        step,
    )

    class MyWorkflow(Workflow):
        @step
        async def my_step(self, ev: StartEvent) -> StopEvent:
            raise ValueError("Expected error")

    w = MyWorkflow(timeout=10, verbose=False)
    wrapper = create_pyfunc_wrapper(w)

    # The exception type surfaced by predict() changed across llama_index versions.
    with pytest.raises(
        (
            ValueError,  # llama_index < 0.12.1
            WorkflowRuntimeError,  # llama_index >= 0.12.1
        ),
        match="Expected error",
    ):
        wrapper.predict({"name": "Alice"})