diff --git a/src/climpred/bootstrap.py b/src/climpred/bootstrap.py
--- a/src/climpred/bootstrap.py
+++ b/src/climpred/bootstrap.py
@@ -18,6 +18,14 @@
     resample_iterations_idx as _resample_iterations_idx,
 )
 
+try:
+    import xbootstrap as xb
+
+    XBOOTSTRAP_AVAILABLE = True
+except ImportError:
+    xb = None
+    XBOOTSTRAP_AVAILABLE = False
+
 from .checks import (
     has_dims,
     has_valid_lead_units,
@@ -69,6 +77,70 @@ def _resample(initialized, resample_dim):
     return smp_initialized
 
 
+def _resample_multiple_dims(initialized, resample_dims):
+    """Resample with replacement along several dimensions, one after another.
+
+    Each dimension in ``resample_dims`` is passed to ``_resample`` in turn, feeding
+    the result of one call into the next.
+
+    Args:
+        initialized (xr.Dataset): input xr.Dataset to be resampled.
+        resample_dims (list of str): dimensions to resample along.
+
+    Returns:
+        xr.Dataset: resampled along all ``resample_dims``.
+
+    """
+    result = initialized
+    for dim in resample_dims:
+        result = _resample(result, dim)
+    return result
+
+
+def _resample_multiple_dims_xbootstrap(
+    initialized, resample_dims, n_iteration, block_sizes=None
+):
+    """Resample with replacement across multiple dimensions using xbootstrap.
+
+    Args:
+        initialized (xr.Dataset): input xr.Dataset to be resampled.
+        resample_dims (list of str): dimensions to resample along. Dimensions
+            not present in ``initialized`` are ignored.
+        n_iteration (int): number of bootstrap iterations.
+        block_sizes (dict, optional): block size for each dimension. Dimensions
+            without an entry use block size 1.
+
+    Returns:
+        xr.Dataset: resampled dataset with an ``iteration`` dimension added.
+
+    Raises:
+        ImportError: if xbootstrap is not installed.
+
+    """
+    if not XBOOTSTRAP_AVAILABLE:
+        raise ImportError(
+            "xbootstrap is required for efficient multi-dimensional resampling. "
+            "Install with: pip install xbootstrap"
+        )
+
+    block_sizes = block_sizes or {}
+    blocks = {
+        dim: block_sizes.get(dim, 1) for dim in resample_dims if dim in initialized.dims
+    }
+
+    if not blocks:
+        # Nothing to resample: still add the documented ``iteration`` dimension so
+        # that callers can index it unconditionally.
+        return initialized.expand_dims(iteration=n_iteration)
+
+    return xb.block_bootstrap(
+        initialized,
+        blocks=blocks,
+        n_iteration=n_iteration,
+        circular=True,
+    )
+
+
 def _distribution_to_ci(ds, ci_low, ci_high, dim="iteration"):
     """Get confidence intervals from bootstrapped distribution.
 
@@ -465,6 +537,56 @@ def _chunk_before_resample_iterations_idx(
     return ds
 
 
+def resample_skill_xbootstrap(self, iterations, resample_dim, verify_kwargs):
+    """Bootstrap skill using xbootstrap for multi-dimensional resampling.
+
+    All resampled versions of ``initialized`` are drawn with one xbootstrap call;
+    skill is then verified once per iteration. Falls back to
+    ``resample_skill_loop`` if xbootstrap is not installed.
+
+    Args:
+        self (PredictionEnsemble): ensemble to bootstrap.
+        iterations (int): number of bootstrap iterations.
+        resample_dim (str or list of str): dimension(s) to resample along.
+        verify_kwargs (dict): keyword arguments passed to ``verify``.
+
+    Returns:
+        xr.Dataset: skill of each iteration concatenated along ``iteration``.
+
+    """
+    logging.info("use resample_skill_xbootstrap")
+
+    if not XBOOTSTRAP_AVAILABLE:
+        logging.warning("xbootstrap not available, falling back to loop method")
+        return resample_skill_loop(self, iterations, resample_dim, verify_kwargs)
+
+    resample_dims = [resample_dim] if isinstance(resample_dim, str) else resample_dim
+
+    # Draw all resampled versions of initialized at once
+    resampled_data = _resample_multiple_dims_xbootstrap(
+        self.get_initialized(), resample_dims, iterations
+    )
+
+    # Reuse one copy and swap its datasets instead of copying every iteration
+    self_for_loop = self.copy()
+    resampled_skills = []
+    for i in range(iterations):
+        self_for_loop._datasets["initialized"] = resampled_data.isel(iteration=i)
+        if "uninitialized" in verify_kwargs["reference"]:
+            if not self.get_uninitialized():
+                self_for_loop._datasets["uninitialized"] = (
+                    self.generate_uninitialized().get_uninitialized()
+                )
+            else:
+                # uninitialized is only ever resampled along member
+                self_for_loop._datasets["uninitialized"] = _resample(
+                    self.get_uninitialized(), "member"
+                )
+        resampled_skills.append(self_for_loop.verify(**verify_kwargs))
+
+    return xr.concat(resampled_skills, "iteration")
+
+
 def resample_skill_loop(self, iterations, resample_dim, verify_kwargs):
     # slow: loop and verify each time
     # used for HindcastEnsemble.bootstrap(metric='acc') and if
@@ -484,8 +606,10 @@ def resample_skill_loop(self, iterations, resample_dim, verify_kwargs):
     for i in loop:
         # resample initialized
-        self_for_loop._datasets["initialized"] = _resample(
-            self.get_initialized(), resample_dim
-        )
+        # a str names one dimension, anything else is a sequence of dimensions
+        self_for_loop._datasets["initialized"] = _resample_multiple_dims(
+            self.get_initialized(),
+            [resample_dim] if isinstance(resample_dim, str) else resample_dim,
+        )
         if "uninitialized" in verify_kwargs["reference"]:
             # resample uninitialized
             if not self.get_uninitialized():
diff --git a/src/climpred/classes.py b/src/climpred/classes.py
--- a/src/climpred/classes.py
+++ b/src/climpred/classes.py
@@ -32,6 +32,7 @@
 from .alignment import return_inits_and_verif_dates
 from .bias_removal import bias_correction, gaussian_bias_removal, xclim_sdba
 from .bootstrap import (
+    XBOOTSTRAP_AVAILABLE,
     _distribution_to_ci,
     _p_ci_from_sig,
     _pvalue_from_distributions,
@@ -39,6 +40,7 @@
     resample_skill_exclude_resample_dim_from_dim,
     resample_skill_loop,
     resample_skill_resample_before,
+    resample_skill_xbootstrap,
     resample_uninitialized_from_initialized,
     warn_if_chunking_would_increase_performance,
 )
@@ -1013,7 +1015,7 @@ def _bootstrap(
         groupby: Optional[groupbyType] = None,
         iterations: Optional[int] = None,
         sig: int = 95,
-        resample_dim: Optional[str] = None,
+        resample_dim: Optional[Union[str, List[str]]] = None,
         **metric_kwargs: metric_kwargsType,
     ) -> xr.Dataset:
         """PredictionEnsemble.bootstrap() parent method.
@@ -1069,20 +1071,40 @@ def _bootstrap(
         skill = self2.verify(**verify_kwargs)
 
+        # Normalize resample_dim: a one-element list behaves like the plain string,
+        # so only requests for several dimensions take the multi-dimensional path
+        if isinstance(resample_dim, (list, tuple)):
+            resample_dims = list(resample_dim)
+            if len(resample_dims) == 1:
+                resample_dim = resample_dims[0]
+        else:
+            resample_dims = [resample_dim]
+        multi_dim_resample = len(resample_dims) > 1
+
         # different ways to compute resample_skill
         if (
             _metric.name in PEARSON_R_CONTAINING_METRICS
             and self.kind == "hindcast"
             and alignment
         ):
-            if ("same_verif" in alignment) & (resample_dim == "init"):
+            if "same_verif" in alignment and "init" in resample_dims:
                 raise KeywordError(
                     "Cannot have alignment='same_verifs' and resample_dim='init' and "
                     "metric='pearson_r'. Change `resample_dim` to 'member' to keep "
                     "common verification alignment or `alignment` to 'same_inits' to "
                     "resample over initializations."
                 )
 
-        if OPTIONS["bootstrap_resample_skill_func"] == "default":
+        if multi_dim_resample:
+            # several dimensions: xbootstrap if installed, else loop and verify
+            if XBOOTSTRAP_AVAILABLE:
+                resampled_skills = resample_skill_xbootstrap(
+                    self, iterations, resample_dims, verify_kwargs
+                )
+            else:
+                resampled_skills = resample_skill_loop(
+                    self, iterations, resample_dims, verify_kwargs
+                )
+        elif OPTIONS["bootstrap_resample_skill_func"] == "default":
             if (
                 _metric.name in PEARSON_R_CONTAINING_METRICS
                 and self.kind == "hindcast"
@@ -1716,7 +1738,7 @@ def bootstrap(
         groupby: Optional[groupbyType] = None,
         iterations: Optional[int] = None,
         sig: int = 95,
-        resample_dim: str = "member",
+        resample_dim: Union[str, List[str]] = "member",
         **metric_kwargs: metric_kwargsType,
     ) -> xr.Dataset:
         """Bootstrap with replacement according to :cite:t:`Goddard2013`.
@@ -2461,7 +2483,7 @@ def bootstrap(
         groupby: Optional[groupbyType] = None,
         iterations: Optional[int] = None,
         sig: int = 95,
-        resample_dim: str = "member",
+        resample_dim: Union[str, List[str]] = "member",
         **metric_kwargs: metric_kwargsType,
     ) -> xr.Dataset:
         """Bootstrap with replacement according to :cite:t:`Goddard2013`.
@@ -2498,6 +2520,9 @@ def bootstrap(
 
                 - ``"member"``: select a different set of members from hind
                 - ``"init"``: select a different set of initializations from hind
+                - ``["member", "init"]``: resample along both ``member`` and
+                  ``init`` (uses xbootstrap if installed, otherwise resamples
+                  one dimension after the other in a loop)
 
             groupby: group ``init`` before passing ``initialized`` to ``bootstrap``.
             **metric_kwargs: arguments passed to ``metric``.
@@ -2571,4 +2596,16 @@ def bootstrap(
 
+            Resampling along ``member`` and ``init`` in the same call:
+
+            >>> skill = HindcastEnsemble.bootstrap(
+            ...     metric="crps",
+            ...     comparison="m2o",
+            ...     dim="member",
+            ...     iterations=50,
+            ...     resample_dim=["member", "init"],
+            ...     alignment="same_inits",
+            ...     reference=["persistence", "climatology", "uninitialized"],
+            ... )
+
         """
         return self._bootstrap(
             metric=metric,
diff --git a/src/climpred/tests/test_bootstrap.py b/src/climpred/tests/test_bootstrap.py
--- a/src/climpred/tests/test_bootstrap.py
+++ b/src/climpred/tests/test_bootstrap.py
@@ -12,9 +12,12 @@
 
 from climpred import HindcastEnsemble
 from climpred.bootstrap import (
+    XBOOTSTRAP_AVAILABLE,
     _bootstrap_by_stacking,
     _chunk_before_resample_iterations_idx,
     _resample,
+    _resample_multiple_dims,
+    _resample_multiple_dims_xbootstrap,
     bootstrap_uninit_pm_ensemble_from_control_cftime,
 )
 from climpred.constants import CONCAT_KWARGS
@@ -544,3 +547,92 @@ def test_generate_uninitialized(hindcast_hist_obs_1d):
     assert not hindcast_hist_obs_1d_new.verify(**kw).equals(
         hindcast_hist_obs_1d.verify(**kw)
     )
+
+
+def _init_member_lead_dataset():
+    """Return a small random dataset with init, member and lead dimensions."""
+    data = np.random.random((3, 4, 5))
+    return xr.Dataset(
+        {"var": (["init", "member", "lead"], data)},
+        coords={"init": range(3), "member": range(4), "lead": range(5)},
+    )
+
+
+def _assert_bootstrap_structure(result):
+    """Check that ``result`` has the dims and labels the bootstrap tests expect."""
+    assert "results" in result.dims
+    assert "skill" in result.dims
+    assert "verify skill" in result.results.values
+    assert "low_ci" in result.results.values
+    assert "high_ci" in result.results.values
+
+
+@pytest.mark.skipif(not XBOOTSTRAP_AVAILABLE, reason="xbootstrap not available")
+def test_resample_multiple_dims_xbootstrap():
+    """Test _resample_multiple_dims_xbootstrap adds an iteration dimension."""
+    ds = _init_member_lead_dataset()
+
+    result = _resample_multiple_dims_xbootstrap(ds, ["init", "member"], n_iteration=3)
+
+    assert "iteration" in result.dims
+    assert result.iteration.size == 3
+    for dim in ds.dims:
+        assert dim in result.dims
+
+
+@pytest.mark.skipif(not XBOOTSTRAP_AVAILABLE, reason="xbootstrap not available")
+def test_resample_multiple_dims_xbootstrap_without_matching_dim():
+    """Test that an iteration dimension is returned even if nothing is resampled."""
+    ds = _init_member_lead_dataset()
+
+    result = _resample_multiple_dims_xbootstrap(ds, ["not_a_dim"], n_iteration=3)
+
+    assert result.sizes["iteration"] == 3
+
+
+def test_resample_multiple_dims():
+    """Test _resample_multiple_dims keeps dimensions and coordinates."""
+    ds = _init_member_lead_dataset()
+
+    result = _resample_multiple_dims(ds, ["init", "member"])
+
+    assert result.dims == ds.dims
+    assert set(result.coords) == set(ds.coords)
+    # member labels are expected to come back unchanged
+    assert list(result.member.values) == list(ds.member.values)
+
+
+@pytest.mark.parametrize("use_xbootstrap", [True, False])
+def test_bootstrap_multi_dim_resample(
+    hindcast_hist_obs_1d, monkeypatch, use_xbootstrap
+):
+    """Test bootstrap over member and init with and without xbootstrap."""
+    if use_xbootstrap and not XBOOTSTRAP_AVAILABLE:
+        pytest.skip("xbootstrap not available")
+    if not use_xbootstrap:
+        # force the loop fallback even when xbootstrap is installed
+        monkeypatch.setattr("climpred.classes.XBOOTSTRAP_AVAILABLE", False)
+
+    result = hindcast_hist_obs_1d.bootstrap(
+        metric="rmse",
+        comparison="e2o",
+        dim=[],
+        iterations=2,  # keep low for test speed
+        resample_dim=["member", "init"],
+        alignment="same_inits",
+    )
+
+    _assert_bootstrap_structure(result)
+
+
+@pytest.mark.parametrize("resample_dim", [["member"], ["member", "init"]])
+def test_bootstrap_list_resample_dim_same_structure(hindcast_hist_obs_1d, resample_dim):
+    """Test that a list ``resample_dim`` gives the same dims as ``"member"``."""
+    kwargs = dict(
+        metric="rmse", comparison="e2o", dim=[], iterations=2, alignment="same_inits"
+    )
+
+    result_single = hindcast_hist_obs_1d.bootstrap(resample_dim="member", **kwargs)
+    result_list = hindcast_hist_obs_1d.bootstrap(resample_dim=resample_dim, **kwargs)
+
+    assert result_single.dims == result_list.dims