test_data_definition.py
import datetime
import json
import random

import pandas as pd
import pytest

from evidently._pydantic_compat import parse_obj_as
from evidently.core.datasets import DEFAULT_TRACE_LINK_COLUMN
from evidently.core.datasets import DataDefinition
from evidently.core.datasets import Dataset
from evidently.core.datasets import ServiceColumns
from evidently.core.datasets import infer_column_type
from evidently.legacy.core import ColumnType

# Row count of the synthetic frame built by ``_make_test_frame``; the literal
# column values below are written out with exactly this many entries.
_ROWS = 11


def _make_test_frame(*, with_trace_link: bool) -> pd.DataFrame:
    """Build the synthetic frame shared by the DataDefinition inference tests.

    Column names mirror the column type each test expects to be resolved
    (``num_*``, ``cat_*``, ``datetime*``, ``text_*``).

    Args:
        with_trace_link: when True, append an ``_evidently_trace_link`` column.

    Returns:
        A DataFrame with ``_ROWS`` rows.
    """
    # Seeded RNG so the boolean column is identical on every run; the previous
    # unseeded ``random.choices`` made the fixture non-deterministic and could
    # (with probability 2/2**11) yield a constant column.
    rng = random.Random(0)
    columns = dict(
        num_1=pd.Series([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1]),
        num_2=pd.Series([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
        num_3=pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
        cat_1=pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1]),
        cat_2=pd.Series(["a", "b", "c", "d", "e", "a", "b", "c", "d", "e", "a"]),
        cat_3=pd.Series(rng.choices([True, False], k=_ROWS)),
        datetime=pd.Series([datetime.datetime.now()] * _ROWS),
        datetime_2=pd.date_range("2025-01-01", periods=_ROWS, freq="D"),
        text_1=pd.Series(["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]),
        text_2=pd.Series(["a", "b", "c", "d", "e", "a", "b", "c", "d", "e", "f"]),
    )
    if with_trace_link:
        # Appended last to keep the original column order.
        columns["_evidently_trace_link"] = pd.Series(["a", "b", "c", "d", "e", "a", "b", "c", "d", "e", "f"])
    return pd.DataFrame(data=columns)


def _assert_inferred(dataset, numerical, categorical, datetime_cols, text, service_columns):
    """Assert the column groups and service columns resolved for ``dataset``.

    Column groups are compared as sets, so ordering is not checked.
    """
    definition = dataset.data_definition
    assert set(numerical) == set(definition.get_numerical_columns())
    assert set(categorical) == set(definition.get_categorical_columns())
    assert set(datetime_cols) == set(definition.get_datetime_columns())
    assert set(text) == set(definition.get_text_columns())
    assert service_columns == definition.service_columns


@pytest.mark.parametrize(
    "data,expected",
    [
        (pd.Series(["a", "b", "a", "b", "a"]), ColumnType.Categorical),
        (pd.Series([0.1, 0.2, 0.3, 0.4, 0.5]), ColumnType.Numerical),
        (pd.Series([0.1, 0.1, 0.2, 0.2, 0.2]), ColumnType.Numerical),
        (pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), ColumnType.Numerical),
        (pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), ColumnType.Categorical),
        (pd.Series([True, False, True, False, True, False]), ColumnType.Categorical),
        (pd.Series([datetime.datetime.now(), datetime.datetime.now()]), ColumnType.Datetime),
        (pd.Series(["a", "b", "c", "d", "e", "f", "g"]), ColumnType.Text),
        (pd.Categorical(["a", "b", "c", "d", "e", "f", "g"]), ColumnType.Categorical),
        (pd.Series(pd.date_range("2025-01-01", periods=11, freq="D").values), ColumnType.Datetime),
    ],
)
def test_infer_column_type(data: pd.Series, expected: ColumnType):
    """``infer_column_type`` maps each sample series to the expected ColumnType."""
    assert infer_column_type(data) == expected


@pytest.mark.parametrize(
    "definition,numerical,categorical,datetime_cols,text,service_columns",
    [
        (
            None,
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(numerical_columns=["num_1"]),
            ("num_1",),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(categorical_columns=["cat_1"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1",),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(text_columns=["text_2"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_2",),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(datetime_columns=["datetime_2"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime_2",),
            ("text_1", "text_2"),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(timestamp="datetime"),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime_2",),
            ("text_1", "text_2"),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(numerical_columns=[]),
            tuple(),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(id_column="num_1"),
            ("num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(categorical_columns=["num_3"]),
            ("num_1", "num_2"),
            ("num_3",),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            ServiceColumns(trace_link=DEFAULT_TRACE_LINK_COLUMN),
        ),
        (
            DataDefinition(service_columns=ServiceColumns(trace_link="another_trace_link")),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            # With a different trace link configured, the ``_evidently_trace_link``
            # column is expected to be inferred as ordinary text.
            ("text_1", "text_2", "_evidently_trace_link"),
            ServiceColumns(trace_link="another_trace_link"),
        ),
    ],
)
def test_data_definition(definition, numerical, categorical, datetime_cols, text, service_columns):
    """Column inference on a frame that includes an ``_evidently_trace_link`` column.

    Each case starts from an optional partial ``DataDefinition`` and checks the
    column groups and service columns that ``Dataset.from_pandas`` resolves.
    """
    data = _make_test_frame(with_trace_link=True)
    dataset = Dataset.from_pandas(data, data_definition=definition)
    _assert_inferred(dataset, numerical, categorical, datetime_cols, text, service_columns)


@pytest.mark.parametrize(
    "definition,numerical,categorical,datetime_cols,text,service_columns",
    [
        (
            None,
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(numerical_columns=["num_1"]),
            ("num_1",),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(categorical_columns=["cat_1"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1",),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(text_columns=["text_2"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_2",),
            None,
        ),
        (
            DataDefinition(datetime_columns=["datetime_2"]),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime_2",),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(timestamp="datetime"),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime_2",),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(numerical_columns=[]),
            tuple(),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(id_column="num_1"),
            ("num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(categorical_columns=["num_3"]),
            ("num_1", "num_2"),
            ("num_3",),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            None,
        ),
        (
            DataDefinition(service_columns=ServiceColumns(trace_link="another_trace_link")),
            ("num_1", "num_2", "num_3"),
            ("cat_1", "cat_2", "cat_3"),
            ("datetime", "datetime_2"),
            ("text_1", "text_2"),
            ServiceColumns(trace_link="another_trace_link"),
        ),
    ],
)
def test_data_definition_without_service(definition, numerical, categorical, datetime_cols, text, service_columns):
    """Column inference on a frame with no ``_evidently_trace_link`` column.

    The expectations require ``service_columns`` to stay ``None`` unless the
    supplied definition sets it explicitly.
    """
    data = _make_test_frame(with_trace_link=False)
    dataset = Dataset.from_pandas(data, data_definition=definition)
    _assert_inferred(dataset, numerical, categorical, datetime_cols, text, service_columns)


def test_data_definition_serialization():
    """A DataDefinition survives a ``.json()`` -> ``parse_obj_as`` round trip."""
    data_definition = DataDefinition(text_columns=["text_2"])
    parsed = parse_obj_as(DataDefinition, json.loads(data_definition.json()))
    assert parsed == data_definition


def test_data_definition_serialization_empty():
    """Parsing an empty mapping yields a default-constructed DataDefinition."""
    data_definition = DataDefinition()
    parsed = parse_obj_as(DataDefinition, {})
    assert parsed == data_definition