/ mlflow / models / evaluation / _shap_patch.py
_shap_patch.py
 1  import pickle
 2  
 3  import shap
 4  from shap._serializable import Deserializer, Serializable, Serializer
 5  
 6  
 7  class _PatchedKernelExplainer(shap.KernelExplainer):
 8      def save(self, out_file, model_saver=None, masker_saver=None):
 9          """
10          This patched `save` method fix `KernelExplainer.save`.
11          Issues in original `KernelExplainer.save`:
12           - It saves model by calling model.save, but shap.utils._legacy.Model has no save method
13           - It tries to save "masker", but there's no "masker" in KernelExplainer
14           - It does not save "KernelExplainer.data" attribute, the attribute is required when
15             loading back
16          Note: `model_saver` and `masker_saver` are meaningless argument for `KernelExplainer.save`,
17          the model in "KernelExplainer" is an instance of `shap.utils._legacy.Model`
18          (it wraps the predict function), we can only use pickle to dump it.
19          and no `masker` for KernelExplainer so `masker_saver` is meaningless.
20          but I preserve the 2 argument for overridden API compatibility.
21          """
22          pickle.dump(type(self), out_file)
23          with Serializer(out_file, "shap.Explainer", version=0) as s:
24              s.save("model", self.model)
25              s.save("link", self.link)
26              s.save("data", self.data)
27  
28      @classmethod
29      def load(cls, in_file, model_loader=None, masker_loader=None, instantiate=True):
30          """
31          This patched `load` method fix `KernelExplainer.load`.
32          Issues in original KernelExplainer.load:
33           - Use mismatched model loader to load model
34           - Try to load non-existent "masker" attribute
35           - Does not load "data" attribute and then cause calling " KernelExplainer"
36             constructor lack of "data" argument.
37          Note: `model_loader` and `masker_loader` are meaningless argument for
38          `KernelExplainer.save`, because the `model` object is saved by pickle dump,
39          we must use pickle load to load it.
40          and no `masker` for KernelExplainer so `masker_loader` is meaningless.
41          but I preserve the 2 argument for overridden API compatibility.
42          """
43          if instantiate:
44              return cls._instantiated_load(in_file, model_loader=None, masker_loader=None)
45  
46          kwargs = Serializable.load(in_file, instantiate=False)
47          with Deserializer(in_file, "shap.Explainer", min_version=0, max_version=0) as s:
48              kwargs["model"] = s.load("model")
49              kwargs["link"] = s.load("link")
50              kwargs["data"] = s.load("data")
51          return kwargs