/ tests / pyfunc / test_pyfunc_input_converter.py
test_pyfunc_input_converter.py
 1  from dataclasses import asdict, dataclass
 2  from typing import Optional
 3  
 4  import pandas as pd
 5  import pytest
 6  
 7  from mlflow.models.rag_signatures import ChatCompletionRequest
 8  from mlflow.pyfunc.utils.input_converter import _hydrate_dataclass
 9  
10  
def test_hydrate_dataclass_input_no_dataclass():
    """_hydrate_dataclass raises ValueError when the target type is not a dataclass."""

    class NotADataclass:
        pass

    # A one-row frame whose first row is the input handed to the hydrator.
    frame = pd.DataFrame({"a": 1, "b": 2}, index=[0])
    row = frame.iloc[0]

    with pytest.raises(ValueError, match="NotADataclass is not a dataclass"):
        _hydrate_dataclass(NotADataclass, row)
23  
24  
def test_hydrate_dataclass_simple():
    """A flat dataclass is populated from the columns of a single pandas row."""

    @dataclass
    class MyDataclass:
        a: int
        b: int

    row = pd.DataFrame({"a": [1], "b": [2]}).iloc[0]
    expected = MyDataclass(a=1, b=2)

    assert _hydrate_dataclass(MyDataclass, row) == expected
38  
39  
def test_hydrate_dataclass_complex():
    """A ``list[MyDataclass]`` field is hydrated from a list of plain dicts."""

    @dataclass
    class MyDataclass:
        a: int
        b: int

    @dataclass
    class MyListDataclass:
        c: list[MyDataclass]

    raw_items = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
    # The extra list wrapper puts the whole list of dicts into one cell of column "c".
    row = pd.DataFrame({"c": [raw_items]}).iloc[0]

    # Built before hydration, so it reflects the untouched input dicts.
    expected = MyListDataclass(c=[MyDataclass(**item) for item in raw_items])
    assert _hydrate_dataclass(MyListDataclass, row) == expected
57  
58  
@dataclass
class CustomInput:
    """Small dataclass used as the type of nested custom fields in the hydration tests."""

    id: int = 0  # default lets tests build CustomInput() with no arguments
62  
63  
@dataclass
class FlexibleChatCompletionRequest(ChatCompletionRequest):
    """ChatCompletionRequest subclass that adds two optional nested-dataclass fields.

    The two fields spell "optional" differently on purpose. This lets the
    hydration tests cover both ``Optional[X]`` and PEP 604 ``X | None``
    annotations.
    """

    # typing.Optional spelling; the noqa stops the linter (UP045) from rewriting it.
    custom_input: Optional[CustomInput] = None  # noqa: UP045
    # PEP 604 spelling. It is evaluated when the class is created (this module has
    # no `from __future__ import annotations`), so it needs Python 3.10+.
    another_custom_input: CustomInput | None = None
68  
69  
def test_hydrate_child_dataclass():
    """Populated nested-dataclass fields on a ChatCompletionRequest subclass round-trip.

    Covers both the ``Optional[CustomInput]`` field and the
    ``CustomInput | None`` field.
    """
    request = FlexibleChatCompletionRequest(
        custom_input=CustomInput(),
        another_custom_input=CustomInput(),
    )
    # asdict returns a deep, plain-dict copy, so `request` itself stays untouched.
    payload = asdict(request)

    assert _hydrate_dataclass(FlexibleChatCompletionRequest, payload) == request
82  
83  
def test_hydrate_optional_dataclass():
    """Optional nested-dataclass fields that are None stay None after hydration."""
    expected = FlexibleChatCompletionRequest(custom_input=None, another_custom_input=None)
    hydrated = _hydrate_dataclass(FlexibleChatCompletionRequest, asdict(expected))
    assert hydrated == expected