# File: tests/models/test_dependencies_schema.py
  1  from unittest import mock
  2  
  3  from mlflow.models import dependencies_schemas
  4  from mlflow.models.dependencies_schemas import (
  5      DependenciesSchemas,
  6      DependenciesSchemasType,
  7      RetrieverSchema,
  8      _get_dependencies_schemas,
  9      _get_retriever_schema,
 10      set_retriever_schema,
 11  )
 12  
 13  
 14  def test_retriever_creation():
 15      vsi = RetrieverSchema(
 16          name="index-name",
 17          primary_key="primary-key",
 18          text_column="text-column",
 19          doc_uri="doc-uri",
 20          other_columns=["column1", "column2"],
 21      )
 22      assert vsi.name == "index-name"
 23      assert vsi.primary_key == "primary-key"
 24      assert vsi.text_column == "text-column"
 25      assert vsi.doc_uri == "doc-uri"
 26      assert vsi.other_columns == ["column1", "column2"]
 27  
 28  
 29  def test_retriever_to_dict():
 30      vsi = RetrieverSchema(
 31          name="index-name",
 32          primary_key="primary-key",
 33          text_column="text-column",
 34          doc_uri="doc-uri",
 35          other_columns=["column1", "column2"],
 36      )
 37      expected_dict = {
 38          DependenciesSchemasType.RETRIEVERS.value: [
 39              {
 40                  "name": "index-name",
 41                  "primary_key": "primary-key",
 42                  "text_column": "text-column",
 43                  "doc_uri": "doc-uri",
 44                  "other_columns": ["column1", "column2"],
 45              }
 46          ]
 47      }
 48      assert vsi.to_dict() == expected_dict
 49  
 50  
 51  def test_retriever_from_dict():
 52      data = {
 53          "name": "index-name",
 54          "primary_key": "primary-key",
 55          "text_column": "text-column",
 56          "doc_uri": "doc-uri",
 57          "other_columns": ["column1", "column2"],
 58      }
 59      vsi = RetrieverSchema.from_dict(data)
 60      assert vsi.name == "index-name"
 61      assert vsi.primary_key == "primary-key"
 62      assert vsi.text_column == "text-column"
 63      assert vsi.doc_uri == "doc-uri"
 64      assert vsi.other_columns == ["column1", "column2"]
 65  
 66  
 67  def test_dependencies_schemas_to_dict():
 68      vsi = RetrieverSchema(
 69          name="index-name",
 70          primary_key="primary-key",
 71          text_column="text-column",
 72          doc_uri="doc-uri",
 73          other_columns=["column1", "column2"],
 74      )
 75      schema = DependenciesSchemas(retriever_schemas=[vsi])
 76      expected_dict = {
 77          "dependencies_schemas": {
 78              DependenciesSchemasType.RETRIEVERS.value: [
 79                  {
 80                      "name": "index-name",
 81                      "primary_key": "primary-key",
 82                      "text_column": "text-column",
 83                      "doc_uri": "doc-uri",
 84                      "other_columns": ["column1", "column2"],
 85                  }
 86              ]
 87          }
 88      }
 89      assert schema.to_dict() == expected_dict
 90  
 91  
 92  def test_set_retriever_schema_creation():
 93      set_retriever_schema(
 94          primary_key="primary-key",
 95          text_column="text-column",
 96          doc_uri="doc-uri",
 97          other_columns=["column1", "column2"],
 98      )
 99      with _get_dependencies_schemas() as schema:
