/ tests / test_pydantic_aliases.py
test_pydantic_aliases.py
  1  import glob
  2  import os
  3  from collections import defaultdict
  4  from importlib import import_module
  5  from inspect import isabstract
  6  from typing import Dict
  7  from typing import Set
  8  from typing import Type
  9  from typing import TypeVar
 10  
 11  import pytest
 12  
 13  import evidently
 14  from evidently._pydantic_compat import import_string
 15  from evidently.core import registries
 16  from evidently.core.container import MetricContainer
 17  from evidently.core.datasets import ColumnCondition
 18  from evidently.core.datasets import Descriptor
 19  from evidently.core.datasets import SpecialColumnInfo
 20  from evidently.core.metric_types import BoundTest
 21  from evidently.core.metric_types import Metric as MetricV2
 22  from evidently.core.metric_types import MetricResult as MetricResultV2
 23  from evidently.core.metric_types import MetricTest
 24  from evidently.legacy.base_metric import BasePreset
 25  from evidently.legacy.base_metric import ColumnName
 26  from evidently.legacy.base_metric import Metric
 27  from evidently.legacy.base_metric import MetricResult
 28  from evidently.legacy.collector.config import CollectorTrigger
 29  from evidently.legacy.collector.storage import CollectorStorage
 30  from evidently.legacy.features.generated_features import BaseDescriptor
 31  from evidently.legacy.features.generated_features import FeatureDescriptor
 32  from evidently.legacy.features.generated_features import GeneratedFeatures
 33  from evidently.legacy.features.llm_judge import BaseLLMPromptTemplate
 34  from evidently.legacy.metric_preset.metric_preset import MetricPreset
 35  from evidently.legacy.metrics.data_drift.embedding_drift_methods import DriftMethod
 36  from evidently.legacy.test_preset.test_preset import TestPreset
 37  from evidently.legacy.tests.base_test import Test
 38  from evidently.legacy.tests.base_test import TestParameters
 39  from evidently.legacy.ui.components.base import Component as ComponentLegacy
 40  from evidently.legacy.ui.dashboards.base import DashboardPanel
 41  from evidently.legacy.utils.llm.prompts import PromptBlock
 42  from evidently.legacy.utils.llm.prompts import PromptTemplate
 43  from evidently.llm.datagen.base import BaseDatasetGenerator
 44  from evidently.llm.optimization.optimizer import OptimizerConfig
 45  from evidently.llm.optimization.optimizer import OptimizerLog
 46  from evidently.llm.optimization.prompts import OptimizationScorer
 47  from evidently.llm.optimization.prompts import PromptExecutor
 48  from evidently.llm.optimization.prompts import PromptOptimizerStrategy
 49  from evidently.llm.prompts.content import PromptContent
 50  from evidently.llm.rag.index import DataCollectionProvider
 51  from evidently.llm.rag.splitter import Splitter
 52  from evidently.pydantic_utils import TYPE_ALIASES
 53  from evidently.pydantic_utils import EvidentlyBaseModel
 54  from evidently.pydantic_utils import PolymorphicModel
 55  from evidently.pydantic_utils import WithTestAndMetricDependencies
 56  from evidently.pydantic_utils import get_base_class
 57  from evidently.pydantic_utils import is_not_abstract
 58  from evidently.sdk.configs import ConfigContent
 59  from evidently.ui.service.components import DataStorageComponent
 60  from evidently.ui.service.components import MetadataStorageComponent
 61  from evidently.ui.service.components.base import Component
 62  from evidently.ui.service.components.snapshot_links import SnapshotDatasetLinksComponent
 63  from evidently.ui.service.components.storage import BlobStorageComponent
 64  from evidently.ui.service.components.storage import DatasetFileStorageComponent
 65  from evidently.ui.service.components.storage import DatasetMetadataComponent
 66  from evidently.ui.service.components.storage import StorageComponent
 67  from evidently.ui.service.datasets.data_source import DataSource
 68  from evidently.ui.service.datasets.data_source import DataSourceDTO
 69  from evidently.ui.service.datasets.filters import FilterBy
 70  from evidently.ui.service.datasets.filters import FilterByNumber
 71  from evidently.ui.service.datasets.filters import FilterByString
 72  
 73  T = TypeVar("T")
 74  
 75  
 76  # todo: deduplicate code
 77  def find_all_subclasses(
 78      base: Type[T],
 79      base_module: str = "evidently",
 80      path: str = os.path.dirname(evidently.__file__),
 81      include_abstract: bool = False,
 82  ) -> Set[Type[T]]:
 83      classes = set()
 84      for mod in glob.glob(path + "/**/*.py", recursive=True):
 85          mod_path = os.path.relpath(mod, path)[:-3]
 86          mod_name = f"{base_module}." + mod_path.replace("/", ".").replace("\\", ".")
 87          if mod_name.endswith("__"):
 88              continue
 89          module = import_module(mod_name)
 90          for key, value in module.__dict__.items():
 91              if isinstance(value, type) and value is not base and issubclass(value, base):
 92                  if not isabstract(value) or include_abstract:
 93                      classes.add(value)
 94  
 95      return classes
 96  
 97  
 98  REGISTRY_MAPPING: Dict[Type[PolymorphicModel], str] = {
 99      # legacy
100      Test: "evidently.legacy.tests._registry",
101      TestParameters: "evidently.legacy.tests._registry",
102      TestPreset: "evidently.legacy.test_preset._registry",
103      MetricResult: "evidently.legacy.metrics._registry",
104      Metric: "evidently.legacy.metrics._registry",
105      MetricPreset: "evidently.legacy.metric_preset._registry",
106      FeatureDescriptor: "evidently.legacy.descriptors._registry",
107      GeneratedFeatures: "evidently.legacy.features._registry",
108      # new api
109      MetricTest: registries.metric_tests.__name__,
110      MetricV2: registries.metrics.__name__,
111      MetricContainer: registries.presets.__name__,
112      MetricResultV2: registries.metric_results.__name__,
113      BoundTest: registries.bound_tests.__name__,
114      Descriptor: registries.descriptors.__name__,
115      ColumnCondition: registries.column_conditions.__name__,
116      PromptContent: registries.prompts.__name__,
117      ConfigContent: registries.configs.__name__,
118      OptimizerConfig: registries.optimizers.__name__,
119      OptimizerLog: registries.optimizers.__name__,
120      OptimizationScorer: registries.optimizers.__name__,
121      PromptExecutor: registries.optimizers.__name__,
122      PromptOptimizerStrategy: registries.optimizers.__name__,
123      PromptBlock: registries.prompt_blocks.__name__,
124      PromptTemplate: registries.prompt_templates.__name__,
125      BaseLLMPromptTemplate: registries.prompts.__name__,
126      DataCollectionProvider: registries.rag.__name__,
127      Splitter: registries.rag.__name__,
128      BaseDatasetGenerator: registries.datagen.__name__,
129      SpecialColumnInfo: registries.descriptors.__name__,
130      BlobStorageComponent: registries.components.__name__,
131      DataStorageComponent: registries.components.__name__,
132      MetadataStorageComponent: registries.components.__name__,
133      StorageComponent: registries.components.__name__,
134      DatasetFileStorageComponent: registries.components.__name__,
135      DatasetMetadataComponent: registries.components.__name__,
136      SnapshotDatasetLinksComponent: registries.components.__name__,
137      DataSource: registries.dataset_models.__name__,
138      FilterByNumber: registries.dataset_models.__name__,
139      FilterByString: registries.dataset_models.__name__,
140      DataSourceDTO: registries.dataset_models.__name__,
141  }
142  
143  
def test_all_aliases_registered():
    """Every PolymorphicModel subclass with a custom type name must be in ``TYPE_ALIASES``.

    A class needs a registration when ``__get_type__()`` differs from
    ``__get_classpath__()``. It counts as registered when ``TYPE_ALIASES`` maps
    ``(get_base_class(cls), type name)`` to exactly its classpath.

    As a developer aid, the missing ``register_type_alias(...)`` lines are
    appended to the registry module that ``REGISTRY_MAPPING`` lists for the
    class's base. Lines for bases without a mapping are printed instead.
    NOTE: this writes into source files of the checked-out package.
    """
    not_registered = []

    for cls in find_all_subclasses(PolymorphicModel, include_abstract=True):
        if cls.__is_base_type__():
            continue
        classpath = cls.__get_classpath__()
        typename = cls.__get_type__()
        if classpath == typename:
            # no custom type name, so nothing to register
            continue
        key = (get_base_class(cls), typename)
        if key not in TYPE_ALIASES or TYPE_ALIASES[key] != classpath:
            not_registered.append(cls)

    register_msgs = []
    file_to_type = defaultdict(list)
    for cls in sorted(not_registered, key=lambda c: get_base_class(c).__name__ + " " + c.__get_classpath__()):
        base_class = get_base_class(cls)
        msg = f'register_type_alias({base_class.__name__}, "{cls.__get_classpath__()}", "{cls.__get_type__()}")'
        if base_class in REGISTRY_MAPPING:
            file_to_type[REGISTRY_MAPPING[base_class]].append(msg)
        else:
            register_msgs.append(msg)

    for file, msgs in file_to_type.items():
        mod = import_string(file)
        # Terminate with a newline so the edited module stays a well-formed text file.
        with open(mod.__file__, "a", encoding="utf-8") as f:
            f.write("\n" + "\n".join(msgs) + "\n")
    print("\n".join(register_msgs))
    assert len(not_registered) == 0, "Not all aliases registered:\n" + "\n".join(
        sorted(c.__get_classpath__() for c in not_registered)
    )
