test_pyfunc_input_converter.py
# Tests for ``_hydrate_dataclass``: rebuilding dataclass instances (flat fields,
# list-of-dataclass fields, subclass fields, and Optional / ``X | None`` fields)
# from either a single pandas DataFrame row or a plain dict from ``asdict``.
from dataclasses import asdict, dataclass
from typing import Optional

import pandas as pd
import pytest

from mlflow.models.rag_signatures import ChatCompletionRequest
from mlflow.pyfunc.utils.input_converter import _hydrate_dataclass


def test_hydrate_dataclass_input_no_dataclass():
    """A target type that is not a dataclass is rejected with a ValueError."""

    # The class name is part of the expected error message, so keep it as-is.
    class NotADataclass:
        pass

    row = pd.DataFrame({"a": 1, "b": 2}, index=[0]).iloc[0]

    with pytest.raises(ValueError, match="NotADataclass is not a dataclass"):
        _hydrate_dataclass(NotADataclass, row)


def test_hydrate_dataclass_simple():
    """A dataclass with two int fields is rebuilt from one DataFrame row."""

    @dataclass
    class Pair:
        a: int
        b: int

    row = pd.DataFrame({"a": [1], "b": [2]}).iloc[0]
    hydrated = _hydrate_dataclass(Pair, row)

    assert hydrated == Pair(a=1, b=2)


def test_hydrate_dataclass_complex():
    """A ``list[<dataclass>]`` field is hydrated from a cell holding a list of dicts."""

    @dataclass
    class Pair:
        a: int
        b: int

    @dataclass
    class PairList:
        c: list[Pair]

    items = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
    # Wrap in an outer list so the whole ``items`` list becomes a single cell.
    row = pd.DataFrame({"c": [items]}).iloc[0]
    hydrated = _hydrate_dataclass(PairList, row)

    assert hydrated == PairList(c=[Pair(a=1, b=2), Pair(a=3, b=4)])


@dataclass
class CustomInput:
    # Single defaulted field, so ``CustomInput()`` needs no arguments.
    id: int = 0


@dataclass
class FlexibleChatCompletionRequest(ChatCompletionRequest):
    # Extends ``ChatCompletionRequest`` with nested dataclass fields. The two
    # fields intentionally spell "optional" two ways -- ``typing.Optional[X]``
    # and PEP 604 ``X | None`` -- so hydration is exercised for both forms.
    custom_input: Optional[CustomInput] = None  # noqa: UP045
    another_custom_input: CustomInput | None = None


def test_hydrate_child_dataclass():
    """Nested dataclass fields on a ``ChatCompletionRequest`` subclass round-trip via ``asdict``."""
    expected = FlexibleChatCompletionRequest(
        custom_input=CustomInput(), another_custom_input=CustomInput()
    )
    # asdict() deep-copies into plain dicts, so the hydrated object is built
    # independently of ``expected``.
    payload = asdict(expected)

    assert _hydrate_dataclass(FlexibleChatCompletionRequest, payload) == expected


def test_hydrate_optional_dataclass():
    """Optional-typed fields set to ``None`` come back as ``None``."""
    expected = FlexibleChatCompletionRequest(custom_input=None, another_custom_input=None)
    payload = asdict(expected)

    assert _hydrate_dataclass(FlexibleChatCompletionRequest, payload) == expected