Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 141 additions & 3 deletions src/climpred/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
resample_iterations_idx as _resample_iterations_idx,
)

try:
import xbootstrap as xb

XBOOTSTRAP_AVAILABLE = True
except ImportError:
xb = None
XBOOTSTRAP_AVAILABLE = False

from .checks import (
has_dims,
has_valid_lead_units,
Expand Down Expand Up @@ -69,6 +77,71 @@ def _resample(initialized, resample_dim):
return smp_initialized


def _resample_multiple_dims(initialized, resample_dims):
    """Resample with replacement across several dimensions, one after another.

    ``_resample`` is applied once per entry of ``resample_dims``, in the given
    order, each call operating on the output of the previous one.

    Args:
        initialized (xr.Dataset): input xr.Dataset to be resampled.
        resample_dims (str or iterable of str): dimension(s) to resample along.
            A single dimension name is treated as a one-element list.

    Returns:
        xr.Dataset: resampled along all ``resample_dims``. ``initialized`` is
            returned unchanged when ``resample_dims`` is empty.

    """
    # A bare string would otherwise be iterated character by character
    # ("member" -> "m", "e", ...), which is never what the caller means.
    if isinstance(resample_dims, str):
        resample_dims = [resample_dims]

    result = initialized
    for dim in resample_dims:
        result = _resample(result, dim)
    return result


def _resample_multiple_dims_xbootstrap(
    initialized, resample_dims, n_iteration, block_sizes=None
):
    """Resample with replacement across multiple dimensions using xbootstrap.

    Args:
        initialized (xr.Dataset): input xr.Dataset to be resampled.
        resample_dims (str or iterable of str): dimension(s) to resample along.
            A single dimension name is treated as a one-element list. Names
            that are not dimensions of ``initialized`` are ignored.
        n_iteration (int): number of bootstrap iterations.
        block_sizes (dict, optional): block sizes for each dimension.
            Dimensions missing from the dict (or all of them if None) use
            block size 1.

    Returns:
        xr.Dataset: the result of ``xb.block_bootstrap``; callers in this
            module index it along an ``iteration`` dimension. If none of
            ``resample_dims`` is present in ``initialized``, the input is
            returned unchanged and therefore WITHOUT an ``iteration`` dimension.

    Raises:
        ImportError: if xbootstrap is not installed.

    """
    if not XBOOTSTRAP_AVAILABLE:
        raise ImportError(
            "xbootstrap is required for efficient multi-dimensional resampling. "
            "Install with: pip install xbootstrap"
        )

    # A bare string would be iterated per character, match no dimension and
    # silently return the input unresampled.
    if isinstance(resample_dims, str):
        resample_dims = [resample_dims]

    # An empty mapping makes every dimension fall back to block size 1 below.
    if block_sizes is None:
        block_sizes = {}

    # Iterating ``dims`` yields the dimension names whether it is a mapping,
    # a tuple or a set, so no ``.keys()`` call is needed.
    available_dims = set(initialized.dims)
    blocks = {
        dim: block_sizes.get(dim, 1) for dim in resample_dims if dim in available_dims
    }

    if not blocks:
        # No dimensions to resample
        return initialized

    return xb.block_bootstrap(
        initialized,
        blocks=blocks,
        n_iteration=n_iteration,
        circular=True,  # Use circular bootstrap
    )


def _distribution_to_ci(ds, ci_low, ci_high, dim="iteration"):
"""Get confidence intervals from bootstrapped distribution.

Expand Down Expand Up @@ -465,6 +538,66 @@ def _chunk_before_resample_iterations_idx(
return ds


def resample_skill_xbootstrap(self, iterations, resample_dim, verify_kwargs):
    """Bootstrap skill using xbootstrap for multi-dimensional resampling.

    All resampled versions of the initialized dataset are generated in a
    single xbootstrap call; skill is then still computed once per iteration
    via ``verify``. Falls back to ``resample_skill_loop`` when xbootstrap is
    not installed.

    Args:
        self: ensemble object providing ``get_initialized``,
            ``get_uninitialized``, ``generate_uninitialized``, ``copy``,
            ``verify`` and a ``_datasets`` mapping.
        iterations (int): number of bootstrap iterations.
        resample_dim (str or iterable of str): dimension(s) to resample along.
        verify_kwargs (dict): keyword arguments forwarded to ``verify``. Must
            contain a ``"reference"`` entry supporting ``in`` membership tests.

    Returns:
        Skill of every iteration concatenated along ``"iteration"``.

    """
    logging.info("use resample_skill_xbootstrap")

    if not XBOOTSTRAP_AVAILABLE:
        # Fallback to loop method if xbootstrap not available
        logging.warning("xbootstrap not available, falling back to loop method")
        return resample_skill_loop(self, iterations, resample_dim, verify_kwargs)

    # Handle both a single dimension and any iterable of dimensions
    if isinstance(resample_dim, str):
        resample_dims = [resample_dim]
    else:
        resample_dims = list(resample_dim)

    # Use xbootstrap to generate all resampled versions at once
    resampled_data = _resample_multiple_dims_xbootstrap(
        self.get_initialized(), resample_dims, iterations
    )

    # One working copy is enough: every iteration overwrites the datasets it
    # needs, the same way resample_skill_loop reuses ``self_for_loop``.
    # Copying inside the loop would duplicate the ensemble per iteration.
    self_temp = self.copy()

    resampled_skills = []
    for i in range(iterations):
        # Swap in the resampled initialized data of this iteration
        self_temp._datasets["initialized"] = resampled_data.isel(iteration=i)

        # Handle uninitialized data if needed
        if "uninitialized" in verify_kwargs["reference"]:
            if not self.get_uninitialized():
                self_temp._datasets["uninitialized"] = (
                    self.generate_uninitialized().get_uninitialized()
                )
            else:
                # For uninitialized, we still use the original member resampling
                self_temp._datasets["uninitialized"] = _resample(
                    self.get_uninitialized(), "member"
                )

        # Compute skill for this iteration
        resampled_skills.append(self_temp.verify(**verify_kwargs))

    # Concatenate all skills along iteration dimension
    return xr.concat(resampled_skills, "iteration")


def resample_skill_loop(self, iterations, resample_dim, verify_kwargs):
# slow: loop and verify each time
# used for HindcastEnsemble.bootstrap(metric='acc') and if
Expand All @@ -483,9 +616,14 @@ def resample_skill_loop(self, iterations, resample_dim, verify_kwargs):
loop = tqdm(loop)
for i in loop:
# resample initialized
self_for_loop._datasets["initialized"] = _resample(
self.get_initialized(), resample_dim
)
if isinstance(resample_dim, list):
self_for_loop._datasets["initialized"] = _resample_multiple_dims(
self.get_initialized(), resample_dim
)
else:
self_for_loop._datasets["initialized"] = _resample(
self.get_initialized(), resample_dim
)
if "uninitialized" in verify_kwargs["reference"]:
# resample uninitialized
if not self.get_uninitialized():
Expand Down
Loading
Loading