/ tests / utils / test_logging_utils.py
test_logging_utils.py
  1  import logging
  2  import os
  3  import re
  4  import subprocess
  5  import sys
  6  import uuid
  7  from io import StringIO
  8  
  9  import pytest
 10  
 11  import mlflow
 12  from mlflow.utils import logging_utils
 13  from mlflow.utils.logging_utils import LOGGING_LINE_FORMAT, eprint, suppress_logs
 14  
 15  logger = logging.getLogger(mlflow.__name__)
 16  
 17  LOGGING_FNS_TO_TEST = [logger.info, logger.warning, logger.critical, eprint]
 18  
 19  
 20  @pytest.fixture(autouse=True)
 21  def reset_stderr():
 22      prev_stderr = sys.stderr
 23      yield
 24      sys.stderr = prev_stderr
 25  
 26  
 27  @pytest.fixture(autouse=True)
 28  def reset_logging_enablement():
 29      yield
 30      logging_utils.enable_logging()
 31  
 32  
 33  @pytest.fixture(autouse=True)
 34  def reset_logging_level():
 35      level_before = logger.level
 36      yield
 37      logger.setLevel(level_before)
 38  
 39  
 40  class SampleStream:
 41      def __init__(self):
 42          self.content = None
 43          self.flush_count = 0
 44  
 45      def write(self, text):
 46          self.content = (self.content or "") + text
 47  
 48      def flush(self):
 49          self.flush_count += 1
 50  
 51      def reset(self):
 52          self.content = None
 53          self.flush_count = 0
 54  
 55  
 56  @pytest.mark.parametrize("logging_fn", LOGGING_FNS_TO_TEST)
 57  def test_event_logging_apis_respect_stderr_reassignment(logging_fn):
 58      stream1 = SampleStream()
 59      stream2 = SampleStream()
 60      message_content = "test message"
 61  
 62      sys.stderr = stream1
 63      assert stream1.content is None
 64      logging_fn(message_content)
 65      assert message_content in stream1.content
 66      assert stream2.content is None
 67      stream1.reset()
 68  
 69      sys.stderr = stream2
 70      assert stream2.content is None
 71      logging_fn(message_content)
 72      assert message_content in stream2.content
 73      assert stream1.content is None
 74  
 75  
 76  @pytest.mark.parametrize("logging_fn", LOGGING_FNS_TO_TEST)
 77  def test_event_logging_apis_respect_stream_disablement_enablement(logging_fn):
 78      stream = SampleStream()
 79      sys.stderr = stream
 80      message_content = "test message"
 81  
 82      assert stream.content is None
 83      logging_fn(message_content)
 84      assert message_content in stream.content
 85      stream.reset()
 86  
 87      logging_utils.disable_logging()
 88      logging_fn(message_content)
 89      assert stream.content is None
 90      stream.reset()
 91  
 92      logging_utils.enable_logging()
 93      assert stream.content is None
 94      logging_fn(message_content)
 95      assert message_content in stream.content
 96  
 97  
 98  def test_event_logging_stream_flushes_properly():
 99      stream = SampleStream()
100      sys.stderr = stream
101  
102      eprint("foo", flush=True)
103      assert "foo" in stream.content
104      assert stream.flush_count > 0
105  
106  
def test_debug_logs_emitted_correctly_when_configured():
    """DEBUG records reach stderr once the ``mlflow`` logger level is set to DEBUG."""
    sink = SampleStream()
    sys.stderr = sink

    logger.setLevel(logging.DEBUG)
    logger.debug("test debug")

    assert "test debug" in sink.content
114  
115  
def test_suppress_logs():
    """A record whose message matches the pattern is not emitted inside ``suppress_logs``.

    The same message is first logged without suppression to prove the capture
    handler works, then logged inside ``suppress_logs`` and expected to vanish.
    """
    module = "test_logger"
    # Named ``test_logger`` so it does not shadow the module-level ``logger``.
    test_logger = logging.getLogger(module)
    message = "This message should be suppressed."

    capture_stream = StringIO()
    stream_handler = logging.StreamHandler(capture_stream)
    test_logger.addHandler(stream_handler)
    try:
        # Sanity check: without suppression the handler captures the record.
        test_logger.error(message)
        assert message in capture_stream.getvalue()

        # Clear the buffer. ``truncate`` does not move the stream position, so
        # rewind first; otherwise a later write would be NUL-padded.
        capture_stream.seek(0)
        capture_stream.truncate(0)
        with suppress_logs(module, re.compile(r"This .* be suppressed.")):
            test_logger.error(message)
        assert capture_stream.getvalue() == ""
    finally:
        # Loggers are process-global: detach and close the handler so it does
        # not keep capturing in other tests that use the same logger name.
        test_logger.removeHandler(stream_handler)
        stream_handler.close()
