/ tests / llama_index / test_llama_index_pyfunc_wrapper.py
test_llama_index_pyfunc_wrapper.py
  1  import llama_index.core
  2  import numpy as np
  3  import pandas as pd
  4  import pytest
  5  from llama_index.core import QueryBundle
  6  from llama_index.core.llms import ChatMessage
  7  from packaging.version import Version
  8  
  9  import mlflow
 10  from mlflow.llama_index.pyfunc_wrapper import (
 11      _CHAT_MESSAGE_HISTORY_PARAMETER_NAME,
 12      CHAT_ENGINE_NAME,
 13      QUERY_ENGINE_NAME,
 14      RETRIEVER_ENGINE_NAME,
 15      create_pyfunc_wrapper,
 16  )
 17  
 18  
################## Inference Input #################
 20  def test_format_predict_input_str_chat(single_index):
 21      wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
 22      formatted_data = wrapped_model._format_predict_input("string")
 23      assert formatted_data == "string"
 24  
 25  
 26  def test_format_predict_input_dict_chat(single_index):
 27      wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
 28      formatted_data = wrapped_model._format_predict_input({"query": "string"})
 29      assert isinstance(formatted_data, dict)
 30  
 31  
 32  def test_format_predict_input_message_history_chat(single_index):
 33      payload = {
 34          "message": "string",
 35          _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [{"role": "user", "content": "hi"}] * 3,
 36      }
 37      wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
 38      formatted_data = wrapped_model._format_predict_input(payload)
 39  
 40      assert isinstance(formatted_data, dict)
 41      assert formatted_data["message"] == payload["message"]
 42      assert isinstance(formatted_data[_CHAT_MESSAGE_HISTORY_PARAMETER_NAME], list)
 43      assert all(
 44          isinstance(x, ChatMessage) for x in formatted_data[_CHAT_MESSAGE_HISTORY_PARAMETER_NAME]
 45      )
 46  
 47  
 48  @pytest.mark.parametrize(
 49      "data",
 50      [
 51          [
 52              {
 53                  "query": "string",
 54                  _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [{"role": "user", "content": "hi"}] * 3,
 55              }
 56          ]
 57          * 3,
 58          pd.DataFrame(
 59              [
 60                  {
 61                      "query": "string",
 62                      _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [{"role": "user", "content": "hi"}] * 3,
 63                  }
 64              ]
 65              * 3
 66          ),
 67      ],
 68  )
 69  def test_format_predict_input_message_history_chat_iterable(single_index, data):
 70      wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
 71      formatted_data = wrapped_model._format_predict_input(data)
 72  
 73      if isinstance(data, pd.DataFrame):
 74          data = data.to_dict("records")
 75  
 76      assert isinstance(formatted_data, list)
 77      assert formatted_data[0]["query"] == data[0]["query"]
 78      assert isinstance(formatted_data[0][_CHAT_MESSAGE_HISTORY_PARAMETER_NAME], list)
 79      assert all(
 80          isinstance(x, ChatMessage) for x in formatted_data[0][_CHAT_MESSAGE_HISTORY_PARAMETER_NAME]
 81      )
 82  
 83  
 84  def test_format_predict_input_message_history_chat_invalid_type(single_index):
 85      payload = {
 86          "message": "string",
 87          _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: ["invalid history string", "user: hi"],
 88      }
 89      wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
 90      with pytest.raises(ValueError, match="It must be a list of dicts"):
 91          _ = wrapped_model._format_predict_input(payload)
 92  
 93  
 94  @pytest.mark.parametrize(
 95      "data",
 96      [
 97          "string",
 98          ["string"],  # iterables of length 1 should be treated non-iterables
 99          {"query_str": "string"},
100          {"query_str": "string", "custom_embedding_strs": ["string"], "embedding": [1.0]},
101          pd.DataFrame({
102              "query_str": ["string"],
103              "custom_embedding_strs": [["string"]],
104              "embedding": [[1.0]],
105          }),
106      ],
107  )
108  def test_format_predict_input_no_iterable_query(single_index, data):
109      wrapped_model = create_pyfunc_wrapper(single_index, QUERY_ENGINE_NAME)
110      formatted_data = wrapped_model._format_predict_input(data)
111      assert isinstance(formatted_data, QueryBundle)
112  
113  
@pytest.mark.parametrize(
    "data",
    [
        ["string", "string"],
        [{"query_str": "string"}] * 4,
        [{"query_str": "string", "custom_embedding_strs": ["string"], "embedding": [1.0]}] * 4,
        # Two references to the same single-row DataFrame.
        [
            pd.DataFrame({
                "query_str": ["string"],
                "custom_embedding_strs": [["string"]],
                "embedding": [[1.0]],
            })
        ]
        * 2,
    ],
)
def test_format_predict_input_iterable_query(single_index, data):
    """Multi-query inputs to the query engine wrapper yield a list of ``QueryBundle``."""
    query_wrapper = create_pyfunc_wrapper(single_index, QUERY_ENGINE_NAME)
    bundles = query_wrapper._format_predict_input(data)

    assert isinstance(bundles, list)
    for bundle in bundles:
        assert isinstance(bundle, QueryBundle)
