/ tests / tracking / test_log_image.py
test_log_image.py
  1  import json
  2  import os
  3  import posixpath
  4  
  5  import numpy as np
  6  import pytest
  7  
  8  import mlflow
  9  from mlflow.utils.file_utils import local_file_uri_to_path
 10  from mlflow.utils.time import get_current_time_millis
 11  
 12  
 13  @pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
 14  def test_log_image_numpy(subdir):
 15      import numpy as np
 16      from PIL import Image
 17  
 18      filename = "image.png"
 19      artifact_file = filename if subdir is None else posixpath.join(subdir, filename)
 20  
 21      image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
 22  
 23      with mlflow.start_run():
 24          mlflow.log_image(image, artifact_file)
 25  
 26          artifact_path = None if subdir is None else posixpath.normpath(subdir)
 27          artifact_uri = mlflow.get_artifact_uri(artifact_path)
 28          run_artifact_dir = local_file_uri_to_path(artifact_uri)
 29          assert os.listdir(run_artifact_dir) == [filename]
 30  
 31          logged_path = os.path.join(run_artifact_dir, filename)
 32          loaded_image = np.asarray(Image.open(logged_path), dtype=np.uint8)
 33          np.testing.assert_array_equal(loaded_image, image)
 34  
 35  
 36  @pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
 37  def test_log_image_pillow(subdir):
 38      from PIL import Image, ImageChops
 39  
 40      filename = "image.png"
 41      artifact_file = filename if subdir is None else posixpath.join(subdir, filename)
 42  
 43      image = Image.new("RGB", (100, 100))
 44  
 45      with mlflow.start_run():
 46          mlflow.log_image(image, artifact_file)
 47  
 48          artifact_path = None if subdir is None else posixpath.normpath(subdir)
 49          artifact_uri = mlflow.get_artifact_uri(artifact_path)
 50          run_artifact_dir = local_file_uri_to_path(artifact_uri)
 51          assert os.listdir(run_artifact_dir) == [filename]
 52  
 53          logged_path = os.path.join(run_artifact_dir, filename)
 54          loaded_image = Image.open(logged_path)
 55          # How to check Pillow image equality: https://stackoverflow.com/a/6204954/6943581
 56          assert ImageChops.difference(loaded_image, image).getbbox() is None
 57  
 58  
 59  def test_log_image_raises_for_unsupported_objects():
 60      with mlflow.start_run():
 61          with pytest.raises(TypeError, match="Unsupported image object type"):
 62              mlflow.log_image("not_image", "image.png")
 63  
 64  
 65  @pytest.mark.parametrize(
 66      "size",
 67      [
 68          (100, 100),  # Grayscale (2D)
 69          (100, 100, 1),  # Grayscale (3D)
 70          (100, 100, 3),  # RGB
 71          (100, 100, 4),  # RGBA
 72      ],
 73  )
 74  def test_log_image_numpy_shape(size):
 75      import numpy as np
 76  
 77      filename = "image.png"
 78      image = np.random.randint(0, 256, size=size, dtype=np.uint8)
 79  
 80      with mlflow.start_run():
 81          mlflow.log_image(image, filename)
 82          artifact_uri = mlflow.get_artifact_uri()
 83          run_artifact_dir = local_file_uri_to_path(artifact_uri)
 84          assert os.listdir(run_artifact_dir) == [filename]
 85  
 86  
 87  @pytest.mark.parametrize(
 88      "dtype",
 89      [
 90          # Ref.: https://numpy.org/doc/stable/user/basics.types.html#array-types-and-conversions-between-types
 91          "int8",
 92          "int16",
 93          "int32",
 94          "int64",
 95          "uint8",
 96          "uint16",
 97          "uint32",
 98          "uint64",
 99          "float16",
100          "float32",
101          "float64",
102          "bool",
103      ],
104  )
105  def test_log_image_numpy_dtype(dtype):
106      import numpy as np
107  
108      filename = "image.png"
109      image = np.random.randint(0, 2, size=(100, 100, 3)).astype(np.dtype(dtype))
110  
111      with mlflow.start_run():
112          mlflow.log_image(image, filename)
113          artifact_uri = mlflow.get_artifact_uri()
114          run_artifact_dir = local_file_uri_to_path(artifact_uri)
115          assert os.listdir(run_artifact_dir) == [filename]
116  
117  
@pytest.mark.parametrize(
    "array",
    # 1 pixel images with out-of-range values
    [[[-1]], [[256]], [[-0.1]], [[1.1]]],
)
def test_log_image_numpy_emits_warning_for_out_of_range_values(array):
    """Out-of-range pixel values: integer arrays raise a ValueError, float arrays warn.

    The Python type of the single pixel (``int`` or ``float``) selects both the dtype
    the 1x1 image is cast to and which failure mode is expected.
    """
    pixel = array[0][0]
    image = np.array(array).astype(type(pixel))
    if isinstance(pixel, int):
        expectation = pytest.raises(
            ValueError, match="Integer pixel values out of acceptable range"
        )
    else:
        expectation = pytest.warns(
            UserWarning, match="Float pixel values out of acceptable range"
        )
    with mlflow.start_run(), expectation:
        mlflow.log_image(image, "image.png")
