/ tests / models / test_auth_policy.py
test_auth_policy.py
  1  from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
  2  from mlflow.models.resources import (
  3      DatabricksFunction,
  4      DatabricksServingEndpoint,
  5      DatabricksUCConnection,
  6      DatabricksVectorSearchIndex,
  7  )
  8  
  9  
 10  def test_complete_auth_policy():
 11      system_auth_policy = SystemAuthPolicy(
 12          resources=[
 13              DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
 14              DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
 15              DatabricksFunction(function_name="rag.studio.test_function_a"),
 16              DatabricksUCConnection(connection_name="test_connection_1"),
 17          ]
 18      )
 19  
 20      user_auth_policy = UserAuthPolicy(
 21          api_scopes=[
 22              "catalog.catalogs",
 23              "vectorsearch.vector-search-indexes",
 24              "workspace.workspace",
 25          ]
 26      )
 27  
 28      auth_policy = AuthPolicy(
 29          user_auth_policy=user_auth_policy, system_auth_policy=system_auth_policy
 30      )
 31  
 32      serialized_auth_policy = auth_policy.to_dict()
 33  
 34      expected_serialized_auth_policy = {
 35          "user_auth_policy": {
 36              "api_scopes": [
 37                  "catalog.catalogs",
 38                  "vectorsearch.vector-search-indexes",
 39                  "workspace.workspace",
 40              ]
 41          },
 42          "system_auth_policy": {
 43              "resources": {
 44                  "databricks": {
 45                      "serving_endpoint": [{"name": "databricks-mixtral-8x7b-instruct"}],
 46                      "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
 47                      "function": [{"name": "rag.studio.test_function_a"}],
 48                      "uc_connection": [{"name": "test_connection_1"}],
 49                  },
 50                  "api_version": "1",
 51              }
 52          },
 53      }
 54      assert serialized_auth_policy == expected_serialized_auth_policy
 55  
 56  
 57  def test_user_auth_policy():
 58      user_auth_policy = UserAuthPolicy(
 59          api_scopes=[
 60              "catalog.catalogs",
 61              "vectorsearch.vector-search-indexes",
 62              "workspace.workspace",
 63          ]
 64      )
 65  
 66      auth_policy = AuthPolicy(user_auth_policy=user_auth_policy)
 67  
 68      serialized_auth_policy = auth_policy.to_dict()
 69  
 70      expected_serialized_auth_policy = {
 71          "system_auth_policy": {},
 72          "user_auth_policy": {
 73              "api_scopes": [
 74                  "catalog.catalogs",
 75                  "vectorsearch.vector-search-indexes",
 76                  "workspace.workspace",
 77              ]
 78          },
 79      }
 80      assert serialized_auth_policy == expected_serialized_auth_policy
 81  
 82  
 83  def test_system_auth_policy():
 84      system_auth_policy = SystemAuthPolicy(
 85          resources=[
 86              DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
 87              DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
 88              DatabricksFunction(function_name="rag.studio.test_function_a"),
 89              DatabricksUCConnection(connection_name="test_connection_1"),
 90          ]
 91      )
 92  
 93      auth_policy = AuthPolicy(system_auth_policy=system_auth_policy)
 94  
 95      serialized_auth_policy = auth_policy.to_dict()
 96  
 97      expected_serialized_auth_policy = {
 98          "system_auth_policy": {
 99              "resources": {
100                  "databricks": {
101                      "serving_endpoint": [{"name": "databricks-mixtral-8x7b-instruct"}],
102                      "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
103                      "function": [{"name": "rag.studio.test_function_a"}],
104                      "uc_connection": [{"name": "test_connection_1"}],
105                  },
106                  "api_version": "1",
107              }
108          },
109          "user_auth_policy": {},
110      }
111      assert serialized_auth_policy == expected_serialized_auth_policy
112  
113  
def test_empty_auth_policy():
    """An AuthPolicy built with no arguments serializes both sections as empty dicts."""
    assert AuthPolicy().to_dict() == {"user_auth_policy": {}, "system_auth_policy": {}}