# test_pydantic_aliases.py
"""Consistency checks for polymorphic pydantic type aliases.

The tests below verify that:

* every non-base ``PolymorphicModel`` subclass with its own type name is
  registered in ``TYPE_ALIASES``;
* every registered alias points at a classpath that ``load_alias`` can import;
* every concrete subclass's type follows the ``evidently:<base type>:<ClassName>`` scheme.
"""

import glob
import os
from collections import defaultdict
from importlib import import_module
from inspect import isabstract
from typing import Dict
from typing import List
from typing import Set
from typing import Type
from typing import TypeVar

import pytest

import evidently
from evidently._pydantic_compat import import_string
from evidently.core import registries
from evidently.core.container import MetricContainer
from evidently.core.datasets import ColumnCondition
from evidently.core.datasets import Descriptor
from evidently.core.datasets import SpecialColumnInfo
from evidently.core.metric_types import BoundTest
from evidently.core.metric_types import Metric as MetricV2
from evidently.core.metric_types import MetricResult as MetricResultV2
from evidently.core.metric_types import MetricTest
from evidently.legacy.base_metric import BasePreset
from evidently.legacy.base_metric import ColumnName
from evidently.legacy.base_metric import Metric
from evidently.legacy.base_metric import MetricResult
from evidently.legacy.collector.config import CollectorTrigger
from evidently.legacy.collector.storage import CollectorStorage
from evidently.legacy.features.generated_features import BaseDescriptor
from evidently.legacy.features.generated_features import FeatureDescriptor
from evidently.legacy.features.generated_features import GeneratedFeatures
from evidently.legacy.features.llm_judge import BaseLLMPromptTemplate
from evidently.legacy.metric_preset.metric_preset import MetricPreset
from evidently.legacy.metrics.data_drift.embedding_drift_methods import DriftMethod
from evidently.legacy.test_preset.test_preset import TestPreset
from evidently.legacy.tests.base_test import Test
from evidently.legacy.tests.base_test import TestParameters
from evidently.legacy.ui.components.base import Component as ComponentLegacy
from evidently.legacy.ui.dashboards.base import DashboardPanel
from evidently.legacy.utils.llm.prompts import PromptBlock
from evidently.legacy.utils.llm.prompts import PromptTemplate
from evidently.llm.datagen.base import BaseDatasetGenerator
from evidently.llm.optimization.optimizer import OptimizerConfig
from evidently.llm.optimization.optimizer import OptimizerLog
from evidently.llm.optimization.prompts import OptimizationScorer
from evidently.llm.optimization.prompts import PromptExecutor
from evidently.llm.optimization.prompts import PromptOptimizerStrategy
from evidently.llm.prompts.content import PromptContent
from evidently.llm.rag.index import DataCollectionProvider
from evidently.llm.rag.splitter import Splitter
from evidently.pydantic_utils import TYPE_ALIASES
from evidently.pydantic_utils import EvidentlyBaseModel
from evidently.pydantic_utils import PolymorphicModel
from evidently.pydantic_utils import WithTestAndMetricDependencies
from evidently.pydantic_utils import get_base_class
from evidently.pydantic_utils import is_not_abstract
from evidently.sdk.configs import ConfigContent
from evidently.ui.service.components import DataStorageComponent
from evidently.ui.service.components import MetadataStorageComponent
from evidently.ui.service.components.base import Component
from evidently.ui.service.components.snapshot_links import SnapshotDatasetLinksComponent
from evidently.ui.service.components.storage import BlobStorageComponent
from evidently.ui.service.components.storage import DatasetFileStorageComponent
from evidently.ui.service.components.storage import DatasetMetadataComponent
from evidently.ui.service.components.storage import StorageComponent
from evidently.ui.service.datasets.data_source import DataSource
from evidently.ui.service.datasets.data_source import DataSourceDTO
from evidently.ui.service.datasets.filters import FilterBy
from evidently.ui.service.datasets.filters import FilterByNumber
from evidently.ui.service.datasets.filters import FilterByString

T = TypeVar("T")


