# test_e2e.py
"""End-to-end tests for MLflow webhooks.

Two subprocesses are started once per test module: a receiver app (``app.py``
next to this file) that records every request it receives, and an
``mlflow server`` with webhook delivery enabled. Each test registers a webhook,
triggers an MLflow operation, and inspects what the receiver recorded.
"""

import contextlib
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generator

import psutil
import pytest
import requests
from cryptography.fernet import Fernet

from mlflow import MlflowClient
from mlflow.entities.webhook import WebhookAction, WebhookEntity, WebhookEvent

from tests.helper_functions import get_safe_port
from tests.webhooks.app import WEBHOOK_SECRET

# Timeout (seconds) for the harness's own HTTP calls to the receiver app, so a
# hung process fails the test instead of blocking the whole run.
_HTTP_TIMEOUT = 10


@dataclass
class WebhookLogEntry:
    """One webhook request as reported by the receiver app's ``/logs`` endpoint."""

    endpoint: str  # receiver path that was hit, e.g. "/insecure-webhook"
    headers: dict[str, str]  # request headers; names arrive lower-cased (see tests)
    status_code: int  # HTTP status the receiver answered with
    payload: dict[str, Any]  # JSON body of the webhook request
    error: str | None = None  # receiver-side rejection reason, if any
    attempt: int | None = None  # delivery attempt number, when the endpoint reports it


def wait_until_ready(health_endpoint: str, max_attempts: int = 10) -> None:
    """Poll ``health_endpoint`` until it answers HTTP 200.

    Args:
        health_endpoint: Full URL of the health-check endpoint.
        max_attempts: Number of polls before giving up; polls are ~1 second apart.

    Raises:
        RuntimeError: If no poll returned HTTP 200.
    """
    for _ in range(max_attempts):
        try:
            resp = requests.get(health_endpoint, timeout=2)
        except requests.RequestException:
            pass  # Not accepting connections yet; retry after the back-off below.
        else:
            if resp.status_code == 200:
                return
        # Back off after *every* failed poll, including non-200 answers (e.g. a
        # 503 while booting); otherwise all attempts could be spent instantly.
        time.sleep(1)
    raise RuntimeError(f"Failed to start server at {health_endpoint}")


def _terminate_process_tree(prc: subprocess.Popen) -> None:
    """Terminate ``prc`` and all of its descendants.

    ``mlflow server`` spawns worker processes (gunicorn), which would otherwise
    outlive the test module. Descendants get SIGTERM first, are waited on, and
    any that are still alive after 10 seconds are killed. The parent is then
    terminated; the enclosing ``Popen`` context manager waits for it.
    """
    try:
        children = psutil.Process(prc.pid).children(recursive=True)
    except psutil.NoSuchProcess:
        # The server process is already gone (e.g. it failed to start).
        children = []
    for child in children:
        # A worker may exit on its own between listing and signalling it.
        with contextlib.suppress(psutil.NoSuchProcess):
            child.terminate()
    _, alive = psutil.wait_procs(children, timeout=10)
    for child in alive:
        with contextlib.suppress(psutil.NoSuchProcess):
            child.kill()
    prc.terminate()


@contextlib.contextmanager
def _run_mlflow_server(tmp_path: Path) -> Generator[str, None, None]:
    """Run ``mlflow server`` in a subprocess and yield its base URL.

    The server stores its SQLite database and artifacts under ``tmp_path`` and
    gets a fresh webhook-secret encryption key plus the webhook settings listed
    in ``env`` below. The server and its workers are terminated on exit.
    """
    port = get_safe_port()
    backend_store_uri = f"sqlite:///{tmp_path / 'mlflow.db'}"
    artifact_root = (tmp_path / "artifacts").as_uri()
    env = os.environ.copy() | {
        "MLFLOW_WEBHOOK_ALLOWED_SCHEMES": "http",
        "MLFLOW_WEBHOOK_SECRET_ENCRYPTION_KEY": Fernet.generate_key().decode(),
        "MLFLOW_WEBHOOK_REQUEST_MAX_RETRIES": "3",
        "MLFLOW_WEBHOOK_REQUEST_TIMEOUT": "10",
        "MLFLOW_WEBHOOK_CACHE_TTL": "0",  # Disable caching for tests
        "MLFLOW_WEBHOOK_ALLOW_PRIVATE_IPS": "true",  # Allow localhost in e2e tests
        "MLFLOW_SERVER_ENABLE_JOB_EXECUTION": "false",  # Not needed for webhook tests
    }
    with subprocess.Popen(
        [
            sys.executable,
            "-m",
            "mlflow",
            "server",
            f"--port={port}",
            f"--backend-store-uri={backend_store_uri}",
            f"--default-artifact-root={artifact_root}",
        ],
        cwd=tmp_path,
        env=env,
    ) as prc:
        try:
            url = f"http://localhost:{port}"
            wait_until_ready(f"{url}/health")
            yield url
        finally:
            _terminate_process_tree(prc)