139  
140  
def test_log_image_numpy_raises_exception_for_invalid_array_data_type():
    """A numpy array of strings is not a valid image dtype and must raise a TypeError."""
    string_pixels = np.tile("a", (1, 1, 3))
    with mlflow.start_run(), pytest.raises(TypeError, match="Invalid array data type"):
        mlflow.log_image(string_pixels, "image.png")
146  
147  
def test_log_image_numpy_raises_exception_for_invalid_array_shape():
    """A 1D array is neither 2D nor 3D, so ``log_image`` must raise a ValueError."""
    one_dimensional = np.zeros((1,), dtype=np.uint8)
    expected_message = "`image` must be a 2D or 3D array"
    with mlflow.start_run(), pytest.raises(ValueError, match=expected_message):
        mlflow.log_image(one_dimensional, "image.png")
153  
154  
def test_log_image_numpy_raises_exception_for_invalid_channel_length():
    """A 3D array with 5 channels is not a supported layout and must raise a ValueError."""
    five_channels = np.zeros((1, 1, 5), dtype=np.uint8)
    with mlflow.start_run(), pytest.raises(ValueError, match="Invalid channel length"):
        mlflow.log_image(five_channels, "image.png")
160  
161  
def test_log_image_raises_exception_for_unsupported_image_object_type():
    """Passing a non-image object (here a string) to ``log_image`` raises a TypeError."""
    with mlflow.start_run():
        with pytest.raises(TypeError, match="Unsupported image object type"):
            mlflow.log_image("not_image", "image.png")
165  
166  
def test_log_image_with_steps():
    """Log an image under ``key`` with ``step=0`` and check the files written to ``images/``.

    Two files are expected: the full-resolution ``.png`` and a compressed ``.webp`` copy.
    Both names start with the key, the step and the logging timestamp.
    """
    from PIL import Image

    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        # Bracket the call so the timestamp embedded in the filenames can be range-checked.
        start_millis = get_current_time_millis()
        mlflow.log_image(image, key="dog", step=0, synchronous=True)
        end_millis = get_current_time_millis()

        artifact_uri = mlflow.get_artifact_uri("images/")
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        files = os.listdir(run_artifact_dir)

        # .png file for the image and .webp file for compressed image
        assert len(files) == 2
        assert sorted(os.path.splitext(file)[1] for file in files) == [".png", ".webp"]
        for file in files:
            # Names are "+"-separated: key, "step", step, "timestamp", timestamp, then a suffix.
            name, step_tag, step, timestamp_tag, timestamp, *_ = file.split("+")
            assert (name, step_tag, step, timestamp_tag) == ("dog", "step", "0", "timestamp")
            assert start_millis <= int(timestamp) <= end_millis
            if file.endswith(".png"):
                # Close the file handle PIL opens lazily instead of leaking it until GC.
                with Image.open(os.path.join(run_artifact_dir, file)) as loaded:
                    loaded_image = np.asarray(loaded, dtype=np.uint8)
                np.testing.assert_array_equal(loaded_image, image)
196  
197  
@pytest.mark.parametrize("step", [20, 26, 27])
def test_log_image_with_url_encoding_prone_steps(step):
    """Regression test for steps whose digits look like URL escapes.

    Steps 20, 26 and 27 previously produced %20, %26 and %27 patterns in filenames,
    which were URL-decoded and corrupted the artifact path.
    See https://github.com/mlflow/mlflow/issues/21085
    """
    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
    expected_prefix = f"dog+step+{step}+timestamp+"

    with mlflow.start_run():
        mlflow.log_image(image, key="dog", step=step, synchronous=True)

        images_dir = local_file_uri_to_path(mlflow.get_artifact_uri("images/"))
        logged_files = os.listdir(images_dir)

        assert len(logged_files) == 2
        assert all(name.startswith(expected_prefix) for name in logged_files)
        assert not any("%" in name for name in logged_files)
