# mlflow/pyfunc/stdin_server.py
 1  import argparse
 2  import inspect
 3  import json
 4  import logging
 5  import sys
 6  
 7  from mlflow.pyfunc import scoring_server
 8  from mlflow.pyfunc.model import _log_warning_if_params_not_in_predict_signature
 9  
10  _logger = logging.getLogger(__name__)
11  logging.basicConfig(level=logging.INFO)
12  
13  parser = argparse.ArgumentParser()
14  parser.add_argument("--model-uri")
15  args = parser.parse_args()
16  
17  _logger.info("Loading model from %s", args.model_uri)
18  
19  model = scoring_server.load_model_with_mlflow_config(args.model_uri)
20  input_schema = model.metadata.get_input_schema()
21  _logger.info("Loaded model")
22  
23  _logger.info("Waiting for request")
24  for line in sys.stdin:
25      _logger.info("Received request")
26      request = json.loads(line)
27  
28      _logger.info("Parsing input data")
29      data = request["data"]
30      data, params = scoring_server._split_data_and_params(data)
31      data = scoring_server.infer_and_parse_data(data, input_schema)
32  
33      _logger.info("Making predictions")
34      if "params" in inspect.signature(model.predict).parameters:
35          preds = model.predict(data, params=params)
36      else:
37          _log_warning_if_params_not_in_predict_signature(_logger, params)
38          preds = model.predict(data)
39  
40      _logger.info("Writing predictions")
41      with open(request["output_file"], "a") as f:
42          scoring_server.predictions_to_json(preds, f, {"id": request["id"]})
43  
44      _logger.info("Done")