[Feature] SHAP-based explainer for torch models #3049

daidahao wants to merge 120 commits into unit8co:master
Conversation
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
This ensures that the last possible index is always explained when `add_encoders` is used.
Looking forward to your reviews. Cheers!
Codecov Report

```
@@            Coverage Diff             @@
##           master    #3049      +/-   ##
==========================================
+ Coverage   95.76%   95.87%   +0.11%
==========================================
  Files         158      159       +1
  Lines       17241    17632     +391
==========================================
+ Hits        16510    16904     +394
+ Misses        731      728       -3
==========================================
```
Thanks a lot for the PR @daidahao 🥳 It will take me some time to review this one, so please bear with us :) From a quick glance and out of curiosity: is the DeepExplainer not yet implemented? From an initial investigation, @dumjax mentioned that DeepExplainer would be much more efficient than the regular model-agnostic explainers.
No worries. I expect this to be a lonnnnnng review, given how big the PR is.

Yes, DeepExplainer is not implemented here, for the reasons listed above.

Also, DeepExplainer and GradientExplainer come with additional expectations on model inputs. In terms of efficiency, KernelExplainer is without doubt the slowest among KernelExplainer, PermutationExplainer (the default), DeepExplainer, and GradientExplainer.

From my usage, the other three are a lot faster than Kernel. I have not done a direct comparison between Permutation and Deep/Gradient, but I found Permutation's runtime to be quite reasonable when running the unit tests locally (I could even explain Chronos-2 locally, though I have disabled it for tests). That said, I think DeepExplainer and GradientExplainer can be added in a future release if they prove to be more efficient than Permutation AND provided that we can address the compatibility issues here (mainly for DeepExplainer). We also need to clarify that some explainers provide no base values while others do, and what the limitations of those explainers might be.
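To illustrate the permutation idea discussed above, here is a toy, dependency-light sketch of how a permutation-based explainer attributes a prediction to its features. It is not shap or Darts code; the model, baseline, and variable names are all made up for illustration. For a linear model the exact Shapley values are recoverable, which makes the additivity property easy to verify:

```python
import itertools
import numpy as np

# Toy model: linear in 3 features, so exact Shapley values are recoverable.
weights = np.array([2.0, -1.0, 0.5])

def model(x):
    return float(x @ weights)

background = np.zeros(3)        # baseline input (all features "absent")
x = np.array([1.0, 3.0, 2.0])   # instance to explain

n = len(x)
phi = np.zeros(n)
perms = list(itertools.permutations(range(n)))
for perm in perms:
    current = background.copy()
    for i in perm:
        before = model(current)
        current[i] = x[i]       # "add" feature i on top of those already present
        phi[i] += model(current) - before
phi /= len(perms)

base_value = model(background)
# SHAP additivity: base value + sum of contributions == model output for x
assert np.isclose(base_value + phi.sum(), model(x))
```

Real permutation explainers sample permutations rather than enumerating all of them, which is what keeps them tractable (and faster than Kernel) on models with many inputs.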
Checklist before merging this PR:
Fixes #871. Fixes #2788. Fixes #2296. Fixes #2571. Fixes #1262. Fixes #2566. Fixes #1332.
Summary
This PR:
- adds `TorchExplainer` for explaining torch models with SHAP,
- renames the existing SHAP explainer to `SKLearnExplainer`,
- adds an `explain_single()` method for explaining a single prediction instance in both explainers.

(NEW) Torch Explainer
`TorchExplainer` is introduced for `TorchForecastingModel` instances, with a feature set aligned with the SKLearn explainer:

- `explain()`.
- `explain_single()`.
- `summary_plot()` and `force_plot()`.

It supports target, past covariates, future covariates, and static covariates (including component-specific/global covariates), and returns SHAP values in `SHAPExplainabilityResult`/`SHAPSingleExplainabilityResult` objects.
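The result objects group SHAP values by horizon and component. A minimal, hypothetical sketch of that shape follows; plain dicts stand in for the actual result classes, and `get_shap_values` is an invented helper, not the PR's API:

```python
import numpy as np

# Hypothetical illustration of horizon/component-indexed SHAP results;
# not the actual SHAPExplainabilityResult class from this PR.
rng = np.random.default_rng(0)

horizons = [1, 2]
components = ["target_0", "target_1"]
n_samples, n_features = 5, 8

# One SHAP value matrix per (horizon, component) pair.
explained = {
    h: {c: rng.normal(size=(n_samples, n_features)) for c in components}
    for h in horizons
}

def get_shap_values(result, horizon, component):
    """Look up the SHAP value matrix for one horizon/component pair."""
    return result[horizon][component]

vals = get_shap_values(explained, horizon=1, component="target_0")
assert vals.shape == (n_samples, n_features)
```

Indexing by horizon first, then component, mirrors how forecasts themselves are queried, which keeps plotting helpers like `summary_plot()` simple to parameterize.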
Motivation

An increasing number of models in Darts are torch-based (recently #3002, #2980, #2944) and users need a consistent way to explain their forecasts.
For scikit-learn models, the existing `ShapExplainer` (now `SKLearnExplainer`) provides SHAP-based explanations with method selection based on model type. For torch models, we need a new explainer that can handle the different model architectures, while conforming to existing explainability API patterns.
- `permutation` provides general applicability and faster explanations than `kernel` or `sampling`. Users can choose other SHAP methods if desired.
- Wrapping `PLForecastingModule` in a generic `nn.Module` that can be explained by these methods, in addition to the current numpy-based function wrapper.
Design

`TorchExplainer` mirrors the `SKLearnExplainer` API for consistency, with `explain()`, `summary_plot()`, and `force_plot()` methods.
Implementation Details

The explainer wraps `PLForecastingModule` in a numpy-compatible function which:

- calls the module's `forward()` method to get predictions,
- follows `SKLearnExplainer` for consistency in querying and visualization.
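The conversion between SHAP's flat 2-D numpy inputs and the (batch, time, feature) arrays a forecasting module consumes can be sketched as follows. This is a toy numpy stand-in, not the PR's actual `_func_wrapper()`; `toy_module` and the chunk-length values are made up:

```python
import numpy as np

input_chunk_length, n_features = 4, 3
output_chunk_length, n_targets = 2, 1

def toy_module(batch):
    # Stand-in for a forecasting module's forward(): per-feature mean over
    # time, truncated to the target columns and repeated over the horizon.
    # batch: (batch, input_chunk_length, n_features)
    mean = batch.mean(axis=1, keepdims=True)[:, :, :n_targets]
    return mean.repeat(output_chunk_length, axis=1)  # (batch, horizon, targets)

def func_wrapper(flat_x):
    """SHAP-facing function: flat 2-D samples in, flat 2-D predictions out."""
    # (n_samples, input_chunk_length * n_features) -> module input shape
    batch = flat_x.reshape(-1, input_chunk_length, n_features)
    preds = toy_module(batch)
    # back to (n_samples, output_chunk_length * n_targets) for SHAP
    return preds.reshape(len(flat_x), -1)

flat = np.arange(2 * input_chunk_length * n_features, dtype=float).reshape(2, -1)
out = func_wrapper(flat)
assert out.shape == (2, output_chunk_length * n_targets)
```

SHAP's model-agnostic explainers only see flat samples, so a wrapper like this is what lets one explainer implementation serve modules with different input layouts.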
API Reference

- `TorchExplainer`. Other SHAP explainers have similar class signatures.
- `PLForecastingModule._get_batch_prediction()` is incorporated into `TorchExplainer._func_wrapper()` for SHAP, which handles the conversion between flat numpy arrays and the torch tensors expected by the module.
- `create_lagged_component_names()` is the ground truth for feature naming conventions in Darts.
Differences to SKLearnExplainer

- Model types: `TorchForecastingModel` vs `SKLearnModel`.
- Supported SHAP methods (`kernel`, `sampling`, `partition`, `permutation`; sklearn additionally supports tree/linear/additive where applicable).
- `TorchExplainer` can explain likelihood parameters of probabilistic forecasts, while `SKLearnExplainer` can only explain the median (quantile) or mean (poisson) predictions.
- `TorchExplainer` uses batched tensors to prevent OOM errors, while `SKLearnExplainer` uses full-size numpy arrays.
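The batching strategy in the last point can be pictured with a small, dependency-light sketch (numpy in place of torch tensors; `predict_in_batches` is an invented name, not Darts code): evaluate the model on fixed-size chunks so peak memory stays bounded, then concatenate.

```python
import numpy as np

def predict_in_batches(func, x, batch_size=32):
    """Evaluate func on x in fixed-size chunks and concatenate the results.

    Illustrative sketch of splitting a large explanation workload to bound
    peak memory; not the PR's actual implementation.
    """
    outputs = [func(x[i : i + batch_size]) for i in range(0, len(x), batch_size)]
    return np.concatenate(outputs, axis=0)

x = np.random.default_rng(1).normal(size=(100, 6))
batched = predict_in_batches(lambda chunk: chunk * 2.0, x, batch_size=32)
assert np.allclose(batched, x * 2.0)  # batching does not change the result
```

Because SHAP issues many perturbed copies of each sample, the evaluation batch can be far larger than a normal prediction batch, which is why chunking matters for torch models on limited GPU memory.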
Methods

- `explain()` for horizon/component-level explanations over forecastable timestamps.
- `explain_single()` for one forecast instance (equivalent prediction context to `predict(n=output_chunk_length)`).
- `summary_plot()` shows distributions of feature contributions.
- `force_plot()` shows feature contributions for a specific horizon/component.
Use Cases

Summary Plot
Feature-importance distribution analysis per horizon/component for torch models.
Force Plot
Local additive contribution view for a selected horizon and target component.
Explaining Multiple Instances
Batch explanations from foreground data with optional sampling controls for performance.
Explaining Single Instance
Per-instance explanation API (`explain_single()`) for local interpretability.

Explaining Probabilistic Forecasts
Probabilistic torch models are supported by explaining each likelihood parameter component, treating them as separate targets. This is useful for understanding how features contribute to uncertainty estimates.
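Treating likelihood parameters as separate targets can be sketched like this (a toy numpy example with invented names, assuming a Gaussian likelihood emitting `mu`/`sigma` per target component; not the PR's code):

```python
import numpy as np

# Toy probabilistic output: (batch, horizon, n_components * n_params),
# e.g. a Gaussian likelihood emitting (mu, sigma) for each target component.
n_components, params = 2, ["mu", "sigma"]
preds = np.random.default_rng(2).normal(size=(8, 3, n_components * len(params)))

# Split each likelihood parameter out as its own "target" so it can be
# explained independently. Purely illustrative.
per_param_targets = {}
for c in range(n_components):
    for j, p in enumerate(params):
        per_param_targets[f"component_{c}_{p}"] = preds[:, :, c * len(params) + j]

assert per_param_targets["component_0_mu"].shape == (8, 3)
```

Once each parameter is a standalone target, the same explain/plot pipeline applies unchanged, e.g. the SHAP values for `sigma` reveal which features drive the width of the predictive distribution.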
(CHANGE) SKLearn Explainer
The previous `ShapExplainer` is renamed and aligned with the new naming/API style.

Renaming
- `ShapExplainer` -> `SKLearnExplainer`.
- `ShapExplainabilityResult` -> `SHAPExplainabilityResult`.
- `SHAPSingleExplainabilityResult` for `explain_single()` outputs.
- `darts.explainability` now exposes `SKLearnExplainer`, `TorchExplainer`, and the SHAP result classes.

Bug Fixes
- Encodings are now generated via `generate_fit_predict_encodings()`, improving consistency with forecasting behavior.

(NEW) Explaining Single Instance
`SKLearnExplainer.explain_single()` is added, returning SHAP and feature values for a single prediction instance in the same style as the torch explainer.

(NEW) Explainability Notebook
Added `examples/28-Explainability-examples.ipynb` covering:

- `summary_plot()` and scatter dependence plots for both explainers (same below).
- `explain()` and `force_plot()` and common SHAP visualizations.
- `explain_single()` and corresponding visualizations.
- `TorchExplainer` and visualizing component-specific explanations.
- Migrating from `ShapExplainer` to `SKLearnExplainer`.

The notebook is wired into the docs examples (`docs/source/examples.rst`) and referenced in docs indexing.

Miscellaneous
- `SHAP` capitalization.
- `darts/tests/explainability/test_sklearn_explainer.py`
- `darts/tests/explainability/test_torch_explainer.py`

Other Information
Users of `ShapExplainer` should migrate to `SKLearnExplainer`.