Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

- Added native multi-quantile support for `CatBoostModel` by using CatBoost’s `MultiQuantile` loss for faster training and inference. Set `likelihood="multiquantile"` to enable this feature. [#3032](https://github.com/unit8co/darts/pull/3032) by [Zhihao Dai](https://github.com/daidahao)
- Added native multi-quantile support for `XGBModel`. Similar to the regular quantile support, it still fits dedicated models per quantile, but it is more efficient due to fewer tabularization operations. Set `likelihood="multiquantile"` to enable this feature. [#3056](https://github.com/unit8co/darts/pull/3056) by [Oswald Zink](https://github.com/ozink-u8)
- Improvements to `RegressionEnsembleModel` : [#2773](https://github.com/unit8co/darts/issues/2773) by [Gabriel Margaria](https://github.com/Jaco-Pastorius).
- `StatsForecastModel` now accepts `model` as a StatsForecast model name, class, or instance; `model_kwargs` supplies constructor arguments when `model` is a name or class. This simplifies config-driven setups. [#3058](https://github.com/unit8co/darts/pull/3058) by [Trevin Chow](https://github.com/tmchow).
- Improvements to `RegressionEnsembleModel` : [#3041](https://github.com/unit8co/darts/pull/3041) by [Gabriel Margaria](https://github.com/Jaco-Pastorius).
- Base forecasting models using `output_chunk_shift>0` are now fully supported. If you're using a custom `regression_model`, simply set its output shift to be the same as that of the base models.
- Added support for `output_chunk_length>1` for the ensemble (regression) model. This means that the ensemble model can now consume information from base model forecasts over the entire horizon.
- Scaled metrics (`ase`, `sse`, `mase`, `msse`, `rmsse`) no longer raise a hard `ValueError` when the `insample` series has zero error scale (constant or perfectly seasonal signals). A new `zero_division` parameter controls the behavior: [#3059](https://github.com/unit8co/darts/pull/3059) by [Mahima Sharma](https://github.com/mahi-ma)
Expand Down
76 changes: 55 additions & 21 deletions darts/models/forecasting/sf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
------------------
"""

from typing import Protocol, runtime_checkable

import numpy as np
from statsforecast.models import _TS
import statsforecast.models as sf_models

from darts import TimeSeries, concatenate
from darts.logging import get_logger
from darts.logging import get_logger, raise_log
from darts.models import LinearRegressionModel
from darts.models.forecasting.forecasting_model import (
TransferableFutureCovariatesLocalForecastingModel,
Expand All @@ -19,11 +21,29 @@
logger = get_logger(__name__)


@runtime_checkable
class _SFModel(Protocol):
"""This serves as a protocol for expected StatsForecast model API."""

uses_exog: bool

def __init__(*args, **kwargs): ...

def fit(self, *args, **kwargs): ...

def predict(self, *args, **kwargs) -> dict: ...

def forecast(self, *args, **kwargs) -> dict: ...

def forward(self, *args, **kwargs) -> dict: ...


class StatsForecastModel(TransferableFutureCovariatesLocalForecastingModel):
@random_method
def __init__(
self,
model: _TS,
model: str | type[sf_models._TS] | sf_models._TS = "AutoARIMA",
model_kwargs: dict | None = None,
add_encoders: dict | None = None,
quantiles: list[float] | None = None,
random_state: int | None = None,
Expand Down Expand Up @@ -70,7 +90,13 @@ def __init__(
Parameters
----------
model
Any StatsForecast model.
Name, class, or instance of the StatsForecast base model to be used from ``statsforecast.models``, e.g.,
``"AutoARIMA"``, ``AutoARIMA``, or ``AutoARIMA()``. See all `StatsForecast models
<https://nixtlaverse.nixtla.io/statsforecast/src/core/models_intro.html>`__ here.
model_kwargs
A dictionary of model parameters to initialize the StatsForecast base model. The expected
parameters depend on the base model used. Only effective when `model` is a string or class.
Default: ``None``.
add_encoders
A large number of future covariates can be automatically generated with `add_encoders`.
This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
Expand Down Expand Up @@ -106,12 +132,11 @@ def encode_year(idx):
>>> from darts.datasets import AirPassengersDataset
>>> from darts.models import StatsForecastModel
>>> from darts.utils.timeseries_generation import datetime_attribute_timeseries
>>> from statsforecast.models import AutoARIMA
>>> series = AirPassengersDataset().load()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the AutoARIMA import is no longer necessary

>>> # optionally, use some future covariates; e.g. the value of the month encoded as a sine and cosine series
>>> future_cov = datetime_attribute_timeseries(series, "month", cyclic=True, add_length=6)
>>> # define AutoARIMA parameters
>>> model = StatsForecastModel(model=AutoARIMA(season_length=12))
>>> model = StatsForecastModel(model="AutoARIMA", model_kwargs={"season_length": 12})
>>> model.fit(series, future_covariates=future_cov)
>>> pred = model.predict(6, future_covariates=future_cov)
>>> print(pred.values())
Expand All @@ -122,10 +147,31 @@ def encode_year(idx):
[502.67834069]
[566.04774778]]
"""
if not isinstance(model, _TS):
raise ValueError(
"`model` must be a StatsForecast model imported from `statsforecast.models`."
model_kwargs = model_kwargs or {}
if isinstance(model, sf_models._TS):
pass
elif isinstance(model, str):
try:
model_class = getattr(sf_models, model)
except AttributeError:
raise_log(
ValueError(
f"Could not find a StatsForecast model class named `{model}` "
f"in `statsforecast.models`."
),
logger,
)
model = model_class(**model_kwargs)
elif isinstance(model, type) and issubclass(model, sf_models._TS):
model = model(**model_kwargs)
else:
raise_log(
ValueError(
"`model` must be a valid StatsForecast model name (str), class or instance."
),
logger,
)

self.model: _SFModel = model
self._likelihood = QuantilePrediction(quantiles=quantiles)

Expand Down Expand Up @@ -377,18 +423,6 @@ def _supports_non_retrainable_historical_forecasts(self) -> bool:
return self._supports_native_transferable_series


class _SFModel(_TS):
"""This serves as a protocol for expected StatsForecast model API."""

def fit(self, *args, **kwargs): ...

def predict(self, *args, **kwargs) -> dict: ...

def forecast(self, *args, **kwargs) -> dict: ...

def forward(self, *args, **kwargs) -> dict: ...


def _unpack_sf_dict(
forecast_dict: dict,
levels: list[float] | None,
Expand Down
53 changes: 46 additions & 7 deletions darts/tests/models/forecasting/test_sf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
allow_module_level=True,
)

from statsforecast.models import AutoETS as SF_AutoETS
from statsforecast.models import SimpleExponentialSmoothing as SF_ETS
import statsforecast.models as sf_models
from statsforecast.utils import ConformalIntervals

import darts.utils.timeseries_generation as tg
Expand Down Expand Up @@ -51,7 +50,10 @@ class TestSFModels:
(AutoARIMA, {"season_length": 12}), # (native, native)
(AutoMFLES, {"season_length": 12, "test_size": 12}), # (custom, native)
(AutoETS, {"season_length": 12}), # (native, custom)
(StatsForecastModel, {"model": SF_ETS(alpha=0.1)}), # (custom, custom)
(
StatsForecastModel,
{"model": sf_models.SimpleExponentialSmoothing(alpha=0.1)},
), # (custom, custom)
],
)
def test_transferable_series_forecast(self, config):
Expand Down Expand Up @@ -151,7 +153,7 @@ def test_transferable_series_forecast(self, config):
(AutoETS, {"season_length": 12}, False), # (native, custom, native)
(
StatsForecastModel,
{"model": SF_ETS(alpha=0.1)},
{"model": sf_models.SimpleExponentialSmoothing(alpha=0.1)},
True,
), # (custom, custom, conformal)
],
Expand Down Expand Up @@ -187,7 +189,9 @@ def test_probabilistic_support(self, config):
if isinstance(model, AutoMFLES):
kwargs["prediction_intervals"] = ci
else:
kwargs["model"] = SF_ETS(alpha=0.1, prediction_intervals=ci)
kwargs["model"] = sf_models.SimpleExponentialSmoothing(
alpha=0.1, prediction_intervals=ci
)
model = model_cls(**kwargs).fit(series)

with pytest.raises(ValueError) as exc:
Expand Down Expand Up @@ -249,7 +253,7 @@ def test_probabilistic_support(self, config):
"model",
[
AutoETS(season_length=12, model="ZZZ"),
StatsForecastModel(SF_AutoETS(season_length=12, model="ZZZ")),
StatsForecastModel(sf_models.AutoETS(season_length=12, model="ZZZ")),
],
)
def test_custom_fc_support_fit_on_residuals(self, model):
Expand Down Expand Up @@ -280,7 +284,7 @@ def test_custom_fc_support_fit_on_residuals(self, model):
"model",
[
AutoETS(season_length=12, model="ZZZ"),
StatsForecastModel(SF_AutoETS(season_length=12, model="ZZZ")),
StatsForecastModel(sf_models.AutoETS(season_length=12, model="ZZZ")),
],
)
def test_custom_fc_support_fit_a_linreg(self, model):
Expand Down Expand Up @@ -361,3 +365,38 @@ def test_wrong_covariates(self):
with pytest.raises(ValueError) as exc:
_ = model.predict(n=n, future_covariates=fc[:-1])
assert exc_expected in str(exc.value)

@pytest.mark.parametrize(
"model, model_kwargs",
[
("AutoETS", {}),
("AutoETS", {"season_length": 12}),
(sf_models.AutoETS, {}),
(sf_models.AutoETS, {"season_length": 12}),
(sf_models.AutoETS(), {}),
(sf_models.AutoETS(season_length=12), {}),
],
)
def test_model_creation(self, model, model_kwargs):
model = StatsForecastModel(
model=model,
model_kwargs=model_kwargs,
)
model.fit(series=self.series[:24])
preds = model.predict(n=12)
assert len(preds) == 12

class InvalidModel:
pass

@pytest.mark.parametrize("model", ["InvalidModel", InvalidModel, InvalidModel()])
def test_invalid_model(self, model):
if not isinstance(model, str):
with pytest.raises(ValueError, match="must be a valid StatsForecast model"):
_ = StatsForecastModel(model=model)
else:
with pytest.raises(
ValueError,
match="Could not find a StatsForecast model class named `InvalidModel`",
):
_ = StatsForecastModel(model=model)
4 changes: 1 addition & 3 deletions darts/utils/likelihood_models/statsforecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
----------------------------------
"""

from abc import ABC

import numpy as np

from darts.logging import get_logger
Expand All @@ -18,7 +16,7 @@
logger = get_logger(__name__)


class QuantilePrediction(Likelihood, ABC):
class QuantilePrediction(Likelihood):
def __init__(self, quantiles: list[float]):
"""Quantile Prediction Likelihood

Expand Down
Loading