136  
137  
@pytest.mark.parametrize(
    "data",
    [
        "string",
        # A length-1 iterable should be treated as a single (non-iterable) input.
        ["string"],
        {"query_str": "string"},
        {"query_str": "string", "custom_embedding_strs": ["string"], "embedding": [1.0]},
        pd.DataFrame({
            "query_str": ["string"],
            "custom_embedding_strs": [["string"]],
            "embedding": [[1.0]],
        }),
    ],
)
def test_format_predict_input_no_iterable_retriever(single_index, data):
    """Single-query inputs to the retriever wrapper yield one ``QueryBundle``."""
    retriever_wrapper = create_pyfunc_wrapper(single_index, RETRIEVER_ENGINE_NAME)
    bundle = retriever_wrapper._format_predict_input(data)
    assert isinstance(bundle, QueryBundle)
156  
157  
@pytest.mark.parametrize(
    "data",
    [
        ["string", "string"],
        [{"query_str": "string"}] * 4,
        [{"query_str": "string", "custom_embedding_strs": ["string"], "embedding": [1.0]}] * 4,
        # Two references to the same single-row DataFrame.
        [
            pd.DataFrame({
                "query_str": ["string"],
                "custom_embedding_strs": [["string"]],
                "embedding": [[1.0]],
            })
        ]
        * 2,
    ],
)
def test_format_predict_input_iterable_retriever(single_index, data):
    """Multi-query inputs to the retriever wrapper yield a list of ``QueryBundle``."""
    retriever_wrapper = create_pyfunc_wrapper(single_index, RETRIEVER_ENGINE_NAME)
    bundles = retriever_wrapper._format_predict_input(data)

    assert isinstance(bundles, list)
    for bundle in bundles:
        assert isinstance(bundle, QueryBundle)
179  
180  
@pytest.mark.parametrize("engine_type", ["query", "retriever"])
def test_format_predict_input_correct(single_index, engine_type):
    """Each of these single-query input shapes is formatted into a ``QueryBundle``."""
    wrapper = create_pyfunc_wrapper(single_index, engine_type)

    single_query_inputs = [
        pd.DataFrame({"query_str": ["hi"]}),
        np.array(["hi"]),
        {"query_str": ["hi"]},
        {"query_str": "hi"},
        ["hi"],
        "hi",
    ]
    for query_input in single_query_inputs:
        assert isinstance(wrapper._format_predict_input(query_input), QueryBundle)
196  
197  
@pytest.mark.parametrize("engine_type", ["query", "retriever"])
def test_format_predict_input_correct_schema_complex(single_index, engine_type):
    """A payload carrying extra ``QueryBundle``-style fields formats as DataFrame and as dict."""
    wrapper = create_pyfunc_wrapper(single_index, engine_type)

    # DataFrame form: list-valued fields are wrapped once more so each forms a single row.
    frame_payload = {
        "query_str": "hi",
        "image_path": "some/path",
        "custom_embedding_strs": [["a"]],
        "embedding": [[1.0]],
    }
    frame_result = wrapper._format_predict_input(pd.DataFrame(frame_payload))
    assert isinstance(frame_result, QueryBundle)

    # Dict form: same keys in the same order, list-valued fields unwrapped by one level.
    dict_payload = {
        **frame_payload,
        "custom_embedding_strs": ["a"],
        "embedding": [1.0],
    }
    dict_result = wrapper._format_predict_input(dict_payload)
    assert isinstance(dict_result, QueryBundle)
