# entrypoint.py
import argparse
import os
import sys
from typing import Optional, Sequence

import numpy as np
import sklearn
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

import mlflow

# Versions that `--test` mode requires the mlflow-created virtual environment to provide.
EXPECTED_PYTHON_VERSION = (3, 8, 18)
EXPECTED_SKLEARN_VERSION = "1.0.2"


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]`` (argparse default).

    Returns:
        Namespace with a boolean ``test`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test",
        action="store_true",
        help="If specified, check this script is running in a virtual environment created by mlflow "
        "and python and scikit-learn versions are correct.",
    )
    return parser.parse_args(argv)


def _check_environment() -> None:
    """Verify this script runs in a virtual environment with the expected versions.

    Raises:
        AssertionError: if ``VIRTUAL_ENV`` is not set, or the Python or
            scikit-learn version differs from the expected one. The checks raise
            explicitly instead of using ``assert`` so they still run under
            ``python -O``, which strips assert statements.
    """
    if "VIRTUAL_ENV" not in os.environ:
        raise AssertionError("VIRTUAL_ENV is not set; not running in a virtual environment")
    if sys.version_info[:3] != EXPECTED_PYTHON_VERSION:
        raise AssertionError(sys.version_info)
    if sklearn.__version__ != EXPECTED_SKLEARN_VERSION:
        raise AssertionError(sklearn.__version__)


def _train_model():
    """Fit and return a StandardScaler + SVC pipeline on a fixed 4-sample, 2-class toy dataset."""
    X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
    y = np.array([1, 1, 2, 2])

    clf = make_pipeline(StandardScaler(), SVC(gamma="auto"))
    clf.fit(X, y)
    return clf


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Optionally validate the environment, then train a model and log it to a new MLflow run.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    if args.test:
        _check_environment()

    clf = _train_model()

    with mlflow.start_run():
        mlflow.sklearn.log_model(clf, name="model")


# Guarded so that importing this module does not parse argv or start an MLflow run.
if __name__ == "__main__":
    main()