# test_log_image.py
import os
import posixpath
import re

import numpy as np
import pytest

import mlflow
from mlflow.utils.file_utils import local_file_uri_to_path
from mlflow.utils.time import get_current_time_millis

# File names written by ``log_image(key=...)`` start with
# ``<key>+step+<step>+timestamp+<timestamp>`` (see
# ``test_log_image_with_url_encoding_prone_steps``). Whatever follows the
# timestamp (e.g. the file extension) is intentionally left unmatched.
_IMAGE_FILENAME_PATTERN = re.compile(
    r"^(?P<key>.+?)\+step\+(?P<step>\d+)\+timestamp\+(?P<timestamp>\d+)"
)


def _local_artifact_dir(artifact_path=None):
    """Return the local directory backing the active run's artifact URI.

    ``artifact_path`` is forwarded to ``mlflow.get_artifact_uri``. Every caller in
    this module invokes it inside ``mlflow.start_run()``.
    """
    return local_file_uri_to_path(mlflow.get_artifact_uri(artifact_path))


def _list_logged_images():
    """Return the file names under the active run's ``images/`` artifact directory."""
    return os.listdir(_local_artifact_dir("images/"))


def _expect_png_and_webp(files):
    """Assert ``files`` is exactly one ``.png`` plus one ``.webp``; return the ``.png`` name."""
    assert sorted(os.path.splitext(f)[1] for f in files) == [".png", ".webp"], files
    return next(f for f in files if f.endswith(".png"))


def _parse_image_filename(filename):
    """Return ``(key, step, timestamp)`` parsed from a ``log_image(key=...)`` file name."""
    match = _IMAGE_FILENAME_PATTERN.match(filename)
    assert match is not None, f"Unexpected image file name: {filename}"
    return match["key"], int(match["step"]), int(match["timestamp"])


@pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
def test_log_image_numpy(subdir):
    """A NumPy array logged via ``artifact_file`` round-trips losslessly as a PNG."""
    from PIL import Image

    filename = "image.png"
    artifact_file = filename if subdir is None else posixpath.join(subdir, filename)

    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image, artifact_file)

        artifact_path = None if subdir is None else posixpath.normpath(subdir)
        run_artifact_dir = _local_artifact_dir(artifact_path)
        assert os.listdir(run_artifact_dir) == [filename]

        logged_path = os.path.join(run_artifact_dir, filename)
        with Image.open(logged_path) as loaded:
            np.testing.assert_array_equal(np.asarray(loaded, dtype=np.uint8), image)


@pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
def test_log_image_pillow(subdir):
    """A Pillow image logged via ``artifact_file`` round-trips unchanged."""
    from PIL import Image, ImageChops

    filename = "image.png"
    artifact_file = filename if subdir is None else posixpath.join(subdir, filename)

    image = Image.new("RGB", (100, 100))

    with mlflow.start_run():
        mlflow.log_image(image, artifact_file)

        artifact_path = None if subdir is None else posixpath.normpath(subdir)
        run_artifact_dir = _local_artifact_dir(artifact_path)
        assert os.listdir(run_artifact_dir) == [filename]

        logged_path = os.path.join(run_artifact_dir, filename)
        with Image.open(logged_path) as loaded_image:
            # How to check Pillow image equality: https://stackoverflow.com/a/6204954/6943581
            assert ImageChops.difference(loaded_image, image).getbbox() is None


@pytest.mark.parametrize(
    "size",
    [
        (100, 100),  # Grayscale (2D)
        (100, 100, 1),  # Grayscale (3D)
        (100, 100, 3),  # RGB
        (100, 100, 4),  # RGBA
    ],
)
def test_log_image_numpy_shape(size):
    """Grayscale (2D and 3D), RGB and RGBA arrays are all accepted."""
    filename = "image.png"
    image = np.random.randint(0, 256, size=size, dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image, filename)
        assert os.listdir(_local_artifact_dir()) == [filename]


