# tests/telemetry/test_client.py
   1  import json
   2  import logging
   3  import threading
   4  import time
   5  import warnings
   6  from unittest import mock
   7  
   8  import pytest
   9  
  10  import mlflow
  11  from mlflow.environment_variables import _MLFLOW_TELEMETRY_SESSION_ID
  12  from mlflow.telemetry.client import (
  13      BATCH_SIZE,
  14      BATCH_TIME_INTERVAL_SECONDS,
  15      MAX_QUEUE_SIZE,
  16      MAX_WORKERS,
  17      RETRYABLE_ERRORS,
  18      UNRECOVERABLE_ERRORS,
  19      TelemetryClient,
  20      _is_localhost_uri,
  21      get_telemetry_client,
  22  )
  23  from mlflow.telemetry.events import CreateLoggedModelEvent, CreateRunEvent
  24  from mlflow.telemetry.schemas import Record, SourceSDK, Status
  25  from mlflow.utils.os import is_windows
  26  from mlflow.version import IS_TRACING_SDK_ONLY, VERSION
  27  
  28  from tests.telemetry.helper_functions import validate_telemetry_record
  29  
  30  if not IS_TRACING_SDK_ONLY:
  31      from mlflow.tracking._tracking_service.utils import _use_tracking_uri
  32  
  33  
  34  def test_telemetry_client_initialization(mock_telemetry_client: TelemetryClient, mock_requests):
  35      assert mock_telemetry_client.info is not None
  36      assert mock_telemetry_client._queue.maxsize == MAX_QUEUE_SIZE
  37      assert mock_telemetry_client._max_workers == MAX_WORKERS
  38      assert mock_telemetry_client._batch_size == BATCH_SIZE
  39      assert mock_telemetry_client._batch_time_interval == BATCH_TIME_INTERVAL_SECONDS
  40  
  41  
  42  def test_telemetry_client_session_id(
  43      mock_telemetry_client: TelemetryClient, mock_requests, monkeypatch
  44  ):
  45      monkeypatch.setenv(_MLFLOW_TELEMETRY_SESSION_ID.name, "test_session_id")
  46      with TelemetryClient() as telemetry_client:
  47          assert telemetry_client.info["session_id"] == "test_session_id"
  48      monkeypatch.delenv(_MLFLOW_TELEMETRY_SESSION_ID.name, raising=False)
  49      with TelemetryClient() as telemetry_client:
  50          assert telemetry_client.info["session_id"] != "test_session_id"
  51  
  52  
  53  def test_add_record_and_send(mock_telemetry_client: TelemetryClient, mock_requests):
  54      # Create a test record
  55      record = Record(
  56          event_name="test_event",
  57          timestamp_ns=time.time_ns(),
  58          status=Status.SUCCESS,
  59      )
  60  
  61      # Add record and wait for processing
  62      mock_telemetry_client.add_record(record)
  63      mock_telemetry_client.flush()
  64      received_record = next(
  65          req for req in mock_requests if req["data"]["event_name"] == "test_event"
  66      )
  67  
  68      assert "data" in received_record
  69      assert "partition-key" in received_record
  70  
  71      data = received_record["data"]
  72      assert data["event_name"] == "test_event"
  73      assert data["status"] == "success"
  74  
  75  
  76  def test_add_records_and_send(mock_telemetry_client: TelemetryClient, mock_requests):
  77      # Pre-populate pending_records with 200 records
  78      initial_records = [
  79          Record(
  80              event_name=f"initial_{i}",
  81              timestamp_ns=time.time_ns(),
  82              status=Status.SUCCESS,
  83          )
  84          for i in range(200)
  85      ]
  86      mock_telemetry_client.add_records(initial_records)
  87  
  88      # We haven't hit the batch size limit yet, so expect no records to be sent
  89      assert len(mock_telemetry_client._pending_records) == 200
  90      assert len(mock_requests) == 0
  91  
  92      # Add 1000 more records
  93      # Expected behavior:
  94      # - First 300 records fill to 500 -> send batch (200 + 300) to queue
  95      # - Next 500 records -> send batch to queue
  96      # - Last 200 records remain in pending (200 < 500)
  97      additional_records = [
  98          Record(
  99              event_name=f"additional_{i}",
 100              timestamp_ns=time.time_ns(),
 101              status=Status.SUCCESS,
 102          )
 103          for i in range(1000)
 104      ]
 105      mock_telemetry_client.add_records(additional_records)
 106  
 107      # Verify batching logic:
 108      # - 2 batches should be in the queue
 109      # - 200 records should remain in pending
 110      assert mock_telemetry_client._queue.qsize() == 2
 111      assert len(mock_telemetry_client._pending_records) == 200
 112  
 113      # Flush to process queue and send the remaining partial batch
 114      mock_telemetry_client.flush()
 115  
 116      # Verify all 1200 records were sent
 117      assert len(mock_requests) == 1200
 118      event_names = {req["data"]["event_name"] for req in mock_requests}
 119      assert all(f"initial_{i}" in event_names for i in range(200))
 120      assert all(f"additional_{i}" in event_names for i in range(1000))
 121  
 122  
 123  def test_record_with_session_and_installation_id(
 124      mock_telemetry_client: TelemetryClient, mock_requests
 125  ):
 126      record = Record(
 127          event_name="test_event",
 128          timestamp_ns=time.time_ns(),
 129          status=Status.SUCCESS,
 130          session_id="session_id_override",
 131          installation_id="installation_id_override",
 132      )
 133      mock_telemetry_client.add_record(record)
 134      mock_telemetry_client.flush()
 135      assert mock_requests[0]["data"]["session_id"] == "session_id_override"
 136      assert mock_requests[0]["data"]["installation_id"] == "installation_id_override"
 137  
 138      record = Record(
 139          event_name="test_event",
 140          timestamp_ns=time.time_ns(),
 141          status=Status.SUCCESS,
 142      )
 143      mock_telemetry_client.add_record(record)
 144      mock_telemetry_client.flush()
 145      assert mock_requests[1]["data"]["session_id"] == mock_telemetry_client.info["session_id"]
 146      assert (
 147          mock_requests[1]["data"]["installation_id"] == mock_telemetry_client.info["installation_id"]
 148      )
 149  
 150  
 151  def test_batch_processing(mock_telemetry_client: TelemetryClient, mock_requests):
 152      mock_telemetry_client._batch_size = 3  # Set small batch size for testing
 153  
 154      # Add multiple records
 155      for i in range(5):
 156          record = Record(
 157              event_name=f"test_event_{i}",
 158              timestamp_ns=time.time_ns(),
 159              status=Status.SUCCESS,
 160          )
 161          mock_telemetry_client.add_record(record)
 162  
 163      mock_telemetry_client.flush()
 164      events = {req["data"]["event_name"] for req in mock_requests}
 165      assert all(event_name in events for event_name in [f"test_event_{i}" for i in range(5)])
 166  
 167  
 168  def test_flush_functionality(mock_telemetry_client: TelemetryClient, mock_requests):
 169      record = Record(
 170          event_name="test_event",
 171          timestamp_ns=time.time_ns(),
 172          status=Status.SUCCESS,
 173      )
 174      mock_telemetry_client.add_record(record)
 175  
 176      mock_telemetry_client.flush()
 177      events = {req["data"]["event_name"] for req in mock_requests}
 178      assert record.event_name in events
 179  
 180  
 181  def test_record_sent(mock_telemetry_client: TelemetryClient, mock_requests):
 182      record_1 = Record(
 183          event_name="test_event_1",
 184          timestamp_ns=time.time_ns(),
 185          status=Status.SUCCESS,
 186      )
 187      mock_telemetry_client.add_record(record_1)
 188      mock_telemetry_client.flush()
 189  
 190      assert len(mock_requests) == 1
 191      data = mock_requests[0]["data"]
 192      assert data["event_name"] == record_1.event_name
 193      assert data["status"] == "success"
 194  
 195      session_id = data.get("session_id")
 196      installation_id = data.get("installation_id")
 197      assert session_id is not None
 198      assert installation_id is not None
 199  
 200      record_2 = Record(
 201          event_name="test_event_2",
 202          timestamp_ns=time.time_ns(),
 203          status=Status.FAILURE,
 204      )
 205      record_3 = Record(
 206          event_name="test_event_3",
 207          timestamp_ns=time.time_ns(),
 208          status=Status.SUCCESS,
 209      )
 210      mock_telemetry_client.add_record(record_2)
 211      mock_telemetry_client.add_record(record_3)
 212      mock_telemetry_client.flush()
 213      assert len(mock_requests) == 3
 214  
 215      # all record should have the same session id and installation id
 216      assert {req["data"].get("session_id") for req in mock_requests} == {session_id}
 217      assert {req["data"].get("installation_id") for req in mock_requests} == {installation_id}
 218  
 219  