class AppClient:
    """HTTP client for the webhook receiver app started by ``_run_app``."""

    def __init__(self, base: str) -> None:
        self._base = base  # e.g. "http://localhost:<port>", without a trailing slash

    def get_url(self, endpoint: str) -> str:
        """Return the absolute URL of ``endpoint`` (a path beginning with ``/``)."""
        return f"{self._base}{endpoint}"

    def reset(self) -> None:
        """Reset both logs and counters"""
        resp = requests.post(self.get_url("/reset"), timeout=_HTTP_TIMEOUT)
        resp.raise_for_status()

    def get_logs(self) -> list[WebhookLogEntry]:
        """Return the requests the receiver has recorded so far (cleared by ``reset``)."""
        response = requests.get(self.get_url("/logs"), timeout=_HTTP_TIMEOUT)
        response.raise_for_status()
        logs_data = response.json().get("logs", [])
        return [WebhookLogEntry(**log_data) for log_data in logs_data]

    def wait_for_logs(self, expected_count: int, timeout: float = 5.0) -> list[WebhookLogEntry]:
        """Wait for webhooks to be delivered with a timeout.

        Polls ``/logs`` every 0.1 s until at least ``expected_count`` entries exist.

        Returns:
            All recorded entries (possibly more than ``expected_count``).

        Raises:
            TimeoutError: If fewer than ``expected_count`` entries exist once
                ``timeout`` seconds have elapsed.
        """
        # Monotonic clock: immune to wall-clock adjustments during the test.
        deadline = time.monotonic() + timeout
        while True:
            logs = self.get_logs()
            if len(logs) >= expected_count:
                return logs
            # Check the deadline only after a fetch, so logs arriving right at
            # the end are still accepted rather than reported as a timeout.
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Timeout waiting for {expected_count} webhook logs. "
                    f"Got {len(logs)} logs after {timeout}s timeout."
                )
            time.sleep(0.1)


@contextlib.contextmanager
def _run_app(tmp_path: Path) -> Generator[AppClient, None, None]:
    """Run the receiver app (``app.py`` next to this file) and yield a client for it."""
    port = get_safe_port()
    app_path = Path(__file__).parent / "app.py"
    with subprocess.Popen(
        [
            sys.executable,
            app_path,
            str(port),
        ],
        cwd=tmp_path,
    ) as prc:
        try:
            url = f"http://localhost:{port}"
            wait_until_ready(f"{url}/health")
            yield AppClient(url)
        finally:
            prc.terminate()


@pytest.fixture(scope="module")
def app_client(tmp_path_factory: pytest.TempPathFactory) -> Generator[AppClient, None, None]:
    """Module-wide receiver app."""
    tmp_path = tmp_path_factory.mktemp("app")
    with _run_app(tmp_path) as client:
        yield client


@pytest.fixture(scope="module")
def mlflow_server(
    app_client: AppClient, tmp_path_factory: pytest.TempPathFactory
) -> Generator[str, None, None]:
    """Module-wide MLflow server; yields its base URL."""
    # ``app_client`` is requested only for ordering: the receiver starts before,
    # and stops after, the MLflow server.
    tmp_path = tmp_path_factory.mktemp("mlflow_server")
    with _run_mlflow_server(tmp_path) as url:
        yield url


@pytest.fixture(scope="module")
def mlflow_client(mlflow_server: str) -> Generator[MlflowClient, None, None]:
    """``MlflowClient`` for the test server, with client-side HTTP retries disabled."""
    with pytest.MonkeyPatch.context() as mp:
        # Disable retries to fail fast. MLflow reads this variable when it sends
        # each request, so the override must stay active while the tests use
        # the client: yield inside the context rather than returning from it.
        mp.setenv("MLFLOW_HTTP_REQUEST_MAX_RETRIES", "0")
        yield MlflowClient(tracking_uri=mlflow_server, registry_uri=mlflow_server)