217  
218  
@pytest.mark.parametrize(
    ("engine_type", "input"),
    [
        ("query", {"query_str": "hello!"}),
        ("retriever", {"query_str": "hello!"}),
    ],
)
def test_spark_udf_retriever_and_query_engine(model_path, spark, single_index, engine_type, input):
    """A saved query/retriever model, loaded as a Spark UDF, yields one string per row."""
    mlflow.llama_index.save_model(
        llama_index_model=single_index,
        engine_type=engine_type,
        path=model_path,
        input_example=input,
    )
    predict_udf = mlflow.pyfunc.spark_udf(spark, model_path, result_type="string")

    scored = spark.createDataFrame([{"query_str": "hi"}]).withColumn(
        "predictions", predict_udf()
    )
    predictions = scored.toPandas()["predictions"].tolist()

    assert len(predictions) == 1
    assert isinstance(predictions[0], str)
239  
240  
def test_spark_udf_chat(model_path, spark, single_index):
    """A saved chat model, loaded as a Spark UDF, yields one string per row."""
    chat_frame = pd.DataFrame({
        "message": ["string"],
        _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [[{"role": "user", "content": "string"}]],
    })
    mlflow.llama_index.save_model(
        llama_index_model=single_index,
        engine_type="chat",
        path=model_path,
        input_example=chat_frame,
    )
    predict_udf = mlflow.pyfunc.spark_udf(spark, model_path, result_type="string")

    scored = spark.createDataFrame(chat_frame).withColumn("predictions", predict_udf())
    predictions = scored.toPandas()["predictions"].tolist()

    assert len(predictions) == 1
    assert isinstance(predictions[0], str)
259  
260  
@pytest.mark.skipif(
    Version(llama_index.core.__version__) < Version("0.11.0"),
    reason="Workflow was introduced in 0.11.0",
)
@pytest.mark.asyncio
async def test_wrap_workflow():
    """A wrapped ``Workflow`` predicts on dict, list-of-dict and DataFrame inputs."""
    from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step

    class MyWorkflow(Workflow):
        # Single-step workflow that greets whoever is named on the start event.
        @step
        async def my_step(self, ev: StartEvent) -> StopEvent:
            return StopEvent(result=f"Hi, {ev.name}!")

    workflow = MyWorkflow(timeout=10, verbose=False)
    wrapper = create_pyfunc_wrapper(workflow)
    assert wrapper.get_raw_model() == workflow

    # One dict -> one scalar result.
    single_result = wrapper.predict({"name": "Alice"})
    assert single_result == "Hi, Alice!"

    # Several dicts -> list of results in input order.
    batch_inputs = [
        {"name": "Bob"},
        {"name": "Charlie"},
    ]
    batch_results = wrapper.predict(batch_inputs)
    assert batch_results == ["Hi, Bob!", "Hi, Charlie!"]

    # One-row DataFrame -> scalar result.
    one_row_result = wrapper.predict(pd.DataFrame({"name": ["David"]}))
    assert one_row_result == "Hi, David!"

    # Multi-row DataFrame -> list of results.
    multi_row_results = wrapper.predict(pd.DataFrame({"name": ["Eve", "Frank"]}))
    assert multi_row_results == ["Hi, Eve!", "Hi, Frank!"]
292  
293  
@pytest.mark.skipif(
    Version(llama_index.core.__version__) < Version("0.11.0"),
    reason="Workflow was introduced in 0.11.0",
)
@pytest.mark.asyncio
async def test_wrap_workflow_raise_exception():
    """An error raised inside a workflow step surfaces from ``predict``."""
    from llama_index.core.workflow import (
        StartEvent,
        StopEvent,
        Workflow,
        WorkflowRuntimeError,
        step,
    )

    class MyWorkflow(Workflow):
        # Single-step workflow whose only step always fails.
        @step
        async def my_step(self, ev: StartEvent) -> StopEvent:
            raise ValueError("Expected error")

    wrapper = create_pyfunc_wrapper(MyWorkflow(timeout=10, verbose=False))

    # Accept either exception type, per the llama_index version noted on each entry.
    accepted_errors = (
        ValueError,  # llama_index < 0.12.1
        WorkflowRuntimeError,  # llama_index >= 0.12.1
    )
    with pytest.raises(accepted_errors, match="Expected error"):
        wrapper.predict({"name": "Alice"})