/ examples / shap / binary_classification.py
binary_classification.py
 1  import os
 2  
 3  import numpy as np
 4  import shap
 5  from sklearn.datasets import load_breast_cancer
 6  from sklearn.ensemble import RandomForestClassifier
 7  
 8  import mlflow
 9  from mlflow.artifacts import download_artifacts
10  from mlflow.tracking import MlflowClient
11  
# Build a deliberately small training set: only the leading rows and feature
# columns of the breast cancer dataset are kept.
n_rows, n_features = 50, 8
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X = X.iloc[:n_rows, :n_features]
y = y.iloc[:n_rows]

# Fit a random forest on the reduced data (`fit` returns the estimator itself).
model = RandomForestClassifier().fit(X, y)


def _second_class_proba(data):
    """Return column 1 of ``model.predict_proba(data)`` as a 1-D result."""
    return model.predict_proba(data)[:, 1]


# Log a SHAP explanation of the prediction function inside a fresh MLflow run.
with mlflow.start_run() as run:
    mlflow.shap.log_explanation(_second_class_proba, X)

run_id = run.info.run_id

# Enumerate what was stored for this run under the explanation directory.
# NOTE(review): presumably this matches log_explanation's default artifact
# path, since none is passed above -- confirm against the MLflow docs.
client = MlflowClient()
artifact_path = "model_explanations_shap"
artifacts = [entry.path for entry in client.list_artifacts(run_id, artifact_path)]
print("# artifacts:")
print(artifacts)

# Download the explanation directory and read both saved arrays back,
# base values first, then SHAP values.
dst_path = download_artifacts(run_id=run_id, artifact_path=artifact_path)
base_values, shap_values = (
    np.load(os.path.join(dst_path, file_name))
    for file_name in ("base_values.npy", "shap_values.npy")
)

# Render a matplotlib force plot for the first sample only.
first_row = 0
shap.force_plot(
    float(base_values),
    shap_values[first_row, :],
    X.iloc[first_row, :],
    matplotlib=True,
)