/ tests / models / test_resources.py
test_resources.py
  1  import pytest
  2  
  3  from mlflow.models.resources import (
  4      DEFAULT_API_VERSION,
  5      DatabricksApp,
  6      DatabricksFunction,
  7      DatabricksGenieSpace,
  8      DatabricksLakebase,
  9      DatabricksServingEndpoint,
 10      DatabricksSQLWarehouse,
 11      DatabricksTable,
 12      DatabricksUCConnection,
 13      DatabricksVectorSearchIndex,
 14      _ResourceBuilder,
 15  )
 16  
 17  
 18  @pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
 19  def test_serving_endpoint(on_behalf_of_user):
 20      endpoint = DatabricksServingEndpoint(
 21          endpoint_name="llm_server", on_behalf_of_user=on_behalf_of_user
 22      )
 23      expected = (
 24          {"serving_endpoint": [{"name": "llm_server"}]}
 25          if on_behalf_of_user is None
 26          else {"serving_endpoint": [{"name": "llm_server", "on_behalf_of_user": on_behalf_of_user}]}
 27      )
 28      assert endpoint.to_dict() == expected
 29      assert _ResourceBuilder.from_resources([endpoint]) == {
 30          "api_version": DEFAULT_API_VERSION,
 31          "databricks": expected,
 32      }
 33  
 34  
 35  @pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
 36  def test_index_name(on_behalf_of_user):
 37      index = DatabricksVectorSearchIndex(index_name="index1", on_behalf_of_user=on_behalf_of_user)
 38      expected = (
 39          {"vector_search_index": [{"name": "index1"}]}
 40          if on_behalf_of_user is None
 41          else {"vector_search_index": [{"name": "index1", "on_behalf_of_user": on_behalf_of_user}]}
 42      )
 43      assert index.to_dict() == expected
 44      assert _ResourceBuilder.from_resources([index]) == {
 45          "api_version": DEFAULT_API_VERSION,
 46          "databricks": expected,
 47      }
 48  
 49  
 50  @pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
 51  def test_sql_warehouse(on_behalf_of_user):
 52      sql_warehouse = DatabricksSQLWarehouse(warehouse_id="id1", on_behalf_of_user=on_behalf_of_user)
 53      expected = (
 54          {"sql_warehouse": [{"name": "id1"}]}
 55          if on_behalf_of_user is None
 56          else {"sql_warehouse": [{"name": "id1", "on_behalf_of_user": on_behalf_of_user}]}
 57      )
 58      assert sql_warehouse.to_dict() == expected
 59      assert _ResourceBuilder.from_resources([sql_warehouse]) == {
 60          "api_version": DEFAULT_API_VERSION,
 61          "databricks": expected,
 62      }
 63  
 64  
 65  @pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
 66  def test_uc_function(on_behalf_of_user):
 67      uc_function = DatabricksFunction(function_name="function", on_behalf_of_user=on_behalf_of_user)
 68      expected = (
 69          {"function": [{"name": "function"}]}
 70          if on_behalf_of_user is None
 71          else {"function": [{"name": "function", "on_behalf_of_user": on_behalf_of_user}]}
 72      )
 73      assert uc_function.to_dict() == expected
 74      assert _ResourceBuilder.from_resources([uc_function]) == {
 75          "api_version": DEFAULT_API_VERSION,
 76          "databricks": expected,
 77      }
 78  
 79  
 80  @pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
 81  def test_genie_space(on_behalf_of_user):
 82      genie_space = DatabricksGenieSpace(genie_space_id="id1", on_behalf_of_user=on_behalf_of_user)
 83      expected = (
 84          {"genie_space": [{"name": "id1"}]}
 85          if on_behalf_of_user is None
 86          else {"genie_space": [{"name": "id1", "on_behalf_of_user": on_behalf_of_user}]}
 87      )
 88  
 89      assert genie_space.to_dict() == expected
 90      assert _ResourceBuilder.from_resources([genie_space]) == {
 91          "api_version": DEFAULT_API_VERSION,
 92          "databricks": expected,
 93      }
 94  
 95  
 96  @pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
 97  def test_uc_connection(on_behalf_of_user):
 98      uc_function = DatabricksUCConnection(
 99          connection_name="slack_connection", on_behalf_of_user=on_behalf_of_user
100      )
101      expected = (
102          {"uc_connection": [{"name": "slack_connection"}]}
103          if on_behalf_of_user is None
104          else {
105              "uc_connection": [{"name": "slack_connection", "on_behalf_of_user": on_behalf_of_user}]
106          }
107      )
108      assert uc_function.to_dict() == expected
109      assert _ResourceBuilder.from_resources([uc_function]) == {
110          "api_version": DEFAULT_API_VERSION,
111          "databricks": expected,
112      }
113  
114  
@pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
def test_table(on_behalf_of_user):
    """A table serializes to its name, plus the OBO flag when it is not None.

    The same payload must appear under the ``databricks`` target when the table is
    passed through ``_ResourceBuilder.from_resources``.
    """
    serialized = {"name": "tableName"}
    if on_behalf_of_user is not None:
        serialized["on_behalf_of_user"] = on_behalf_of_user
    expected = {"table": [serialized]}

    resource = DatabricksTable(
        table_name="tableName", on_behalf_of_user=on_behalf_of_user
    )

    assert resource.to_dict() == expected
    assert _ResourceBuilder.from_resources([resource]) == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": expected,
    }
