/ tests / future / test_data_definition.py
test_data_definition.py
  1  import datetime
  2  import json
  3  import random
  4  
  5  import pandas as pd
  6  import pytest
  7  
  8  from evidently._pydantic_compat import parse_obj_as
  9  from evidently.core.datasets import DEFAULT_TRACE_LINK_COLUMN
 10  from evidently.core.datasets import DataDefinition
 11  from evidently.core.datasets import Dataset
 12  from evidently.core.datasets import ServiceColumns
 13  from evidently.core.datasets import infer_column_type
 14  from evidently.legacy.core import ColumnType
 15  
 16  
 17  @pytest.mark.parametrize(
 18      "data,expected",
 19      [
 20          (pd.Series(["a", "b", "a", "b", "a"]), ColumnType.Categorical),
 21          (pd.Series([0.1, 0.2, 0.3, 0.4, 0.5]), ColumnType.Numerical),
 22          (pd.Series([0.1, 0.1, 0.2, 0.2, 0.2]), ColumnType.Numerical),
 23          (pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), ColumnType.Numerical),
 24          (pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), ColumnType.Categorical),
 25          (pd.Series([True, False, True, False, True, False]), ColumnType.Categorical),
 26          (pd.Series([datetime.datetime.now(), datetime.datetime.now()]), ColumnType.Datetime),
 27          (pd.Series(["a", "b", "c", "d", "e", "f", "g"]), ColumnType.Text),
 28          (pd.Categorical(["a", "b", "c", "d", "e", "f", "g"]), ColumnType.Categorical),
 29          (pd.Series(pd.date_range("2025-01-01", periods=11, freq="D").values), ColumnType.Datetime),
 30      ],
 31  )
 32  def test_infer_column_type(data: pd.Series, expected: ColumnType):
 33      assert infer_column_type(data) == expected
 34  
 35  
 36  @pytest.mark.parametrize(
 37      "definition,numerical,categorical,datetime_cols,text,service_columns",
 38      [
 39          (
 40              None,
 41              ("num_1", "num_2", "num_3"),
 42              ("cat_1", "cat_2", "cat_3"),
 43              ("datetime", "datetime_2"),
 44              ("text_1", "text_2"),
 45              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
 46          ),
 47          (
 48              DataDefinition(numerical_columns=["num_1"]),
 49              ("num_1",),
 50              ("cat_1", "cat_2", "cat_3"),
 51              ("datetime", "datetime_2"),
 52              ("text_1", "text_2"),
 53              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
 54          ),
 55          (
 56              DataDefinition(categorical_columns=["cat_1"]),
 57              ("num_1", "num_2", "num_3"),
 58              ("cat_1",),
 59              ("datetime", "datetime_2"),
 60              ("text_1", "text_2"),
 61              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
 62          ),
 63          (
 64              DataDefinition(text_columns=["text_2"]),
 65              ("num_1", "num_2", "num_3"),
 66              ("cat_1", "cat_2", "cat_3"),
 67              ("datetime", "datetime_2"),
 68              ("text_2",),
 69              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
 70          ),
 71          (
 72              DataDefinition(datetime_columns=["datetime_2"]),
 73              ("num_1", "num_2", "num_3"),
 74              ("cat_1", "cat_2", "cat_3"),
 75              ("datetime_2",),
 76              ("text_1", "text_2"),
 77              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
 78          ),
 79          (
 80              DataDefinition(timestamp="datetime"),
 81              ("num_1", "num_2", "num_3"),
 82              ("cat_1", "cat_2", "cat_3"),
 83              ("datetime_2",),
 84              ("text_1", "text_2"),
 85              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
 86          ),
 87          (
 88              DataDefinition(numerical_columns=[]),
 89              tuple(),
 90              ("cat_1", "cat_2", "cat_3"),
 91              ("datetime", "datetime_2"),
 92              ("text_1", "text_2"),
 93              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
 94          ),
 95          (
 96              DataDefinition(id_column="num_1"),
 97              ("num_2", "num_3"),
 98              ("cat_1", "cat_2", "cat_3"),
 99              ("datetime", "datetime_2"),
100              ("text_1", "text_2"),
101              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
102          ),
103          (
104              DataDefinition(categorical_columns=["num_3"]),
105              ("num_1", "num_2"),
106              ("num_3",),
107              ("datetime", "datetime_2"),
108              ("text_1", "text_2"),
109              ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
110          ),
111          (
112              DataDefinition(service_columns=ServiceColumns(trace_link="another_trace_link")),
113              ("num_1", "num_2", "num_3"),
114              ("cat_1", "cat_2", "cat_3"),
115              ("datetime", "datetime_2"),
116              ("text_1", "text_2", "_evidently_trace_link"),
117              ServiceColumns(trace_link="another_trace_link"),
118          ),
119      ],
120  )
121  def test_data_definition(definition, numerical, categorical, datetime_cols, text, service_columns):
122      data = pd.DataFrame(
123          data=dict(
124              num_1=pd.Series([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1]),
125              num_2=pd.Series([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
126              num_3=pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
127              cat_1=pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1]),
128              cat_2=pd.Series(["a", "b", "c", "d", "e", "a", "b", "c", "d", "e", "a"]),
129              cat_3=pd.Series(random.choices([True, False], k=11)),
130              datetime=pd.Series([datetime.datetime.now()] * 11),
131              datetime_2=pd.date_range("2025-01-01", periods=11, freq="D"),
132              text_1=pd.Series(["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]),
133              text_2=pd.Series(["a", "b", "c", "d", "e", "a", "b", "c", "d", "e", "f"]),
134              _evidently_trace_link=pd.Series(["a", "b", "c", "d", "e", "a", "b", "c", "d", "e", "f"]),
135          )
136      )
137      dataset = Dataset.from_pandas(data, data_definition=definition)
138      assert set(numerical) == set(dataset.data_definition.get_numerical_columns())
139      assert set(categorical) == set(dataset.data_definition.get_categorical_columns())
140      assert set(datetime_cols) == set(dataset.data_definition.get_datetime_columns())
141      assert set(text) == set(dataset.data_definition.get_text_columns())
142      assert service_columns == dataset.data_definition.service_columns
143  
144  
@pytest.mark.parametrize(
    "definition,numerical,categorical,datetime_cols,text,service_columns",
    [
        # No definition supplied: every bucket is inferred and no service
        # columns are expected.
        (
            None,
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        # Explicitly listing one bucket restricts only that bucket.
        (
            DataDefinition(numerical_columns=["num_1"]),
            ("num_1",),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(categorical_columns=["cat_1"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1",),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(text_columns=["text_2"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_2",),
            None,
        ),
        (
            DataDefinition(datetime_columns=["datetime_2"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime_2",),
            ("text_1", "text_2"),
            None,
        ),
        # The timestamp column is expected to be absent from the datetime bucket.
        (
            DataDefinition(timestamp="datetime"),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime_2",),
            ("text_1", "text_2"),
            None,
        ),
        # An explicit empty list is expected to yield no numerical columns.
        (
            DataDefinition(numerical_columns=[]),
            tuple(),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        # The id column is expected to be absent from the numerical bucket.
        (
            DataDefinition(id_column="num_1"),
            ("num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        # A column claimed as categorical is expected to leave the numerical bucket.
        (
            DataDefinition(categorical_columns=["num_3"]),
            ("num_1", "num_2"),
            ("num_3",),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        # Explicitly configured service columns are expected to be kept
        # even though the frame has no matching column.
        (
            DataDefinition(service_columns=ServiceColumns(trace_link="another_trace_link")),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            ServiceColumns(trace_link="another_trace_link"),
        ),
    ],
)
def test_data_definition_without_service(definition, numerical, categorical, datetime_cols, text, service_columns):
    """Check the column buckets and service columns when the frame has no trace-link column.

    The frame deliberately omits ``_evidently_trace_link``. Column groups are
    compared as sets, so ordering is not checked.

    Args:
        definition: ``DataDefinition`` passed to ``Dataset.from_pandas``, or ``None``.
        numerical: Column names expected from ``get_numerical_columns()``.
        categorical: Column names expected from ``get_categorical_columns()``.
        datetime_cols: Column names expected from ``get_datetime_columns()``.
        text: Column names expected from ``get_text_columns()``.
        service_columns: Expected value of ``data_definition.service_columns``
            (``None`` when none are expected).
    """
    data = pd.DataFrame(
        data=dict(
            num_1=pd.Series([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1]),
            num_2=pd.Series([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
            num_3=pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
            cat_1=pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1]),
            cat_2=pd.Series(["a", "b", "c", "d", "e", "a", "b", "c", "d", "e", "a"]),
            # Fixed boolean pattern so every run sees identical input data.
            cat_3=pd.Series([True, False] * 5 + [True]),
            datetime=pd.Series([datetime.datetime.now()] * 11),
            datetime_2=pd.date_range("2025-01-01", periods=11, freq="D"),
            text_1=pd.Series(["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]),
            text_2=pd.Series(["a", "b", "c", "d", "e", "a", "b", "c", "d", "e", "f"]),
        )
    )
    dataset = Dataset.from_pandas(data, data_definition=definition)
    resolved = dataset.data_definition
    assert set(numerical) == set(resolved.get_numerical_columns())
    assert set(categorical) == set(resolved.get_categorical_columns())
    assert set(datetime_cols) == set(resolved.get_datetime_columns())
    assert set(text) == set(resolved.get_text_columns())
    assert service_columns == resolved.service_columns
251  
252  
def test_data_definition_serialization():
    """Check that a ``DataDefinition`` is equal to itself after a JSON round trip."""
    original = DataDefinition(text_columns=["text_2"])
    # Serialize to a JSON string, decode to plain Python objects, then re-parse.
    payload = json.loads(original.json())
    restored = parse_obj_as(DataDefinition, payload)
    assert restored == original
257  
258  
def test_data_definition_serialization_empty():
    """Check that parsing an empty mapping equals a default-constructed ``DataDefinition``."""
    default_definition = DataDefinition()
    from_empty_payload = parse_obj_as(DataDefinition, {})
    assert from_empty_payload == default_definition