# conftest.py
import logging
import os
import subprocess
from functools import lru_cache

import docker
import pytest
import requests
from packaging.version import Version

import mlflow

TEST_IMAGE_NAME = "test_image"
MLFLOW_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
RESOURCE_DIR = os.path.join(MLFLOW_ROOT, "tests", "resources", "dockerfile")

docker_client = docker.from_env()

_logger = logging.getLogger(__name__)

# Timeout (seconds) for the PyPI metadata request, so a stalled network connection
# cannot hang the whole test session.
_PYPI_REQUEST_TIMEOUT = 30


@pytest.fixture(autouse=True)
def clean_up_docker():
    """Remove test containers, the test image, and the Docker build cache after each test.

    This is an autouse fixture, so it wraps every test this conftest applies to. All of
    its work happens after ``yield``, i.e. once the test body has finished.
    """
    yield

    # Get all containers using the test image. ``all=True`` also matches stopped/exited
    # containers; without it the Docker SDK lists only running ones.
    containers = docker_client.containers.list(all=True, filters={"ancestor": TEST_IMAGE_NAME})
    for container in containers:
        container.remove(force=True)

    # Clean up the image; it may never have been built (e.g. the test failed early).
    try:
        docker_client.images.remove(TEST_IMAGE_NAME, force=True)
    except docker.errors.ImageNotFound:
        pass

    # Clean up the build cache. This is best effort: a failing command, or a missing
    # ``docker`` CLI (the SDK above talks to the daemon directly), only logs a warning.
    try:
        subprocess.check_call(["docker", "builder", "prune", "-a", "-f"])
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        _logger.warning("Failed to clean up docker system: %s", e)


@lru_cache(maxsize=1)
def get_released_mlflow_version():
    """Return the newest final (non-dev, non-pre-release) MLflow version published on PyPI.

    The result is cached, so PyPI is queried at most once per process. Exceptions are
    not cached, so a failed call is retried on the next invocation.

    Raises:
        requests.HTTPError: If PyPI responds with an error status.
        requests.Timeout: If PyPI does not respond within ``_PYPI_REQUEST_TIMEOUT`` seconds.
    """
    url = "https://pypi.org/pypi/mlflow/json"
    response = requests.get(url, timeout=_PYPI_REQUEST_TIMEOUT)
    response.raise_for_status()
    data = response.json()
    versions = [
        v for v in map(Version, data["releases"]) if not (v.is_devrelease or v.is_prerelease)
    ]
    return str(max(versions))


def save_model_with_latest_mlflow_version(flavor, extra_pip_requirements=None, **kwargs):
    """
    Save a model, overriding the MLflow version from the dev version to the latest release.

    By default a model is saved with the dev version of MLflow, which is not available on PyPI.
    Usually we can work around this by adding the --serve-wheel flag that starts a local PyPI
    server; however, this doesn't work when installing dependencies inside a Docker container.
    Hence, this function uses `extra_pip_requirements` to save the model with the latest
    released MLflow.

    Args:
        flavor: Name of an attribute of the ``mlflow`` module (e.g. ``"sklearn"``) whose
            ``save_model`` is called.
        extra_pip_requirements: Optional iterable of additional pip requirements. It is
            copied, never modified in place. Not used for the ``"langchain"`` flavor.
        kwargs: Forwarded to ``mlflow.<flavor>.save_model``.
    """
    latest_mlflow_version = get_released_mlflow_version()
    if flavor == "langchain":
        # The full requirement list is pinned via `pip_requirements`, replacing any value
        # passed in kwargs. `extra_pip_requirements` is not forwarded because MLflow does
        # not accept it together with `pip_requirements`.
        kwargs["pip_requirements"] = [
            f"mlflow[gateway]=={latest_mlflow_version}",
            "langchain<1.1.0",
        ]
    else:
        # Copy so the caller's list is not mutated; otherwise reusing one list across calls
        # would accumulate duplicate pins.
        extra_pip_requirements = list(extra_pip_requirements or [])
        extra_pip_requirements.append(f"mlflow=={latest_mlflow_version}")
        if flavor == "lightgbm":
            # Adding pyarrow < 18 to prevent pip installation resolution conflicts.
            extra_pip_requirements.append("pyarrow<18")
        kwargs["extra_pip_requirements"] = extra_pip_requirements
    flavor_module = getattr(mlflow, flavor)
    flavor_module.save_model(**kwargs)