Description
When plot_confusion_matrix is used on a highly imbalanced dataset, the color scale does not adapt to the distribution of cell values.
If one class dominates (e.g., 100× more samples than the others), the largest cells saturate the color scale, making smaller but important misclassifications almost invisible.
This hurts readability and interpretability for real-world imbalanced classification problems.
Minimal Reproducible Example
```python
from shapash.plots.plot_evaluation_metrics import plot_confusion_matrix

# Imbalanced dataset
y_true = [1, 2, 1, 3, 2, 2, 3, 2] + [3] * 100
y_pred = [1, 2, 1, 1, 2, 2, 3, 1] + [3] * 100

# Format labels as categorical strings (c01, c02, c03)
y_true_labels = [f"c{int(el):02d}" for el in y_true]
y_pred_labels = [f"c{int(el):02d}" for el in y_pred]

plot_confusion_matrix(
    y_true=y_true_labels,
    y_pred=y_pred_labels,
    width=500,
    height=400,
)
```

Current Behavior
- The color scale is computed from the raw min-max values.
- The dominant class (here c03) drives the color intensity.
- Minority-class confusion cells appear almost white.
- Small but relevant misclassifications are visually drowned out.
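To make the saturation concrete, the sketch below tallies the confusion matrix of the reproducible example in plain Python and applies a linear min-max mapping, as the color scale is described above. It does not call Shapash; `min_max_normalize` is a hypothetical helper used only for illustration.

```python
from collections import Counter

# Data from the minimal reproducible example
y_true = [1, 2, 1, 3, 2, 2, 3, 2] + [3] * 100
y_pred = [1, 2, 1, 1, 2, 2, 3, 1] + [3] * 100
labels = sorted(set(y_true) | set(y_pred))  # [1, 2, 3] -> c01, c02, c03

# Tally (true, predicted) pairs; rows = true class, columns = predicted class
pairs = Counter(zip(y_true, y_pred))
matrix = [[pairs[(t, p)] for p in labels] for t in labels]


def min_max_normalize(rows):
    """Hypothetical helper: map every cell linearly to [0, 1] using the global min and max."""
    values = [v for row in rows for v in row]
    lo, hi = min(values), max(values)
    span = hi - lo or 1  # avoid division by zero when all cells are equal
    return [[(v - lo) / span for v in row] for row in rows]


intensity = min_max_normalize(matrix)
for label, counts, row in zip(labels, matrix, intensity):
    print(f"c{label:02d}", counts, [round(x, 3) for x in row])
```

Only the c03/c03 cell (101) receives a strong color; every other non-zero cell (counts of 1 to 3) lands below 0.03 on the [0, 1] scale, which is why those cells look almost white.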
Expected Behavior
The color scale should adapt to the value distribution.
A quantile-based scaling (as implemented in contribution plots in Shapash) would:
- Improve readability.
- Enhance contrast for minority classes.
- Prevent dominant cells from saturating the scale.
- Preserve interpretability in imbalanced settings.
The colors should look like this, even when the last cell contains 100 individuals:
[Screenshot: expected confusion-matrix colors]
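One way to obtain quantile-based scaling is to color each cell by its rank among all cells (an empirical quantile) rather than by its raw count. This is a sketch under that assumption, not the code used by Shapash contribution plots; `quantile_normalize` is a hypothetical name.

```python
def quantile_normalize(rows):
    """Hypothetical helper: color each cell by its rank among all cells.

    score(v) = (#cells strictly below v) / (#cells strictly below the maximum),
    so the smallest value maps to 0 and the largest to 1 regardless of magnitude.
    """
    values = [v for row in rows for v in row]
    top = max(values)
    denom = sum(1 for x in values if x < top)
    if denom == 0:  # every cell has the same value
        return [[0.0 for _ in row] for row in rows]
    return [[sum(1 for x in values if x < v) / denom for v in row] for row in rows]


# Confusion-matrix counts of the reproducible example (rows c01..c03 = true class)
matrix = [[2, 0, 0], [1, 3, 0], [1, 0, 101]]
for row in quantile_normalize(matrix):
    print(row)
```

On the example matrix, the counts 0, 1, 2, 3 and 101 map to 0, 0.5, 0.75, 0.875 and 1.0, so minority-class errors become clearly visible. The trade-off is that colors no longer encode magnitudes linearly, so the raw counts would still need to be displayed as cell annotations.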

Suggested Improvement
Reuse the quantile-based color scaling strategy currently used in Shapash contribution plots:
- Compute upper bounds based on a high percentile (e.g., 95th or 99th).
- Clamp extreme values to avoid scale domination.
- Maintain consistency with other Shapash visualizations.
This would ensure:
- Visual coherence across Shapash plots.
- Robust behavior for real-world imbalanced datasets.
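The percentile-clamping step from the suggestion can be sketched as follows, assuming the cap is computed over all cells and interpolated linearly (the same rule as the default method of `numpy.percentile`). `percentile` and `clamped_normalize` are hypothetical helpers, written in plain Python to stay dependency-free.

```python
def percentile(values, q):
    """Hypothetical helper: linear-interpolation percentile, q in [0, 100]."""
    data = sorted(values)
    pos = (len(data) - 1) * q / 100
    lower = int(pos)
    upper = min(lower + 1, len(data) - 1)
    return data[lower] + (data[upper] - data[lower]) * (pos - lower)


def clamped_normalize(rows, q=95):
    """Hypothetical helper: clamp cells at the q-th percentile, then scale linearly to [0, 1]."""
    values = [v for row in rows for v in row]
    lo, cap = min(values), percentile(values, q)
    span = cap - lo or 1  # avoid division by zero when the cap equals the minimum
    return [[(min(v, cap) - lo) / span for v in row] for row in rows]


matrix = [[2, 0, 0], [1, 3, 0], [1, 0, 101]]
print(percentile([v for row in matrix for v in row], 95))
for row in clamped_normalize(matrix, q=95):
    print([round(x, 3) for x in row])
```

On this 3×3 example, the 95th percentile interpolates between the two largest counts (3 and 101) and gives a cap of about 61.8, so a count of 3 only reaches about 0.05. With very few cells, a single dominant cell still pulls the cap up, so the percentile (or a rank-based mapping) may need tuning for small confusion matrices.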