test_client.py
1 import json 2 import logging 3 import threading 4 import time 5 import warnings 6 from unittest import mock 7 8 import pytest 9 10 import mlflow 11 from mlflow.environment_variables import _MLFLOW_TELEMETRY_SESSION_ID 12 from mlflow.telemetry.client import ( 13 BATCH_SIZE, 14 BATCH_TIME_INTERVAL_SECONDS, 15 MAX_QUEUE_SIZE, 16 MAX_WORKERS, 17 RETRYABLE_ERRORS, 18 UNRECOVERABLE_ERRORS, 19 TelemetryClient, 20 _is_localhost_uri, 21 get_telemetry_client, 22 ) 23 from mlflow.telemetry.events import CreateLoggedModelEvent, CreateRunEvent 24 from mlflow.telemetry.schemas import Record, SourceSDK, Status 25 from mlflow.utils.os import is_windows 26 from mlflow.version import IS_TRACING_SDK_ONLY, VERSION 27 28 from tests.telemetry.helper_functions import validate_telemetry_record 29 30 if not IS_TRACING_SDK_ONLY: 31 from mlflow.tracking._tracking_service.utils import _use_tracking_uri 32 33 34 def test_telemetry_client_initialization(mock_telemetry_client: TelemetryClient, mock_requests): 35 assert mock_telemetry_client.info is not None 36 assert mock_telemetry_client._queue.maxsize == MAX_QUEUE_SIZE 37 assert mock_telemetry_client._max_workers == MAX_WORKERS 38 assert mock_telemetry_client._batch_size == BATCH_SIZE 39 assert mock_telemetry_client._batch_time_interval == BATCH_TIME_INTERVAL_SECONDS 40 41 42 def test_telemetry_client_session_id( 43 mock_telemetry_client: TelemetryClient, mock_requests, monkeypatch 44 ): 45 monkeypatch.setenv(_MLFLOW_TELEMETRY_SESSION_ID.name, "test_session_id") 46 with TelemetryClient() as telemetry_client: 47 assert telemetry_client.info["session_id"] == "test_session_id" 48 monkeypatch.delenv(_MLFLOW_TELEMETRY_SESSION_ID.name, raising=False) 49 with TelemetryClient() as telemetry_client: 50 assert telemetry_client.info["session_id"] != "test_session_id" 51 52 53 def test_add_record_and_send(mock_telemetry_client: TelemetryClient, mock_requests): 54 # Create a test record 55 record = Record( 56 event_name="test_event", 57 
timestamp_ns=time.time_ns(), 58 status=Status.SUCCESS, 59 ) 60 61 # Add record and wait for processing 62 mock_telemetry_client.add_record(record) 63 mock_telemetry_client.flush() 64 received_record = next( 65 req for req in mock_requests if req["data"]["event_name"] == "test_event" 66 ) 67 68 assert "data" in received_record 69 assert "partition-key" in received_record 70 71 data = received_record["data"] 72 assert data["event_name"] == "test_event" 73 assert data["status"] == "success" 74 75 76 def test_add_records_and_send(mock_telemetry_client: TelemetryClient, mock_requests): 77 # Pre-populate pending_records with 200 records 78 initial_records = [ 79 Record( 80 event_name=f"initial_{i}", 81 timestamp_ns=time.time_ns(), 82 status=Status.SUCCESS, 83 ) 84 for i in range(200) 85 ] 86 mock_telemetry_client.add_records(initial_records) 87 88 # We haven't hit the batch size limit yet, so expect no records to be sent 89 assert len(mock_telemetry_client._pending_records) == 200 90 assert len(mock_requests) == 0 91 92 # Add 1000 more records 93 # Expected behavior: 94 # - First 300 records fill to 500 -> send batch (200 + 300) to queue 95 # - Next 500 records -> send batch to queue 96 # - Last 200 records remain in pending (200 < 500) 97 additional_records = [ 98 Record( 99 event_name=f"additional_{i}", 100 timestamp_ns=time.time_ns(), 101 status=Status.SUCCESS, 102 ) 103 for i in range(1000) 104 ] 105 mock_telemetry_client.add_records(additional_records) 106 107 # Verify batching logic: 108 # - 2 batches should be in the queue 109 # - 200 records should remain in pending 110 assert mock_telemetry_client._queue.qsize() == 2 111 assert len(mock_telemetry_client._pending_records) == 200 112 113 # Flush to process queue and send the remaining partial batch 114 mock_telemetry_client.flush() 115 116 # Verify all 1200 records were sent 117 assert len(mock_requests) == 1200 118 event_names = {req["data"]["event_name"] for req in mock_requests} 119 assert all(f"initial_{i}" in 
event_names for i in range(200)) 120 assert all(f"additional_{i}" in event_names for i in range(1000)) 121 122 123 def test_record_with_session_and_installation_id( 124 mock_telemetry_client: TelemetryClient, mock_requests 125 ): 126 record = Record( 127 event_name="test_event", 128 timestamp_ns=time.time_ns(), 129 status=Status.SUCCESS, 130 session_id="session_id_override", 131 installation_id="installation_id_override", 132 ) 133 mock_telemetry_client.add_record(record) 134 mock_telemetry_client.flush() 135 assert mock_requests[0]["data"]["session_id"] == "session_id_override" 136 assert mock_requests[0]["data"]["installation_id"] == "installation_id_override" 137 138 record = Record( 139 event_name="test_event", 140 timestamp_ns=time.time_ns(), 141 status=Status.SUCCESS, 142 ) 143 mock_telemetry_client.add_record(record) 144 mock_telemetry_client.flush() 145 assert mock_requests[1]["data"]["session_id"] == mock_telemetry_client.info["session_id"] 146 assert ( 147 mock_requests[1]["data"]["installation_id"] == mock_telemetry_client.info["installation_id"] 148 ) 149 150 151 def test_batch_processing(mock_telemetry_client: TelemetryClient, mock_requests): 152 mock_telemetry_client._batch_size = 3 # Set small batch size for testing 153 154 # Add multiple records 155 for i in range(5): 156 record = Record( 157 event_name=f"test_event_{i}", 158 timestamp_ns=time.time_ns(), 159 status=Status.SUCCESS, 160 ) 161 mock_telemetry_client.add_record(record) 162 163 mock_telemetry_client.flush() 164 events = {req["data"]["event_name"] for req in mock_requests} 165 assert all(event_name in events for event_name in [f"test_event_{i}" for i in range(5)]) 166 167 168 def test_flush_functionality(mock_telemetry_client: TelemetryClient, mock_requests): 169 record = Record( 170 event_name="test_event", 171 timestamp_ns=time.time_ns(), 172 status=Status.SUCCESS, 173 ) 174 mock_telemetry_client.add_record(record) 175 176 mock_telemetry_client.flush() 177 events = 
{req["data"]["event_name"] for req in mock_requests} 178 assert record.event_name in events 179 180 181 def test_record_sent(mock_telemetry_client: TelemetryClient, mock_requests): 182 record_1 = Record( 183 event_name="test_event_1", 184 timestamp_ns=time.time_ns(), 185 status=Status.SUCCESS, 186 ) 187 mock_telemetry_client.add_record(record_1) 188 mock_telemetry_client.flush() 189 190 assert len(mock_requests) == 1 191 data = mock_requests[0]["data"] 192 assert data["event_name"] == record_1.event_name 193 assert data["status"] == "success" 194 195 session_id = data.get("session_id") 196 installation_id = data.get("installation_id") 197 assert session_id is not None 198 assert installation_id is not None 199 200 record_2 = Record( 201 event_name="test_event_2", 202 timestamp_ns=time.time_ns(), 203 status=Status.FAILURE, 204 ) 205 record_3 = Record( 206 event_name="test_event_3", 207 timestamp_ns=time.time_ns(), 208 status=Status.SUCCESS, 209 ) 210 mock_telemetry_client.add_record(record_2) 211 mock_telemetry_client.add_record(record_3) 212 mock_telemetry_client.flush() 213 assert len(mock_requests) == 3 214 215 # all record should have the same session id and installation id 216 assert {req["data"].get("session_id") for req in mock_requests} == {session_id} 217 assert {req["data"].get("installation_id") for req in mock_requests} == {installation_id} 218 219 220 def test_client_shutdown(mock_telemetry_client: TelemetryClient, mock_requests): 221 for _ in range(100): 222 record = Record( 223 event_name="test_event", 224 timestamp_ns=time.time_ns(), 225 status=Status.SUCCESS, 226 ) 227 mock_telemetry_client.add_record(record) 228 229 start_time = time.time() 230 mock_telemetry_client.flush(terminate=True) 231 end_time = time.time() 232 assert end_time - start_time < 0.1 233 events = {req["data"]["event_name"] for req in mock_requests} 234 assert "test_event" not in events 235 assert not mock_telemetry_client.is_active 236 237 238 @pytest.mark.parametrize( 239 "url", 
240 [ 241 "http://127.0.0.1:9999/nonexistent", 242 "http://127.0.0.1:9999/unauthorized", 243 "http://127.0.0.1:9999/forbidden", 244 "http://127.0.0.1:9999/bad_request", 245 ], 246 ) 247 def test_telemetry_collection_stopped_on_error(mock_requests, mock_telemetry_client, url): 248 mock_telemetry_client.config.ingestion_url = url 249 250 # Add a record - should not crash 251 record = Record( 252 event_name="test_event", 253 timestamp_ns=time.time_ns(), 254 status=Status.SUCCESS, 255 ) 256 mock_telemetry_client.add_record(record) 257 258 mock_telemetry_client.flush(terminate=True) 259 260 assert mock_telemetry_client._is_stopped is True 261 assert mock_telemetry_client.is_active is False 262 requests_count = len(mock_requests) 263 assert requests_count <= 1 264 265 # add record after stopping should be no-op 266 mock_telemetry_client.add_record(record) 267 mock_telemetry_client.flush(terminate=True) 268 assert len(mock_requests) == requests_count 269 270 271 @pytest.mark.parametrize("error_code", [429, 500]) 272 @pytest.mark.parametrize("terminate", [True, False]) 273 def test_telemetry_retry_on_error(error_code, terminate): 274 record = Record( 275 event_name="test_event", 276 timestamp_ns=time.time_ns(), 277 status=Status.SUCCESS, 278 ) 279 280 class MockPostTracker: 281 def __init__(self): 282 self.count = 0 283 self.responses = [] 284 285 def mock_post(self, url, json=None, **kwargs): 286 self.count += 1 287 if self.count < 3: 288 return mock.Mock(status_code=error_code) 289 else: 290 self.responses.extend(json["records"]) 291 return mock.Mock(status_code=200) 292 293 tracker = MockPostTracker() 294 295 with ( 296 mock.patch("requests.post", side_effect=tracker.mock_post), 297 TelemetryClient() as telemetry_client, 298 ): 299 telemetry_client.add_record(record) 300 start_time = time.time() 301 telemetry_client.flush(terminate=terminate) 302 duration = time.time() - start_time 303 if terminate: 304 assert duration < 1.5 305 else: 306 assert duration < 2.5 307 308 
if terminate: 309 assert tracker.responses == [] 310 else: 311 assert record.event_name in [resp["data"]["event_name"] for resp in tracker.responses] 312 313 314 @pytest.mark.parametrize("error_type", [ConnectionError, TimeoutError]) 315 @pytest.mark.parametrize("terminate", [True, False]) 316 def test_telemetry_retry_on_request_error(error_type, terminate): 317 record = Record( 318 event_name="test_event", 319 timestamp_ns=time.time_ns(), 320 status=Status.SUCCESS, 321 ) 322 323 class MockPostTracker: 324 def __init__(self): 325 self.count = 0 326 self.responses = [] 327 328 def mock_post(self, url, json=None, **kwargs): 329 self.count += 1 330 if self.count < 3: 331 raise error_type() 332 else: 333 self.responses.extend(json["records"]) 334 return mock.Mock(status_code=200) 335 336 tracker = MockPostTracker() 337 338 with ( 339 mock.patch("requests.post", side_effect=tracker.mock_post), 340 TelemetryClient() as telemetry_client, 341 ): 342 telemetry_client.add_record(record) 343 start_time = time.time() 344 telemetry_client.flush(terminate=terminate) 345 duration = time.time() - start_time 346 if terminate: 347 assert duration < 1.5 348 else: 349 assert duration < 2.5 350 351 # no retry when terminating 352 if terminate: 353 assert tracker.responses == [] 354 else: 355 assert record.event_name in [resp["data"]["event_name"] for resp in tracker.responses] 356 357 358 def test_stop_event(mock_telemetry_client: TelemetryClient, mock_requests): 359 mock_telemetry_client._is_stopped = True 360 361 record = Record( 362 event_name="test_event", 363 timestamp_ns=time.time_ns(), 364 status=Status.SUCCESS, 365 ) 366 mock_telemetry_client.add_record(record) 367 368 # we need to terminate since the threads are stopped 369 mock_telemetry_client.flush(terminate=True) 370 371 # No records should be sent since the client is stopped 372 events = {req["data"]["event_name"] for req in mock_requests} 373 assert record.event_name not in events 374 375 376 def 
test_concurrent_record_addition(mock_telemetry_client: TelemetryClient, mock_requests): 377 def add_records(thread_id): 378 for i in range(5): 379 record = Record( 380 event_name=f"test_event_{thread_id}_{i}", 381 timestamp_ns=time.time_ns(), 382 status=Status.SUCCESS, 383 ) 384 mock_telemetry_client.add_record(record) 385 time.sleep(0.1) 386 387 # Start multiple threads 388 threads = [] 389 for i in range(3): 390 thread = threading.Thread(name=f"telemetry-client-{i}", target=add_records, args=(i,)) 391 threads.append(thread) 392 thread.start() 393 394 # Wait for all threads to complete 395 for thread in threads: 396 thread.join() 397 398 mock_telemetry_client.flush() 399 400 # Should have received records from all threads 401 events = {req["data"]["event_name"] for req in mock_requests} 402 assert all( 403 event_name in events 404 for event_name in [ 405 f"test_event_{thread_id}_{i}" for thread_id in range(3) for i in range(5) 406 ] 407 ) 408 409 410 def test_telemetry_info_inclusion(mock_telemetry_client: TelemetryClient, mock_requests): 411 record = Record( 412 event_name="test_event", 413 timestamp_ns=time.time_ns(), 414 status=Status.SUCCESS, 415 ) 416 mock_telemetry_client.add_record(record) 417 418 mock_telemetry_client.flush() 419 420 # Verify telemetry info is included 421 data = next(req["data"] for req in mock_requests if req["data"]["event_name"] == "test_event") 422 423 # Check that telemetry info fields are present 424 assert mock_telemetry_client.info.items() <= data.items() 425 426 # Check that record fields are present 427 assert data["event_name"] == "test_event" 428 assert data["status"] == "success" 429 430 431 def test_partition_key(mock_telemetry_client: TelemetryClient, mock_requests): 432 record = Record( 433 event_name="test_event", 434 timestamp_ns=time.time_ns(), 435 status=Status.SUCCESS, 436 ) 437 mock_telemetry_client.add_record(record) 438 mock_telemetry_client.add_record(record) 439 440 mock_telemetry_client.flush() 441 442 # Verify 
partition key is random 443 assert mock_requests[0]["partition-key"] != mock_requests[1]["partition-key"] 444 445 446 def test_max_workers_setup(monkeypatch): 447 monkeypatch.setattr("mlflow.telemetry.client.MAX_WORKERS", 8) 448 with TelemetryClient() as telemetry_client: 449 assert telemetry_client._max_workers == 8 450 telemetry_client.activate() 451 # Test that correct number of threads are created 452 assert len(telemetry_client._consumer_threads) == 8 453 454 # Verify thread names 455 for i, thread in enumerate(telemetry_client._consumer_threads): 456 assert thread.name == f"MLflowTelemetryConsumer-{i}" 457 assert thread.daemon is True 458 459 460 def test_log_suppression_in_consumer_thread(mock_requests, capsys, mock_telemetry_client): 461 # Clear any existing captured output 462 capsys.readouterr() 463 464 # Log from main thread - this should be captured 465 logger = logging.getLogger("mlflow.telemetry.client") 466 logger.info("TEST LOG FROM MAIN THREAD") 467 468 original_process = mock_telemetry_client._process_records 469 470 def process_with_log(records): 471 logger.info("TEST LOG FROM CONSUMER THREAD") 472 original_process(records) 473 474 mock_telemetry_client._process_records = process_with_log 475 476 record = Record( 477 event_name="test_event", 478 timestamp_ns=time.time_ns(), 479 status=Status.SUCCESS, 480 ) 481 mock_telemetry_client.add_record(record) 482 483 mock_telemetry_client.flush() 484 events = {req["data"]["event_name"] for req in mock_requests} 485 assert record.event_name in events 486 487 captured = capsys.readouterr() 488 489 assert "TEST LOG FROM MAIN THREAD" in captured.err 490 # Verify that the consumer thread log was suppressed 491 assert "TEST LOG FROM CONSUMER THREAD" not in captured.err 492 493 494 def test_consumer_thread_no_stderr_output(mock_requests, capsys, mock_telemetry_client): 495 # Clear any existing captured output 496 capsys.readouterr() 497 498 # Log from main thread - this should be captured 499 logger = 
logging.getLogger("mlflow.telemetry.client") 500 logger.info("MAIN THREAD LOG BEFORE CLIENT") 501 502 # Clear output after client initialization to focus on consumer thread output 503 capsys.readouterr() 504 505 # Add multiple records to ensure consumer thread processes them 506 for i in range(5): 507 record = Record( 508 event_name=f"test_event_{i}", 509 timestamp_ns=time.time_ns(), 510 status=Status.SUCCESS, 511 ) 512 mock_telemetry_client.add_record(record) 513 514 mock_telemetry_client.flush() 515 # Wait for all records to be processed 516 events = {req["data"]["event_name"] for req in mock_requests} 517 assert all(event_name in events for event_name in [f"test_event_{i}" for i in range(5)]) 518 519 # Capture output after consumer thread has processed all records 520 captured = capsys.readouterr() 521 522 # Verify consumer thread produced no stderr output 523 assert captured.err == "" 524 525 # Log from main thread after processing - this should be captured 526 logger.info("MAIN THREAD LOG AFTER PROCESSING") 527 captured_after = capsys.readouterr() 528 assert "MAIN THREAD LOG AFTER PROCESSING" in captured_after.err 529 530 531 def test_batch_time_interval(mock_requests, monkeypatch): 532 monkeypatch.setattr("mlflow.telemetry.client.BATCH_TIME_INTERVAL_SECONDS", 1) 533 telemetry_client = TelemetryClient() 534 535 assert telemetry_client._batch_time_interval == 1 536 537 # Add first record 538 record1 = Record( 539 event_name="test_event_1", 540 timestamp_ns=time.time_ns(), 541 status=Status.SUCCESS, 542 ) 543 telemetry_client.add_record(record1) 544 assert len(telemetry_client._pending_records) == 1 545 546 events = {req["data"]["event_name"] for req in mock_requests} 547 assert "test_event_1" not in events 548 549 # Add second record before time interval 550 record2 = Record( 551 event_name="test_event_2", 552 timestamp_ns=time.time_ns(), 553 status=Status.SUCCESS, 554 ) 555 telemetry_client.add_record(record2) 556 assert len(telemetry_client._pending_records) 
== 2 557 558 # Wait for time interval to pass 559 time.sleep(1.5) 560 assert len(telemetry_client._pending_records) == 0 561 # records are sent due to time interval 562 events = {req["data"]["event_name"] for req in mock_requests} 563 assert "test_event_1" in events 564 assert "test_event_2" in events 565 566 record3 = Record( 567 event_name="test_event_3", 568 timestamp_ns=time.time_ns(), 569 status=Status.SUCCESS, 570 ) 571 telemetry_client.add_record(record3) 572 telemetry_client.flush() 573 574 # Verify all records were sent 575 event_names = {req["data"]["event_name"] for req in mock_requests} 576 assert all(env in event_names for env in ["test_event_1", "test_event_2", "test_event_3"]) 577 578 579 def test_set_telemetry_client_non_blocking(): 580 start_time = time.time() 581 with TelemetryClient() as telemetry_client: 582 assert time.time() - start_time < 1 583 assert telemetry_client is not None 584 time.sleep(1.1) 585 assert not any( 586 thread.name.startswith("GetTelemetryConfig") for thread in threading.enumerate() 587 ) 588 589 590 @pytest.mark.parametrize( 591 "mock_requests_return_value", 592 [ 593 mock.Mock(status_code=403), 594 mock.Mock( 595 status_code=200, 596 json=mock.Mock( 597 return_value={ 598 "mlflow_version": VERSION, 599 "disable_telemetry": True, 600 } 601 ), 602 ), 603 mock.Mock( 604 status_code=200, 605 json=mock.Mock( 606 return_value={ 607 "mlflow_version": "1.0.0", 608 "disable_telemetry": False, 609 "ingestion_url": "http://localhost:9999", 610 } 611 ), 612 ), 613 mock.Mock( 614 status_code=200, 615 json=mock.Mock( 616 return_value={ 617 "mlflow_version": VERSION, 618 "disable_telemetry": False, 619 "ingestion_url": "http://localhost:9999", 620 "rollout_percentage": 0, 621 } 622 ), 623 ), 624 mock.Mock( 625 status_code=200, 626 json=mock.Mock( 627 return_value={ 628 "mlflow_version": VERSION, 629 "disable_telemetry": False, 630 "ingestion_url": "http://localhost:9999", 631 "rollout_percentage": 70, 632 } 633 ), 634 ), 635 ], 636 ) 
637 @pytest.mark.no_mock_requests_get 638 def test_client_get_config_none(mock_requests_return_value): 639 with ( 640 mock.patch("mlflow.telemetry.client.requests.get") as mock_requests, 641 mock.patch("random.randint", return_value=80), 642 ): 643 mock_requests.return_value = mock_requests_return_value 644 client = TelemetryClient() 645 client._get_config() 646 assert client.config is None 647 648 649 @pytest.mark.no_mock_requests_get 650 def test_client_get_config_not_none(): 651 with ( 652 mock.patch("mlflow.telemetry.client.requests.get") as mock_requests, 653 mock.patch("random.randint", return_value=50), 654 ): 655 mock_requests.return_value = mock.Mock( 656 status_code=200, 657 json=mock.Mock( 658 return_value={ 659 "mlflow_version": VERSION, 660 "disable_telemetry": False, 661 "ingestion_url": "http://localhost:9999", 662 "rollout_percentage": 70, 663 } 664 ), 665 ) 666 with TelemetryClient() as telemetry_client: 667 telemetry_client._get_config() 668 assert telemetry_client.config.ingestion_url == "http://localhost:9999" 669 assert telemetry_client.config.disable_events == set() 670 671 with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests: 672 mock_requests.return_value = mock.Mock( 673 status_code=200, 674 json=mock.Mock( 675 return_value={ 676 "mlflow_version": VERSION, 677 "disable_telemetry": False, 678 "ingestion_url": "http://localhost:9999", 679 "rollout_percentage": 100, 680 } 681 ), 682 ) 683 with TelemetryClient() as telemetry_client: 684 telemetry_client._get_config() 685 assert telemetry_client.config.ingestion_url == "http://localhost:9999" 686 assert telemetry_client.config.disable_events == set() 687 688 with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests: 689 mock_requests.return_value = mock.Mock( 690 status_code=200, 691 json=mock.Mock( 692 return_value={ 693 "mlflow_version": VERSION, 694 "disable_telemetry": False, 695 "ingestion_url": "http://localhost:9999", 696 "rollout_percentage": 100, 697 
"disable_events": [], 698 "disable_sdks": ["mlflow-tracing"], 699 } 700 ), 701 ) 702 with ( 703 mock.patch( 704 "mlflow.telemetry.client.get_source_sdk", return_value=SourceSDK.MLFLOW_TRACING 705 ), 706 TelemetryClient() as telemetry_client, 707 ): 708 telemetry_client._get_config() 709 assert telemetry_client.config is None 710 711 with ( 712 mock.patch( 713 "mlflow.telemetry.client.get_source_sdk", return_value=SourceSDK.MLFLOW_SKINNY 714 ), 715 TelemetryClient() as telemetry_client, 716 ): 717 telemetry_client._get_config() 718 assert telemetry_client.config.ingestion_url == "http://localhost:9999" 719 assert telemetry_client.config.disable_events == set() 720 721 with ( 722 mock.patch("mlflow.telemetry.client.get_source_sdk", return_value=SourceSDK.MLFLOW), 723 TelemetryClient() as telemetry_client, 724 ): 725 telemetry_client._get_config() 726 assert telemetry_client.config.ingestion_url == "http://localhost:9999" 727 assert telemetry_client.config.disable_events == set() 728 729 730 @pytest.mark.no_mock_requests_get 731 @pytest.mark.skipif(is_windows(), reason="This test only passes on non-Windows") 732 def test_get_config_disable_non_windows(): 733 with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests_get: 734 mock_requests_get.return_value = mock.Mock( 735 status_code=200, 736 json=mock.Mock( 737 return_value={ 738 "mlflow_version": VERSION, 739 "disable_telemetry": False, 740 "ingestion_url": "http://localhost:9999", 741 "rollout_percentage": 100, 742 "disable_os": ["linux", "darwin"], 743 } 744 ), 745 ) 746 with TelemetryClient() as telemetry_client: 747 telemetry_client._get_config() 748 assert telemetry_client.config is None 749 750 with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests: 751 mock_requests.return_value = mock.Mock( 752 status_code=200, 753 json=mock.Mock( 754 return_value={ 755 "mlflow_version": VERSION, 756 "disable_telemetry": False, 757 "ingestion_url": "http://localhost:9999", 758 
"rollout_percentage": 100, 759 "disable_os": ["win32"], 760 } 761 ), 762 ) 763 with TelemetryClient() as telemetry_client: 764 telemetry_client._get_config() 765 assert telemetry_client.config.ingestion_url == "http://localhost:9999" 766 assert telemetry_client.config.disable_events == set() 767 768 769 @pytest.mark.no_mock_requests_get 770 @pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows") 771 def test_get_config_windows(): 772 with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests: 773 mock_requests.return_value = mock.Mock( 774 status_code=200, 775 json=mock.Mock( 776 return_value={ 777 "mlflow_version": VERSION, 778 "disable_telemetry": False, 779 "ingestion_url": "http://localhost:9999", 780 "rollout_percentage": 100, 781 "disable_os": ["win32"], 782 } 783 ), 784 ) 785 with TelemetryClient() as telemetry_client: 786 telemetry_client._get_config() 787 assert telemetry_client.config is None 788 789 with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests: 790 mock_requests.return_value = mock.Mock( 791 status_code=200, 792 json=mock.Mock( 793 return_value={ 794 "mlflow_version": VERSION, 795 "disable_telemetry": False, 796 "ingestion_url": "http://localhost:9999", 797 "rollout_percentage": 100, 798 "disable_os": ["linux", "darwin"], 799 } 800 ), 801 ) 802 with TelemetryClient() as telemetry_client: 803 telemetry_client._get_config() 804 assert telemetry_client.config.ingestion_url == "http://localhost:9999" 805 assert telemetry_client.config.disable_events == set() 806 807 808 @pytest.mark.no_mock_requests_get 809 def test_client_set_to_none_if_config_none(): 810 with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests: 811 mock_requests.return_value = mock.Mock( 812 status_code=200, 813 json=mock.Mock( 814 return_value={ 815 "mlflow_version": VERSION, 816 "disable_telemetry": True, 817 } 818 ), 819 ) 820 with TelemetryClient() as telemetry_client: 821 assert telemetry_client is not 
None 822 telemetry_client.activate() 823 telemetry_client._config_thread.join(timeout=3) 824 assert not telemetry_client._config_thread.is_alive() 825 assert telemetry_client.config is None 826 assert telemetry_client._is_config_fetched is True 827 assert telemetry_client._is_stopped 828 829 830 @pytest.mark.no_mock_requests_get 831 def test_records_not_dropped_when_fetching_config(mock_requests): 832 record = Record( 833 event_name="test_event", 834 timestamp_ns=time.time_ns(), 835 status=Status.SUCCESS, 836 duration_ms=0, 837 ) 838 with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests_get: 839 mock_requests_get.return_value = mock.Mock( 840 status_code=200, 841 json=mock.Mock( 842 return_value={ 843 "mlflow_version": VERSION, 844 "disable_telemetry": False, 845 "ingestion_url": "http://localhost:9999", 846 "rollout_percentage": 100, 847 } 848 ), 849 ) 850 with TelemetryClient() as telemetry_client: 851 telemetry_client.activate() 852 # wait for config to be fetched 853 telemetry_client._config_thread.join(timeout=3) 854 telemetry_client.add_record(record) 855 telemetry_client.flush() 856 validate_telemetry_record( 857 telemetry_client, mock_requests, record.event_name, check_params=False 858 ) 859 860 861 @pytest.mark.no_mock_requests_get 862 @pytest.mark.parametrize("error_code", [400, 401, 403, 404, 412, 500, 502, 503, 504]) 863 def test_config_fetch_no_retry(mock_requests, error_code): 864 record = Record( 865 event_name="test_event", 866 timestamp_ns=time.time_ns(), 867 status=Status.SUCCESS, 868 ) 869 870 def mock_requests_get(*args, **kwargs): 871 time.sleep(1) 872 return mock.Mock(status_code=error_code) 873 874 with ( 875 mock.patch("mlflow.telemetry.client.requests.get", side_effect=mock_requests_get), 876 TelemetryClient() as telemetry_client, 877 ): 878 telemetry_client.add_record(record) 879 telemetry_client.flush() 880 events = [req["data"]["event_name"] for req in mock_requests] 881 assert record.event_name not in events 882 
assert get_telemetry_client() is None 883 884 885 def test_warning_suppression_in_shutdown(recwarn, mock_telemetry_client: TelemetryClient): 886 def flush_mock(*args, **kwargs): 887 warnings.warn("test warning") 888 889 with mock.patch.object(mock_telemetry_client, "flush", flush_mock): 890 mock_telemetry_client._at_exit_callback() 891 assert len(recwarn) == 0 892 893 894 @pytest.mark.skipif(IS_TRACING_SDK_ONLY, reason="Requires full tracking SDK") 895 @pytest.mark.parametrize("tracking_uri_scheme", ["databricks", "databricks-uc", "uc"]) 896 def test_databricks_tracking_uri_scheme_does_not_use_oss_path(mock_requests, tracking_uri_scheme): 897 record = Record( 898 event_name="test_event", 899 timestamp_ns=time.time_ns(), 900 status=Status.SUCCESS, 901 ) 902 903 with ( 904 _use_tracking_uri(f"{tracking_uri_scheme}://profile_name"), 905 mock.patch( 906 "mlflow.telemetry.client.http_request", 907 return_value=mock.Mock(status_code=200), 908 ), 909 mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"), 910 mock.patch("mlflow.telemetry.client._IS_MLFLOW_DEV_VERSION", False), 911 TelemetryClient() as telemetry_client, 912 ): 913 telemetry_client.add_record(record) 914 telemetry_client.flush(terminate=True) 915 # OSS ingestion path should not receive records 916 assert len(mock_requests) == 0 917 918 919 @pytest.mark.skipif(IS_TRACING_SDK_ONLY, reason="Requires full tracking SDK") 920 @pytest.mark.parametrize("tracking_uri_scheme", ["databricks", "databricks-uc", "uc"]) 921 def test_databricks_end_to_end_forwarding(tracking_uri_scheme): 922 record = Record( 923 event_name="test_event", 924 timestamp_ns=time.time_ns(), 925 status=Status.SUCCESS, 926 duration_ms=42, 927 params={"key": "value"}, 928 ) 929 930 with ( 931 _use_tracking_uri(f"{tracking_uri_scheme}://profile_name"), 932 mock.patch( 933 "mlflow.telemetry.client.http_request", 934 return_value=mock.Mock(status_code=200), 935 ) as mock_http, 936 
        # NOTE(review): these lines are the tail of a parametrized forwarding
        # test whose head (the `def` line and `record` setup) lies outside this
        # chunk; kept byte-identical.
        mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
        mock.patch("mlflow.telemetry.client._IS_MLFLOW_DEV_VERSION", False),
        TelemetryClient() as telemetry_client,
    ):
        telemetry_client.add_record(record)
        telemetry_client.flush()

    # Exactly one forwarding request, carrying one event whose params were
    # serialized into "params_json" (the raw "params" key must be stripped).
    mock_http.assert_called_once()
    payload = mock_http.call_args.kwargs["json"]
    assert len(payload["events"]) == 1
    event = payload["events"][0]
    assert event["event_name"] == "test_event"
    assert event["tracking_uri_scheme"] == tracking_uri_scheme
    assert "params_json" in event
    assert "params" not in event