129  
130  
@pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
def test_app(on_behalf_of_user):
    """A Databricks app serializes to its name, plus the OBO flag when it is not None.

    The same payload must appear under the ``databricks`` target when the app is
    passed through ``_ResourceBuilder.from_resources``.
    """
    serialized = {"name": "id1"}
    if on_behalf_of_user is not None:
        serialized["on_behalf_of_user"] = on_behalf_of_user
    expected = {"app": [serialized]}

    resource = DatabricksApp(app_name="id1", on_behalf_of_user=on_behalf_of_user)

    assert resource.to_dict() == expected
    assert _ResourceBuilder.from_resources([resource]) == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": expected,
    }
144  
145  
@pytest.mark.parametrize("on_behalf_of_user", [True, False, None])
def test_lakebase(on_behalf_of_user):
    """A Lakebase instance serializes its instance name as ``name``, plus the OBO flag
    when it is not None.

    The same payload must appear under the ``databricks`` target when the instance is
    passed through ``_ResourceBuilder.from_resources``.
    """
    serialized = {"name": "lakebase_name"}
    if on_behalf_of_user is not None:
        serialized["on_behalf_of_user"] = on_behalf_of_user
    expected = {"lakebase": [serialized]}

    resource = DatabricksLakebase(
        database_instance_name="lakebase_name", on_behalf_of_user=on_behalf_of_user
    )

    assert resource.to_dict() == expected
    assert _ResourceBuilder.from_resources([resource]) == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": expected,
    }
161  
162  
def test_resources():
    """Resources of every imported kind are grouped by type under the ``databricks`` target.

    Multiple resources of the same type (serving endpoints, functions) must be kept in
    declaration order within their type's list.
    """
    resources = [
        DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
        DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
        DatabricksServingEndpoint(endpoint_name="databricks-llama-8x7b-instruct"),
        DatabricksSQLWarehouse(warehouse_id="id123"),
        DatabricksFunction(function_name="rag.studio.test_function_1"),
        DatabricksFunction(function_name="rag.studio.test_function_2"),
        DatabricksUCConnection(connection_name="slack_connection"),
        DatabricksApp(app_name="test_databricks_app"),
        DatabricksLakebase(database_instance_name="test_databricks_lakebase"),
        # Genie spaces and tables are serialized like the others (see their dedicated
        # tests); include them so the merge covers every imported resource class.
        DatabricksGenieSpace(genie_space_id="test_genie_space"),
        DatabricksTable(table_name="rag.studio.test_table"),
    ]
    expected = {
        "api_version": DEFAULT_API_VERSION,
        "databricks": {
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-llama-8x7b-instruct"},
            ],
            "sql_warehouse": [{"name": "id123"}],
            "function": [
                {"name": "rag.studio.test_function_1"},
                {"name": "rag.studio.test_function_2"},
            ],
            "uc_connection": [{"name": "slack_connection"}],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
            "genie_space": [{"name": "test_genie_space"}],
            "table": [{"name": "rag.studio.test_table"}],
        },
    }

    assert _ResourceBuilder.from_resources(resources) == expected
195  
196  
def test_invoker_resources():
    """Only resources built with ``on_behalf_of_user=True`` carry the flag in the output.

    Resources created without the flag serialize to just their name, even when they
    share a type list with flagged resources.
    """
    docs_index = "rag.studio_bugbash.databricks_docs_index"
    mixtral = "databricks-mixtral-8x7b-instruct"
    llama = "databricks-llama-8x7b-instruct"
    first_fn = "rag.studio.test_function_1"
    second_fn = "rag.studio.test_function_2"

    declared = [
        DatabricksVectorSearchIndex(index_name=docs_index, on_behalf_of_user=True),
        DatabricksServingEndpoint(endpoint_name=mixtral),
        DatabricksServingEndpoint(endpoint_name=llama, on_behalf_of_user=True),
        DatabricksSQLWarehouse(warehouse_id="id123"),
        DatabricksFunction(function_name=first_fn),
        DatabricksFunction(function_name=second_fn, on_behalf_of_user=True),
        DatabricksUCConnection(connection_name="slack_connection"),
    ]

    obo = {"on_behalf_of_user": True}
    grouped = {
        "vector_search_index": [{"name": docs_index, **obo}],
        "serving_endpoint": [{"name": mixtral}, {"name": llama, **obo}],
        "sql_warehouse": [{"name": "id123"}],
        "function": [{"name": first_fn}, {"name": second_fn, **obo}],
        "uc_connection": [{"name": "slack_connection"}],
    }

    assert _ResourceBuilder.from_resources(declared) == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": grouped,
    }
