test_exceptions.py
import json
import pickle

import pytest

from mlflow.exceptions import MlflowException, RestException
from mlflow.protos.databricks_pb2 import (
    ENDPOINT_NOT_FOUND,
    INTERNAL_ERROR,
    INVALID_PARAMETER_VALUE,
    INVALID_STATE,
    IO_ERROR,
    RESOURCE_ALREADY_EXISTS,
)


def test_error_code_constructor():
    """An explicit ``error_code`` is reported back by its enum name."""
    err = MlflowException("test", error_code=INVALID_PARAMETER_VALUE)
    assert err.error_code == "INVALID_PARAMETER_VALUE"


def test_default_error_code():
    """Without an ``error_code`` argument the code is ``INTERNAL_ERROR``."""
    err = MlflowException("test")
    assert err.error_code == "INTERNAL_ERROR"


def test_serialize_to_json():
    """``serialize_as_json`` emits the message and the error-code name."""
    payload = json.loads(MlflowException("test").serialize_as_json())
    assert payload["message"] == "test"
    assert payload["error_code"] == "INTERNAL_ERROR"


def test_get_http_status_code():
    """Each error code maps to the expected HTTP status code."""
    expected_statuses = [
        (MlflowException("test default"), 500),
        # As its message says, IO_ERROR serves here as a code that is not in the status map.
        (MlflowException("code not in map", error_code=IO_ERROR), 500),
        (MlflowException("test", error_code=INVALID_STATE), 500),
        (MlflowException("test", error_code=ENDPOINT_NOT_FOUND), 404),
        (MlflowException("test", error_code=INVALID_PARAMETER_VALUE), 400),
        (MlflowException("test", error_code=INTERNAL_ERROR), 500),
        (MlflowException("test", error_code=RESOURCE_ALREADY_EXISTS), 400),
    ]
    for err, status in expected_statuses:
        assert err.get_http_status_code() == status


def test_invalid_parameter_value():
    """The ``invalid_parameter_value`` factory sets ``INVALID_PARAMETER_VALUE``."""
    err = MlflowException.invalid_parameter_value("test")
    assert err.error_code == "INVALID_PARAMETER_VALUE"


def test_rest_exception():
    """A serialized MlflowException can be rebuilt as a RestException."""
    source = MlflowException("test", error_code=RESOURCE_ALREADY_EXISTS)
    rebuilt = RestException(json.loads(source.serialize_as_json()))
    assert rebuilt.error_code == "RESOURCE_ALREADY_EXISTS"
    assert "test" in rebuilt.message


def test_rest_exception_with_unrecognized_error_code():
    """RestException tolerates convertible and unrecognized error codes."""
    # NOTE: both payloads use the key "messages" (not "message"); the text is
    # still expected to show up in str(exc).

    # A convertible code: "403" is translated to PERMISSION_DENIED.
    convertible = RestException({"error_code": "403", "messages": "something important."})
    assert "something important." in str(convertible)
    assert convertible.error_code == "PERMISSION_DENIED"
    json.loads(convertible.serialize_as_json())

    # An unrecognized code must still construct and serialize to valid JSON.
    unknown = RestException({"error_code": "weird error", "messages": "something important."})
    assert "something important." in str(unknown)
    json.loads(unknown.serialize_as_json())


def test_rest_exception_pickleable():
    """RestException survives a pickle round trip."""
    before = RestException({"error_code": "INTERNAL_ERROR", "message": "abc"})
    after = pickle.loads(pickle.dumps(before))

    assert before.error_code == after.error_code
    assert before.message == after.message


def test_rest_exception_with_null_error_code():
    """A ``None`` error code falls back to ``INTERNAL_ERROR``."""
    payload = {"error_code": None, "message": "test message"}
    exc = RestException(payload)
    assert exc.error_code == "INTERNAL_ERROR"
    assert "test message" in str(exc)


def test_rest_exception_with_missing_error_code():
    """A payload with no ``error_code`` key falls back to ``INTERNAL_ERROR``."""
    payload = {"message": "test message"}
    exc = RestException(payload)
    assert exc.error_code == "INTERNAL_ERROR"
    assert "test message" in str(exc)


# --- sqlstate / error_class auto-derive tests ---


def test_sqlstate_auto_derived_from_error_code():
    """``sqlstate`` and ``error_class`` are derived from ``error_code`` when omitted."""
    cases = [
        ({}, "XXM00", "CLIENT_INTERNAL_ERROR"),
        ({"error_code": INVALID_PARAMETER_VALUE}, "KAM00", "INVALID_PARAMETER_VALUE"),
        ({"error_code": INTERNAL_ERROR}, "XXM00", "CLIENT_INTERNAL_ERROR"),
    ]
    for kwargs, want_sqlstate, want_error_class in cases:
        exc = MlflowException("test", **kwargs)
        assert exc.sqlstate == want_sqlstate
        assert exc.error_class == want_error_class


def test_sqlstate_explicit_overrides_auto_derive():
    """Explicit ``sqlstate``/``error_class`` take precedence over derived values."""
    exc = MlflowException(
        "test",
        error_code=INVALID_PARAMETER_VALUE,
        sqlstate="KAM01",
        error_class="SCHEMA_ENFORCEMENT_FAILED",
    )
    assert (exc.sqlstate, exc.error_class) == ("KAM01", "SCHEMA_ENFORCEMENT_FAILED")