176  
177  
@pytest.mark.parametrize(
    "base_class,classpath", [(base_class, classpath) for (base_class, _), classpath in TYPE_ALIASES.items()]
)
def test_all_registered_classpath_exist(base_class: Type[PolymorphicModel], classpath):
    """Every classpath registered in ``TYPE_ALIASES`` must load through its base class.

    The parameter list is built from ``TYPE_ALIASES`` at collection time.
    """
    try:
        base_class.load_alias(classpath)
    except ImportError as e:
        # pytest.fail, unlike ``assert False``, survives ``python -O``.
        pytest.fail(f"wrong classpath registered '{classpath}': {e}")
186  
187  
def test_all_aliases_correct():
    """Concrete PolymorphicModel subclasses must use the ``evidently:<kind>:<ClassName>`` alias.

    ``<kind>`` comes from the first entry of ``base_class_type_mapping`` whose
    key is a superclass of the checked class. Order matters: the first matching
    entry wins. A non-skipped concrete class matching no entry fails the test.
    """
    base_class_type_mapping = {
        Metric: "metric",
        Test: "test",
        GeneratedFeatures: "feature",
        BaseDescriptor: "descriptor",
        MetricPreset: "metric_preset",
        TestPreset: "test_preset",
        MetricResult: "metric_result",
        DriftMethod: "drift_method",
        TestParameters: "test_parameters",
        ColumnName: "base",
        CollectorTrigger: "collector_trigger",
        CollectorStorage: "collector_storage",
        BaseLLMPromptTemplate: "prompt_template",
        DashboardPanel: "dashboard_panel",
        PromptBlock: "prompt_block",
        PromptTemplate: "prompt_template",
        MetricV2: MetricV2.__alias_type__,
        MetricResultV2: MetricResultV2.__alias_type__,
        MetricTest: MetricTest.__alias_type__,
        BoundTest: BoundTest.__alias_type__,
        Descriptor: Descriptor.__alias_type__,
        MetricContainer: MetricContainer.__alias_type__,
        ColumnCondition: ColumnCondition.__alias_type__,
        PromptContent: PromptContent.__alias_type__,
        ConfigContent: ConfigContent.__alias_type__,
        OptimizerConfig: OptimizerConfig.__alias_type__,
        OptimizerLog: OptimizerLog.__alias_type__,
        OptimizationScorer: OptimizationScorer.__alias_type__,
        PromptExecutor: PromptExecutor.__alias_type__,
        PromptOptimizerStrategy: PromptOptimizerStrategy.__alias_type__,
        DataCollectionProvider: DataCollectionProvider.__alias_type__,
        Splitter: Splitter.__alias_type__,
        BaseDatasetGenerator: BaseDatasetGenerator.__alias_type__,
        SpecialColumnInfo: SpecialColumnInfo.__alias_type__,
        DataSourceDTO: DataSourceDTO.__alias_type__,
        DataSource: DataSource.__alias_type__,
    }
    # Every subclass of these bases is exempt from the alias check.
    skip = (
        Component,
        ComponentLegacy,
        FilterBy,
    )
    # Only these exact classes (not their subclasses) are exempt.
    skip_literal = [EvidentlyBaseModel, WithTestAndMetricDependencies, BasePreset]
    for cls in find_all_subclasses(PolymorphicModel, include_abstract=True):
        if cls in skip_literal or issubclass(cls, skip) or not is_not_abstract(cls):
            continue
        for base_class, base_type in base_class_type_mapping.items():
            if issubclass(cls, base_class):
                alias = cls.__get_type__()
                expected = f"evidently:{base_type}:{cls.__name__}"
                assert alias is not None, f"{cls.__name__} has no alias"
                assert alias == expected, f"wrong alias for {cls.__name__}: got {alias!r}, expected {expected!r}"
                break
        else:
            pytest.fail(f"No base class type mapping for {cls}")