def test_client_shutdown(mock_telemetry_client: TelemetryClient, mock_requests):
    """Terminating flush should drop queued records and return quickly."""
    for _ in range(100):
        record = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )
        mock_telemetry_client.add_record(record)

    start_time = time.time()
    mock_telemetry_client.flush(terminate=True)
    end_time = time.time()
    # Shutdown must not block draining 100 pending records.
    assert end_time - start_time < 0.1
    events = {req["data"]["event_name"] for req in mock_requests}
    # Records still pending at termination are discarded, not sent.
    assert "test_event" not in events
    assert not mock_telemetry_client.is_active
 236  
 237  
@pytest.mark.parametrize(
    "url",
    [
        "http://127.0.0.1:9999/nonexistent",
        "http://127.0.0.1:9999/unauthorized",
        "http://127.0.0.1:9999/forbidden",
        "http://127.0.0.1:9999/bad_request",
    ],
)
def test_telemetry_collection_stopped_on_error(mock_requests, mock_telemetry_client, url):
    """Unrecoverable responses from the ingestion endpoint stop the client for good."""
    mock_telemetry_client.config.ingestion_url = url

    # Add a record - should not crash
    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    mock_telemetry_client.add_record(record)

    mock_telemetry_client.flush(terminate=True)

    # Failing endpoint marks the client stopped and inactive.
    assert mock_telemetry_client._is_stopped is True
    assert mock_telemetry_client.is_active is False
    # At most one request may have reached the server before stopping.
    requests_count = len(mock_requests)
    assert requests_count <= 1

    # add record after stopping should be no-op
    mock_telemetry_client.add_record(record)
    mock_telemetry_client.flush(terminate=True)
    assert len(mock_requests) == requests_count
 269  
 270  
 271  @pytest.mark.parametrize("error_code", [429, 500])
 272  @pytest.mark.parametrize("terminate", [True, False])
 273  def test_telemetry_retry_on_error(error_code, terminate):
 274      record = Record(
 275          event_name="test_event",
 276          timestamp_ns=time.time_ns(),
 277          status=Status.SUCCESS,
 278      )
 279  
 280      class MockPostTracker:
 281          def __init__(self):
 282              self.count = 0
 283              self.responses = []
 284  
 285          def mock_post(self, url, json=None, **kwargs):
 286              self.count += 1
 287              if self.count < 3:
 288                  return mock.Mock(status_code=error_code)
 289              else:
 290                  self.responses.extend(json["records"])
 291                  return mock.Mock(status_code=200)
 292  
 293      tracker = MockPostTracker()
 294  
 295      with (
 296          mock.patch("requests.post", side_effect=tracker.mock_post),
 297          TelemetryClient() as telemetry_client,
 298      ):
 299          telemetry_client.add_record(record)
 300          start_time = time.time()
 301          telemetry_client.flush(terminate=terminate)
 302          duration = time.time() - start_time
 303          if terminate:
 304              assert duration < 1.5
 305          else:
 306              assert duration < 2.5
 307  
 308          if terminate:
 309              assert tracker.responses == []
 310          else:
 311              assert record.event_name in [resp["data"]["event_name"] for resp in tracker.responses]
 312  
 313  
