# mlflow/utils/promptlab_utils.py
  1  import json
  2  import os
  3  import tempfile
  4  import time
  5  from datetime import datetime, timezone
  6  
  7  from mlflow.entities.param import Param
  8  from mlflow.entities.run_status import RunStatus
  9  from mlflow.entities.run_tag import RunTag
 10  from mlflow.utils.file_utils import make_containing_dirs, write_to
 11  from mlflow.utils.mlflow_tags import MLFLOW_LOGGED_ARTIFACTS, MLFLOW_RUN_SOURCE_TYPE
 12  from mlflow.version import VERSION as __version__
 13  
 14  
 15  def create_eval_results_json(prompt_parameters, model_input, model_output_parameters, model_output):
 16      columns = [param.key for param in prompt_parameters] + ["prompt", "output"]
 17      data = [param.value for param in prompt_parameters] + [model_input, model_output]
 18  
 19      updated_columns = columns + [param.key for param in model_output_parameters]
 20      updated_data = data + [param.value for param in model_output_parameters]
 21  
 22      eval_results = {"columns": updated_columns, "data": [updated_data]}
 23  
 24      return json.dumps(eval_results)
 25  
 26  
def _create_promptlab_run_impl(
    store,
    experiment_id: str,
    run_name: str,
    tags: list[RunTag],
    prompt_template: str,
    prompt_parameters: list[Param],
    model_route: str,
    model_parameters: list[Param],
    model_input: str,
    model_output_parameters: list[Param],
    model_output: str,
    mlflow_version: str,
    user_id: str,
    start_time: str,
):
    """Create and finalize a tracking run for a single promptlab evaluation.

    Creates a run in ``store``, logs the model/prompt parameters and
    promptlab tags, records a logged model, writes the model artifacts and
    an ``eval_results_table.json`` artifact, and then marks the run
    FINISHED (or FAILED if any step after run creation raised).

    Args:
        store: Tracking store exposing ``create_run``, ``log_batch``,
            ``get_run``, ``record_logged_model``, and ``update_run_info``.
        experiment_id: Experiment under which the run is created.
        run_name: Display name for the run; also passed to
            ``update_run_info`` when finalizing.
        tags: Initial run tags passed to ``create_run``.
        prompt_template: Raw prompt template; logged as the
            ``prompt_template`` param and saved with the model.
        prompt_parameters: Template fill-in values; become the model
            signature's string input columns and the eval table columns.
        model_route: Gateway route; logged as the ``model_route`` param.
        model_parameters: Model invocation params (e.g. temperature).
        model_input: Fully rendered prompt sent to the model.
        model_output_parameters: Extra output metadata (e.g. token counts).
        model_output: Model's response text.
        mlflow_version: Unused here — pip requirement pins the server's own
            VERSION instead; kept for interface compatibility.
        user_id: User recorded on the run.
        start_time: Run start time forwarded to ``create_run``
            (annotated ``str`` but presumably epoch millis — TODO confirm).

    Returns:
        The finalized run fetched from ``store``.
    """
    run = store.create_run(experiment_id, user_id, start_time, tags, run_name)
    run_id = run.info.run_id

    try:
        # Coerce all param values to strings, as required by the Param entity.
        prompt_parameters = [
            Param(key=param.key, value=str(param.value)) for param in prompt_parameters
        ]
        model_parameters = [
            Param(key=param.key, value=str(param.value)) for param in model_parameters
        ]
        model_output_parameters = [
            Param(key=param.key, value=str(param.value)) for param in model_output_parameters
        ]

        # log model parameters
        parameters_to_log = [
            *model_parameters,
            Param("model_route", model_route),
            Param("prompt_template", prompt_template),
        ]

        # Tags let the UI discover the eval table artifact and identify the
        # run as originating from prompt engineering.
        tags_to_log = [
            RunTag(
                MLFLOW_LOGGED_ARTIFACTS,
                json.dumps([{"path": "eval_results_table.json", "type": "table"}]),
            ),
            RunTag(MLFLOW_RUN_SOURCE_TYPE, "PROMPT_ENGINEERING"),
        ]

        store.log_batch(run_id, [], parameters_to_log, tags_to_log)

        # log model
        from mlflow.models import Model

        artifact_dir = store.get_run(run_id).info.artifact_uri

        utc_time_created = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
        promptlab_model = Model(
            artifact_path="model",
            run_id=run_id,
            utc_time_created=utc_time_created,
        )
        store.record_logged_model(run_id, promptlab_model)

        # Build a string->string signature from the prompt parameters; if the
        # signature/schema modules are unavailable, save the model without one.
        try:
            from mlflow.models.signature import ModelSignature
            from mlflow.types.schema import ColSpec, DataType, Schema
        except ImportError:
            signature = None
        else:
            inputs_colspecs = [ColSpec(DataType.string, param.key) for param in prompt_parameters]
            outputs_colspecs = [ColSpec(DataType.string, "output")]
            signature = ModelSignature(
                inputs=Schema(inputs_colspecs),
                outputs=Schema(outputs_colspecs),
            )

        # Imported lazily: server handlers are only available in a server
        # context and would be a circular/heavy import at module load time.
        from mlflow.prompt.promptlab_model import save_model
        from mlflow.server.handlers import (
            _get_artifact_repo_mlflow_artifacts,
            _get_proxied_run_artifact_destination_path,
            _is_servable_proxied_run_artifact_root,
        )

        # write artifact files
        from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

        with tempfile.TemporaryDirectory() as local_dir:
            # Save the promptlab model locally, then upload the whole temp
            # dir (model + eval table) in one log_artifacts call below.
            save_model(
                mlflow_model=promptlab_model,
                path=os.path.join(local_dir, "model"),
                signature=signature,
                input_example={"inputs": [param.value for param in prompt_parameters]},
                prompt_template=prompt_template,
                prompt_parameters=prompt_parameters,
                model_parameters=model_parameters,
                model_route=model_route,
                pip_requirements=[f"mlflow[gateway]=={__version__}"],
            )

            eval_results_json = create_eval_results_json(
                prompt_parameters, model_input, model_output_parameters, model_output
            )
            eval_results_json_file_path = os.path.join(local_dir, "eval_results_table.json")
            make_containing_dirs(eval_results_json_file_path)
            write_to(eval_results_json_file_path, eval_results_json)

            # Proxied artifact roots (mlflow-artifacts:/) must go through the
            # server's artifact repo; otherwise resolve the repo from the URI.
            if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
                artifact_repo = _get_artifact_repo_mlflow_artifacts()
                artifact_path = _get_proxied_run_artifact_destination_path(
                    proxied_artifact_root=run.info.artifact_uri,
                )
                artifact_repo.log_artifacts(local_dir, artifact_path=artifact_path)
            else:
                artifact_repo = get_artifact_repository(artifact_dir)
                artifact_repo.log_artifacts(local_dir)

    except Exception:
        # NOTE(review): the exception is swallowed — the run is marked FAILED
        # and returned rather than re-raised; callers see failure via status.
        store.update_run_info(run_id, RunStatus.FAILED, int(time.time() * 1000), run_name)
    else:
        # end time is the current number of milliseconds since the UNIX epoch.
        store.update_run_info(run_id, RunStatus.FINISHED, int(time.time() * 1000), run_name)

    return store.get_run(run_id=run_id)