/ examples / rest_api / mlflow_tracking_rest_api.py
mlflow_tracking_rest_api.py
  1  """
  2  This simple example shows how to use the MLflow REST API to create a new
  3  run inside an experiment and log parameters/metrics to it. Using the REST
  4  API instead of the MLflow client library can be useful when embedding tracking
  5  in an application that should not depend on the whole MLflow library, or when
  6  making your own HTTP requests from a programming language other than Python.
  7  For more details on MLflow REST API endpoints check the following page:
  8  
  9  https://www.mlflow.org/docs/latest/rest-api.html
 10  """
 11  
 12  import argparse
 13  import os
 14  import pwd
 15  
 16  import requests
 17  
 18  from mlflow.utils.time import get_current_time_millis
 19  
 20  _DEFAULT_USER_ID = "unknown"
 21  
 22  
 23  class MlflowTrackingRestApi:
 24      def __init__(self, hostname, port, experiment_id):
 25          self.base_url = "http://" + hostname + ":" + str(port) + "/api/2.0/mlflow"
 26          self.experiment_id = experiment_id
 27          self.run_id = self.create_run()
 28  
 29      def create_run(self):
 30          """Create a new run for tracking."""
 31          url = self.base_url + "/runs/create"
 32          # user_id is deprecated and will be removed from the API in a future release
 33          payload = {
 34              "experiment_id": self.experiment_id,
 35              "start_time": get_current_time_millis(),
 36              "user_id": _get_user_id(),
 37          }
 38          r = requests.post(url, json=payload)
 39          run_id = None
 40          if r.status_code == 200:
 41              run_id = r.json()["run"]["info"]["run_uuid"]
 42          else:
 43              print("Creating run failed!")
 44          return run_id
 45  
 46      def search_experiments(self):
 47          """Get all experiments."""
 48          url = self.base_url + "/experiments/search"
 49          r = requests.get(url)
 50          experiments = None
 51          if r.status_code == 200:
 52              experiments = r.json()["experiments"]
 53          return experiments
 54  
 55      def log_param(self, param):
 56          """Log a parameter dict for the given run."""
 57          url = self.base_url + "/runs/log-parameter"
 58          payload = {"run_id": self.run_id, "key": param["key"], "value": param["value"]}
 59          r = requests.post(url, json=payload)
 60          return r.status_code
 61  
 62      def log_metric(self, metric):
 63          """Log a metric dict for the given run."""
 64          url = self.base_url + "/runs/log-metric"
 65          payload = {
 66              "run_id": self.run_id,
 67              "key": metric["key"],
 68              "value": metric["value"],
 69              "timestamp": metric["timestamp"],
 70              "step": metric["step"],
 71          }
 72          r = requests.post(url, json=payload)
 73          return r.status_code
 74  
 75  
 76  def _get_user_id():
 77      """Get the ID of the user for the current run."""
 78      try:
 79          return pwd.getpwuid(os.getuid())[0]
 80      except ImportError:
 81          return _DEFAULT_USER_ID
 82  
 83  
 84  if __name__ == "__main__":
 85      # Command-line arguments
 86      parser = argparse.ArgumentParser(description="MLflow REST API Example")
 87  
 88      parser.add_argument(
 89          "--hostname",
 90          type=str,
 91          default="localhost",
 92          dest="hostname",
 93          help="MLflow server hostname/ip (default: localhost)",
 94      )
 95  
 96      parser.add_argument(
 97          "--port",
 98          type=int,
 99          default=5000,
100          dest="port",
101          help="MLflow server port number (default: 5000)",
102      )
103  
104      parser.add_argument(
105          "--experiment-id",
106          type=int,
107          default=0,
108          dest="experiment_id",
109          help="Experiment ID (default: 0)",
110      )
111  
112      print("Running mlflow_tracking_rest_api.py")
113  
114      args = parser.parse_args()
115  
116      mlflow_rest = MlflowTrackingRestApi(args.hostname, args.port, args.experiment_id)
117      # Parameter is a key/val pair (str types)
118      param = {"key": "alpha", "value": "0.1980"}
119      status_code = mlflow_rest.log_param(param)
120      if status_code == 200:
121          print(
122              "Successfully logged parameter: {} with value: {}".format(param["key"], param["value"])
123          )
124      else:
125          print("Logging parameter failed!")
126      # Metric is a key/val pair (key/val have str/float types)
127      metric = {
128          "key": "precision",
129          "value": 0.769,
130          "timestamp": get_current_time_millis(),
131          "step": 1,
132      }
133      status_code = mlflow_rest.log_metric(metric)
134      if status_code == 200:
135          print(
136              "Successfully logged parameter: {} with value: {}".format(
137                  metric["key"], metric["value"]
138              )
139          )
140      else:
141          print("Logging metric failed!")