@pytest.mark.parametrize("error_type", [ConnectionError, TimeoutError])
@pytest.mark.parametrize("terminate", [True, False])
def test_telemetry_retry_on_request_error(error_type, terminate):
    """Connection-level errors are retried unless the flush is terminating."""
    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )

    class MockPostTracker:
        # Raises `error_type` on the first two calls, succeeds on the third.
        def __init__(self):
            self.count = 0
            self.responses = []

        def mock_post(self, url, json=None, **kwargs):
            self.count += 1
            if self.count < 3:
                raise error_type()
            else:
                self.responses.extend(json["records"])
                return mock.Mock(status_code=200)

    tracker = MockPostTracker()

    with (
        mock.patch("requests.post", side_effect=tracker.mock_post),
        TelemetryClient() as telemetry_client,
    ):
        telemetry_client.add_record(record)
        start_time = time.time()
        telemetry_client.flush(terminate=terminate)
        duration = time.time() - start_time
        # Flush should stay within its time budget whether or not it retries.
        if terminate:
            assert duration < 1.5
        else:
            assert duration < 2.5

    # no retry when terminating
    if terminate:
        assert tracker.responses == []
    else:
        assert record.event_name in [resp["data"]["event_name"] for resp in tracker.responses]
 356  
 357  
 358  def test_stop_event(mock_telemetry_client: TelemetryClient, mock_requests):
 359      mock_telemetry_client._is_stopped = True
 360  
 361      record = Record(
 362          event_name="test_event",
 363          timestamp_ns=time.time_ns(),
 364          status=Status.SUCCESS,
 365      )
 366      mock_telemetry_client.add_record(record)
 367  
 368      # we need to terminate since the threads are stopped
 369      mock_telemetry_client.flush(terminate=True)
 370  
 371      # No records should be sent since the client is stopped
 372      events = {req["data"]["event_name"] for req in mock_requests}
 373      assert record.event_name not in events
 374  
 375  
