# test_resources.py
"""Tests for the Databricks resource declarations in ``mlflow.models.resources``.

Each resource type is checked for its ``to_dict`` serialization and for the
payload ``_ResourceBuilder`` produces, both from Python resource objects and
from a YAML file.
"""

import textwrap

import pytest

from mlflow.models.resources import (
    DEFAULT_API_VERSION,
    DatabricksApp,
    DatabricksFunction,
    DatabricksGenieSpace,
    DatabricksLakebase,
    DatabricksServingEndpoint,
    DatabricksSQLWarehouse,
    DatabricksTable,
    DatabricksUCConnection,
    DatabricksVectorSearchIndex,
    _ResourceBuilder,
)

# Every single-resource test runs with the flag set, cleared, and omitted (None).
_parametrize_on_behalf_of_user = pytest.mark.parametrize("on_behalf_of_user", [True, False, None])


def _expected_resource(resource_type, name, on_behalf_of_user):
    """Build the ``to_dict`` payload expected for a single resource.

    The ``on_behalf_of_user`` key is expected only when the flag is not ``None``.
    """
    entry = {"name": name}
    if on_behalf_of_user is not None:
        entry["on_behalf_of_user"] = on_behalf_of_user
    return {resource_type: [entry]}


def _assert_serializes_to(resource, expected):
    """Check ``resource.to_dict()`` and the builder payload that wraps it."""
    assert resource.to_dict() == expected
    assert _ResourceBuilder.from_resources([resource]) == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": expected,
    }


def _write_yaml(path, content):
    """Write ``content`` (with common leading indentation removed) to ``path``.

    Returns ``path`` so callers can create and bind the file in one step.
    """
    path.write_text(textwrap.dedent(content))
    return path


@_parametrize_on_behalf_of_user
def test_serving_endpoint(on_behalf_of_user):
    endpoint = DatabricksServingEndpoint(
        endpoint_name="llm_server", on_behalf_of_user=on_behalf_of_user
    )
    _assert_serializes_to(
        endpoint, _expected_resource("serving_endpoint", "llm_server", on_behalf_of_user)
    )


@_parametrize_on_behalf_of_user
def test_index_name(on_behalf_of_user):
    index = DatabricksVectorSearchIndex(index_name="index1", on_behalf_of_user=on_behalf_of_user)
    _assert_serializes_to(
        index, _expected_resource("vector_search_index", "index1", on_behalf_of_user)
    )


@_parametrize_on_behalf_of_user
def test_sql_warehouse(on_behalf_of_user):
    sql_warehouse = DatabricksSQLWarehouse(warehouse_id="id1", on_behalf_of_user=on_behalf_of_user)
    _assert_serializes_to(
        sql_warehouse, _expected_resource("sql_warehouse", "id1", on_behalf_of_user)
    )


@_parametrize_on_behalf_of_user
def test_uc_function(on_behalf_of_user):
    uc_function = DatabricksFunction(function_name="function", on_behalf_of_user=on_behalf_of_user)
    _assert_serializes_to(
        uc_function, _expected_resource("function", "function", on_behalf_of_user)
    )


@_parametrize_on_behalf_of_user
def test_genie_space(on_behalf_of_user):
    genie_space = DatabricksGenieSpace(genie_space_id="id1", on_behalf_of_user=on_behalf_of_user)
    _assert_serializes_to(genie_space, _expected_resource("genie_space", "id1", on_behalf_of_user))


@_parametrize_on_behalf_of_user
def test_uc_connection(on_behalf_of_user):
    connection = DatabricksUCConnection(
        connection_name="slack_connection", on_behalf_of_user=on_behalf_of_user
    )
    _assert_serializes_to(
        connection, _expected_resource("uc_connection", "slack_connection", on_behalf_of_user)
    )


@_parametrize_on_behalf_of_user
def test_table(on_behalf_of_user):
    table = DatabricksTable(table_name="tableName", on_behalf_of_user=on_behalf_of_user)
    _assert_serializes_to(table, _expected_resource("table", "tableName", on_behalf_of_user))


@_parametrize_on_behalf_of_user
def test_app(on_behalf_of_user):
    app = DatabricksApp(app_name="id1", on_behalf_of_user=on_behalf_of_user)
    _assert_serializes_to(app, _expected_resource("app", "id1", on_behalf_of_user))


@_parametrize_on_behalf_of_user
def test_lakebase(on_behalf_of_user):
    lakebase = DatabricksLakebase(
        database_instance_name="lakebase_name", on_behalf_of_user=on_behalf_of_user
    )
    _assert_serializes_to(
        lakebase, _expected_resource("lakebase", "lakebase_name", on_behalf_of_user)
    )


def test_resources():
    """Resources of the same type are grouped under one key, in input order."""
    resources = [
        DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
        DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
        DatabricksServingEndpoint(endpoint_name="databricks-llama-8x7b-instruct"),
        DatabricksSQLWarehouse(warehouse_id="id123"),
        DatabricksFunction(function_name="rag.studio.test_function_1"),
        DatabricksFunction(function_name="rag.studio.test_function_2"),
        DatabricksUCConnection(connection_name="slack_connection"),
        DatabricksApp(app_name="test_databricks_app"),
        DatabricksLakebase(database_instance_name="test_databricks_lakebase"),
    ]
    expected = {
        "api_version": DEFAULT_API_VERSION,
        "databricks": {
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-llama-8x7b-instruct"},
            ],
            "sql_warehouse": [{"name": "id123"}],
            "function": [
                {"name": "rag.studio.test_function_1"},
                {"name": "rag.studio.test_function_2"},
            ],
            "uc_connection": [{"name": "slack_connection"}],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
        },
    }

    assert _ResourceBuilder.from_resources(resources) == expected


