# databricks_request_header_provider.py
from mlflow.tracking.request_header.abstract_request_header_provider import RequestHeaderProvider
from mlflow.utils import databricks_utils


class DatabricksRequestHeaderProvider(RequestHeaderProvider):
    """
    Request header provider that describes the Databricks environment (notebook, job,
    cluster, workload) a request originates from.
    """

    def in_context(self):
        """
        Report whether any of the Databricks environment checks (cluster, notebook, job)
        succeeds.

        The checks run in that order and stop at the first truthy result, which is
        returned as-is; if none is truthy, the result of the last check is returned.
        """
        environment_checks = (
            databricks_utils.is_in_cluster,
            databricks_utils.is_in_databricks_notebook,
            databricks_utils.is_in_databricks_job,
        )
        outcome = False
        for check in environment_checks:
            outcome = check()
            if outcome:
                break
        return outcome

    def request_headers(self):
        """
        Build the dictionary of request headers for the current Databricks environment.

        Notebook, job and cluster entries are added only when the matching
        ``databricks_utils`` check passes. Workload entries are added only when the
        corresponding lookup returns something other than ``None``.
        """
        headers = {}

        if databricks_utils.is_in_databricks_notebook():
            headers["notebook_id"] = databricks_utils.get_notebook_id()

        if databricks_utils.is_in_databricks_job():
            headers.update(
                {
                    "job_id": databricks_utils.get_job_id(),
                    "job_run_id": databricks_utils.get_job_run_id(),
                    "job_type": databricks_utils.get_job_type(),
                }
            )

        if databricks_utils.is_in_cluster():
            headers["cluster_id"] = databricks_utils.get_cluster_id()

        # Both workload lookups run before either value is inspected; a ``None`` result
        # means the key is left out of the headers entirely.
        workload_fields = {
            "workload_id": databricks_utils.get_workload_id(),
            "workload_class": databricks_utils.get_workload_class(),
        }
        headers.update(
            {key: value for key, value in workload_fields.items() if value is not None}
        )

        return headers