@pytest.mark.parametrize(
    "dtype",
    [
        # Ref.: https://numpy.org/doc/stable/user/basics.types.html#array-types-and-conversions-between-types
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "float16",
        "float32",
        "float64",
        "bool",
    ],
)
def test_log_image_numpy_dtype(dtype):
    """Arrays of every listed NumPy dtype with 0/1 pixel values can be logged."""
    filename = "image.png"
    image = np.random.randint(0, 2, size=(100, 100, 3)).astype(np.dtype(dtype))

    with mlflow.start_run():
        mlflow.log_image(image, filename)
        assert os.listdir(_local_artifact_dir()) == [filename]


@pytest.mark.parametrize(
    "array",
    # 1 pixel images with out-of-range values
    [[[-1]], [[256]], [[-0.1]], [[1.1]]],
)
def test_log_image_numpy_emits_warning_for_out_of_range_values(array):
    """Out-of-range integer pixels raise ``ValueError``; out-of-range floats only warn."""
    image = np.array(array).astype(type(array[0][0]))
    if isinstance(array[0][0], int):
        expectation = pytest.raises(
            ValueError, match="Integer pixel values out of acceptable range"
        )
    else:
        expectation = pytest.warns(
            UserWarning, match="Float pixel values out of acceptable range"
        )
    with mlflow.start_run(), expectation:
        mlflow.log_image(image, "image.png")


def test_log_image_numpy_raises_exception_for_invalid_array_data_type():
    """Non-numeric (string) arrays are rejected with ``TypeError``."""
    with mlflow.start_run(), pytest.raises(TypeError, match="Invalid array data type"):
        mlflow.log_image(np.tile("a", (1, 1, 3)), "image.png")


def test_log_image_numpy_raises_exception_for_invalid_array_shape():
    """A 1D array is rejected with ``ValueError``."""
    with mlflow.start_run(), pytest.raises(ValueError, match="`image` must be a 2D or 3D array"):
        mlflow.log_image(np.zeros((1,), dtype=np.uint8), "image.png")


def test_log_image_numpy_raises_exception_for_invalid_channel_length():
    """A 3D array with 5 channels is rejected with ``ValueError``."""
    with mlflow.start_run(), pytest.raises(ValueError, match="Invalid channel length"):
        mlflow.log_image(np.zeros((1, 1, 5), dtype=np.uint8), "image.png")


def test_log_image_raises_exception_for_unsupported_image_object_type():
    """Objects that are neither arrays nor Pillow images are rejected with ``TypeError``."""
    with mlflow.start_run(), pytest.raises(TypeError, match="Unsupported image object type"):
        mlflow.log_image("not_image", "image.png")


def test_log_image_with_steps():
    """Logging by ``key``/``step`` writes a lossless ``.png`` plus a ``.webp`` copy."""
    from PIL import Image

    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image, key="dog", step=0, synchronous=True)

        run_artifact_dir = _local_artifact_dir("images/")
        files = os.listdir(run_artifact_dir)

        # .png file for the image and .webp file for compressed image
        png_file = _expect_png_and_webp(files)
        for file in files:
            key, step, timestamp = _parse_image_filename(file)
            assert (key, step) == ("dog", 0)
            # No explicit timestamp was passed; the embedded one must not be in the future.
            assert timestamp <= get_current_time_millis()

        with Image.open(os.path.join(run_artifact_dir, png_file)) as loaded:
            np.testing.assert_array_equal(np.asarray(loaded, dtype=np.uint8), image)


@pytest.mark.parametrize("step", [20, 26, 27])
def test_log_image_with_url_encoding_prone_steps(step):
    """Regression test: steps like 20, 26, 27 previously created %20, %26, %27 patterns
    in filenames that got URL-decoded, corrupting the artifact path.
    See https://github.com/mlflow/mlflow/issues/21085
    """
    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image, key="dog", step=step, synchronous=True)

        files = _list_logged_images()

        assert len(files) == 2
        for file in files:
            assert file.startswith(f"dog+step+{step}+timestamp+")
            assert "%" not in file


