/ tests / webhooks / test_delivery.py
test_delivery.py
  1  from pathlib import Path
  2  from unittest.mock import patch
  3  
  4  import pytest
  5  
  6  from mlflow.entities.webhook import Webhook, WebhookAction, WebhookEntity, WebhookEvent
  7  from mlflow.store.model_registry.file_store import FileStore
  8  from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore
  9  from mlflow.webhooks.delivery import deliver_webhook
 10  from mlflow.webhooks.delivery import test_webhook as send_test_webhook
 11  
 12  
 13  @pytest.fixture
 14  def file_store(tmp_path: Path) -> FileStore:
 15      pytest.skip("FileStore is no longer supported.")
 16      return FileStore(str(tmp_path))
 17  
 18  
 19  @pytest.fixture
 20  def sql_store(tmp_path: Path) -> SqlAlchemyStore:
 21      db_file = tmp_path / "test.db"
 22      db_uri = f"sqlite:///{db_file}"
 23      return SqlAlchemyStore(db_uri)
 24  
 25  
 26  @pytest.fixture
 27  def webhook_event() -> WebhookEvent:
 28      return WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)
 29  
 30  
 31  @pytest.fixture
 32  def webhook_payload() -> dict[str, str]:
 33      return {"name": "test_model", "description": "Test model"}
 34  
 35  
 36  def test_deliver_webhook_exits_early_for_file_store(
 37      file_store: FileStore, webhook_event: WebhookEvent, webhook_payload: dict[str, str]
 38  ) -> None:
 39      pytest.skip("FileStore is no longer supported.")
 40      with patch("mlflow.webhooks.delivery._deliver_webhook_impl") as mock_impl:
 41          deliver_webhook(
 42              event=webhook_event,
 43              payload=webhook_payload,
 44              store=file_store,
 45          )
 46  
 47          # _deliver_webhook_impl should not be called for FileStore
 48          mock_impl.assert_not_called()
 49  
 50  
 51  def test_deliver_webhook_calls_impl_for_sql_store(
 52      sql_store: SqlAlchemyStore, webhook_event: WebhookEvent, webhook_payload: dict[str, str]
 53  ) -> None:
 54      with patch("mlflow.webhooks.delivery._deliver_webhook_impl") as mock_impl:
 55          deliver_webhook(
 56              event=webhook_event,
 57              payload=webhook_payload,
 58              store=sql_store,
 59          )
 60  
 61          # _deliver_webhook_impl should be called for SqlAlchemyStore
 62          mock_impl.assert_called_once_with(
 63              event=webhook_event,
 64              payload=webhook_payload,
 65              store=sql_store,
 66          )
 67  
 68  
 69  def test_deliver_webhook_handles_exception_for_sql_store(
 70      sql_store: SqlAlchemyStore, webhook_event: WebhookEvent, webhook_payload: dict[str, str]
 71  ) -> None:
 72      with (
 73          patch("mlflow.webhooks.delivery._deliver_webhook_impl", side_effect=Exception("Test")),
 74          patch("mlflow.webhooks.delivery._logger") as mock_logger,
 75      ):
 76          # This should not raise an exception
 77          deliver_webhook(
 78              event=webhook_event,
 79              payload=webhook_payload,
 80              store=sql_store,
 81          )
 82  
 83          # Verify that the error was logged
 84          mock_logger.error.assert_called_once()
 85          assert "Failed to deliver webhook for event" in str(mock_logger.error.call_args)
 86  
 87  
 88  def test_deliver_webhook_no_exception_for_file_store(
 89      file_store: FileStore, webhook_event: WebhookEvent, webhook_payload: dict[str, str]
 90  ) -> None:
 91      pytest.skip("FileStore is no longer supported.")
 92      with (
 93          patch(
 94              "mlflow.webhooks.delivery._deliver_webhook_impl", side_effect=Exception("Test")
 95          ) as mock_impl,
 96          patch("mlflow.webhooks.delivery._logger") as mock_logger,
 97      ):
 98          # This should not raise an exception and should return early
 99          deliver_webhook(
100              event=webhook_event,
101              payload=webhook_payload,
102              store=file_store,
103          )
104  
105          # _deliver_webhook_impl should not be called, so no error should be logged
106          mock_impl.assert_not_called()
107          mock_logger.error.assert_not_called()
108  
109  
def test_test_webhook_rejects_private_ip() -> None:
    """send_test_webhook must refuse a URL whose host resolves to a non-public address.

    The target host is ``localhost`` (normally a loopback address). The test
    expects a failed result whose error message reports the non-public
    resolution.
    """
    event = WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)
    webhook = Webhook(
        webhook_id="wh-1",
        name="test",
        url="https://localhost/hook",
        events=[event],
        creation_timestamp=0,
        last_updated_timestamp=0,
    )

    result = send_test_webhook(webhook)

    assert result.success is False
    # Check for None first so a missing message fails with a clear assertion
    # instead of a TypeError raised by the ``in`` operator.
    assert result.error_message is not None
    assert "must not resolve to a non-public" in result.error_message