@pytest.fixture(autouse=True)
def cleanup(mlflow_client: MlflowClient, app_client: AppClient) -> Generator[None, None, None]:
    """After each test, delete all webhooks and clear the receiver's logs and counters."""
    yield

    for webhook in mlflow_client.list_webhooks():
        mlflow_client.delete_webhook(webhook.webhook_id)

    app_client.reset()


def test_registered_model_created(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Creating a registered model delivers its name, description and tags."""
    mlflow_client.create_webhook(
        name="registered_model_created",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
    )
    registered_model = mlflow_client.create_registered_model(
        name="test_name",
        description="test_description",
        tags={"test_tag_key": "test_tag_value"},
    )
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": registered_model.name,
        "description": registered_model.description,
        "tags": registered_model.tags,
    }


def test_model_version_created(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Creating a model version delivers its full description."""
    mlflow_client.create_webhook(
        name="model_version_created",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
    )
    registered_model = mlflow_client.create_registered_model(name="model_version_created")
    model_version = mlflow_client.create_model_version(
        name=registered_model.name,
        source="s3://bucket/path/to/model",
        run_id="1234567890abcdef",
        tags={"test_tag_key": "test_tag_value"},
        description="test_description",
    )
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": registered_model.name,
        "version": model_version.version,
        "source": "s3://bucket/path/to/model",
        "run_id": "1234567890abcdef",
        "description": "test_description",
        "tags": {"test_tag_key": "test_tag_value"},
    }


def test_model_version_tag_set(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Setting a model version tag delivers the key and new value."""
    mlflow_client.create_webhook(
        name="model_version_tag_set",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION_TAG, WebhookAction.SET)],
    )
    registered_model = mlflow_client.create_registered_model(name="model_version_tag_set")
    model_version = mlflow_client.create_model_version(
        name=registered_model.name,
        source="s3://bucket/path/to/model",
        run_id="1234567890abcdef",
    )
    mlflow_client.set_model_version_tag(
        name=model_version.name,
        version=model_version.version,
        key="test_tag_key",
        value="new_value",
    )
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": "model_version_tag_set",
        "version": model_version.version,
        "key": "test_tag_key",
        "value": "new_value",
    }


def test_model_version_tag_deleted(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Deleting a model version tag delivers only the deleted key (no SET events)."""
    mlflow_client.create_webhook(
        name="model_version_tag_deleted",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION_TAG, WebhookAction.DELETED)],
    )
    registered_model = mlflow_client.create_registered_model(name="model_version_tag_deleted")
    model_version = mlflow_client.create_model_version(
        name=registered_model.name,
        source="s3://bucket/path/to/model",
        run_id="1234567890abcdef",
        tags={"test_tag_key": "test_tag_value"},
    )
    mlflow_client.set_model_version_tag(
        name=model_version.name,
        version=model_version.version,
        key="test_tag_key",
        value="new_value",
    )
    mlflow_client.delete_model_version_tag(
        name=model_version.name, version=model_version.version, key="test_tag_key"
    )
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": registered_model.name,
        "version": model_version.version,
        "key": "test_tag_key",
    }


def test_model_version_alias_created(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Setting a registered model alias delivers the alias and target version."""
    mlflow_client.create_webhook(
        name="model_version_alias_created",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION_ALIAS, WebhookAction.CREATED)],
    )
    registered_model = mlflow_client.create_registered_model(name="model_version_alias_created")
    model_version = mlflow_client.create_model_version(
        name=registered_model.name,
        source="s3://bucket/path/to/model",
        run_id="1234567890abcdef",
        tags={"test_tag_key": "test_tag_value"},
        description="test_description",
    )
    mlflow_client.set_registered_model_alias(
        name=model_version.name, version=model_version.version, alias="test_alias"
    )
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": registered_model.name,
        "version": model_version.version,
        "alias": "test_alias",
    }


def test_model_version_alias_deleted(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Deleting a registered model alias delivers the model name and alias."""
    mlflow_client.create_webhook(
        name="model_version_alias_deleted",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION_ALIAS, WebhookAction.DELETED)],
    )
    registered_model = mlflow_client.create_registered_model(name="model_version_alias_deleted")
    model_version = mlflow_client.create_model_version(
        name=registered_model.name,
        source="s3://bucket/path/to/model",
        run_id="1234567890abcdef",
        tags={"test_tag_key": "test_tag_value"},
        description="test_description",
    )
    mlflow_client.set_registered_model_alias(
        name=model_version.name, version=model_version.version, alias="test_alias"
    )
    mlflow_client.delete_registered_model_alias(name=model_version.name, alias="test_alias")
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": registered_model.name,
        "alias": "test_alias",
    }