def test_concurrent_record_addition(mock_telemetry_client: TelemetryClient, mock_requests):
    """Records added concurrently from several threads are all delivered."""
    def add_records(thread_id):
        # Each producer thread contributes 5 uniquely named records.
        for i in range(5):
            record = Record(
                event_name=f"test_event_{thread_id}_{i}",
                timestamp_ns=time.time_ns(),
                status=Status.SUCCESS,
            )
            mock_telemetry_client.add_record(record)
            time.sleep(0.1)

    # Start multiple threads
    threads = []
    for i in range(3):
        thread = threading.Thread(name=f"telemetry-client-{i}", target=add_records, args=(i,))
        threads.append(thread)
        thread.start()

    # Wait for all threads to complete
    for thread in threads:
        thread.join()

    mock_telemetry_client.flush()

    # Should have received records from all threads
    events = {req["data"]["event_name"] for req in mock_requests}
    assert all(
        event_name in events
        for event_name in [
            f"test_event_{thread_id}_{i}" for thread_id in range(3) for i in range(5)
        ]
    )
 408  
 409  
 410  def test_telemetry_info_inclusion(mock_telemetry_client: TelemetryClient, mock_requests):
 411      record = Record(
 412          event_name="test_event",
 413          timestamp_ns=time.time_ns(),
 414          status=Status.SUCCESS,
 415      )
 416      mock_telemetry_client.add_record(record)
 417  
 418      mock_telemetry_client.flush()
 419  
 420      # Verify telemetry info is included
 421      data = next(req["data"] for req in mock_requests if req["data"]["event_name"] == "test_event")
 422  
 423      # Check that telemetry info fields are present
 424      assert mock_telemetry_client.info.items() <= data.items()
 425  
 426      # Check that record fields are present
 427      assert data["event_name"] == "test_event"
 428      assert data["status"] == "success"
 429  
 430  
 431  def test_partition_key(mock_telemetry_client: TelemetryClient, mock_requests):
 432      record = Record(
 433          event_name="test_event",
 434          timestamp_ns=time.time_ns(),
 435          status=Status.SUCCESS,
 436      )
 437      mock_telemetry_client.add_record(record)
 438      mock_telemetry_client.add_record(record)
 439  
 440      mock_telemetry_client.flush()
 441  
 442      # Verify partition key is random
 443      assert mock_requests[0]["partition-key"] != mock_requests[1]["partition-key"]
 444  
 445  
 446  def test_max_workers_setup(monkeypatch):
 447      monkeypatch.setattr("mlflow.telemetry.client.MAX_WORKERS", 8)
 448      with TelemetryClient() as telemetry_client:
 449          assert telemetry_client._max_workers == 8
 450          telemetry_client.activate()
 451          # Test that correct number of threads are created
 452          assert len(telemetry_client._consumer_threads) == 8
 453  
 454          # Verify thread names
 455          for i, thread in enumerate(telemetry_client._consumer_threads):
 456              assert thread.name == f"MLflowTelemetryConsumer-{i}"
 457              assert thread.daemon is True
 458  
 459  
def test_log_suppression_in_consumer_thread(mock_requests, capsys, mock_telemetry_client):
    """Logs emitted inside the consumer thread must not reach stderr."""
    # Clear any existing captured output
    capsys.readouterr()

    # Log from main thread - this should be captured
    logger = logging.getLogger("mlflow.telemetry.client")
    logger.info("TEST LOG FROM MAIN THREAD")

    original_process = mock_telemetry_client._process_records

    # Wrap record processing so the consumer thread emits a log line.
    def process_with_log(records):
        logger.info("TEST LOG FROM CONSUMER THREAD")
        original_process(records)

    mock_telemetry_client._process_records = process_with_log

    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    mock_telemetry_client.add_record(record)

    mock_telemetry_client.flush()
    # Ensure the record was actually processed (and thus the log emitted).
    events = {req["data"]["event_name"] for req in mock_requests}
    assert record.event_name in events

    captured = capsys.readouterr()

    assert "TEST LOG FROM MAIN THREAD" in captured.err
    # Verify that the consumer thread log was suppressed
    assert "TEST LOG FROM CONSUMER THREAD" not in captured.err
 492  
 493  
def test_consumer_thread_no_stderr_output(mock_requests, capsys, mock_telemetry_client):
    """Normal record processing in the consumer thread writes nothing to stderr."""
    # Clear any existing captured output
    capsys.readouterr()

    # Log from main thread - this should be captured
    # NOTE(review): this log line is never asserted on; it only exercises the
    # logger before the capture window below.
    logger = logging.getLogger("mlflow.telemetry.client")
    logger.info("MAIN THREAD LOG BEFORE CLIENT")

    # Clear output after client initialization to focus on consumer thread output
    capsys.readouterr()

    # Add multiple records to ensure consumer thread processes them
    for i in range(5):
        record = Record(
            event_name=f"test_event_{i}",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )
        mock_telemetry_client.add_record(record)

    mock_telemetry_client.flush()
    # Wait for all records to be processed
    events = {req["data"]["event_name"] for req in mock_requests}
    assert all(event_name in events for event_name in [f"test_event_{i}" for i in range(5)])

    # Capture output after consumer thread has processed all records
    captured = capsys.readouterr()

    # Verify consumer thread produced no stderr output
    assert captured.err == ""

    # Log from main thread after processing - this should be captured
    logger.info("MAIN THREAD LOG AFTER PROCESSING")
    captured_after = capsys.readouterr()
    assert "MAIN THREAD LOG AFTER PROCESSING" in captured_after.err
 529  
 530  