def test_databricks_forwarding_disabled_for_dev_versions():
    """Records are not forwarded to Databricks when running a dev build of MLflow."""
    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )

    with TelemetryClient() as client:
        # Force the client to consider itself pointed at Databricks.
        client.info["tracking_uri_scheme"] = "databricks"

        with (
            mock.patch(
                "mlflow.telemetry.client.http_request",
                return_value=mock.Mock(status_code=200),
            ) as mock_http,
            # Simulate a dev version; forwarding should be skipped entirely.
            mock.patch("mlflow.telemetry.client._IS_MLFLOW_DEV_VERSION", True),
        ):
            client._process_records([record])

        mock_http.assert_not_called()


def test_forward_to_databricks_params_json_serialization():
    """Record params are JSON-serialized into "params_json" on the forwarded event."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        record = Record(
            event_name="genai_evaluate",
            timestamp_ns=1700000000000000000,
            status=Status.SUCCESS,
            params={"predict_fn_provided": True},
        )

        with (
            mock.patch(
                "mlflow.telemetry.client.http_request",
                return_value=mock.Mock(status_code=200),
            ) as mock_http,
            mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
        ):
            client._forward_to_databricks([record])

        # "params" is replaced by its JSON string form under "params_json".
        event = mock_http.call_args.kwargs["json"]["events"][0]
        assert "params" not in event
        assert "params_json" in event
        assert json.loads(event["params_json"]) == {"predict_fn_provided": True}


def test_forward_to_databricks_no_params_json_when_params_none():
    """No "params_json" key is emitted when the record carries no params."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        record = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )

        with (
            mock.patch(
                "mlflow.telemetry.client.http_request",
                return_value=mock.Mock(status_code=200),
            ) as mock_http,
            mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
        ):
            client._forward_to_databricks([record])

        event = mock_http.call_args.kwargs["json"]["events"][0]
        assert "params" not in event
        assert "params_json" not in event