133  
134  
@pytest.mark.parametrize(
    ("log_level", "expected"),
    [
        ("DEBUG", True),
        ("INFO", False),
        ("NOTSET", False),
    ],
)
def test_logging_level(log_level: str, expected: bool) -> None:
    """``MLFLOW_LOGGING_LEVEL`` decides whether ``_debug`` output is emitted.

    Runs in a fresh interpreter so mlflow is imported with the env var already set.
    """
    marker = str(uuid.uuid4())
    script = f"from mlflow.utils.logging_utils import _debug; _debug({marker!r})"
    child_env = {**os.environ, "MLFLOW_LOGGING_LEVEL": log_level}
    output = subprocess.check_output(
        [sys.executable, "-c", script],
        env=child_env,
        stderr=subprocess.STDOUT,  # merge stderr so log lines are captured too
        text=True,
    )

    assert (marker in output) is expected
157  
158  
@pytest.mark.parametrize(
    "env_var_name",
    ["MLFLOW_CONFIGURE_LOGGING", "MLFLOW_LOGGING_CONFIGURE_LOGGING"],
)
@pytest.mark.parametrize(
    "value",
    ["0", "1"],
)
def test_mlflow_configure_logging_env_var(env_var_name: str, value: str) -> None:
    """Both env var spellings control mlflow's logger setup on import.

    With ``"1"`` the ``mlflow`` logger must be enabled for INFO after import;
    with ``"0"`` the child script only requires it to be enabled for WARNING.
    """
    expected_level = logging.WARNING
    if value == "1":
        expected_level = logging.INFO
    script = f"""
import logging
import mlflow

assert logging.getLogger("mlflow").isEnabledFor({expected_level})
"""
    subprocess.check_call(
        [sys.executable, "-c", script],
        env={**os.environ, env_var_name: value},
    )
182  
183  
@pytest.mark.parametrize("configure_logging", ["0", "1"])
def test_alembic_logging_respects_configure_flag(configure_logging: str, tmp_sqlite_uri: str):
    """The alembic logger's effective format depends on ``MLFLOW_CONFIGURE_LOGGING``.

    With ``"1"`` the child script requires the alembic logger to have its own
    handler whose format is ``LOGGING_LINE_FORMAT``. With ``"0"`` it requires the
    alembic logger to propagate, and the root handler to keep the user's
    ``basicConfig`` format.
    """
    user_specified_format = "CUSTOM: %(name)s - %(message)s"
    # The format the child process is expected to observe.
    if configure_logging == "0":
        expected_format = user_specified_format
    else:
        expected_format = LOGGING_LINE_FORMAT
    code = f"""
import logging

# user-specified format, this should only take effect if configure_logging is 0
logging.basicConfig(level=logging.INFO, format={user_specified_format!r})

import mlflow

# Check the alembic logger format, which is now configured in _configure_mlflow_loggers
alembic_logger = logging.getLogger("alembic")
if {configure_logging!r} == "1":
    # When MLFLOW_CONFIGURE_LOGGING is enabled, alembic logger has its own handler
    assert len(alembic_logger.handlers) > 0
    actual_format = alembic_logger.handlers[0].formatter._fmt
else:
    # When MLFLOW_CONFIGURE_LOGGING is disabled, alembic logger propagates to root
    assert alembic_logger.propagate
    root_logger = logging.getLogger()
    actual_format = root_logger.handlers[0].formatter._fmt

assert actual_format == {expected_format!r}, actual_format
"""
    child_env = os.environ.copy() | {
        "MLFLOW_TRACKING_URI": tmp_sqlite_uri,
        "MLFLOW_CONFIGURE_LOGGING": configure_logging,
    }
    subprocess.check_call([sys.executable, "-c", code], env=child_env)