def test_batch_time_interval(mock_requests, monkeypatch):
    """Pending records are flushed automatically once the batch interval elapses."""
    monkeypatch.setattr("mlflow.telemetry.client.BATCH_TIME_INTERVAL_SECONDS", 1)
    # NOTE(review): this client is never terminated; consider using it as a
    # context manager so its threads are cleaned up.
    telemetry_client = TelemetryClient()

    assert telemetry_client._batch_time_interval == 1

    # Add first record
    record1 = Record(
        event_name="test_event_1",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    telemetry_client.add_record(record1)
    assert len(telemetry_client._pending_records) == 1

    # Below the batch size and before the interval: nothing sent yet.
    events = {req["data"]["event_name"] for req in mock_requests}
    assert "test_event_1" not in events

    # Add second record before time interval
    record2 = Record(
        event_name="test_event_2",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    telemetry_client.add_record(record2)
    assert len(telemetry_client._pending_records) == 2

    # Wait for time interval to pass
    time.sleep(1.5)
    assert len(telemetry_client._pending_records) == 0
    # records are sent due to time interval
    events = {req["data"]["event_name"] for req in mock_requests}
    assert "test_event_1" in events
    assert "test_event_2" in events

    record3 = Record(
        event_name="test_event_3",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    telemetry_client.add_record(record3)
    telemetry_client.flush()

    # Verify all records were sent
    event_names = {req["data"]["event_name"] for req in mock_requests}
    assert all(env in event_names for env in ["test_event_1", "test_event_2", "test_event_3"])
 577  
 578  
 579  def test_set_telemetry_client_non_blocking():
 580      start_time = time.time()
 581      with TelemetryClient() as telemetry_client:
 582          assert time.time() - start_time < 1
 583          assert telemetry_client is not None
 584          time.sleep(1.1)
 585          assert not any(
 586              thread.name.startswith("GetTelemetryConfig") for thread in threading.enumerate()
 587          )
 588  
 589  
@pytest.mark.parametrize(
    "mock_requests_return_value",
    [
        # HTTP error fetching the config.
        mock.Mock(status_code=403),
        # Telemetry explicitly disabled in the config payload.
        mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": True,
                }
            ),
        ),
        # Config served for a different mlflow version.
        mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": "1.0.0",
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                }
            ),
        ),
        # Zero rollout percentage.
        mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 0,
                }
            ),
        ),
        # Rollout 70 with randint mocked to 80 -> outside the rollout.
        mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 70,
                }
            ),
        ),
    ],
)
@pytest.mark.no_mock_requests_get
def test_client_get_config_none(mock_requests_return_value):
    """_get_config leaves config as None for every disabling response above."""
    with (
        mock.patch("mlflow.telemetry.client.requests.get") as mock_requests,
        mock.patch("random.randint", return_value=80),
    ):
        mock_requests.return_value = mock_requests_return_value
        client = TelemetryClient()
        client._get_config()
        assert client.config is None
 647  
 648  
@pytest.mark.no_mock_requests_get
def test_client_get_config_not_none():
    """_get_config accepts the fetched config when the client is in rollout scope."""
    # randint(50) falls inside rollout_percentage=70 -> config accepted.
    with (
        mock.patch("mlflow.telemetry.client.requests.get") as mock_requests,
        mock.patch("random.randint", return_value=50),
    ):
        mock_requests.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 70,
                }
            ),
        )
        with TelemetryClient() as telemetry_client:
            telemetry_client._get_config()
            assert telemetry_client.config.ingestion_url == "http://localhost:9999"
            assert telemetry_client.config.disable_events == set()

    # Full rollout requires no randint mock.
    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests:
        mock_requests.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                }
            ),
        )
        with TelemetryClient() as telemetry_client:
            telemetry_client._get_config()
            assert telemetry_client.config.ingestion_url == "http://localhost:9999"
            assert telemetry_client.config.disable_events == set()

    # disable_sdks only nulls the config when it matches the current SDK.
    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests:
        mock_requests.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                    "disable_events": [],
                    "disable_sdks": ["mlflow-tracing"],
                }
            ),
        )
        with (
            mock.patch(
                "mlflow.telemetry.client.get_source_sdk", return_value=SourceSDK.MLFLOW_TRACING
            ),
            TelemetryClient() as telemetry_client,
        ):
            telemetry_client._get_config()
            # Current SDK is in disable_sdks -> telemetry disabled.
            assert telemetry_client.config is None

        with (
            mock.patch(
                "mlflow.telemetry.client.get_source_sdk", return_value=SourceSDK.MLFLOW_SKINNY
            ),
            TelemetryClient() as telemetry_client,
        ):
            telemetry_client._get_config()
            assert telemetry_client.config.ingestion_url == "http://localhost:9999"
            assert telemetry_client.config.disable_events == set()

        with (
            mock.patch("mlflow.telemetry.client.get_source_sdk", return_value=SourceSDK.MLFLOW),
            TelemetryClient() as telemetry_client,
        ):
            telemetry_client._get_config()
            assert telemetry_client.config.ingestion_url == "http://localhost:9999"
            assert telemetry_client.config.disable_events == set()
 728  
 729  
@pytest.mark.no_mock_requests_get
@pytest.mark.skipif(is_windows(), reason="This test only passes on non-Windows")
def test_get_config_disable_non_windows():
    """disable_os listing the current platform should null out the config."""
    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests_get:
        mock_requests_get.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                    "disable_os": ["linux", "darwin"],
                }
            ),
        )
        with TelemetryClient() as telemetry_client:
            telemetry_client._get_config()
            # Current (non-Windows) OS is in disable_os -> telemetry disabled.
            assert telemetry_client.config is None

    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests:
        mock_requests.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                    "disable_os": ["win32"],
                }
            ),
        )
        with TelemetryClient() as telemetry_client:
            telemetry_client._get_config()
            # Only win32 is disabled; this platform keeps telemetry on.
            assert telemetry_client.config.ingestion_url == "http://localhost:9999"
            assert telemetry_client.config.disable_events == set()
 767  
 768  