@pytest.mark.parametrize("status_code", list(UNRECOVERABLE_ERRORS))
def test_forward_to_databricks_stops_on_unrecoverable_error(status_code):
    """Any unrecoverable HTTP status from forwarding permanently stops the client."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        record = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )

        with (
            mock.patch(
                "mlflow.telemetry.client.http_request",
                return_value=mock.Mock(status_code=status_code),
            ),
            mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
        ):
            client._forward_to_databricks([record])

        # The client shuts itself down rather than retrying a hopeless request.
        assert client._is_stopped
        assert not client.is_active


def test_forward_to_databricks_credential_failure_non_fatal():
    """Failing to resolve Databricks credentials must not stop the client."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        record = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )

        with (
            mock.patch(
                "mlflow.utils.databricks_utils.get_databricks_host_creds",
                side_effect=Exception("no creds"),
            ),
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
        ):
            client._forward_to_databricks([record])

        assert not client._is_stopped


@pytest.mark.parametrize("error_code", RETRYABLE_ERRORS)
def test_forward_to_databricks_retries_on_retryable_error(error_code):
    """Retryable HTTP errors are retried until a 200 response succeeds."""
    with TelemetryClient() as client:
        client.info["tracking_uri_scheme"] = "databricks"
        record = Record(
            event_name="test_event",
            timestamp_ns=time.time_ns(),
            status=Status.SUCCESS,
        )

        with (
            # Two retryable failures followed by a success.
            mock.patch(
                "mlflow.telemetry.client.http_request",
                side_effect=[
                    mock.Mock(status_code=error_code),
                    mock.Mock(status_code=error_code),
                    mock.Mock(status_code=200),
                ],
            ) as mock_http,
            mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
            mock.patch(
                "mlflow.tracking._tracking_service.utils.get_tracking_uri",
                return_value="databricks",
            ),
            # Skip backoff sleeps so the test stays fast.
            mock.patch("mlflow.telemetry.client.time.sleep"),
        ):
            client._forward_to_databricks([record])

        assert mock_http.call_count == 3