def test_webhook_with_secret(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """A webhook with the correct secret is accepted and carries a v1 signature header."""
    # Create webhook with secret that matches the one in app.py
    mlflow_client.create_webhook(
        name="secure_webhook",
        url=app_client.get_url("/secure-webhook"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
        secret=WEBHOOK_SECRET,
    )

    registered_model = mlflow_client.create_registered_model(
        name="test_hmac_model",
        description="Testing HMAC signature",
        tags={"env": "test"},
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/secure-webhook"
    assert logs[0].payload == {
        "name": registered_model.name,
        "description": registered_model.description,
        "tags": registered_model.tags,
    }
    assert logs[0].status_code == 200
    # HTTP headers are case-insensitive and FastAPI normalizes them to lowercase
    assert "x-mlflow-signature" in logs[0].headers
    assert logs[0].headers["x-mlflow-signature"].startswith("v1,")


def test_webhook_with_wrong_secret(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """A webhook signed with the wrong secret is rejected by the receiver with 401."""
    # Create webhook with wrong secret that doesn't match the one in app.py
    mlflow_client.create_webhook(
        name="wrong_secret_webhook",
        url=app_client.get_url("/secure-webhook"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
        secret="wrong-secret",  # This doesn't match WEBHOOK_SECRET in app.py
    )

    # This should fail at the webhook endpoint due to signature mismatch
    # But MLflow will still create the registered model
    mlflow_client.create_registered_model(
        name="test_wrong_hmac",
        description="Testing wrong HMAC signature",
    )

    # The webhook request should have failed, but error should be logged
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/secure-webhook"
    assert logs[0].error == "Invalid signature"
    assert logs[0].status_code == 401


def test_webhook_without_secret_to_secure_endpoint(
    mlflow_client: MlflowClient, app_client: AppClient
) -> None:
    """An unsigned webhook sent to the secure endpoint is rejected with 400."""
    # Create webhook without secret pointing to secure endpoint
    mlflow_client.create_webhook(
        name="no_secret_to_secure",
        url=app_client.get_url("/secure-webhook"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
        # No secret provided
    )

    mlflow_client.create_registered_model(
        name="test_no_secret_to_secure",
        description="Testing no secret to secure endpoint",
    )

    # The webhook request should fail due to missing signature, but error should be logged
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/secure-webhook"
    assert logs[0].error == "Missing signature header"
    assert logs[0].status_code == 400


def test_webhook_test_insecure_endpoint(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """``test_webhook`` sends the example payload for the webhook's event."""
    # Create webhook for testing
    webhook = mlflow_client.create_webhook(
        name="test_webhook",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
    )

    # Test the webhook
    result = mlflow_client.test_webhook(webhook.webhook_id)

    # Check that the test was successful
    assert result.success is True
    assert result.response_status == 200
    assert result.error_message is None

    # Check that the test payload was received
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": "example_model",
        "version": "1",
        "source": "models:/123",
        "run_id": "abcd1234abcd5678",
        "tags": {"example_key": "example_value"},
        "description": "An example model version",
    }


def test_webhook_test_secure_endpoint(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """``test_webhook`` signs the example payload when the webhook has a secret."""
    # Create webhook with secret for testing
    webhook = mlflow_client.create_webhook(
        name="test_secure_webhook",
        url=app_client.get_url("/secure-webhook"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
        secret=WEBHOOK_SECRET,
    )

    # Test the webhook
    result = mlflow_client.test_webhook(webhook.webhook_id)

    # Check that the test was successful
    assert result.success is True
    assert result.response_status == 200
    assert result.error_message is None

    # Check that the test payload was received with proper signature
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/secure-webhook"
    assert logs[0].payload == {
        "name": "example_model",
        "tags": {"example_key": "example_value"},
        "description": "An example registered model",
    }

    assert logs[0].status_code == 200
    assert "x-mlflow-signature" in logs[0].headers
    assert logs[0].headers["x-mlflow-signature"].startswith("v1,")


def test_webhook_test_with_specific_event(
    mlflow_client: MlflowClient, app_client: AppClient
) -> None:
    """``test_webhook`` with an explicit event sends that event's example payload."""
    # Create webhook that supports multiple events
    webhook = mlflow_client.create_webhook(
        name="multi_event_webhook",
        url=app_client.get_url("/insecure-webhook"),
        events=[
            WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED),
            WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED),
            WebhookEvent(WebhookEntity.MODEL_VERSION_TAG, WebhookAction.SET),
        ],
    )

    # Test with a specific event (not the first one)
    result = mlflow_client.test_webhook(
        webhook.webhook_id, event=WebhookEvent(WebhookEntity.MODEL_VERSION_TAG, WebhookAction.SET)
    )

    # Check that the test was successful
    assert result.success is True
    assert result.response_status == 200
    assert result.error_message is None

    # Check that the correct payload was sent
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": "example_model",
        "version": "1",
        "key": "example_key",
        "value": "example_value",
    }


def test_webhook_test_failed_endpoint(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """``test_webhook`` against a missing endpoint reports the 404 and its body."""
    # Create webhook pointing to non-existent endpoint
    webhook = mlflow_client.create_webhook(
        name="failed_webhook",
        url=app_client.get_url("/nonexistent-endpoint"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
    )

    # Test the webhook
    result = mlflow_client.test_webhook(webhook.webhook_id)

    # Check that the test failed
    assert result.success is False
    assert result.response_status == 404
    assert result.error_message is None  # No error message for HTTP errors
    assert result.response_body is not None  # Should contain error response


def test_webhook_test_with_wrong_secret(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """``test_webhook`` with a wrong secret reports the receiver's 401."""
    # Create webhook with wrong secret
    webhook = mlflow_client.create_webhook(
        name="wrong_secret_test_webhook",
        url=app_client.get_url("/secure-webhook"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
        secret="wrong-secret",
    )

    # Test the webhook
    result = mlflow_client.test_webhook(webhook.webhook_id)

    # Check that the test failed due to wrong signature
    assert result.success is False
    assert result.response_status == 401
    assert result.error_message is None

    # Check that error was logged
    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/secure-webhook"
    assert logs[0].error == "Invalid signature"
    assert logs[0].status_code == 401


def test_webhook_retry_on_5xx_error(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Delivery is retried after 500 responses until the flaky endpoint succeeds."""
    # Create webhook pointing to flaky endpoint
    mlflow_client.create_webhook(
        name="retry_test_webhook",
        url=app_client.get_url("/flaky-webhook"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
    )

    # Create a registered model to trigger the webhook
    registered_model = mlflow_client.create_registered_model(
        name="test_retry_model",
        description="Testing retry logic",
    )

    logs = app_client.wait_for_logs(expected_count=3, timeout=15)
    # Exactly three deliveries: no retries once the third attempt succeeds.
    assert len(logs) == 3

    # First two attempts should fail with 500
    assert logs[0].endpoint == "/flaky-webhook"
    assert logs[0].status_code == 500
    assert logs[0].error == "Server error (will retry)"
    assert logs[0].payload["name"] == registered_model.name

    assert logs[1].endpoint == "/flaky-webhook"
    assert logs[1].status_code == 500
    assert logs[1].error == "Server error (will retry)"

    # Third attempt should succeed
    assert logs[2].endpoint == "/flaky-webhook"
    assert logs[2].status_code == 200
    assert logs[2].error is None
    assert logs[2].payload["name"] == registered_model.name


def test_webhook_retry_on_429_rate_limit(
    mlflow_client: MlflowClient, app_client: AppClient
) -> None:
    """Delivery is retried after a 429 response and succeeds on the second attempt."""
    # Create webhook pointing to rate-limited endpoint
    mlflow_client.create_webhook(
        name="rate_limit_test_webhook",
        url=app_client.get_url("/rate-limited-webhook"),
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
    )

    # Create a registered model to trigger the webhook
    registered_model = mlflow_client.create_registered_model(
        name="test_rate_limit_model",
        description="Testing 429 retry logic",
    )

    logs = app_client.wait_for_logs(expected_count=2, timeout=10)
    # Exactly two deliveries: no retries once the second attempt succeeds.
    assert len(logs) == 2

    # First attempt should fail with 429
    assert logs[0].endpoint == "/rate-limited-webhook"
    assert logs[0].status_code == 429
    assert logs[0].error == "Rate limited"
    assert logs[0].payload["name"] == registered_model.name
    assert logs[0].attempt == 1

    # Second attempt should succeed
    assert logs[1].endpoint == "/rate-limited-webhook"
    assert logs[1].status_code == 200
    assert logs[1].error is None
    assert logs[1].payload["name"] == registered_model.name
    assert logs[1].attempt == 2


# Prompt Registry Webhook Tests


def test_prompt_created(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Creating a prompt delivers its name, description and user tags."""
    mlflow_client.create_webhook(
        name="prompt_created",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.PROMPT, WebhookAction.CREATED)],
    )

    prompt = mlflow_client.create_prompt(
        name="test_prompt",
        description="test_prompt_description",
        tags={"custom_tag": "custom_value"},
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": prompt.name,
        "description": prompt.description,
        "tags": {"custom_tag": "custom_value"},
    }


def test_prompt_version_created(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Creating a prompt version delivers its template, description and tags."""
    mlflow_client.create_webhook(
        name="prompt_version_created",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.PROMPT_VERSION, WebhookAction.CREATED)],
    )

    prompt = mlflow_client.create_prompt(
        name="test_prompt_version",
        description="A test prompt",
    )

    mlflow_client.create_prompt_version(
        name=prompt.name,
        template="Hello {{name}}! How are you today?",
        description="test_prompt_version_description",
        tags={"version_tag": "v1"},
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": prompt.name,
        "version": "1",  # Version comes as string
        "template": "Hello {{name}}! How are you today?",
        "description": "test_prompt_version_description",
        "tags": {
            "version_tag": "v1",
        },
    }


def test_prompt_tag_set(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Setting a prompt tag delivers the key and value."""
    mlflow_client.create_webhook(
        name="prompt_tag_set",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.PROMPT_TAG, WebhookAction.SET)],
    )

    prompt = mlflow_client.create_prompt(
        name="test_prompt_tag_set",
        description="A test prompt",
    )

    mlflow_client.set_prompt_tag(
        name=prompt.name,
        key="environment",
        value="production",
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": prompt.name,
        "key": "environment",
        "value": "production",
    }


def test_prompt_tag_deleted(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Deleting a prompt tag delivers the deleted key."""
    mlflow_client.create_webhook(
        name="prompt_tag_deleted",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.PROMPT_TAG, WebhookAction.DELETED)],
    )

    prompt = mlflow_client.create_prompt(
        name="test_prompt_tag_deleted",
        tags={"environment": "staging"},
    )

    mlflow_client.delete_prompt_tag(
        name=prompt.name,
        key="environment",
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": prompt.name,
        "key": "environment",
    }


def test_prompt_version_tag_set(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Setting a prompt version tag delivers the version, key and value."""
    mlflow_client.create_webhook(
        name="prompt_version_tag_set",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.PROMPT_VERSION_TAG, WebhookAction.SET)],
    )

    prompt = mlflow_client.create_prompt(name="test_prompt_version_tag_set")
    prompt_version = mlflow_client.create_prompt_version(
        name=prompt.name,
        template="Hello {{name}}!",
    )

    mlflow_client.set_prompt_version_tag(
        name=prompt.name,
        version=str(prompt_version.version),
        key="quality_score",
        value="excellent",
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": prompt.name,
        "version": "1",
        "key": "quality_score",
        "value": "excellent",
    }


def test_prompt_version_tag_deleted(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Deleting a prompt version tag delivers the version and deleted key."""
    mlflow_client.create_webhook(
        name="prompt_version_tag_deleted",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.PROMPT_VERSION_TAG, WebhookAction.DELETED)],
    )

    prompt = mlflow_client.create_prompt(name="test_prompt_version_tag_deleted")
    prompt_version = mlflow_client.create_prompt_version(
        name=prompt.name,
        template="Hello {{name}}!",
        tags={"quality_score": "good"},
    )

    mlflow_client.delete_prompt_version_tag(
        name=prompt.name,
        version=str(prompt_version.version),
        key="quality_score",
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": prompt.name,
        "version": "1",
        "key": "quality_score",
    }


def test_prompt_alias_created(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Setting a prompt alias delivers the alias and target version."""
    mlflow_client.create_webhook(
        name="prompt_alias_created",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.PROMPT_ALIAS, WebhookAction.CREATED)],
    )

    prompt = mlflow_client.create_prompt(name="test_prompt_alias_created")
    prompt_version = mlflow_client.create_prompt_version(
        name=prompt.name,
        template="Hello {{name}}!",
    )

    mlflow_client.set_prompt_alias(
        name=prompt.name,
        version=int(prompt_version.version),
        alias="production",
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": prompt.name,
        "alias": "production",
        "version": "1",
    }


def test_prompt_alias_deleted(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """Deleting a prompt alias delivers the prompt name and alias."""
    mlflow_client.create_webhook(
        name="prompt_alias_deleted",
        url=app_client.get_url("/insecure-webhook"),
        events=[WebhookEvent(WebhookEntity.PROMPT_ALIAS, WebhookAction.DELETED)],
    )

    prompt = mlflow_client.create_prompt(name="test_prompt_alias_deleted")
    prompt_version = mlflow_client.create_prompt_version(
        name=prompt.name,
        template="Hello {{name}}!",
    )
    mlflow_client.set_prompt_alias(
        name=prompt.name,
        version=int(prompt_version.version),
        alias="staging",
    )

    mlflow_client.delete_prompt_alias(
        name=prompt.name,
        alias="staging",
    )

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": prompt.name,
        "alias": "staging",
    }


def test_prompt_webhook_with_mixed_events(
    mlflow_client: MlflowClient, app_client: AppClient
) -> None:
    """One webhook subscribed to model and prompt events receives all four payloads."""
    mlflow_client.create_webhook(
        name="mixed_events_webhook",
        url=app_client.get_url("/insecure-webhook"),
        events=[
            WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED),
            WebhookEvent(WebhookEntity.PROMPT, WebhookAction.CREATED),
            WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED),
            WebhookEvent(WebhookEntity.PROMPT_VERSION, WebhookAction.CREATED),
        ],
    )

    model = mlflow_client.create_registered_model(
        name="regular_model",
        description="Regular model description",
    )

    prompt = mlflow_client.create_prompt(
        name="test_prompt_mixed",
        description="Prompt description",
    )

    mlflow_client.create_model_version(
        name=model.name,
        source="s3://bucket/model",
        run_id="1234567890abcdef",
    )

    mlflow_client.create_prompt_version(
        name=prompt.name,
        template="Hello {{name}}!",
    )

    logs = app_client.wait_for_logs(expected_count=4, timeout=10)
    assert len(logs) == 4

    # Webhooks are processed asynchronously and may arrive out of order
    expected_payloads = [
        {
            "name": "regular_model",
            "description": "Regular model description",
            "tags": {},
        },
        {
            "name": "test_prompt_mixed",
            "description": "Prompt description",
            "tags": {},
        },
        {
            "name": "regular_model",
            "source": "s3://bucket/model",
            "run_id": "1234567890abcdef",
            "version": "1",
            "description": None,
            "tags": {},
        },
        {
            "name": "test_prompt_mixed",
            "template": "Hello {{name}}!",
            "version": "1",
            "description": None,
            "tags": {},
        },
    ]
    actual_payloads = [log.payload for log in logs]
    assert sorted(actual_payloads, key=str) == sorted(expected_payloads, key=str)


def test_prompt_webhook_test_endpoint(mlflow_client: MlflowClient, app_client: AppClient) -> None:
    """``test_webhook`` with a prompt event sends that event's example payload."""
    webhook = mlflow_client.create_webhook(
        name="prompt_test_webhook",
        url=app_client.get_url("/insecure-webhook"),
        events=[
            WebhookEvent(WebhookEntity.PROMPT, WebhookAction.CREATED),
            WebhookEvent(WebhookEntity.PROMPT_VERSION, WebhookAction.CREATED),
            WebhookEvent(WebhookEntity.PROMPT_VERSION_TAG, WebhookAction.SET),
        ],
    )

    result = mlflow_client.test_webhook(
        webhook.webhook_id,
        event=WebhookEvent(WebhookEntity.PROMPT_VERSION_TAG, WebhookAction.SET),
    )

    assert result.success is True
    assert result.response_status == 200
    assert result.error_message is None

    logs = app_client.wait_for_logs(expected_count=1)
    assert len(logs) == 1
    assert logs[0].endpoint == "/insecure-webhook"
    assert logs[0].payload == {
        "name": "example_prompt",
        "version": "1",
        "key": "example_key",
        "value": "example_value",
    }