@pytest.mark.no_mock_requests_get
@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
def test_get_config_windows():
    """Windows counterpart of the disable_os handling in _get_config."""
    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests:
        mock_requests.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                    "disable_os": ["win32"],
                }
            ),
        )
        with TelemetryClient() as telemetry_client:
            telemetry_client._get_config()
            # win32 is in disable_os -> telemetry disabled on Windows.
            assert telemetry_client.config is None

    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests:
        mock_requests.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                    "disable_os": ["linux", "darwin"],
                }
            ),
        )
        with TelemetryClient() as telemetry_client:
            telemetry_client._get_config()
            # Only non-Windows platforms are disabled; Windows keeps telemetry.
            assert telemetry_client.config.ingestion_url == "http://localhost:9999"
            assert telemetry_client.config.disable_events == set()
 806  
 807  
@pytest.mark.no_mock_requests_get
def test_client_set_to_none_if_config_none():
    """When the backend responds with disable_telemetry=True, the client keeps a
    None config, marks the fetch as done, and stops itself."""
    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests:
        mock_requests.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": True,
                }
            ),
        )
        with TelemetryClient() as telemetry_client:
            assert telemetry_client is not None
            telemetry_client.activate()
            # Wait for the background config-fetch thread to complete.
            telemetry_client._config_thread.join(timeout=3)
            assert not telemetry_client._config_thread.is_alive()
            # disable_telemetry=True must leave the client without a config
            # while still flagging the fetch as finished and stopping the client.
            assert telemetry_client.config is None
            assert telemetry_client._is_config_fetched is True
            assert telemetry_client._is_stopped
 828  
 829  
@pytest.mark.no_mock_requests_get
def test_records_not_dropped_when_fetching_config(mock_requests):
    """A record added around config fetching is still delivered once the config
    arrives with a full rollout."""
    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
        duration_ms=0,
    )
    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests_get:
        # 100% rollout so the record is never sampled out.
        mock_requests_get.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                }
            ),
        )
        with TelemetryClient() as telemetry_client:
            telemetry_client.activate()
            # wait for config to be fetched
            telemetry_client._config_thread.join(timeout=3)
            telemetry_client.add_record(record)
            telemetry_client.flush()
            # The record must show up on the mocked ingestion endpoint.
            validate_telemetry_record(
                telemetry_client, mock_requests, record.event_name, check_params=False
            )
 859  
 860  
@pytest.mark.no_mock_requests_get
@pytest.mark.parametrize("error_code", [400, 401, 403, 404, 412, 500, 502, 503, 504])
def test_config_fetch_no_retry(mock_requests, error_code):
    """Any 4xx/5xx on the config fetch disables telemetry: the pending record is
    never sent and the module-level client is torn down."""
    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )

    def mock_requests_get(*args, **kwargs):
        # Simulate a slow config endpoint before returning the error status.
        time.sleep(1)
        return mock.Mock(status_code=error_code)

    with (
        mock.patch("mlflow.telemetry.client.requests.get", side_effect=mock_requests_get),
        TelemetryClient() as telemetry_client,
    ):
        telemetry_client.add_record(record)
        telemetry_client.flush()
        # The event must not have reached the ingestion endpoint...
        events = [req["data"]["event_name"] for req in mock_requests]
        assert record.event_name not in events
        # ...and the global telemetry client should have been discarded.
        assert get_telemetry_client() is None
 883  
 884  
 885  def test_warning_suppression_in_shutdown(recwarn, mock_telemetry_client: TelemetryClient):
 886      def flush_mock(*args, **kwargs):
 887          warnings.warn("test warning")
 888  
 889      with mock.patch.object(mock_telemetry_client, "flush", flush_mock):
 890          mock_telemetry_client._at_exit_callback()
 891          assert len(recwarn) == 0
 892  
 893  
 894  @pytest.mark.skipif(IS_TRACING_SDK_ONLY, reason="Requires full tracking SDK")
 895  @pytest.mark.parametrize("tracking_uri_scheme", ["databricks", "databricks-uc", "uc"])
 896  def test_databricks_tracking_uri_scheme_does_not_use_oss_path(mock_requests, tracking_uri_scheme):
 897      record = Record(
 898          event_name="test_event",
 899          timestamp_ns=time.time_ns(),
 900          status=Status.SUCCESS,
 901      )
 902  
 903      with (
 904          _use_tracking_uri(f"{tracking_uri_scheme}://profile_name"),
 905          mock.patch(
 906              "mlflow.telemetry.client.http_request",
 907              return_value=mock.Mock(status_code=200),
 908          ),
 909          mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
 910          mock.patch("mlflow.telemetry.client._IS_MLFLOW_DEV_VERSION", False),
 911          TelemetryClient() as telemetry_client,
 912      ):
 913          telemetry_client.add_record(record)
 914          telemetry_client.flush(terminate=True)
 915          # OSS ingestion path should not receive records
 916          assert len(mock_requests) == 0
 917  
 918  
 919  @pytest.mark.skipif(IS_TRACING_SDK_ONLY, reason="Requires full tracking SDK")
 920  @pytest.mark.parametrize("tracking_uri_scheme", ["databricks", "databricks-uc", "uc"])
 921  def test_databricks_end_to_end_forwarding(tracking_uri_scheme):
 922      record = Record(
 923          event_name="test_event",
 924          timestamp_ns=time.time_ns(),
 925          status=Status.SUCCESS,
 926          duration_ms=42,
 927          params={"key": "value"},
 928      )
 929  
 930      with (
 931          _use_tracking_uri(f"{tracking_uri_scheme}://profile_name"),
 932          mock.patch(
 933              "mlflow.telemetry.client.http_request",
 934              return_value=mock.Mock(status_code=200),
 935          ) as mock_http,
 936          mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
 937          mock.patch("mlflow.telemetry.client._IS_MLFLOW_DEV_VERSION", False),
 938          TelemetryClient() as telemetry_client,
 939      ):
 940          telemetry_client.add_record(record)
 941          telemetry_client.flush()
 942  
 943          mock_http.assert_called_once()
 944          payload = mock_http.call_args.kwargs["json"]
 945          assert len(payload["events"]) == 1
 946          event = payload["events"][0]
 947          assert event["event_name"] == "test_event"
 948          assert event["tracking_uri_scheme"] == tracking_uri_scheme
 949          assert "params_json" in event
 950          assert "params" not in event
 951  
 952  
 953  def test_databricks_forwarding_disabled_for_dev_versions():
 954      record = Record(
 955          event_name="test_event",
 956          timestamp_ns=time.time_ns(),
 957          status=Status.SUCCESS,
 958      )
 959  
 960      with TelemetryClient() as client:
 961          client.info["tracking_uri_scheme"] = "databricks"
 962  
 963          with (
 964              mock.patch(
 965                  "mlflow.telemetry.client.http_request",
 966                  return_value=mock.Mock(status_code=200),
 967              ) as mock_http,
 968              mock.patch("mlflow.telemetry.client._IS_MLFLOW_DEV_VERSION", True),
 969          ):
 970              client._process_records([record])
 971  
 972          mock_http.assert_not_called()
 973  
 974  
 975  def test_forward_to_databricks_params_json_serialization():
 976      with TelemetryClient() as client:
 977          client.info["tracking_uri_scheme"] = "databricks"
 978          record = Record(
 979              event_name="genai_evaluate",
 980              timestamp_ns=1700000000000000000,
 981              status=Status.SUCCESS,
 982              params={"predict_fn_provided": True},
 983          )
 984  
 985          with (
 986              mock.patch(
 987                  "mlflow.telemetry.client.http_request",
 988                  return_value=mock.Mock(status_code=200),
 989              ) as mock_http,
 990              mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
 991              mock.patch(
 992                  "mlflow.tracking._tracking_service.utils.get_tracking_uri",
 993                  return_value="databricks",
 994              ),
 995          ):
 996              client._forward_to_databricks([record])
 997  
 998          event = mock_http.call_args.kwargs["json"]["events"][0]
 999          assert "params" not in event