# todo: deduplicate code
def find_all_subclasses(
    base: Type[T],
    base_module: str = "evidently",
    path: str = os.path.dirname(evidently.__file__),
    include_abstract: bool = False,
) -> Set[Type[T]]:
    """Import every module under ``path`` and collect the subclasses of ``base`` they expose.

    Args:
        base: Class whose subclasses are searched for; ``base`` itself is excluded.
        base_module: Dotted package name that corresponds to the directory ``path``.
        path: Directory scanned recursively for ``*.py`` files.
        include_abstract: When true, classes for which ``inspect.isabstract`` is true are kept too.

    Returns:
        Every class found in any imported module's namespace that subclasses ``base``.
        A class re-exported from several modules appears once, because the result is a set.

    Note:
        Each scanned module is actually imported, so its import-time side effects run.
    """
    classes = set()
    for mod in glob.glob(path + "/**/*.py", recursive=True):
        mod_path = os.path.relpath(mod, path)[:-3]  # strip the ".py" suffix
        mod_name = f"{base_module}." + mod_path.replace("/", ".").replace("\\", ".")
        if mod_name.endswith("__"):
            # Dunder modules such as ``pkg.__init__`` are skipped. Importing them under that
            # name would create a second module object separate from ``pkg``.
            # NOTE(review): classes defined only in an ``__init__.py`` are therefore found
            # only if another scanned module imports them.
            continue
        module = import_module(mod_name)
        for value in module.__dict__.values():
            if isinstance(value, type) and value is not base and issubclass(value, base):
                if include_abstract or not isabstract(value):
                    classes.add(value)

    return classes


# Base class -> dotted name of the registry module that ``test_all_aliases_registered``
# appends missing ``register_type_alias(...)`` calls to. Base classes absent from this
# mapping only get their calls printed.
REGISTRY_MAPPING: Dict[Type[PolymorphicModel], str] = {
    # legacy
    Test: "evidently.legacy.tests._registry",
    TestParameters: "evidently.legacy.tests._registry",
    TestPreset: "evidently.legacy.test_preset._registry",
    MetricResult: "evidently.legacy.metrics._registry",
    Metric: "evidently.legacy.metrics._registry",
    MetricPreset: "evidently.legacy.metric_preset._registry",
    FeatureDescriptor: "evidently.legacy.descriptors._registry",
    GeneratedFeatures: "evidently.legacy.features._registry",
    # new api
    MetricTest: registries.metric_tests.__name__,
    MetricV2: registries.metrics.__name__,
    MetricContainer: registries.presets.__name__,
    MetricResultV2: registries.metric_results.__name__,
    BoundTest: registries.bound_tests.__name__,
    Descriptor: registries.descriptors.__name__,
    ColumnCondition: registries.column_conditions.__name__,
    PromptContent: registries.prompts.__name__,
    ConfigContent: registries.configs.__name__,
    OptimizerConfig: registries.optimizers.__name__,
    OptimizerLog: registries.optimizers.__name__,
    OptimizationScorer: registries.optimizers.__name__,
    PromptExecutor: registries.optimizers.__name__,
    PromptOptimizerStrategy: registries.optimizers.__name__,
    PromptBlock: registries.prompt_blocks.__name__,
    PromptTemplate: registries.prompt_templates.__name__,
    BaseLLMPromptTemplate: registries.prompts.__name__,
    DataCollectionProvider: registries.rag.__name__,
    Splitter: registries.rag.__name__,
    BaseDatasetGenerator: registries.datagen.__name__,
    SpecialColumnInfo: registries.descriptors.__name__,
    BlobStorageComponent: registries.components.__name__,
    DataStorageComponent: registries.components.__name__,
    MetadataStorageComponent: registries.components.__name__,
    StorageComponent: registries.components.__name__,
    DatasetFileStorageComponent: registries.components.__name__,
    DatasetMetadataComponent: registries.components.__name__,
    SnapshotDatasetLinksComponent: registries.components.__name__,
    DataSource: registries.dataset_models.__name__,
    FilterByNumber: registries.dataset_models.__name__,
    FilterByString: registries.dataset_models.__name__,
    DataSourceDTO: registries.dataset_models.__name__,
}


def _find_unregistered_classes() -> List[Type[PolymorphicModel]]:
    """Return polymorphic classes whose type alias is missing from, or wrong in, ``TYPE_ALIASES``."""
    not_registered = []
    for cls in find_all_subclasses(PolymorphicModel, include_abstract=True):
        if cls.__is_base_type__():
            continue
        classpath = cls.__get_classpath__()
        typename = cls.__get_type__()
        if classpath == typename:
            # no separate type name, so there is nothing to register
            continue
        key = (get_base_class(cls), typename)
        if key not in TYPE_ALIASES or TYPE_ALIASES[key] != classpath:
            not_registered.append(cls)
    return not_registered


def _append_registrations(registry_module: str, msgs: List[str]) -> None:
    """Append the ``register_type_alias`` lines ``msgs`` to the source file of ``registry_module``."""
    mod = import_string(registry_module)
    with open(mod.__file__, "a", encoding="utf-8") as f:
        # Leading newline separates the new calls from existing content.
        # Trailing newline keeps the file POSIX-terminated.
        f.write("\n" + "\n".join(msgs) + "\n")