def test_invoker_resources():
    """``on_behalf_of_user`` is serialized per resource, only where it was set."""
    resources = [
        DatabricksVectorSearchIndex(
            index_name="rag.studio_bugbash.databricks_docs_index", on_behalf_of_user=True
        ),
        DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
        DatabricksServingEndpoint(
            endpoint_name="databricks-llama-8x7b-instruct", on_behalf_of_user=True
        ),
        DatabricksSQLWarehouse(warehouse_id="id123"),
        DatabricksFunction(function_name="rag.studio.test_function_1"),
        DatabricksFunction(function_name="rag.studio.test_function_2", on_behalf_of_user=True),
        DatabricksUCConnection(connection_name="slack_connection"),
    ]
    expected = {
        "api_version": DEFAULT_API_VERSION,
        "databricks": {
            "vector_search_index": [
                {"name": "rag.studio_bugbash.databricks_docs_index", "on_behalf_of_user": True}
            ],
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-llama-8x7b-instruct", "on_behalf_of_user": True},
            ],
            "sql_warehouse": [{"name": "id123"}],
            "function": [
                {"name": "rag.studio.test_function_1"},
                {"name": "rag.studio.test_function_2", "on_behalf_of_user": True},
            ],
            "uc_connection": [{"name": "slack_connection"}],
        },
    }

    assert _ResourceBuilder.from_resources(resources) == expected


def test_resources_from_yaml(tmp_path):
    """Load resources from YAML and check the errors for invalid inputs.

    Covers a valid file, a missing file, an unsupported API version, an
    unsupported target URI, an unsupported resource type, and a file that sets
    ``on_behalf_of_user`` on some resources.
    """
    yaml_file = _write_yaml(
        tmp_path / "resources.yaml",
        """
        api_version: "1"
        databricks:
          vector_search_index:
            - name: rag.studio_bugbash.databricks_docs_index
          serving_endpoint:
            - name: databricks-mixtral-8x7b-instruct
            - name: databricks-llama-8x7b-instruct
          sql_warehouse:
            - name: id123
          function:
            - name: rag.studio.test_function_1
            - name: rag.studio.test_function_2
          lakebase:
            - name: test_databricks_lakebase
          uc_connection:
            - name: slack_connection
          app:
            - name: test_databricks_app
        """,
    )

    assert _ResourceBuilder.from_yaml_file(yaml_file) == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": {
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-llama-8x7b-instruct"},
            ],
            "sql_warehouse": [{"name": "id123"}],
            "function": [
                {"name": "rag.studio.test_function_1"},
                {"name": "rag.studio.test_function_2"},
            ],
            "uc_connection": [{"name": "slack_connection"}],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
        },
    }

    with pytest.raises(OSError, match="No such file or directory: 'no-file.yaml'"):
        _ResourceBuilder.from_yaml_file("no-file.yaml")

    incorrect_version = _write_yaml(
        tmp_path / "incorrect_file.yaml",
        """
        api_version: "v1"
        """,
    )

    with pytest.raises(ValueError, match="Unsupported API version: v1"):
        _ResourceBuilder.from_yaml_file(incorrect_version)

    incorrect_target_uri = _write_yaml(
        tmp_path / "incorrect_target_uri.yaml",
        """
        api_version: "1"
        databricks-aa:
          vector_search_index_name:
            - name: rag.studio_bugbash.databricks_docs_index
        """,
    )

    with pytest.raises(ValueError, match="Unsupported target URI: databricks-aa"):
        _ResourceBuilder.from_yaml_file(incorrect_target_uri)

    incorrect_resource = _write_yaml(
        tmp_path / "incorrect_resource.yaml",
        """
        api_version: "1"
        databricks:
          vector_search_index_name:
            - name: rag.studio_bugbash.databricks_docs_index
        """,
    )

    with pytest.raises(ValueError, match="Unsupported resource type: vector_search_index_name"):
        _ResourceBuilder.from_yaml_file(incorrect_resource)

    invokers_yaml_file = _write_yaml(
        tmp_path / "invokers_resources.yaml",
        """
        api_version: "1"
        databricks:
          vector_search_index:
            - name: rag.studio_bugbash.databricks_docs_index
              on_behalf_of_user: true
          serving_endpoint:
            - name: databricks-mixtral-8x7b-instruct
            - name: databricks-llama-8x7b-instruct
              on_behalf_of_user: true
          sql_warehouse:
            - name: id123
          function:
            - name: rag.studio.test_function_1
              on_behalf_of_user: true
            - name: rag.studio.test_function_2
          uc_connection:
            - name: slack_connection
              on_behalf_of_user: true
        """,
    )

    assert _ResourceBuilder.from_yaml_file(invokers_yaml_file) == {
        "api_version": DEFAULT_API_VERSION,
        "databricks": {
            "vector_search_index": [
                {"name": "rag.studio_bugbash.databricks_docs_index", "on_behalf_of_user": True}
            ],
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-llama-8x7b-instruct", "on_behalf_of_user": True},
            ],
            "sql_warehouse": [{"name": "id123"}],
            "function": [
                {"name": "rag.studio.test_function_1", "on_behalf_of_user": True},
                {"name": "rag.studio.test_function_2"},
            ],
            "uc_connection": [{"name": "slack_connection", "on_behalf_of_user": True}],
        },
    }