1000          assert "params_json" in event
1001          assert json.loads(event["params_json"]) == {"predict_fn_provided": True}
1002  
1003  
def test_forward_to_databricks_no_params_json_when_params_none():
    """A record without params is forwarded without any params_json field."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        sample = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )

        with (
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
            mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
            mock.patch(
                "mlflow.telemetry.client.http_request",
                return_value=mock.Mock(status_code=200),
            ) as http_mock,
        ):
            client._forward_to_databricks([sample])

        forwarded = http_mock.call_args.kwargs["json"]["events"][0]
        # Neither the raw params nor an empty params_json should be present.
        assert "params" not in forwarded
        assert "params_json" not in forwarded
1029  
1030  
@pytest.mark.parametrize("status_code", list(UNRECOVERABLE_ERRORS))
def test_forward_to_databricks_stops_on_unrecoverable_error(status_code):
    """An unrecoverable HTTP status from Databricks permanently deactivates the client."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        sample = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )

        with (
            mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
            mock.patch(
                "mlflow.telemetry.client.http_request",
                return_value=mock.Mock(status_code=status_code),
            ),
        ):
            client._forward_to_databricks([sample])

        # The client must shut down and stay inactive after such an error.
        assert client._is_stopped
        assert not client.is_active
1056  
1057  
def test_forward_to_databricks_credential_failure_non_fatal():
    """Failing to resolve Databricks credentials is tolerated: the client keeps running."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        sample = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )

        with (
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
            mock.patch(
                "mlflow.utils.databricks_utils.get_databricks_host_creds",
                side_effect=Exception("no creds"),
            ),
        ):
            client._forward_to_databricks([sample])

        # A credential lookup failure must not stop the client.
        assert not client._is_stopped
1080  
1081  
@pytest.mark.parametrize("error_code", RETRYABLE_ERRORS)
def test_forward_to_databricks_retries_on_retryable_error(error_code):
    """Retryable HTTP errors are retried until the forward eventually succeeds."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        sample = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )

        # Two retryable failures followed by a success; sleep is patched out so
        # the backoff between attempts does not slow the test down.
        responses = [
            mock.Mock(status_code=error_code),
            mock.Mock(status_code=error_code),
            mock.Mock(status_code=200),
        ]
        with (
            mock.patch(
                "mlflow.telemetry.client.http_request", side_effect=responses
            ) as http_mock,
            mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
            mock.patch("mlflow.telemetry.client.time.sleep"),
        ):
            client._forward_to_databricks([sample])

        assert http_mock.call_count == 3