def test_all_aliases_registered():
    """Every non-base polymorphic class with its own type name must be registered in ``TYPE_ALIASES``.

    Side effect (presumably a deliberate developer aid): for base classes listed in
    ``REGISTRY_MAPPING``, the missing ``register_type_alias(...)`` calls are appended to the
    mapped registry module's source file. Calls for other base classes are printed instead.
    The test fails afterwards whenever anything was missing.
    """
    not_registered = _find_unregistered_classes()

    register_msgs = []
    file_to_type = defaultdict(list)
    ordered = sorted(not_registered, key=lambda c: (get_base_class(c).__name__, c.__get_classpath__()))
    for cls in ordered:
        base_class = get_base_class(cls)
        msg = f'register_type_alias({base_class.__name__}, "{cls.__get_classpath__()}", "{cls.__get_type__()}")'
        if base_class in REGISTRY_MAPPING:
            file_to_type[REGISTRY_MAPPING[base_class]].append(msg)
        else:
            register_msgs.append(msg)

    for file, msgs in file_to_type.items():
        _append_registrations(file, msgs)
    if register_msgs:
        print("\n".join(register_msgs))

    missing = ", ".join(c.__get_classpath__() for c in ordered)
    assert len(not_registered) == 0, f"Not all aliases registered: {missing}"


@pytest.mark.parametrize(
    "base_class,classpath", [(base_class, classpath) for (base_class, _), classpath in TYPE_ALIASES.items()]
)
def test_all_registered_classpath_exist(base_class: Type[PolymorphicModel], classpath):
    """Each classpath registered in ``TYPE_ALIASES`` must be loadable via ``load_alias``."""
    try:
        base_class.load_alias(classpath)
    except ImportError as e:
        pytest.fail(f"wrong classpath registered '{classpath}': {e}")


def test_all_aliases_correct():
    """Every concrete polymorphic class must have the type ``evidently:<base type>:<ClassName>``."""
    # Checked in insertion order: the first base class that ``cls`` subclasses decides the
    # expected ``<base type>``, so more specific classes must come before their parents.
    base_class_type_mapping = {
        Metric: "metric",
        Test: "test",
        GeneratedFeatures: "feature",
        BaseDescriptor: "descriptor",
        MetricPreset: "metric_preset",
        TestPreset: "test_preset",
        MetricResult: "metric_result",
        DriftMethod: "drift_method",
        TestParameters: "test_parameters",
        ColumnName: "base",
        CollectorTrigger: "collector_trigger",
        CollectorStorage: "collector_storage",
        BaseLLMPromptTemplate: "prompt_template",
        DashboardPanel: "dashboard_panel",
        PromptBlock: "prompt_block",
        PromptTemplate: "prompt_template",
        MetricV2: MetricV2.__alias_type__,
        MetricResultV2: MetricResultV2.__alias_type__,
        MetricTest: MetricTest.__alias_type__,
        BoundTest: BoundTest.__alias_type__,
        Descriptor: Descriptor.__alias_type__,
        MetricContainer: MetricContainer.__alias_type__,
        ColumnCondition: ColumnCondition.__alias_type__,
        PromptContent: PromptContent.__alias_type__,
        ConfigContent: ConfigContent.__alias_type__,
        OptimizerConfig: OptimizerConfig.__alias_type__,
        OptimizerLog: OptimizerLog.__alias_type__,
        OptimizationScorer: OptimizationScorer.__alias_type__,
        PromptExecutor: PromptExecutor.__alias_type__,
        PromptOptimizerStrategy: PromptOptimizerStrategy.__alias_type__,
        DataCollectionProvider: DataCollectionProvider.__alias_type__,
        Splitter: Splitter.__alias_type__,
        BaseDatasetGenerator: BaseDatasetGenerator.__alias_type__,
        SpecialColumnInfo: SpecialColumnInfo.__alias_type__,
        DataSourceDTO: DataSourceDTO.__alias_type__,
        DataSource: DataSource.__alias_type__,
    }
    # Subclasses of these classes are not checked at all.
    skip = (
        Component,
        ComponentLegacy,
        FilterBy,
    )
    # Exactly these classes are not checked; their subclasses still are.
    skip_literal = [EvidentlyBaseModel, WithTestAndMetricDependencies, BasePreset]
    for cls in find_all_subclasses(PolymorphicModel, include_abstract=True):
        if cls in skip_literal or issubclass(cls, skip) or not is_not_abstract(cls):
            continue
        for base_class, base_type in base_class_type_mapping.items():
            if issubclass(cls, base_class):
                alias = cls.__get_type__()
                assert alias is not None, f"{cls.__name__} has no alias ({alias})"
                assert alias == f"evidently:{base_type}:{cls.__name__}", f"wrong alias for {cls.__name__}"
                break
        else:
            pytest.fail(f"No base class type mapping for {cls}")