
[Feature] SHAP-based explainer for torch models#3049

Open
daidahao wants to merge 120 commits into unit8co:master from daidahao:feature/shap-torch

Conversation

@daidahao
Contributor

@daidahao daidahao commented Mar 27, 2026

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #871. Fixes #2788. Fixes #2296. Fixes #2571. Fixes #1262. Fixes #2566. Fixes #1332.

Summary

This PR:

  • adds TorchExplainer for explaining torch models with SHAP,
  • renames the existing SKLearn explainer to SKLearnExplainer,
  • adds a new explain_single() method for explaining a single prediction instance in both explainers,
  • adds a new explainability notebook with examples for both explainers,
  • includes various bug fixes and improvements to the explainability module,
  • misc updates to docs and tests.

(NEW) Torch Explainer

TorchExplainer is introduced for TorchForecastingModel instances, with a feature set aligned with the SKLearn explainer:

  • Batched explanations with explain().
  • Single-instance explanations with explain_single().
  • Visualization helpers with summary_plot() and force_plot().

It supports target, past covariates, future covariates, and static covariates (including component-specific/global covariates), and returns SHAP values in SHAPExplainabilityResult / SHAPSingleExplainabilityResult objects.

Motivation

An increasing number of models in Darts are torch-based (recently #3002, #2980, #2944) and users need a consistent way to explain their forecasts.

For scikit-learn models, the existing ShapExplainer (now SKLearnExplainer) provides SHAP-based explanations with method selection based on model type.
For torch models, we need a new explainer that can handle the different model architectures, while conforming to existing explainability API patterns.

  • Why SHAP? SHAP gives additive, model-agnostic feature attributions that are consistent across explainers.
  • Why Permutation Explainer? For torch models, defaulting to permutation provides general applicability and faster explanations than kernel or sampling. Users can choose other SHAP methods if desired.
  • Why not DeepExplainer or GradientExplainer? Both are designed for deep learning models and are faster than KernelSHAP. However, they have limitations (from my experiments):
    • DeepExplainer is incompatible with many torch models due to reused layers.
    • Both do not output base values, which are needed for consistent SHAP result objects and visualizations (e.g., waterfall, force plots).
  • Why not captum? Meta's PyTorch native library supports various attribution methods (Integrated Gradients, DeepLIFT, etc.) and is efficient for torch models. However, as of now, it does not support multi-target explanations. Forecasts in Darts are multi-target in nature (multiple horizons x components x likelihood parameters), so using captum would incur for-loop overhead.
  • Future: We can consider supporting DeepExplainer/GradientExplainer as additional SHAP methods in the future if they yield better efficiency for some torch models. This would require wrapping PLForecastingModule in a generic nn.Module that can be explained by these methods, in addition to the current numpy-based function wrapper.

Design

  • TorchExplainer mirrors the SKLearnExplainer API for consistency, with explain(), summary_plot(), and force_plot() methods.
  • It builds SHAP inputs from torch inference datasets to stay consistent with Darts prediction semantics.
  • It handles deterministic and probabilistic models (for probabilistic models, explanations are produced for likelihood parameter components).

Implementation Details

  • Internally, it flattens the SHAP inputs into a 2D numpy array expected by SHAP, while keeping track of the original feature structure (horizon/component/likelihood parameter) for constructing SHAP results.
  • It wraps PLForecastingModule in a numpy-compatible function which:
    • recovers the spaghetti inputs (targets, past/future/static covariates) from the flattened 2D numpy array,
    • calls the module's forward() method to get predictions,
    • returns predictions also in a flattened 2D format, which is then passed to SHAP explainer.
  • It constructs SHAP result objects with the same structure as SKLearnExplainer for consistency in querying and visualization.
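The flatten/recover round trip described above can be sketched in plain numpy (a minimal sketch; the shapes, the `model_fn` name, and the mean-over-time "forward" pass are illustrative stand-ins, not the actual TorchExplainer internals):

```python
import numpy as np

# Hypothetical shapes: 4 forecast instances, input_chunk_length=12, 2 components.
n_inst, icl, n_comp = 4, 12, 2
past_target = np.random.rand(n_inst, icl, n_comp)

# SHAP expects a 2D array of shape (n_samples, n_features):
# flatten each instance's (time x component) block into one feature row.
X_flat = past_target.reshape(n_inst, -1)

def model_fn(X):
    # Recover the original 3D structure, run a stand-in "forward" pass
    # (here: a mean over time), and flatten predictions back to 2D.
    inputs = X.reshape(-1, icl, n_comp)
    preds = inputs.mean(axis=1)            # shape (n, n_comp)
    return preds.reshape(X.shape[0], -1)   # 2D again, as SHAP expects

assert X_flat.shape == (4, 24)
assert model_fn(X_flat).shape == (4, 2)
```

The key point is that only this 2D view is ever handed to SHAP; the original (horizon/component/likelihood parameter) structure is tracked separately so the flat SHAP output can be mapped back into result objects.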

API Reference

Differences to SKLearnExplainer

  • Scope: TorchForecastingModel vs SKLearnModel.
  • Supported SHAP methods differ (torch: kernel, sampling, partition, permutation; sklearn additionally supports tree/linear/additive where applicable).
  • TorchExplainer can explain likelihood parameters of probabilistic forecasts, while SKLearnExplainer can only explain the median (quantile) or mean (poisson) predictions.
  • TorchExplainer uses batched tensors to prevent OOM errors, while SKLearnExplainer uses full-size numpy arrays.
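The batched-evaluation design choice can be sketched with a small, hypothetical helper (the name `predict_in_batches` and the shapes are illustrative, not the actual TorchExplainer code):

```python
import numpy as np

def predict_in_batches(fn, X, batch_size=32):
    # Evaluate `fn` on fixed-size chunks of X and concatenate the results.
    # This bounds peak memory, analogous to how TorchExplainer feeds
    # batched tensors to the model instead of one full-size array.
    outs = [fn(X[i:i + batch_size]) for i in range(0, len(X), batch_size)]
    return np.concatenate(outs, axis=0)

X = np.random.rand(100, 24)
preds = predict_in_batches(lambda x: x.sum(axis=1, keepdims=True), X, batch_size=32)
assert preds.shape == (100, 1)
```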

Methods

  • explain() for horizon/component-level explanations over forecastable timestamps.
  • explain_single() for one forecast instance (equivalent prediction context to predict(n=output_chunk_length)).
  • summary_plot() shows distributions of feature contributions.
  • force_plot() shows feature contributions for a specific horizon/component.

Use Cases

Summary Plot

Feature-importance distribution analysis per horizon/component for torch models.

import shap
shap.initjs()

from darts.datasets import WineDataset
from darts.explainability import TorchExplainer
from darts.models import TiDEModel

series = WineDataset().load().astype("float32")
model = TiDEModel(12, 12).fit(series[:36])
explainer = TorchExplainer(model)
explainer.summary_plot(horizons=[1])
(summary plot screenshot)

Force Plot

Local additive contribution view for a selected horizon and target component.

explainer.force_plot(horizon=1)
(force plot screenshot)

Explaining Multiple Instances

Batch explanations from foreground data with optional sampling controls for performance.

result = explainer.explain(series[:36])
# return a `TimeSeries` of SHAP values where time index
# corresponds to the instance timestamps
result.get_explanation(horizon=1)
# return the raw SHAP explanation object for custom visualizations
shap_object = result.get_shap_explanation_object(horizon=1)
# plot waterfall for the first forecast instance
shap.plots.waterfall(shap_object[0])
(waterfall plot screenshot)

Explaining Single Instance

Per-instance explanation API (explain_single()) for local interpretability.

import numpy as np

single_result = explainer.explain_single(series[:36])
# return a `TimeSeries` of SHAP values where time index corresponds to the **prediction** timestamp
single_result.get_explanation()
# return the raw SHAP explanation object for custom visualizations
single_shap_object = single_result.get_shap_explanation_object()
# plot heatmap for the single instance explanation along the horizon
shap.plots.heatmap(single_shap_object, instance_order=np.arange(12))
(heatmap screenshot)

Explaining Probabilistic Forecasts

Probabilistic torch models are supported by explaining each likelihood parameter component, treating them as separate targets. This is useful for understanding how features contribute to uncertainty estimates.

from darts.utils.likelihood_models import QuantileRegression
# fit a probabilistic model with quantile regression likelihood
prob_model = TiDEModel(12, 12, likelihood=QuantileRegression(quantiles=[0.1, 0.5, 0.9]))
prob_model.fit(series[:36])
# create an explainer for the probabilistic model
prob_explainer = TorchExplainer(prob_model)
# explain the probabilistic forecasts
# this will produce explanations for each likelihood parameter component
# (e.g., Y_q0.100, Y_q0.500, Y_q0.900)
prob_result = prob_explainer.explain(series[:36])
# get SHAP values as a `TimeSeries` for the 0.1 quantile at horizon 1
prob_result.get_explanation(horizon=1, component="Y_q0.100")
            Y_target_lag-12  Y_target_lag-11  Y_target_lag-10  Y_target_lag-9  Y_target_lag-8  ...  Y_target_lag-5  Y_target_lag-4  Y_target_lag-3  Y_target_lag-2  Y_target_lag-1
1981-01-01     -3697.863974      -252.308866       -41.762030        0.572893    -1353.563396  ...      -91.867447     -128.090894      -39.738832      208.761212      -53.789530
1981-02-01     -2648.507187       -80.287658       -53.070808       45.709788     -821.725195  ...       30.775013      -14.196725     -861.172957      392.385305      139.613268
1981-03-01      -477.149089      -195.594982       -51.709828       14.808723      553.345521  ...      -35.324536      273.863595    -1509.285057     -393.562763       31.391378
1981-04-01     -1998.012530      -171.417969       -66.743827      -61.221131      904.461867  ...     -232.607326      590.834413     1727.039878     -267.061452        7.137201
1981-05-01     -1777.624966      -124.021548        -9.517231       -8.676134     -132.987919  ...     -384.271071     -501.348593      927.618172     -129.376857       22.660183
...                     ...              ...              ...             ...             ...  ...             ...             ...             ...             ...             ...
1982-09-01       169.768241        50.163736       159.011448       63.019742    -1706.948496  ...      -76.503501      -20.519832       74.718658      225.857574      -72.105523
1982-10-01      1054.076957       454.064094       260.788809       60.042956    -1343.217095  ...        6.109713      -43.923037     -872.911818      249.960171      -23.256607
1982-11-01      4563.048261       555.714791       -94.331751       99.717854     -361.191972  ...       15.797448      297.566318    -1125.641917       62.627629      -13.438255
1982-12-01      6351.228711      -202.557555       -23.213595      -24.084026      838.380352  ...     -266.064357      417.857843     -372.773744        8.018883      -16.682221
1983-01-01     -2495.383209      -168.530121       -40.911461      -65.382957      382.564556  ...     -297.340957      112.898120     -220.364707      327.032420      -58.788773

shape: (25, 12, 1), freq: MS, size: 2.34 KB

(CHANGE) SKLearn Explainer

The previous ShapExplainer is renamed and aligned with the new naming/API style.

Renaming

  • ShapExplainer -> SKLearnExplainer.
  • ShapExplainabilityResult -> SHAPExplainabilityResult.
  • New SHAPSingleExplainabilityResult for explain_single() outputs.
  • Public imports in darts.explainability now expose SKLearnExplainer, TorchExplainer, and SHAP result classes.

Bug Fixes

  • Improved input processing for explainers by using prediction-aware encoder generation for foreground data (generate_fit_predict_encodings), which is more consistent with forecasting behavior.
  • Better validation and clearer errors in explainability result querying (component/horizon checks).
  • Improved stationarity warnings to indicate the specific component and series index.

(NEW) Explaining Single Instance

SKLearnExplainer.explain_single() is added, returning SHAP and feature values for a single prediction instance in the same style as the torch explainer.

(NEW) Explainability Notebook

Added examples/28-Explainability-examples.ipynb covering:

  • Introduction to SHAP and explainability in Darts.
  • Data and model setup for both sklearn and torch examples.
  • Global explanations with summary_plot() and scatter dependence plots for both explainers (same below).
  • Local batched explanations with explain() and force_plot() and common SHAP visualizations.
  • Local single-instance explanations with explain_single() and corresponding visualizations.
  • Explaining probabilistic forecasts with TorchExplainer and visualizing component-specific explanations.
  • Migration note from ShapExplainer to SKLearnExplainer.
  • Conclusion and references.

Notebook is wired into docs examples (docs/source/examples.rst) and referenced in docs indexing.

Miscellaneous

  • Reworked explainability module exports and docs text to consistently use SHAP capitalization.
  • Added/expanded tests for both explainers:
    • darts/tests/explainability/test_sklearn_explainer.py
    • darts/tests/explainability/test_torch_explainer.py
  • Added torch-side robustness fixes around dataset indexing and future-covariate length handling while creating SHAP arrays.

Other Information

  • This PR includes API renaming in explainability. Existing code using ShapExplainer should migrate to SKLearnExplainer.
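
A minimal migration sketch (illustrative; only the class name changes here, and the constructor arguments shown are placeholders):

```diff
-from darts.explainability import ShapExplainer
-
-explainer = ShapExplainer(model, background_series)
+from darts.explainability import SKLearnExplainer
+
+explainer = SKLearnExplainer(model, background_series)
```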

@daidahao daidahao requested a review from dennisbader as a code owner March 27, 2026 10:28
@daidahao
Contributor Author

@dennisbader @dumjax

Looking forward to your reviews. Cheers!

@codecov

codecov bot commented Mar 27, 2026

Codecov Report

❌ Patch coverage is 99.18963% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.87%. Comparing base (5d9404e) to head (c148909).

Files with missing lines Patch % Lines
darts/explainability/explainability_result.py 93.33% 2 Missing ⚠️
darts/models/forecasting/rnn_model.py 83.33% 2 Missing ⚠️
darts/explainability/__init__.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #3049      +/-   ##
==========================================
+ Coverage   95.76%   95.87%   +0.11%     
==========================================
  Files         158      159       +1     
  Lines       17241    17632     +391     
==========================================
+ Hits        16510    16904     +394     
+ Misses        731      728       -3     

☔ View full report in Codecov by Sentry.

@dennisbader
Collaborator

Thanks a lot for the PR @daidahao 🥳 It will take me some time to review this one, so please bear with us :)

From a quick glance and out of curiosity: Is the DeepExplainer not yet implemented? From an initial discovery of @dumjax, he mentioned that DeepExplainer would be much more efficient than the regular model-agnostic explainers.

@daidahao
Contributor Author

Thanks a lot for the PR @daidahao 🥳 It will take me some time to review this one, so please bear with us :)

No worries. I expect this to be a lonnnnnng review, given how big the PR is.

From a quick glance and out of curiosity: Is the DeepExplainer not yet implemented? From an initial discovery of @dumjax, he mentioned that DeepExplainer would be much more efficient than the regular model-agnostic explainers.

Yes, DeepExplainer is not implemented here for the reasons listed above:

  • It does not allow reused layers. Even trivial things like calling a relu twice inside a network break the explainer.
    Deep Explainer fails on Resnet50 pretrained model. shap/shap#1479
    You can actually see that I experimented with DeepExplainer in previous commits but removed it, because so many models in Darts ran into issues.
  • Both DeepExplainer and GradientExplainer do not output base values, even though DeepExplainer does have additivity guarantee internally (see SHAP code deep_pytorch.py). Many SHAP plots like waterfall would require base values.

Also, DeepExplainer and GradientExplainer expect an nn.Module rather than an agnostic numpy function like the other explainers, so we would need a ModuleWrapper in addition to the FuncWrapper that we have now. If you are interested, I have a local branch with a ModuleWrapper for prototyping GradientExplainer, but its correctness is not guaranteed.

In terms of efficiency, KernelExplainer is without doubt the slowest of KernelExplainer, PermutationExplainer (default), DeepExplainer, and GradientExplainer.

*There is some confusion online equating SHAP with KernelSHAP, which influenced my initial commits; in fact, permutation is more efficient and more common these days, hence the switch to permutation here.

From my usage, the other three are all a lot faster than Kernel. I have not done a direct comparison between Permutation and Deep/Gradient, but I found Permutation's runtime quite reasonable when running the unit tests locally (I could even explain Chronos-2 locally, though I have disabled it for the tests).

That said, I think DeepExplainer and GradientExplainer can be added in a future release if they prove to be more efficient than Permutation AND we can address the compatibility issues (mainly for DeepExplainer). We would also need to clarify which explainers provide base values and which do not, and what the limitations of the latter are.