1111  
1112  
@pytest.mark.no_mock_requests_get
def test_disable_events(mock_requests):
    """Events listed in the config's disable_events are dropped, while events
    not listed there are still reported."""
    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests_get:
        mock_requests_get.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                    "disable_events": [CreateLoggedModelEvent.name],
                    "disable_sdks": [],
                }
            ),
        )
        with (
            TelemetryClient() as telemetry_client,
            mock.patch(
                "mlflow.telemetry.track.get_telemetry_client", return_value=telemetry_client
            ),
        ):
            telemetry_client.activate()
            # Wait for the config (with the disabled event list) to be fetched.
            telemetry_client._config_thread.join(timeout=1)
            # Disabled event: creating a logged model must emit nothing.
            mlflow.initialize_logged_model(name="model", tags={"key": "value"})
            telemetry_client.flush()
            assert len(mock_requests) == 0

            # Non-disabled event: starting a run is still recorded.
            with mlflow.start_run():
                pass
            validate_telemetry_record(
                telemetry_client, mock_requests, CreateRunEvent.name, check_params=False
            )
1146  
1147  
@pytest.mark.no_mock_requests_get
def test_fetch_config_after_first_record():
    """Config fetching is lazy: it is triggered by the first record and runs once."""
    first = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
        duration_ms=0,
    )

    response = mock.Mock(
        status_code=200,
        json=mock.Mock(
            return_value={
                "mlflow_version": VERSION,
                "disable_telemetry": False,
                "ingestion_url": "http://localhost:9999",
                "rollout_percentage": 70,
            }
        ),
    )
    with mock.patch(
        "mlflow.telemetry.client.requests.get", return_value=response
    ) as get_mock:
        with TelemetryClient() as client:
            # No fetch until a record arrives.
            assert client._is_config_fetched is False
            client.add_record(first)
            client._config_thread.join(timeout=1)
            assert client._is_config_fetched is True
        # Exactly one config request across the client's lifetime.
        get_mock.assert_called_once()
1177  
1178  
@pytest.mark.parametrize(
    "uri",
    [
        "http://localhost",
        "http://localhost:5000",
        "http://127.0.0.1",
        "http://127.0.0.1:5000/api/2.0/mlflow",
        "http://[::1]",
    ],
)
def test_is_localhost_uri_returns_true_for_localhost(uri):
    """Loopback hosts (localhost, 127.0.0.1, [::1]) are detected as local."""
    assert _is_localhost_uri(uri)
1191  
1192  
@pytest.mark.parametrize(
    "uri",
    [
        "http://example.com",
        "http://example.com:5000",
        "https://mlflow.example.com",
        "http://192.168.1.1",
        "http://192.168.1.1:5000",
        "http://10.0.0.1:5000",
        "https://my-tracking-server.com/api/2.0/mlflow",
    ],
)
def test_is_localhost_uri_returns_false_for_remote(uri):
    """Remote hostnames and non-loopback IPs (including private ranges) are not local."""
    assert _is_localhost_uri(uri) is False
1207  
1208  
def test_is_localhost_uri_returns_none_for_empty_hostname():
    """URIs without a network hostname (e.g. file://) yield None rather than False."""
    assert _is_localhost_uri("file:///tmp/mlruns") is None
1211  
1212  
def test_is_localhost_uri_returns_none_on_parse_error():
    """If URI parsing raises, _is_localhost_uri falls back to returning None."""
    # urlparse doesn't raise on most inputs, but we test the fallback behavior
    # by mocking urlparse to raise
    # NOTE(review): this patches urllib.parse.urlparse globally, which only
    # intercepts module-attribute lookups — it would not affect a
    # `from urllib.parse import urlparse` binding in client.py; confirm there.
    with mock.patch("urllib.parse.urlparse", side_effect=ValueError("Invalid URI")):
        assert _is_localhost_uri("http://localhost") is None
1218  
1219  
def test_is_workspace_enabled_included_in_telemetry_info(
    mock_telemetry_client: TelemetryClient, mock_requests, monkeypatch
):
    """With MLFLOW_WORKSPACE set, emitted records carry ws_enabled=True."""
    monkeypatch.setenv("MLFLOW_WORKSPACE", "my-workspace")
    sample = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    mock_telemetry_client.add_record(sample)
    mock_telemetry_client.flush()
    payload = next(
        req["data"] for req in mock_requests if req["data"]["event_name"] == "test_event"
    )
    assert payload["ws_enabled"] is True
1233  
1234  
def test_is_workspace_disabled_included_in_telemetry_info(
    mock_telemetry_client: TelemetryClient, mock_requests, monkeypatch
):
    """Without MLFLOW_WORKSPACE in the environment, records carry ws_enabled=False."""
    monkeypatch.delenv("MLFLOW_WORKSPACE", raising=False)
    sample = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    mock_telemetry_client.add_record(sample)
    mock_telemetry_client.flush()
    payload = next(
        req["data"] for req in mock_requests if req["data"]["event_name"] == "test_event"
    )
    assert payload["ws_enabled"] is False