217  
218  
def test_log_image_with_timestamp():
    """An explicit ``timestamp`` passed to ``log_image`` appears verbatim in the filenames.

    No ``step`` is given, so the filenames are expected to record step 0.
    """
    from PIL import Image

    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run():
        mlflow.log_image(image, key="dog", timestamp=100, synchronous=True)

        artifact_uri = mlflow.get_artifact_uri("images/")
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        files = os.listdir(run_artifact_dir)

        # .png file for the image, and .webp file for compressed image
        assert len(files) == 2
        assert sorted(os.path.splitext(file)[1] for file in files) == [".png", ".webp"]
        for file in files:
            # The explicit timestamp must be encoded in the filename; no separate
            # metadata file is written, so this is where it can be verified.
            assert file.startswith("dog+step+0+timestamp+100+")
            if file.endswith(".png"):
                # Close the file handle PIL opens lazily instead of leaking it until GC.
                with Image.open(os.path.join(run_artifact_dir, file)) as loaded:
                    loaded_image = np.asarray(loaded, dtype=np.uint8)
                np.testing.assert_array_equal(loaded_image, image)
248  
249  
def test_duplicated_log_image_with_step():
    """Repeated ``log_image`` calls with the same key and step keep every file.

    Neither call overwrites the other, so both images are saved.
    """
    first, second = (
        np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) for _ in range(2)
    )

    with mlflow.start_run():
        for picture in (first, second):
            mlflow.log_image(picture, key="dog", step=100, synchronous=True)

        images_dir = local_file_uri_to_path(mlflow.get_artifact_uri("images/"))
        # 2 images and 2 files per image
        assert len(os.listdir(images_dir)) == 2 * 2
269  
270  
def test_duplicated_log_image_with_timestamp():
    """Repeated calls with identical key, step and timestamp still keep every file.

    Neither call overwrites the other, so both images are saved.
    """
    first, second = (
        np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) for _ in range(2)
    )

    with mlflow.start_run():
        for picture in (first, second):
            mlflow.log_image(picture, key="dog", step=100, timestamp=100, synchronous=True)

        images_dir = local_file_uri_to_path(mlflow.get_artifact_uri("images/"))
        # 2 images and 2 files per image
        assert len(os.listdir(images_dir)) == 2 * 2
290  
291  
@pytest.mark.parametrize(
    "args",
    [
        {"key": "image"},
        {"step": 0},
        {"timestamp": 0},
        {"timestamp": 0, "step": 0},
        ["image"],
        ["image", 0],
    ],
)
def test_log_image_raises_exception_for_unexpected_arguments_used(args):
    """``artifact_file`` cannot be combined with ``key``/``step``/``timestamp``.

    ``args`` is either a dict of extra keyword arguments or a list of extra positional
    arguments following ``artifact_file``; both forms must raise a TypeError.
    """
    # Fail loudly on an unexpected parametrization instead of silently passing.
    assert isinstance(args, (dict, list)), f"Unexpected parametrization: {args!r}"
    positional = args if isinstance(args, list) else []
    keyword = args if isinstance(args, dict) else {}

    exception = "The `artifact_file` parameter cannot be used in conjunction"
    with mlflow.start_run(), pytest.raises(TypeError, match=exception):
        mlflow.log_image(np.zeros((1,), dtype=np.uint8), "image.png", *positional, **keyword)
314  
315  
def test_log_image_raises_exception_for_missing_arguments():
    """Calling ``log_image`` with neither ``artifact_file`` nor ``key`` raises a TypeError."""
    expected_message = "Invalid arguments: Please specify exactly one of `artifact_file` or `key`"
    image = np.zeros((1,), dtype=np.uint8)
    with mlflow.start_run(), pytest.raises(TypeError, match=expected_message):
        mlflow.log_image(image)
322  
323  
def test_async_log_image_flush():
    """Queue 100 asynchronous image logs, flush, and expect two files per image on disk."""
    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
    num_images = 100
    with mlflow.start_run():
        for index in range(num_images):
            mlflow.log_image(image, key="dog", step=index, timestamp=index, synchronous=False)

        # Flush pending asynchronous artifact writes before inspecting the directory.
        mlflow.flush_artifact_async_logging()

        images_dir = local_file_uri_to_path(mlflow.get_artifact_uri("images/"))
        assert len(os.listdir(images_dir)) == num_images * 2
339  
340  
def test_log_image_with_slash_in_key():
    """A '/' in the key is written as '~' in filenames, and the artifacts stay downloadable."""
    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run() as run:
        mlflow.log_image(image, key="category/name", step=5, synchronous=True)

        images_dir = local_file_uri_to_path(mlflow.get_artifact_uri("images/"))
        logged_files = os.listdir(images_dir)

        assert len(logged_files) == 2
        for name in logged_files:
            # '~' must be used instead of '#' as the separator
            assert "category~name" in name
            assert "#" not in name

    run_id = run.info.run_id
    client = mlflow.MlflowClient()
    artifacts = client.list_artifacts(run_id, path="images")
    assert len(artifacts) == 2
    for artifact in artifacts:
        # download_artifacts must not raise MlflowException about '#' in path
        assert os.path.exists(client.download_artifacts(run_id, artifact.path))