# test_auth_policy.py
from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksServingEndpoint,
    DatabricksUCConnection,
    DatabricksVectorSearchIndex,
)

# API scopes granted by every UserAuthPolicy built in these tests. Kept as a
# tuple so no test can mutate the shared value; builders hand out fresh lists.
_API_SCOPES = (
    "catalog.catalogs",
    "vectorsearch.vector-search-indexes",
    "workspace.workspace",
)


def _make_system_auth_policy():
    """Build a SystemAuthPolicy with one resource of each Databricks type under test."""
    return SystemAuthPolicy(
        resources=[
            DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
            DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
            DatabricksFunction(function_name="rag.studio.test_function_a"),
            DatabricksUCConnection(connection_name="test_connection_1"),
        ]
    )


def _make_user_auth_policy():
    """Build a UserAuthPolicy granting ``_API_SCOPES``."""
    return UserAuthPolicy(api_scopes=list(_API_SCOPES))


def _expected_system_auth_policy_dict():
    """Return the serialized form expected for ``_make_system_auth_policy()``.

    A fresh dict is built on every call so tests cannot leak mutations into
    each other.
    """
    return {
        "resources": {
            "databricks": {
                "serving_endpoint": [{"name": "databricks-mixtral-8x7b-instruct"}],
                "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
                "function": [{"name": "rag.studio.test_function_a"}],
                "uc_connection": [{"name": "test_connection_1"}],
            },
            "api_version": "1",
        }
    }


def _expected_user_auth_policy_dict():
    """Return the serialized form expected for ``_make_user_auth_policy()``."""
    return {"api_scopes": list(_API_SCOPES)}


def test_complete_auth_policy():
    """Both sub-policies set: each is serialized under its own key."""
    auth_policy = AuthPolicy(
        user_auth_policy=_make_user_auth_policy(),
        system_auth_policy=_make_system_auth_policy(),
    )

    assert auth_policy.to_dict() == {
        "user_auth_policy": _expected_user_auth_policy_dict(),
        "system_auth_policy": _expected_system_auth_policy_dict(),
    }


def test_user_auth_policy():
    """Only a user policy set: the system policy serializes to an empty dict."""
    auth_policy = AuthPolicy(user_auth_policy=_make_user_auth_policy())

    assert auth_policy.to_dict() == {
        "system_auth_policy": {},
        "user_auth_policy": _expected_user_auth_policy_dict(),
    }


def test_system_auth_policy():
    """Only a system policy set: the user policy serializes to an empty dict."""
    auth_policy = AuthPolicy(system_auth_policy=_make_system_auth_policy())

    assert auth_policy.to_dict() == {
        "system_auth_policy": _expected_system_auth_policy_dict(),
        "user_auth_policy": {},
    }


def test_empty_auth_policy():
    """No sub-policies set: both keys serialize to empty dicts."""
    auth_policy = AuthPolicy()

    assert auth_policy.to_dict() == {"system_auth_policy": {}, "user_auth_policy": {}}