_shap_patch.py
import pickle

import shap
from shap._serializable import Deserializer, Serializable, Serializer


class _PatchedKernelExplainer(shap.KernelExplainer):
    """``shap.KernelExplainer`` with working ``save``/``load`` round-tripping.

    Only the serialization methods are overridden; everything else is
    inherited unchanged from ``shap.KernelExplainer``.
    """

    def save(self, out_file, model_saver=None, masker_saver=None):
        """Serialize this explainer to the binary stream ``out_file``.

        This patched ``save`` fixes ``KernelExplainer.save``. Issues in the
        original ``KernelExplainer.save``:

        - It saves the model by calling ``model.save``, but
          ``shap.utils._legacy.Model`` has no ``save`` method.
        - It tries to save a "masker", but ``KernelExplainer`` has no masker.
        - It does not save the ``KernelExplainer.data`` attribute, which is
          required when loading the explainer back.

        Args:
            out_file: Writable binary file-like object.
            model_saver: Ignored. The model held by ``KernelExplainer`` is an
                instance of ``shap.utils._legacy.Model`` (it wraps the predict
                function), so it can only be dumped with pickle.
            masker_saver: Ignored, because ``KernelExplainer`` has no masker.

        Both ignored arguments are kept only so the signature stays compatible
        with the overridden ``save`` API.
        """
        # Class header first (same as ``Serializable.save``); the instantiating
        # load path reads it back to decide which class to construct.
        pickle.dump(type(self), out_file)
        with Serializer(out_file, "shap.Explainer", version=0) as s:
            s.save("model", self.model)
            s.save("link", self.link)
            s.save("data", self.data)

    @classmethod
    def load(cls, in_file, model_loader=None, masker_loader=None, instantiate=True):
        """Read back an explainer written by :meth:`save` from ``in_file``.

        This patched ``load`` fixes ``KernelExplainer.load``. Issues in the
        original ``KernelExplainer.load``:

        - It uses a model loader that does not match how the model was saved.
        - It tries to load a non-existent "masker" attribute.
        - It does not load the "data" attribute, so the ``KernelExplainer``
          constructor is then called without its ``data`` argument.

        Args:
            in_file: Readable binary file-like object containing data written
                by :meth:`save`.
            model_loader: Ignored. :meth:`save` dumps the model with pickle, so
                it must be read back with pickle as well.
            masker_loader: Ignored, because ``KernelExplainer`` has no masker.
            instantiate: If true, delegate to ``cls._instantiated_load`` and
                return its result (a constructed explainer). If false, return
                the mapping of constructor arguments read from the stream.

        Both ignored arguments are kept only so the signature stays compatible
        with the overridden ``load`` API.

        Returns:
            An explainer when ``instantiate`` is true; otherwise a mapping with
            the ``"model"``, ``"link"`` and ``"data"`` entries read from
            ``in_file`` (plus whatever ``Serializable.load`` returned).
        """
        if instantiate:
            # The caller's loaders are deliberately NOT forwarded: neither one
            # applies to KernelExplainer (see Args above).
            return cls._instantiated_load(in_file, model_loader=None, masker_loader=None)

        kwargs = Serializable.load(in_file, instantiate=False)
        # Read fields in exactly the order ``save`` wrote them.
        with Deserializer(in_file, "shap.Explainer", min_version=0, max_version=0) as s:
            kwargs["model"] = s.load("model")
            kwargs["link"] = s.load("link")
            kwargs["data"] = s.load("data")
        return kwargs