def test_sqlstate_serialize_as_json_includes_auto_derived():
    """Derived ``sqlstate``/``error_class`` appear in the JSON payload."""
    exc = MlflowException("test", error_code=INVALID_PARAMETER_VALUE)
    payload = json.loads(exc.serialize_as_json())
    assert payload["sqlstate"] == "KAM00"
    assert payload["error_class"] == "INVALID_PARAMETER_VALUE"


def test_sqlstate_serialize_as_json_includes_explicit():
    """Explicitly supplied ``sqlstate``/``error_class`` appear in the JSON payload."""
    exc = MlflowException("test", sqlstate="KAM01", error_class="SCHEMA_ENFORCEMENT_FAILED")
    payload = json.loads(exc.serialize_as_json())
    assert payload["sqlstate"] == "KAM01"
    assert payload["error_class"] == "SCHEMA_ENFORCEMENT_FAILED"


def test_sqlstate_none_for_unknown_error_code():
    """Codes without a derivation leave both fields ``None`` and omit them from JSON."""
    exc = MlflowException("test", error_code=IO_ERROR)
    assert exc.sqlstate is None
    assert exc.error_class is None
    payload = json.loads(exc.serialize_as_json())
    for key in ("sqlstate", "error_class"):
        assert key not in payload


def test_invalid_parameter_value_auto_derives_sqlstate():
    """The factory's error code drives the derived ``sqlstate``/``error_class``."""
    exc = MlflowException.invalid_parameter_value("bad input")
    assert exc.error_code == "INVALID_PARAMETER_VALUE"
    assert exc.sqlstate == "KAM00"
    assert exc.error_class == "INVALID_PARAMETER_VALUE"


def test_invalid_parameter_value_with_explicit_override():
    """The factory forwards explicit ``sqlstate``/``error_class`` overrides."""
    overrides = {"sqlstate": "KAM02", "error_class": "PREDICTION_FUNCTION_FAILED"}
    exc = MlflowException.invalid_parameter_value("bad input", **overrides)
    assert exc.error_code == "INVALID_PARAMETER_VALUE"
    assert exc.sqlstate == "KAM02"
    assert exc.error_class == "PREDICTION_FUNCTION_FAILED"


# --- RestException CP mapping tests ---


@pytest.mark.parametrize(
    ("error_code", "expected_sqlstate", "expected_error_class"),
    [
        ("PERMISSION_DENIED", "KAMC1", "CP_PERMISSION_DENIED"),
        ("RESOURCE_DOES_NOT_EXIST", "KAMC2", "CP_RESOURCE_NOT_FOUND"),
        ("REQUEST_LIMIT_EXCEEDED", "KAMC3", "CP_REQUEST_RATE_LIMITED"),
        ("INVALID_PARAMETER_VALUE", "KAMC4", "CP_INVALID_PARAMETER_VALUE"),
        ("INTERNAL_ERROR", "XXMC0", "CP_INTERNAL_ERROR"),
        ("TEMPORARILY_UNAVAILABLE", "XXMC1", "CP_TEMPORARILY_UNAVAILABLE"),
        ("INVALID_STATE", "XXMC2", "CP_INVALID_STATE"),
    ],
)
def test_rest_exception_cp_sqlstate_mapping(error_code, expected_sqlstate, expected_error_class):
    """Each REST error code maps to its CP-specific ``sqlstate``/``error_class``."""
    exc = RestException({"error_code": error_code, "message": "test"})
    assert (exc.sqlstate, exc.error_class) == (expected_sqlstate, expected_error_class)


def test_rest_exception_preserves_sqlstate_from_json():
    """``sqlstate``/``error_class`` present in the payload win over the CP mapping."""
    payload = {
        "error_code": "PERMISSION_DENIED",
        "message": "no access",
        "sqlstate": "CUSTOM",
        "error_class": "CUSTOM_CLASS",
    }
    exc = RestException(payload)
    assert exc.sqlstate == "CUSTOM"
    assert exc.error_class == "CUSTOM_CLASS"


def test_rest_exception_ignores_null_sqlstate_from_json():
    """A ``None`` sqlstate in the payload falls back to the CP mapping."""
    payload = {"error_code": "PERMISSION_DENIED", "message": "no access", "sqlstate": None}
    assert RestException(payload).sqlstate == "KAMC1"


def test_rest_exception_pickle_with_sqlstate():
    """Pickling a RestException preserves its code, message, and sqlstate fields."""
    before = RestException({"error_code": "PERMISSION_DENIED", "message": "no access"})
    after = pickle.loads(pickle.dumps(before))
    for attr in ("error_code", "message", "sqlstate", "error_class"):
        assert getattr(before, attr) == getattr(after, attr)


def test_rest_exception_unrecognized_error_code():
    """An unrecognized error code still receives the CP internal-error mapping."""
    exc = RestException({"error_code": "weird error", "messages": "something"})
    # Unrecognized error codes fall back to INTERNAL_ERROR, which maps to XXMC0
    assert exc.sqlstate == "XXMC0"
    assert exc.error_class == "CP_INTERNAL_ERROR"