# stdin_server.py
"""Score an MLflow model against requests read line by line from stdin.

Each input line is parsed as a JSON object from which the keys ``id``,
``data`` and ``output_file`` are read. The predictions for a request are
appended to the file named by ``output_file``, and the request ``id`` is
passed along with them so the output can be matched to its request.

Usage: ``python stdin_server.py --model-uri <uri>``
"""

import argparse
import inspect
import json
import logging
import sys

from mlflow.pyfunc import scoring_server
from mlflow.pyfunc.model import _log_warning_if_params_not_in_predict_signature

_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("--model-uri")
args = parser.parse_args()

_logger.info("Loading model from %s", args.model_uri)

model = scoring_server.load_model_with_mlflow_config(args.model_uri)
input_schema = model.metadata.get_input_schema()
# The model is loaded once and never replaced, so whether its ``predict``
# accepts ``params`` is decided once here rather than on every request.
_predict_accepts_params = "params" in inspect.signature(model.predict).parameters
_logger.info("Loaded model")

_logger.info("Waiting for request")
# The loop ends when stdin reaches EOF (i.e. the writing side closes the pipe).
for line in sys.stdin:
    _logger.info("Received request")
    request = json.loads(line)

    _logger.info("Parsing input data")
    data = request["data"]
    # Separate the optional inference params from the payload, then parse the
    # payload against the model's input schema.
    data, params = scoring_server._split_data_and_params(data)
    data = scoring_server.infer_and_parse_data(data, input_schema)

    _logger.info("Making predictions")
    if _predict_accepts_params:
        preds = model.predict(data, params=params)
    else:
        # ``predict`` has no ``params`` argument: hand the params to the
        # warning helper and call ``predict`` without them.
        _log_warning_if_params_not_in_predict_signature(_logger, params)
        preds = model.predict(data)

    _logger.info("Writing predictions")
    # Append mode: existing content of the output file is left in place.
    with open(request["output_file"], "a") as f:
        scoring_server.predictions_to_json(preds, f, {"id": request["id"]})

    _logger.info("Done")