binary_classification.py
# Example script: fit a random-forest classifier on a slice of the
# breast-cancer dataset, log its SHAP explanation to an MLflow run, then
# download the logged arrays again and render a force plot from them.
import os

import numpy as np
import shap
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

import mlflow
from mlflow.artifacts import download_artifacts
from mlflow.tracking import MlflowClient

# Training data: keep the first 50 rows and the first 8 feature columns.
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X = X.iloc[:50, :8]
y = y.iloc[:50]

# Fit the classifier on that slice.
model = RandomForestClassifier()
model.fit(X, y)


def _predict_proba_column_one(features):
    """Return column 1 of ``model.predict_proba(features)``."""
    return model.predict_proba(features)[:, 1]


# Log the explanation; only this call runs inside the MLflow run context.
with mlflow.start_run() as run:
    mlflow.shap.log_explanation(_predict_proba_column_one, X)

run_id = run.info.run_id

# Print the paths of the artifacts stored under the explanation directory.
client = MlflowClient()
artifact_path = "model_explanations_shap"
logged_entries = client.list_artifacts(run_id, artifact_path)
artifacts = [entry.path for entry in logged_entries]
print("# artifacts:")
print(artifacts)

# Download the explanation directory and read both saved arrays from it.
dst_path = download_artifacts(run_id=run_id, artifact_path=artifact_path)
base_values, shap_values = (
    np.load(os.path.join(dst_path, file_name))
    for file_name in ("base_values.npy", "shap_values.npy")
)

# Force plot for the first row: its SHAP values against the base value.
shap.force_plot(
    float(base_values),
    shap_values[0, :],
    X.iloc[0, :],
    matplotlib=True,
)