# mlflow/rfunc/backend.py
  1  import logging
  2  import os
  3  import re
  4  import subprocess
  5  import sys
  6  
  7  from mlflow.exceptions import MlflowException
  8  from mlflow.models import FlavorBackend
  9  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 10  from mlflow.utils.string_utils import quote
 11  
 12  _logger = logging.getLogger(__name__)
 13  
 14  
 15  class RFuncBackend(FlavorBackend):
 16      """
 17      Flavor backend implementation for the generic R models.
 18      Predict and serve locally models with 'crate' flavor.
 19      """
 20  
 21      def build_image(
 22          self,
 23          model_uri,
 24          image_name,
 25          install_java=False,
 26          install_mlflow=False,
 27          mlflow_home=None,
 28          enable_mlserver=False,
 29          base_image=None,
 30      ):
 31          pass
 32  
 33      def generate_dockerfile(
 34          self,
 35          model_uri,
 36          output_dir,
 37          install_java=False,
 38          install_mlflow=False,
 39          mlflow_home=None,
 40          enable_mlserver=False,
 41          base_image=None,
 42      ):
 43          pass
 44  
 45      version_pattern = re.compile(r"version ([0-9]+\.[0-9]+\.[0-9]+)")
 46  
 47      def predict(
 48          self,
 49          model_uri,
 50          input_path,
 51          output_path,
 52          content_type,
 53          pip_requirements_override=None,
 54          extra_envs=None,
 55      ):
 56          """
 57          Generate predictions using R model saved with MLflow.
 58          Return the prediction results as a JSON.
 59          """
 60          if pip_requirements_override is not None:
 61              raise MlflowException("pip_requirements_override is not supported in the R backend.")
 62          model_path = _download_artifact_from_uri(model_uri)
 63          str_cmd = (
 64              "mlflow:::mlflow_rfunc_predict(model_path = '{0}', input_path = {1}, "
 65              "output_path = {2}, content_type = {3})"
 66          )
 67          command = str_cmd.format(
 68              quote(model_path),
 69              _str_optional(input_path),
 70              _str_optional(output_path),
 71              _str_optional(content_type),
 72          )
 73          _execute(command, extra_envs=extra_envs)
 74  
 75      def serve(
 76          self,
 77          model_uri,
 78          port,
 79          host,
 80          timeout,
 81          enable_mlserver,
 82          synchronous=True,
 83          stdout=None,
 84          stderr=None,
 85      ):
 86          """
 87          Generate R model locally.
 88  
 89          NOTE: The `enable_mlserver` parameter is there to comply with the
 90          FlavorBackend interface but is not supported by MLServer yet.
 91          https://github.com/SeldonIO/MLServer/issues/183
 92          """
 93          if enable_mlserver:
 94              raise Exception("The MLServer inference server is not yet supported in the R backend.")
 95  
 96          if timeout:
 97              _logger.warning("Timeout is not yet supported in the R backend.")
 98  
 99          if not synchronous:
100              raise Exception("RBackend does not support call with synchronous=False")
101  
102          if stdout is not None or stderr is not None:
103              raise Exception("RBackend does not support redirect stdout/stderr.")
104  
105          model_path = _download_artifact_from_uri(model_uri)
106          command = "mlflow::mlflow_rfunc_serve('{}', port = {}, host = '{}')".format(
107              quote(model_path), port, host
108          )
109          _execute(command)
110  
111      def can_score_model(self):
112          # `Rscript --version` writes to stderr in R < 4.2.0 but stdout in R >= 4.2.0.
113          process = subprocess.Popen(
114              ["Rscript", "--version"],
115              close_fds=True,
116              stdout=subprocess.PIPE,
117              stderr=subprocess.STDOUT,
118          )
119          stdout, _ = process.communicate()
120          if process.wait() != 0:
121              return False
122  
123          version = self.version_pattern.search(stdout.decode("utf-8"))
124          if not version:
125              return False
126          version = [int(x) for x in version.group(1).split(".")]
127          return version[0] > 3 or version[0] == 3 and version[1] >= 3
128  
129  
def _execute(command, extra_envs=None):
    """
    Evaluate an R expression with ``Rscript -e``, inheriting this process's
    standard streams, and fail loudly on a non-zero exit status.

    Args:
        command: R expression to evaluate.
        extra_envs: Optional mapping of environment variables overlaid on top
            of the current process environment for the child process.

    Raises:
        Exception: If the Rscript process exits with a non-zero code.
    """
    environment = {**os.environ, **(extra_envs or {})}

    proc = subprocess.Popen(
        ["Rscript", "-e", command],
        env=environment,
        close_fds=False,
        stdin=sys.stdin,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
    if proc.wait() != 0:
        raise Exception("Command returned non zero exit code.")
145  
146  
def _str_optional(s):
    """Render *s* as an R literal: ``NULL`` for None, else a quoted string."""
    if s is None:
        return "NULL"
    return "'{}'".format(quote(str(s)))