@pytest.mark.no_mock_requests_get
def test_disable_events(mock_requests):
    """Events listed in the fetched config's "disable_events" are dropped;
    other events are still recorded."""
    with mock.patch("mlflow.telemetry.client.requests.get") as mock_requests_get:
        # Serve a config that disables only the logged-model creation event.
        mock_requests_get.return_value = mock.Mock(
            status_code=200,
            json=mock.Mock(
                return_value={
                    "mlflow_version": VERSION,
                    "disable_telemetry": False,
                    "ingestion_url": "http://localhost:9999",
                    "rollout_percentage": 100,
                    "disable_events": [CreateLoggedModelEvent.name],
                    "disable_sdks": [],
                }
            ),
        )
        with (
            TelemetryClient() as telemetry_client,
            mock.patch(
                "mlflow.telemetry.track.get_telemetry_client", return_value=telemetry_client
            ),
        ):
            telemetry_client.activate()
            # Wait for the background config fetch to complete.
            telemetry_client._config_thread.join(timeout=1)
            # Disabled event: nothing should be sent.
            mlflow.initialize_logged_model(name="model", tags={"key": "value"})
            telemetry_client.flush()
            assert len(mock_requests) == 0

            # Non-disabled event: still recorded.
            with mlflow.start_run():
                pass
            validate_telemetry_record(
                telemetry_client, mock_requests, CreateRunEvent.name, check_params=False
            )


