# backend.py
import logging
import os
import re
import subprocess
import sys

from mlflow.exceptions import MlflowException
from mlflow.models import FlavorBackend
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.string_utils import quote

_logger = logging.getLogger(__name__)


class RFuncBackend(FlavorBackend):
    """
    Flavor backend implementation for the generic R models.
    Predict and serve locally models with 'crate' flavor.

    Prediction and serving are delegated to the ``mlflow`` R package by running
    ``Rscript -e <command>`` in a child process.
    """

    def build_image(
        self,
        model_uri,
        image_name,
        install_java=False,
        install_mlflow=False,
        mlflow_home=None,
        enable_mlserver=False,
        base_image=None,
    ):
        # No-op: this backend does not implement Docker image building.
        pass

    def generate_dockerfile(
        self,
        model_uri,
        output_dir,
        install_java=False,
        install_mlflow=False,
        mlflow_home=None,
        enable_mlserver=False,
        base_image=None,
    ):
        # No-op: this backend does not implement Dockerfile generation.
        pass

    # Extracts the "X.Y.Z" version number from the `Rscript --version` banner.
    version_pattern = re.compile(r"version ([0-9]+\.[0-9]+\.[0-9]+)")

    def predict(
        self,
        model_uri,
        input_path,
        output_path,
        content_type,
        pip_requirements_override=None,
        extra_envs=None,
    ):
        """
        Generate predictions using an R model saved with MLflow.

        The model is downloaded locally and ``mlflow:::mlflow_rfunc_predict`` is
        invoked in a child ``Rscript`` process. The R side writes the results
        (to ``output_path`` when one is given); this method returns None.

        :param model_uri: URI of the model to score; it is downloaded first.
        :param input_path: Path to the input data, or None (passed to R as ``NULL``).
        :param output_path: Path for the predictions, or None (passed to R as ``NULL``).
        :param content_type: Content type of the input, or None (passed to R as ``NULL``).
        :param pip_requirements_override: Not supported by the R backend; must be None.
        :param extra_envs: Optional mapping of extra environment variables for the
            R process.
        :raises MlflowException: If ``pip_requirements_override`` is provided.
        :raises Exception: If the R process exits with a non-zero status.
        """
        if pip_requirements_override is not None:
            raise MlflowException("pip_requirements_override is not supported in the R backend.")
        model_path = _download_artifact_from_uri(model_uri)
        # NOTE(review): `quote` comes from mlflow.utils.string_utils. Its output is
        # embedded inside an R single-quoted literal and passed via argv (no shell).
        # If `quote` performs shell quoting, paths containing quotes/whitespace may
        # yield invalid R code. Confirm before relying on such paths.
        str_cmd = (
            "mlflow:::mlflow_rfunc_predict(model_path = '{0}', input_path = {1}, "
            "output_path = {2}, content_type = {3})"
        )
        command = str_cmd.format(
            quote(model_path),
            _str_optional(input_path),
            _str_optional(output_path),
            _str_optional(content_type),
        )
        _execute(command, extra_envs=extra_envs)

    def serve(
        self,
        model_uri,
        port,
        host,
        timeout,
        enable_mlserver,
        synchronous=True,
        stdout=None,
        stderr=None,
    ):
        """
        Serve an R model locally via ``mlflow::mlflow_rfunc_serve``.

        The call blocks until the child ``Rscript`` process exits.

        NOTE: The `enable_mlserver` parameter is there to comply with the
        FlavorBackend interface but is not supported by MLServer yet.
        https://github.com/SeldonIO/MLServer/issues/183

        :param model_uri: URI of the model to serve; it is downloaded first.
        :param port: Port passed through to the R server.
        :param host: Host passed through to the R server. It is inserted into an
            R string literal without escaping.
        :param timeout: Not supported; a warning is logged if it is truthy.
        :param enable_mlserver: Must be falsy (MLServer is not supported).
        :param synchronous: Must be True (asynchronous serving is not supported).
        :param stdout: Must be None (output redirection is not supported).
        :param stderr: Must be None (output redirection is not supported).
        :raises Exception: If an unsupported option is requested, or the R process
            exits with a non-zero status.
        """
        if enable_mlserver:
            raise Exception("The MLServer inference server is not yet supported in the R backend.")

        if timeout:
            _logger.warning("Timeout is not yet supported in the R backend.")

        if not synchronous:
            raise Exception("RBackend does not support call with synchronous=False")

        if stdout is not None or stderr is not None:
            raise Exception("RBackend does not support redirect stdout/stderr.")

        model_path = _download_artifact_from_uri(model_uri)
        command = "mlflow::mlflow_rfunc_serve('{}', port = {}, host = '{}')".format(
            quote(model_path), port, host
        )
        _execute(command)

    def can_score_model(self):
        """
        Return True if a usable R installation (R >= 3.3.0) is available.

        Returns False when ``Rscript`` cannot be launched (for example, it is not
        on PATH), exits with a non-zero status, or reports a version that cannot
        be parsed or is older than 3.3.0.
        """
        try:
            # `Rscript --version` writes to stderr in R < 4.2.0 but stdout in R >= 4.2.0,
            # so both streams are merged into a single pipe.
            process = subprocess.Popen(
                ["Rscript", "--version"],
                close_fds=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
            )
        except OSError:
            # Rscript is missing or not executable: R models cannot be scored here.
            return False
        output, _ = process.communicate()
        if process.returncode != 0:
            return False

        match = self.version_pattern.search(output.decode("utf-8", errors="replace"))
        if not match:
            return False
        major, minor, _patch = (int(part) for part in match.group(1).split("."))
        return (major, minor) >= (3, 3)


def _execute(command, extra_envs=None):
    """
    Run ``Rscript -e <command>`` and wait for it to finish.

    The child inherits this process's stdin/stdout/stderr and environment, with
    ``extra_envs`` (if any) layered on top.

    :param command: R expression to evaluate.
    :param extra_envs: Optional mapping of additional environment variables.
    :raises Exception: If the R process exits with a non-zero status.
    """
    env = os.environ.copy()
    if extra_envs:
        env.update(extra_envs)

    process = subprocess.Popen(
        ["Rscript", "-e", command],
        env=env,
        close_fds=False,
        stdin=sys.stdin,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
    if process.wait() != 0:
        raise Exception("Command returned non zero exit code.")


def _str_optional(s):
    """
    Render ``s`` for an R call: ``NULL`` when ``s`` is None, otherwise
    ``str(s)`` passed through ``quote`` and wrapped in single quotes.
    """
    return "NULL" if s is None else f"'{quote(str(s))}'"