231  
232  
def test_resources_from_yaml(tmp_path):
    """Parse resource declarations from YAML files, including each validation error.

    Covers a well-formed file, a missing file, an unsupported ``api_version``, an
    unsupported target URI, an unsupported resource type, and a file whose entries
    carry ``on_behalf_of_user`` flags.
    """
    # Well-formed file declaring several resource kinds, none with OBO flags.
    yaml_file = tmp_path / "resources.yaml"
    yaml_file.write_text(
        """
            api_version: "1"
            databricks:
                vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
                serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                - name: databricks-llama-8x7b-instruct
                sql_warehouse:
                - name: id123
                function:
                - name: rag.studio.test_function_1
                - name: rag.studio.test_function_2
                lakebase:
                - name: test_databricks_lakebase
                uc_connection:
                - name: slack_connection
                app:
                - name: test_databricks_app
            """
    )

    parsed = _ResourceBuilder.from_yaml_file(yaml_file)
    assert parsed == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": {
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-llama-8x7b-instruct"},
            ],
            "sql_warehouse": [{"name": "id123"}],
            "function": [
                {"name": "rag.studio.test_function_1"},
                {"name": "rag.studio.test_function_2"},
            ],
            "uc_connection": [{"name": "slack_connection"}],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
        },
    }

    # A path that does not exist surfaces the underlying OSError.
    with pytest.raises(OSError, match="No such file or directory: 'no-file.yaml'"):
        _ResourceBuilder.from_yaml_file("no-file.yaml")

    # Only api_version "1" is accepted.
    incorrect_version = tmp_path / "incorrect_file.yaml"
    incorrect_version.write_text(
        """
            api_version: "v1"
            """
    )
    with pytest.raises(ValueError, match="Unsupported API version: v1"):
        _ResourceBuilder.from_yaml_file(incorrect_version)

    # Unknown top-level target keys are rejected.
    incorrect_target_uri = tmp_path / "incorrect_target_uri.yaml"
    incorrect_target_uri.write_text(
        """
            api_version: "1"
            databricks-aa:
                vector_search_index_name:
                - name: rag.studio_bugbash.databricks_docs_index
            """
    )
    with pytest.raises(ValueError, match="Unsupported target URI: databricks-aa"):
        _ResourceBuilder.from_yaml_file(incorrect_target_uri)

    # Unknown resource types under a valid target are rejected.
    incorrect_resource = tmp_path / "incorrect_resource.yaml"
    incorrect_resource.write_text(
        """
            api_version: "1"
            databricks:
                vector_search_index_name:
                - name: rag.studio_bugbash.databricks_docs_index
            """
    )
    with pytest.raises(ValueError, match="Unsupported resource type: vector_search_index_name"):
        _ResourceBuilder.from_yaml_file(incorrect_resource)

    # on_behalf_of_user flags are carried through only for the entries that set them.
    invokers_yaml_file = tmp_path / "invokers_resources.yaml"
    invokers_yaml_file.write_text(
        """
            api_version: "1"
            databricks:
                vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
                  on_behalf_of_user: true
                serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                - name: databricks-llama-8x7b-instruct
                  on_behalf_of_user: true
                sql_warehouse:
                - name: id123
                function:
                - name: rag.studio.test_function_1
                  on_behalf_of_user: true
                - name: rag.studio.test_function_2
                uc_connection:
                - name: slack_connection
                  on_behalf_of_user: true
            """
    )

    parsed_invokers = _ResourceBuilder.from_yaml_file(invokers_yaml_file)
    assert parsed_invokers == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": {
            "vector_search_index": [
                {"name": "rag.studio_bugbash.databricks_docs_index", "on_behalf_of_user": True}
            ],
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-llama-8x7b-instruct", "on_behalf_of_user": True},
            ],
            "sql_warehouse": [{"name": "id123"}],
            "function": [
                {"name": "rag.studio.test_function_1", "on_behalf_of_user": True},
                {"name": "rag.studio.test_function_2"},
            ],
            "uc_connection": [{"name": "slack_connection", "on_behalf_of_user": True}],
        },
    }