/ tests / test_exceptions.py
test_exceptions.py
  1  import json
  2  import pickle
  3  
  4  import pytest
  5  
  6  from mlflow.exceptions import MlflowException, RestException
  7  from mlflow.protos.databricks_pb2 import (
  8      ENDPOINT_NOT_FOUND,
  9      INTERNAL_ERROR,
 10      INVALID_PARAMETER_VALUE,
 11      INVALID_STATE,
 12      IO_ERROR,
 13      RESOURCE_ALREADY_EXISTS,
 14  )
 15  
 16  
 17  def test_error_code_constructor():
 18      assert (
 19          MlflowException("test", error_code=INVALID_PARAMETER_VALUE).error_code
 20          == "INVALID_PARAMETER_VALUE"
 21      )
 22  
 23  
 24  def test_default_error_code():
 25      assert MlflowException("test").error_code == "INTERNAL_ERROR"
 26  
 27  
 28  def test_serialize_to_json():
 29      mlflow_exception = MlflowException("test")
 30      deserialized = json.loads(mlflow_exception.serialize_as_json())
 31      assert deserialized["message"] == "test"
 32      assert deserialized["error_code"] == "INTERNAL_ERROR"
 33  
 34  
 35  def test_get_http_status_code():
 36      assert MlflowException("test default").get_http_status_code() == 500
 37      assert MlflowException("code not in map", error_code=IO_ERROR).get_http_status_code() == 500
 38      assert MlflowException("test", error_code=INVALID_STATE).get_http_status_code() == 500
 39      assert MlflowException("test", error_code=ENDPOINT_NOT_FOUND).get_http_status_code() == 404
 40      assert MlflowException("test", error_code=INVALID_PARAMETER_VALUE).get_http_status_code() == 400
 41      assert MlflowException("test", error_code=INTERNAL_ERROR).get_http_status_code() == 500
 42      assert MlflowException("test", error_code=RESOURCE_ALREADY_EXISTS).get_http_status_code() == 400
 43  
 44  
 45  def test_invalid_parameter_value():
 46      mlflow_exception = MlflowException.invalid_parameter_value("test")
 47      assert mlflow_exception.error_code == "INVALID_PARAMETER_VALUE"
 48  
 49  
 50  def test_rest_exception():
 51      mlflow_exception = MlflowException("test", error_code=RESOURCE_ALREADY_EXISTS)
 52      json_exception = mlflow_exception.serialize_as_json()
 53      deserialized_rest_exception = RestException(json.loads(json_exception))
 54      assert deserialized_rest_exception.error_code == "RESOURCE_ALREADY_EXISTS"
 55      assert "test" in deserialized_rest_exception.message
 56  
 57  
 58  def test_rest_exception_with_unrecognized_error_code():
 59      # Test that we can create a RestException with a convertible error code.
 60      exception = RestException({"error_code": "403", "messages": "something important."})
 61      assert "something important." in str(exception)
 62      assert exception.error_code == "PERMISSION_DENIED"
 63      json.loads(exception.serialize_as_json())
 64  
 65      # Test that we can create a RestException with an unrecognized error code.
 66      exception = RestException({"error_code": "weird error", "messages": "something important."})
 67      assert "something important." in str(exception)
 68      json.loads(exception.serialize_as_json())
 69  
 70  
 71  def test_rest_exception_pickleable():
 72      e1 = RestException({"error_code": "INTERNAL_ERROR", "message": "abc"})
 73      e2 = pickle.loads(pickle.dumps(e1))
 74  
 75      assert e1.error_code == e2.error_code
 76      assert e1.message == e2.message
 77  
 78  
 79  def test_rest_exception_with_null_error_code():
 80      exception = RestException({"error_code": None, "message": "test message"})
 81      assert exception.error_code == "INTERNAL_ERROR"
 82      assert "test message" in str(exception)
 83  
 84  
 85  def test_rest_exception_with_missing_error_code():
 86      exception = RestException({"message": "test message"})
 87      assert exception.error_code == "INTERNAL_ERROR"
 88      assert "test message" in str(exception)
 89  
 90  
 91  # --- sqlstate / error_class auto-derive tests ---
 92  
 93  
 94  def test_sqlstate_auto_derived_from_error_code():
 95      exc = MlflowException("test")
 96      assert exc.sqlstate == "XXM00"
 97      assert exc.error_class == "CLIENT_INTERNAL_ERROR"
 98  
 99      exc = MlflowException("test", error_code=INVALID_PARAMETER_VALUE)
100      assert exc.sqlstate == "KAM00"
101      assert exc.error_class == "INVALID_PARAMETER_VALUE"
102  
103      exc = MlflowException("test", error_code=INTERNAL_ERROR)
104      assert exc.sqlstate == "XXM00"
105      assert exc.error_class == "CLIENT_INTERNAL_ERROR"
106  
107  
def test_sqlstate_explicit_overrides_auto_derive():
    """Explicit sqlstate/error_class win over the values derived from error_code."""
    exc = MlflowException(
        "test",
        error_code=INVALID_PARAMETER_VALUE,
        error_class="SCHEMA_ENFORCEMENT_FAILED",
        sqlstate="KAM01",
    )
    assert exc.error_class == "SCHEMA_ENFORCEMENT_FAILED" and exc.sqlstate == "KAM01" or (
        # Fall through to individual asserts so a failure pinpoints the field.
        False
    ) or (_ for _ in ()).throw(AssertionError)  # pragma: no cover
