"""promptlab_utils.py — helpers for creating MLflow prompt-engineering ("promptlab") runs."""
import json
import logging
import os
import tempfile
import time
from datetime import datetime, timezone

from mlflow.entities.param import Param
from mlflow.entities.run_status import RunStatus
from mlflow.entities.run_tag import RunTag
from mlflow.utils.file_utils import make_containing_dirs, write_to
from mlflow.utils.mlflow_tags import MLFLOW_LOGGED_ARTIFACTS, MLFLOW_RUN_SOURCE_TYPE
from mlflow.version import VERSION as __version__

_logger = logging.getLogger(__name__)


def create_eval_results_json(prompt_parameters, model_input, model_output_parameters, model_output):
    """Serialize one prompt-engineering evaluation row as a JSON table string.

    The table has one column per prompt parameter, then "prompt" and "output"
    columns, then one column per model-output parameter, and exactly one data
    row in the same order.

    Args:
        prompt_parameters: Iterable of objects with ``key``/``value``
            attributes describing the prompt template variables.
        model_input: The rendered prompt sent to the model.
        model_output_parameters: Iterable of objects with ``key``/``value``
            attributes carrying metadata about the model output.
        model_output: The text produced by the model.

    Returns:
        A JSON string of the form ``{"columns": [...], "data": [[...]]}``.
    """
    columns = [param.key for param in prompt_parameters] + ["prompt", "output"]
    row = [param.value for param in prompt_parameters] + [model_input, model_output]

    columns = columns + [param.key for param in model_output_parameters]
    row = row + [param.value for param in model_output_parameters]

    return json.dumps({"columns": columns, "data": [row]})


def _create_promptlab_run_impl(
    store,
    experiment_id: str,
    run_name: str,
    tags: list[RunTag],
    prompt_template: str,
    prompt_parameters: list[Param],
    model_route: str,
    model_parameters: list[Param],
    model_input: str,
    model_output_parameters: list[Param],
    model_output: str,
    mlflow_version: str,  # NOTE(review): currently unused in the body; kept for interface compatibility
    user_id: str,
    # NOTE(review): annotated ``str`` but passed straight to store.create_run;
    # the end time below is an epoch-milliseconds int — confirm the annotation.
    start_time: str,
):
    """Create and populate a "prompt engineering" run in the tracking store.

    Creates the run, logs model/prompt parameters and promptlab tags, records
    a promptlab model (with an optional signature), writes the model files and
    the eval-results table to the run's artifact store, and finally marks the
    run FINISHED — or FAILED if any step after run creation raised.

    Args:
        store: Tracking store providing ``create_run``, ``log_batch``,
            ``get_run``, ``record_logged_model`` and ``update_run_info``.

    Returns:
        The fully populated run fetched back from ``store``.
    """
    run = store.create_run(experiment_id, user_id, start_time, tags, run_name)
    run_id = run.info.run_id

    try:
        # Normalize all parameter values to strings so they serialize as Params.
        prompt_parameters = [
            Param(key=param.key, value=str(param.value)) for param in prompt_parameters
        ]
        model_parameters = [
            Param(key=param.key, value=str(param.value)) for param in model_parameters
        ]
        model_output_parameters = [
            Param(key=param.key, value=str(param.value)) for param in model_output_parameters
        ]

        # Log model parameters plus the route and template used for this run.
        parameters_to_log = [
            *model_parameters,
            Param("model_route", model_route),
            Param("prompt_template", prompt_template),
        ]

        tags_to_log = [
            # Advertise the eval-results table artifact so it can be rendered
            # as a table rather than a plain file.
            RunTag(
                MLFLOW_LOGGED_ARTIFACTS,
                json.dumps([{"path": "eval_results_table.json", "type": "table"}]),
            ),
            RunTag(MLFLOW_RUN_SOURCE_TYPE, "PROMPT_ENGINEERING"),
        ]

        store.log_batch(run_id, [], parameters_to_log, tags_to_log)

        # Log the model. Imported lazily (as in the original) — presumably to
        # avoid import cycles at module load; keep these imports local.
        from mlflow.models import Model

        artifact_dir = store.get_run(run_id).info.artifact_uri

        utc_time_created = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
        promptlab_model = Model(
            artifact_path="model",
            run_id=run_id,
            utc_time_created=utc_time_created,
        )
        store.record_logged_model(run_id, promptlab_model)

        # Best effort: the model signature is optional if the schema/signature
        # modules are unavailable.
        try:
            from mlflow.models.signature import ModelSignature
            from mlflow.types.schema import ColSpec, DataType, Schema
        except ImportError:
            signature = None
        else:
            inputs_colspecs = [ColSpec(DataType.string, param.key) for param in prompt_parameters]
            outputs_colspecs = [ColSpec(DataType.string, "output")]
            signature = ModelSignature(
                inputs=Schema(inputs_colspecs),
                outputs=Schema(outputs_colspecs),
            )

        from mlflow.prompt.promptlab_model import save_model
        from mlflow.server.handlers import (
            _get_artifact_repo_mlflow_artifacts,
            _get_proxied_run_artifact_destination_path,
            _is_servable_proxied_run_artifact_root,
        )

        # write artifact files
        from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

        # Stage the model and the eval-results table in a temp dir, then
        # upload the whole directory to the run's artifact store in one call.
        with tempfile.TemporaryDirectory() as local_dir:
            save_model(
                mlflow_model=promptlab_model,
                path=os.path.join(local_dir, "model"),
                signature=signature,
                input_example={"inputs": [param.value for param in prompt_parameters]},
                prompt_template=prompt_template,
                prompt_parameters=prompt_parameters,
                model_parameters=model_parameters,
                model_route=model_route,
                pip_requirements=[f"mlflow[gateway]=={__version__}"],
            )

            eval_results_json = create_eval_results_json(
                prompt_parameters, model_input, model_output_parameters, model_output
            )
            eval_results_json_file_path = os.path.join(local_dir, "eval_results_table.json")
            make_containing_dirs(eval_results_json_file_path)
            write_to(eval_results_json_file_path, eval_results_json)

            if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
                # Proxied artifact roots are resolved and served by the
                # tracking server itself.
                artifact_repo = _get_artifact_repo_mlflow_artifacts()
                artifact_path = _get_proxied_run_artifact_destination_path(
                    proxied_artifact_root=run.info.artifact_uri,
                )
                artifact_repo.log_artifacts(local_dir, artifact_path=artifact_path)
            else:
                artifact_repo = get_artifact_repository(artifact_dir)
                artifact_repo.log_artifacts(local_dir)

    except Exception:
        # Best-effort semantics: never raise out of run creation, but log the
        # error so a FAILED promptlab run is diagnosable (previously the
        # exception was swallowed silently).
        _logger.exception("Failed to populate promptlab run %s", run_id)
        store.update_run_info(run_id, RunStatus.FAILED, int(time.time() * 1000), run_name)
    else:
        # end time is the current number of milliseconds since the UNIX epoch.
        store.update_run_info(run_id, RunStatus.FINISHED, int(time.time() * 1000), run_name)

    return store.get_run(run_id=run_id)