8 changes: 2 additions & 6 deletions src/autointent/context/data_handler/__init__.py
@@ -1,10 +1,6 @@
 from ._data_handler import DataHandler
-from ._stratification import (
-    SplitReadinessResult,
-    StratifiedSplitter,
-    check_split_readiness,
-    split_dataset,
-)
+from ._readiness_util import SplitReadinessResult, check_split_readiness
+from ._stratification import StratifiedSplitter, split_dataset

 __all__ = [
     "DataHandler",
195 changes: 195 additions & 0 deletions src/autointent/context/data_handler/_readiness_util.py
@@ -0,0 +1,195 @@
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple

import numpy as np

if TYPE_CHECKING:
    from datasets import Dataset as HFDataset

    from autointent import Dataset
    from autointent.configs import DataConfig

from ._safe_multilabel_stratification import _validate_multilabel_matrix
from ._stratification import StratifiedSplitter


class ClassCount(NamedTuple):
    id: int
    """Class (intent) index."""

    n_samples: int
    """Number of samples from the class (intent)."""


@dataclass(frozen=True)
class SplitReadinessResult:
    """Result of checking whether a dataset can be fed to the autointent pipeline.

    Attributes:
        ready: True if stratification can be performed (enough samples per class).
        underpopulated_classes: List of (label, n_samples) pairs for classes below the minimum.
        min_samples_per_class_required: Minimum samples per class used for the check.
        reason: Human-readable reason when not ready (e.g. OOS not configured).
    """

    ready: bool
    underpopulated_classes: list[ClassCount]
    min_samples_per_class_required: int
    reason: str | None


def check_split_readiness(
    dataset: Dataset,
    split: str,
    config: DataConfig,
    allow_oos_in_train: bool | None = None,
) -> SplitReadinessResult:
    """Check whether the dataset has enough samples per class for the autointent pipeline.

    Args:
        dataset: The dataset to check (e.g. the same one passed to :func:`split_dataset`).
        split: The split name to check (e.g. ``Split.TRAIN``).
        config: Data configuration used to derive the per-class minimum.
        allow_oos_in_train: Same as in :func:`split_dataset`. If the split contains OOS samples
            and this is ``None``, this function raises ``ValueError`` (mirroring the splitting behavior).
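
    Example (illustrative sketch; assumes an already-loaded :class:`autointent.Dataset`
    named ``dataset`` and a ``DataConfig`` named ``config``)::

        result = check_split_readiness(dataset, "train", config)
        if not result.ready:
            print(result.reason)
            for class_count in result.underpopulated_classes:
                print(class_count.id, class_count.n_samples)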
"""
    min_samples_per_class = _min_samples_per_class_for_config(config=config)
    if split not in dataset:
        return SplitReadinessResult(
            ready=False,
            underpopulated_classes=[],
            min_samples_per_class_required=min_samples_per_class,
            reason=f"Dataset has no split '{split}'.",
        )
    hf_split = dataset[split]
    splitter = StratifiedSplitter(
        test_size=config.validation_size,
        label_feature=dataset.label_feature,
        random_seed=None,
    )
    inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train)
    expected_n_classes = _expected_n_classes(dataset, inputs.dataset, splitter.label_feature)

    if inputs.multilabel:
        underpopulated = _find_underpopulated_multilabel(inputs.dataset, splitter.label_feature, min_samples_per_class)
    else:
        underpopulated = _find_underpopulated_multiclass(
            inputs.dataset,
            splitter.label_feature,
            min_samples_per_class,
            expected_n_classes=expected_n_classes,
        )
    ready = len(underpopulated) == 0
    reason: str | None = None

    if ready and (not inputs.multilabel):
        split_ok, split_reason = _check_multiclass_split_size_feasibility(
            dataset=inputs.dataset,
            label_feature=splitter.label_feature,
            test_size=inputs.test_size,
            expected_n_classes=expected_n_classes,
        )
        if not split_ok:
            ready = False
            reason = split_reason

    if not ready and reason is None:
        parts = [f"class {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated]
        reason = "Stratification requires at least {} samples per class. Underpopulated: {}.".format(
            min_samples_per_class, "; ".join(parts)
        )
    return SplitReadinessResult(
        ready=ready,
        underpopulated_classes=underpopulated,
        min_samples_per_class_required=min_samples_per_class,
        reason=reason,
    )


def _min_samples_per_class_for_config(config: DataConfig) -> int:
"""Return a recommended minimum samples-per-class for a given data config."""
    # Base requirement for a single stratified split.
    # For CV, the canonical lower bound is one example per fold.
    base = 2 if config.scheme == "ho" else int(config.n_folds)

    # separation_ratio triggers an extra stratified split of the effective train
    # pool (e.g. decision vs scoring), so we double the requirement.
    factor = 1 if config.separation_ratio is None else 2
    return base * factor


def _find_underpopulated_multiclass(
    dataset: HFDataset, label_feature: str, min_samples_per_class: int, expected_n_classes: int
) -> list[ClassCount]:
    """Return (label, count) for each class with fewer than min_samples_per_class samples."""
    labels: list[int] = dataset[label_feature]
    counts = Counter(labels)

    # Ensure "missing" classes are treated as 0-count (underpopulated).
    result: list[ClassCount] = []
    for label in range(int(expected_n_classes)):
        n_samples = int(counts.get(label, 0))
        if n_samples < min_samples_per_class:
            result.append(ClassCount(id=int(label), n_samples=n_samples))
    return result


def _find_underpopulated_multilabel(
    dataset: HFDataset, label_feature: str, min_samples_per_class: int
) -> list[ClassCount]:
    """Return (label_idx, positive_count) for each label with fewer than min_samples_per_class positives."""
    y = np.asarray(dataset[label_feature])
    _validate_multilabel_matrix(y)
    counts = y.sum(axis=0).astype(int)
    return [
        ClassCount(id=int(idx), n_samples=int(n_samples))
        for idx, n_samples in enumerate(counts)
        if n_samples < min_samples_per_class
    ]


def _check_multiclass_split_size_feasibility(
    dataset: HFDataset, label_feature: str, test_size: float, expected_n_classes: int
) -> tuple[bool, str | None]:
    """Return whether stratified train/test sizes are feasible for multiclass splits.

    Even if each class has >= 2 samples, sklearn's stratified splitting can fail when
    the requested train/test sizes are too small to include every class.
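
    Illustrative numbers: with ``n_samples=10``, ``test_size=0.1``, and 3 classes,
    ``n_test = ceil(0.1 * 10) = 1``, which is less than the number of classes, so a
    stratified split cannot place every class in the test fold.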
"""
    labels = dataset[label_feature]
    n_classes = expected_n_classes
    n_samples = len(labels)

    # Mirror sklearn's float test_size -> n_test calculation (ceil).
    n_test = int(np.ceil(float(test_size) * n_samples))
    n_train = n_samples - n_test

    if n_test <= 0 or n_train <= 0:
        return (
            False,
            f"Requested split sizes are invalid (n_samples={n_samples}, test_size={test_size}).",
        )
    if n_test < n_classes:
        return (
            False,
            f"Stratified split would allocate too few test samples (n_test={n_test}) "
            f"for the number of classes (n_classes={n_classes}).",
        )
    if n_train < n_classes:
        return (
            False,
            f"Stratified split would allocate too few train samples (n_train={n_train}) "
            f"for the number of classes (n_classes={n_classes}).",
        )
    return True, None


def _expected_n_classes(dataset: Dataset, prepared: HFDataset, label_feature: str) -> int:
    if dataset.multilabel:
        return len(prepared[label_feature][0])
    labels: list[int] = prepared[label_feature]
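    # The prepared split may contain label ids beyond ``dataset.n_classes``
    # (labels are assumed 0-based), so take the max of both so every seen class counts.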
    max_seen = max(labels) if labels else -1
    return max(dataset.n_classes, int(max_seen) + 1)
146 changes: 146 additions & 0 deletions src/autointent/context/data_handler/_safe_multilabel_stratification.py
@@ -0,0 +1,146 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from skmultilearn.model_selection import IterativeStratification
from transformers import set_seed

if TYPE_CHECKING:
    import numpy.typing as npt

_MULTILABEL_NDIMS = 2
_RARE_LABEL_COUNT_SINGLETON = 1
_RARE_LABEL_COUNT_PAIR = 2
_COIN_FLIP_P = 0.5


def safe_multilabel_split_indices(
    y: npt.NDArray[Any], test_size: float, random_seed: int | None
) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
    """Split multilabel data with coverage guarantees for rare labels."""
    _validate_multilabel_matrix(y)
    n_samples = int(y.shape[0])
    rng = np.random.default_rng(random_seed)

    train_idx: set[int] = set()
    test_idx: set[int] = set()
    label_counts = y.sum(axis=0).astype(int)

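    # Two forcing passes before stratifying the rest: labels with a single
    # positive sample are pinned to train, and labels with exactly two positives
    # get one sample on each side of the split where possible.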
    _force_singleton_labels(y=y, label_counts=label_counts, train_idx=train_idx)
    _force_pair_labels(y=y, label_counts=label_counts, train_idx=train_idx, test_idx=test_idx, rng=rng)

    forced = train_idx | test_idx
    remaining = np.array(sorted(set(range(n_samples)) - forced), dtype=int)
    _iterative_stratify_remaining(
        y=y,
        remaining=remaining,
        test_size=test_size,
        random_seed=random_seed,
        train_idx=train_idx,
        test_idx=test_idx,
    )
    return _finalize_partition(n_samples=n_samples, train_idx=train_idx, test_idx=test_idx)


def _validate_multilabel_matrix(y: npt.NDArray[Any]) -> None:
    if y.ndim != _MULTILABEL_NDIMS:
        msg = (
            "Expected multilabel data to be a 2D matrix-like structure "
            f"(n_samples, n_labels), got shape={getattr(y, 'shape', None)!r}."
        )
        raise ValueError(msg)


def _assigned_split(sample_idx: int, train_idx: set[int], test_idx: set[int]) -> str | None:
    if sample_idx in train_idx:
        return "train"
    if sample_idx in test_idx:
        return "test"
    return None


def _force_singleton_labels(y: npt.NDArray[Any], label_counts: npt.NDArray[Any], train_idx: set[int]) -> None:
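    # A label with exactly one positive sample cannot appear on both sides of
    # the split, so it is pinned to train (presumably to keep the label learnable).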
    for label, count in enumerate(label_counts):
        if int(count) != _RARE_LABEL_COUNT_SINGLETON:
            continue
        sample = int(np.flatnonzero(y[:, label])[0])
        train_idx.add(sample)


def _force_pair_samples(a: int, b: int, train_idx: set[int], test_idx: set[int], rng: np.random.Generator) -> None:
    a_split = _assigned_split(a, train_idx, test_idx)
    b_split = _assigned_split(b, train_idx, test_idx)

    if a_split is not None and b_split is None:
        (test_idx if a_split == "train" else train_idx).add(b)
        return
    if b_split is not None and a_split is None:
        (test_idx if b_split == "train" else train_idx).add(a)
        return
    if a_split is None and b_split is None:
        if rng.random() < _COIN_FLIP_P:
            train_idx.add(a)
            test_idx.add(b)
        else:
            train_idx.add(b)
            test_idx.add(a)


def _force_pair_labels(
    y: npt.NDArray[Any],
    label_counts: npt.NDArray[Any],
    train_idx: set[int],
    test_idx: set[int],
    rng: np.random.Generator,
) -> None:
    for label, count in enumerate(label_counts):
        if int(count) != _RARE_LABEL_COUNT_PAIR:
            continue
        samples = np.flatnonzero(y[:, label]).astype(int)
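        # Deterministic ordering: the sample carrying fewer labels overall goes first.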
        a, b = sorted(samples.tolist(), key=lambda i: int(y[i].sum()))
        _force_pair_samples(a=a, b=b, train_idx=train_idx, test_idx=test_idx, rng=rng)


def _iterative_stratify_remaining(
    y: npt.NDArray[Any],
    remaining: npt.NDArray[Any],
    test_size: float,
    random_seed: int | None,
    train_idx: set[int],
    test_idx: set[int],
) -> None:
    if len(remaining) == 0:
        return
    if random_seed is not None:
        # Workaround: seeding IterativeStratification from skmultilearn directly
        # is unreliable, so set the global seeds instead.
        set_seed(random_seed)
    splitter = IterativeStratification(
        n_splits=2,
        order=2,
        # NOTE: IterativeStratification expects the fold distribution in (test, train) order,
        # but returns indices as (train, test). This matches the library's behavior and
        # keeps train/test sizes backward-compatible with the prior implementation.
        sample_distribution_per_fold=[test_size, 1.0 - test_size],
    )
    train_r, test_r = next(splitter.split(np.arange(len(remaining)), y[remaining]))
    train_idx |= set(remaining[train_r].tolist())
    test_idx |= set(remaining[test_r].tolist())


def _finalize_partition(
    n_samples: int, train_idx: set[int], test_idx: set[int]
) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
    train_arr = np.array(sorted(train_idx), dtype=int)
    test_arr = np.array(sorted(test_idx), dtype=int)

    if len(train_arr) + len(test_arr) != n_samples:
        msg = (
            "Multilabel split did not partition all samples: "
            f"n_samples={n_samples}, train={len(train_arr)}, test={len(test_arr)}."
        )
        raise RuntimeError(msg)
    if set(train_arr.tolist()) & set(test_arr.tolist()):
        msg = "Multilabel split produced overlapping train/test indices."
        raise RuntimeError(msg)
    return train_arr, test_arr