@pytest.mark.no_mock_requests_get
def test_fetch_config_after_first_record():
    """The remote config is fetched lazily, triggered by the first record,
    and only fetched once."""
    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
        duration_ms=0,
    )

    mock_response = mock.Mock(
        status_code=200,
        json=mock.Mock(
            return_value={
                "mlflow_version": VERSION,
                "disable_telemetry": False,
                "ingestion_url": "http://localhost:9999",
                "rollout_percentage": 70,
            }
        ),
    )
    with mock.patch(
        "mlflow.telemetry.client.requests.get", return_value=mock_response
    ) as mock_requests_get:
        with TelemetryClient() as telemetry_client:
            # No config fetch happens at construction time.
            assert telemetry_client._is_config_fetched is False
            telemetry_client.add_record(record)
            telemetry_client._config_thread.join(timeout=1)
            assert telemetry_client._is_config_fetched is True
            mock_requests_get.assert_called_once()


@pytest.mark.parametrize(
    "uri",
    [
        "http://localhost",
        "http://localhost:5000",
        "http://127.0.0.1",
        "http://127.0.0.1:5000/api/2.0/mlflow",
        "http://[::1]",
    ],
)
def test_is_localhost_uri_returns_true_for_localhost(uri):
    """Loopback hostnames (named, IPv4, and IPv6) are detected as localhost."""
    assert _is_localhost_uri(uri)