117  
118  
def test_sqlstate_serialize_as_json_includes_auto_derived():
    """Derived sqlstate/error_class are included in the JSON payload."""
    exc = MlflowException("test", error_code=INVALID_PARAMETER_VALUE)
    fields = json.loads(exc.serialize_as_json())
    assert fields["sqlstate"] == "KAM00"
    assert fields["error_class"] == "INVALID_PARAMETER_VALUE"
124  
125  
def test_sqlstate_serialize_as_json_includes_explicit():
    """Explicitly supplied sqlstate/error_class are included in the JSON payload."""
    exc = MlflowException(
        "test",
        sqlstate="KAM01",
        error_class="SCHEMA_ENFORCEMENT_FAILED",
    )
    fields = json.loads(exc.serialize_as_json())
    assert fields["sqlstate"] == "KAM01"
    assert fields["error_class"] == "SCHEMA_ENFORCEMENT_FAILED"
131  
132  
def test_sqlstate_none_for_unknown_error_code():
    """An error code with no mapping leaves sqlstate/error_class unset and out of JSON."""
    exc = MlflowException("test", error_code=IO_ERROR)
    assert exc.sqlstate is None
    assert exc.error_class is None

    serialized = json.loads(exc.serialize_as_json())
    for key in ("sqlstate", "error_class"):
        assert key not in serialized
140  
141  
def test_invalid_parameter_value_auto_derives_sqlstate():
    """The ``invalid_parameter_value`` factory also derives sqlstate/error_class."""
    exc = MlflowException.invalid_parameter_value("bad input")
    expected = {
        "error_code": "INVALID_PARAMETER_VALUE",
        "sqlstate": "KAM00",
        "error_class": "INVALID_PARAMETER_VALUE",
    }
    for attr, value in expected.items():
        assert getattr(exc, attr) == value
147  
148  
def test_invalid_parameter_value_with_explicit_override():
    """The factory forwards explicit sqlstate/error_class, keeping its error_code."""
    exc = MlflowException.invalid_parameter_value(
        "bad input",
        sqlstate="KAM02",
        error_class="PREDICTION_FUNCTION_FAILED",
    )
    expected = {
        "error_code": "INVALID_PARAMETER_VALUE",
        "sqlstate": "KAM02",
        "error_class": "PREDICTION_FUNCTION_FAILED",
    }
    for attr, value in expected.items():
        assert getattr(exc, attr) == value
156  
157  
158  # --- RestException CP mapping tests ---
159  
160  
@pytest.mark.parametrize(
    ("error_code", "expected_sqlstate", "expected_error_class"),
    [
        ("PERMISSION_DENIED", "KAMC1", "CP_PERMISSION_DENIED"),
        ("RESOURCE_DOES_NOT_EXIST", "KAMC2", "CP_RESOURCE_NOT_FOUND"),
        ("REQUEST_LIMIT_EXCEEDED", "KAMC3", "CP_REQUEST_RATE_LIMITED"),
        ("INVALID_PARAMETER_VALUE", "KAMC4", "CP_INVALID_PARAMETER_VALUE"),
        ("INTERNAL_ERROR", "XXMC0", "CP_INTERNAL_ERROR"),
        ("TEMPORARILY_UNAVAILABLE", "XXMC1", "CP_TEMPORARILY_UNAVAILABLE"),
        ("INVALID_STATE", "XXMC2", "CP_INVALID_STATE"),
    ],
)
def test_rest_exception_cp_sqlstate_mapping(error_code, expected_sqlstate, expected_error_class):
    """RestException maps each error code to its ``CP_``-prefixed sqlstate/error_class."""
    payload = {"error_code": error_code, "message": "test"}
    exc = RestException(payload)
    assert exc.sqlstate == expected_sqlstate
    assert exc.error_class == expected_error_class
177  
178  
def test_rest_exception_preserves_sqlstate_from_json():
    """sqlstate/error_class present in the payload take precedence over the mapping."""
    payload = {
        "error_code": "PERMISSION_DENIED",
        "message": "no access",
        "sqlstate": "CUSTOM",
        "error_class": "CUSTOM_CLASS",
    }
    exc = RestException(payload)
    assert exc.sqlstate == "CUSTOM"
    assert exc.error_class == "CUSTOM_CLASS"
188  
189  
def test_rest_exception_ignores_null_sqlstate_from_json():
    """A null sqlstate in the payload falls back to the error-code mapping."""
    payload = {
        "error_code": "PERMISSION_DENIED",
        "message": "no access",
        "sqlstate": None,
    }
    assert RestException(payload).sqlstate == "KAMC1"
198  
def test_rest_exception_pickle_with_sqlstate():
    """A pickle round trip preserves error_code, message, sqlstate and error_class."""
    before = RestException({"error_code": "PERMISSION_DENIED", "message": "no access"})
    after = pickle.loads(pickle.dumps(before))

    for attr in ("error_code", "message", "sqlstate", "error_class"):
        assert getattr(before, attr) == getattr(after, attr)
206  
207  
def test_rest_exception_unrecognized_error_code():
    """An unrecognized error code falls back to INTERNAL_ERROR and its CP mapping."""
    exc = RestException({"error_code": "weird error", "messages": "something"})
    # Unrecognized error codes fall back to INTERNAL_ERROR, which maps to XXMC0.
    # Assert the fallback itself, so a wrong code is reported directly rather
    # than only surfacing as a mismatched sqlstate/error_class.
    assert exc.error_code == "INTERNAL_ERROR"
    assert exc.sqlstate == "XXMC0"
    assert exc.error_class == "CP_INTERNAL_ERROR"