def test_log_image_with_timestamp():
    """An explicit ``timestamp`` is embedded in the file names; ``step`` defaults to 0."""
    from PIL import Image

    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image, key="dog", timestamp=100, synchronous=True)

        run_artifact_dir = _local_artifact_dir("images/")
        files = os.listdir(run_artifact_dir)

        # .png file for the image, and .webp file for compressed image
        png_file = _expect_png_and_webp(files)
        for file in files:
            assert _parse_image_filename(file) == ("dog", 0, 100)

        with Image.open(os.path.join(run_artifact_dir, png_file)) as loaded:
            np.testing.assert_array_equal(np.asarray(loaded, dtype=np.uint8), image)


def test_duplicated_log_image_with_step():
    """
    MLflow will save both files if there are multiple calls to log_image
    with the same key and step.
    """
    image1 = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
    image2 = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image1, key="dog", step=100, synchronous=True)
        mlflow.log_image(image2, key="dog", step=100, synchronous=True)

        files = _list_logged_images()
        assert len(files) == 2 * 2  # 2 images and 2 files per image


def test_duplicated_log_image_with_timestamp():
    """
    MLflow will save both files if there are multiple calls to log_image
    with the same key, step, and timestamp.
    """
    image1 = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
    image2 = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image1, key="dog", step=100, timestamp=100, synchronous=True)
        mlflow.log_image(image2, key="dog", step=100, timestamp=100, synchronous=True)

        files = _list_logged_images()
        assert len(files) == 2 * 2  # 2 images and 2 files per image


@pytest.mark.parametrize(
    "args",
    [
        {"key": "image"},
        {"step": 0},
        {"timestamp": 0},
        {"timestamp": 0, "step": 0},
        ["image"],
        ["image", 0],
    ],
)
def test_log_image_raises_exception_for_unexpected_arguments_used(args):
    """``artifact_file`` cannot be combined with ``key``/``step``/``timestamp``,
    whether those are passed by keyword (dict cases) or positionally (list cases).
    """
    positional = args if isinstance(args, list) else []
    keyword = args if isinstance(args, dict) else {}
    # The 1D array is itself invalid (see the invalid-shape test); the argument
    # check is expected to fire first, hence TypeError rather than ValueError.
    image = np.zeros((1,), dtype=np.uint8)

    exception = "The `artifact_file` parameter cannot be used in conjunction"
    with mlflow.start_run(), pytest.raises(TypeError, match=exception):
        mlflow.log_image(image, "image.png", *positional, **keyword)


def test_log_image_raises_exception_for_missing_arguments():
    """Omitting both ``artifact_file`` and ``key`` is a ``TypeError``."""
    exception = "Invalid arguments: Please specify exactly one of `artifact_file` or `key`"
    with mlflow.start_run(), pytest.raises(TypeError, match=exception):
        mlflow.log_image(np.zeros((1,), dtype=np.uint8))


def test_async_log_image_flush():
    """After ``flush_artifact_async_logging`` every asynchronously logged image is on disk."""
    image1 = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
    with mlflow.start_run():
        for i in range(100):
            mlflow.log_image(image1, key="dog", step=i, timestamp=i, synchronous=False)

        mlflow.flush_artifact_async_logging()

        files = _list_logged_images()
        assert len(files) == 100 * 2


def test_log_image_with_slash_in_key():
    """A ``/`` in ``key`` is written as ``~`` (never ``#``) in the file names,
    and the resulting artifacts can be listed and downloaded.
    """
    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image, key="category/name", step=5, synchronous=True)

        files = _list_logged_images()

        assert len(files) == 2
        for file in files:
            # '~' must be used instead of '#' as the separator
            assert "category~name" in file
            assert "#" not in file

        run_id = mlflow.active_run().info.run_id

        client = mlflow.MlflowClient()
        artifacts = client.list_artifacts(run_id, path="images")
        assert len(artifacts) == 2
        for artifact in artifacts:
            # download_artifacts must not raise MlflowException about '#' in path
            local_path = client.download_artifacts(run_id, artifact.path)
            assert os.path.exists(local_path)