@pytest.mark.parametrize(
    "uri",
    [
        "http://example.com",
        "http://example.com:5000",
        "https://mlflow.example.com",
        "http://192.168.1.1",
        "http://192.168.1.1:5000",
        "http://10.0.0.1:5000",
        "https://my-tracking-server.com/api/2.0/mlflow",
    ],
)
def test_is_localhost_uri_returns_false_for_remote(uri):
    """Non-loopback hosts (including private-range IPs) are not localhost."""
    assert _is_localhost_uri(uri) is False


def test_is_localhost_uri_returns_none_for_empty_hostname():
    """URIs without a hostname (e.g. file://) yield None rather than a bool."""
    assert _is_localhost_uri("file:///tmp/mlruns") is None


def test_is_localhost_uri_returns_none_on_parse_error():
    # urlparse doesn't raise on most inputs, but we test the fallback behavior
    # by mocking urlparse to raise
    with mock.patch("urllib.parse.urlparse", side_effect=ValueError("Invalid URI")):
        assert _is_localhost_uri("http://localhost") is None


def test_is_workspace_enabled_included_in_telemetry_info(
    mock_telemetry_client: TelemetryClient, mock_requests, monkeypatch
):
    """With MLFLOW_WORKSPACE set, sent records report ws_enabled=True."""
    monkeypatch.setenv("MLFLOW_WORKSPACE", "my-workspace")
    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    mock_telemetry_client.add_record(record)
    mock_telemetry_client.flush()
    data = next(req["data"] for req in mock_requests if req["data"]["event_name"] == "test_event")
    assert data["ws_enabled"] is True


def test_is_workspace_disabled_included_in_telemetry_info(
    mock_telemetry_client: TelemetryClient, mock_requests, monkeypatch
):
    """Without MLFLOW_WORKSPACE set, sent records report ws_enabled=False."""
    monkeypatch.delenv("MLFLOW_WORKSPACE", raising=False)
    record = Record(
        event_name="test_event",
        timestamp_ns=time.time_ns(),
        status=Status.SUCCESS,
    )
    mock_telemetry_client.add_record(record)
    mock_telemetry_client.flush()
    data = next(req["data"] for req in mock_requests if req["data"]["event_name"] == "test_event")
    assert data["ws_enabled"] is False