Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

- 🔴 Improved the performance of the `TimeSeries.map()` method for functions that take two arguments. The mapping is now applied on the entire time index and values array which requires users to reshape the time index explicitly within the function. See more information in the `TimeSeries.map()` method documentation. [#2911](https://github.com/unit8co/darts/pull/2911) by [Jakub Chłapek](https://github.com/jakubchlapek)

- Improved the `save`/`load` methods for `TorchForecastingModel` to support `os.PathLike` objects as file path. [#2947](https://github.com/unit8co/darts/pull/2947) by [Timon Erhart](https://github.com/timonerhart)

**Fixed**

**Dependencies**
Expand Down
19 changes: 14 additions & 5 deletions darts/models/forecasting/torch_forecasting_model.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great to add this to all other save / load methods:

  • model construction param: work_dir
  • to_onnx()
  • load_weights()
  • work_dir in load_from_checkpoint()
  • work_dir in load_weights_from_checkpoint()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that makes sense. Will try to do so as soon as i find time

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from glob import glob
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -1933,7 +1934,7 @@ def _clean(self) -> Self:

def save(
self,
path: Optional[str] = None,
path: Optional[Union[str, os.PathLike]] = None,
clean: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -1974,14 +1975,17 @@ def save(
"""
if path is None:
# default path
path = self._default_save_path() + ".pt"
path = Path(self._default_save_path() + ".pt")
else:
# ensures that all os.PathLike are accepted
path = Path(path)

# save the TorchForecastingModel (does not save the PyTorch LightningModule, and Trainer)
with open(path, "wb") as f_out:
torch.save(self if not clean else self._clean(), f_out)

# save the LightningModule checkpoint (weights only with `clean=True`)
path_ptl_ckpt = path + ".ckpt"
path_ptl_ckpt = path.with_name(path.name + ".ckpt")
if self.trainer is not None:
self.trainer.save_checkpoint(path_ptl_ckpt, weights_only=clean)

Expand All @@ -2000,7 +2004,9 @@ def save(

@staticmethod
def load(
path: str, pl_trainer_kwargs: Optional[dict] = None, **kwargs
path: Union[str, os.PathLike],
pl_trainer_kwargs: Optional[dict] = None,
**kwargs,
) -> "TorchForecastingModel":
"""
Loads a model from a given file path.
Expand Down Expand Up @@ -2052,14 +2058,17 @@ def load(
For more information, read the `official documentation <https://pytorch-lightning.readthedocs.io/en/stable/
common/lightning_module.html#load-from-checkpoint>`__.
"""
# Ensures all os.PathLike are accepted
path = Path(path)

# load the base TorchForecastingModel (does not contain the actual PyTorch LightningModule)
with open(path, "rb") as fin:
model: TorchForecastingModel = torch.load(
fin, weights_only=False, map_location=kwargs.get("map_location", None)
)

# if a checkpoint was saved, we also load the PyTorch LightningModule from checkpoint
path_ptl_ckpt = path + ".ckpt"
path_ptl_ckpt = path.with_name(path.name + ".ckpt")
if os.path.exists(path_ptl_ckpt):
model.model = model._load_from_checkpoint(path_ptl_ckpt, **kwargs)
else:
Expand Down