100          assert schema.to_dict()["dependencies_schemas"] == {
101              DependenciesSchemasType.RETRIEVERS.value: [
102                  {
103                      "doc_uri": "doc-uri",
104                      "name": "retriever",
105                      "other_columns": ["column1", "column2"],
106                      "primary_key": "primary-key",
107                      "text_column": "text-column",
108                  }
109              ]
110          }
111  
112      # Schema is automatically reset
113      with _get_dependencies_schemas() as schema:
114          assert schema.to_dict() is None
115      assert _get_retriever_schema() == []
116  
117  
def test_set_retriever_schema_creation_with_name():
    """An explicit ``name`` passed to set_retriever_schema is kept verbatim.

    Leaving the ``_get_dependencies_schemas`` context clears the registry.
    """
    set_retriever_schema(
        name="my_ret_2",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    registered_entry = {
        "doc_uri": "doc-uri",
        "name": "my_ret_2",
        "other_columns": ["column1", "column2"],
        "primary_key": "primary-key",
        "text_column": "text-column",
    }
    with _get_dependencies_schemas() as schema:
        serialized = schema.to_dict()["dependencies_schemas"]
        assert serialized == {
            DependenciesSchemasType.RETRIEVERS.value: [registered_entry]
        }

    # Schema is automatically reset once the context above has exited.
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict() is None
    assert _get_retriever_schema() == []
143  
144  
def test_set_retriever_schema_empty_creation():
    """With nothing registered, no dependencies schema is serialized.

    Also checks, as the other registry tests do after a reset, that
    ``_get_retriever_schema`` reports no retriever schemas.
    """
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict() is None
    assert _get_retriever_schema() == []
148  
149  
def test_multiple_set_retriever_schema_creation_with_name():
    """Schemas registered under distinct names are all kept, in call order.

    Leaving the ``_get_dependencies_schemas`` context clears the registry.
    """
    registrations = [
        dict(
            name="my_ret_1",
            primary_key="primary-key-2",
            text_column="text-column-1",
            doc_uri="doc-uri-3",
            other_columns=["column1", "column2"],
        ),
        dict(
            name="my_ret_2",
            primary_key="primary-key",
            text_column="text-column",
            doc_uri="doc-uri",
            other_columns=["column1", "column2"],
        ),
    ]
    for kwargs in registrations:
        set_retriever_schema(**kwargs)

    expected_retrievers = [
        {
            "doc_uri": "doc-uri-3",
            "name": "my_ret_1",
            "other_columns": ["column1", "column2"],
            "primary_key": "primary-key-2",
            "text_column": "text-column-1",
        },
        {
            "doc_uri": "doc-uri",
            "name": "my_ret_2",
            "other_columns": ["column1", "column2"],
            "primary_key": "primary-key",
            "text_column": "text-column",
        },
    ]
    with _get_dependencies_schemas() as schema:
        serialized = schema.to_dict()["dependencies_schemas"]
        assert serialized == {
            DependenciesSchemasType.RETRIEVERS.value: expected_retrievers
        }

    # Schema is automatically reset once the context above has exited.
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict() is None
    assert _get_retriever_schema() == []
190  
191  
def test_multiple_set_retriever_schema_with_same_name_with_different_schemas():
    """Re-registering an existing name with a different schema overrides it.

    A single warning is logged, the overridden entry keeps its original
    position in the list, and the registry is cleared once the
    ``_get_dependencies_schemas`` context exits.
    """
    set_retriever_schema(
        name="my_ret_1",
        primary_key="primary-key-2",
        text_column="text-column-1",
        doc_uri="doc-uri-3",
        other_columns=["column1", "column2"],
    )
    set_retriever_schema(
        name="my_ret_2",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )

    # Only the conflicting re-registration happens under the patched logger.
    with mock.patch.object(dependencies_schemas, "_logger") as mock_logger:
        set_retriever_schema(
            name="my_ret_1",
            primary_key="primary-key",
            text_column="text-column",
            doc_uri="doc-uri",
            other_columns=["column1", "column2"],
        )
        mock_logger.warning.assert_called_once_with(
            "A retriever schema with the name 'my_ret_1' already exists. "
            "Overriding the existing schema."
        )

    with _get_dependencies_schemas() as schema:
        assert schema.to_dict()["dependencies_schemas"] == {
            DependenciesSchemasType.RETRIEVERS.value: [
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_1",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_2",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
            ]
        }

    # Like the other registration tests, verify the registry was reset so no
    # state leaks into later tests.
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict() is None
    assert _get_retriever_schema() == []
240  
241  
def test_multiple_set_retriever_schema_with_same_name_with_same_schema():
    """Re-registering an identical schema under an existing name is silent.

    No warning is logged, the registered list is unchanged, and the registry
    is cleared once the ``_get_dependencies_schemas`` context exits.
    """
    set_retriever_schema(
        name="my_ret_1",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )
    set_retriever_schema(
        name="my_ret_2",
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )

    # Re-register "my_ret_1" with exactly the same fields as before.
    with mock.patch.object(dependencies_schemas, "_logger") as mock_logger:
        set_retriever_schema(
            name="my_ret_1",
            primary_key="primary-key",
            text_column="text-column",
            doc_uri="doc-uri",
            other_columns=["column1", "column2"],
        )
        mock_logger.warning.assert_not_called()

    with _get_dependencies_schemas() as schema:
        assert schema.to_dict()["dependencies_schemas"] == {
            DependenciesSchemasType.RETRIEVERS.value: [
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_1",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
                {
                    "doc_uri": "doc-uri",
                    "name": "my_ret_2",
                    "other_columns": ["column1", "column2"],
                    "primary_key": "primary-key",
                    "text_column": "text-column",
                },
            ]
        }

    # Like the other registration tests, verify the registry was reset so no
    # state leaks into later tests.
    with _get_dependencies_schemas() as schema:
        assert schema.to_dict() is None
    assert _get_retriever_schema() == []