diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c168ca157..3622d7810e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - 🔴 Improved the performance of the `TimeSeries.map()` method for functions that take two arguments. The mapping is now applied on the entire time index and values array which requires users to reshape the time index explicitly within the function. See more information in the `TimeSeries.map()` method documentation. [#2911](https://github.com/unit8co/darts/pull/2911) by [Jakub Chłapek](https://github.com/jakubchlapek) +- Improved the `save`/`load` methods for `TorchForecastingModel` to support `os.PathLike` objects as file path. [#2947](https://github.com/unit8co/darts/pull/2947) by [Timon Erhart](https://github.com/timonerhart) + **Fixed** **Dependencies** diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 74c28d4162..bcc5d3d0f2 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -26,6 +26,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from glob import glob +from pathlib import Path from typing import Any, Callable, Literal, Optional, Union if sys.version_info >= (3, 11): @@ -1933,7 +1934,7 @@ def _clean(self) -> Self: def save( self, - path: Optional[str] = None, + path: Optional[Union[str, os.PathLike]] = None, clean: bool = False, ) -> None: """ @@ -1974,14 +1975,17 @@ def save( """ if path is None: # default path - path = self._default_save_path() + ".pt" + path = Path(self._default_save_path() + ".pt") + else: + # ensures that all os.PathLike are accepted + path = Path(path) # save the TorchForecastingModel (does not save the PyTorch LightningModule, and Trainer) with open(path, "wb") as f_out: torch.save(self if not clean else self._clean(), f_out) # save the LightningModule checkpoint (weights only with `clean=True`) - path_ptl_ckpt = path + ".ckpt" + path_ptl_ckpt = path.with_name(path.name + ".ckpt") if self.trainer is not None: self.trainer.save_checkpoint(path_ptl_ckpt, weights_only=clean) @@ -2000,7 +2004,9 @@ def save( @staticmethod def load( - path: str, pl_trainer_kwargs: Optional[dict] = None, **kwargs + path: Union[str, os.PathLike], + pl_trainer_kwargs: Optional[dict] = None, + **kwargs, ) -> "TorchForecastingModel": """ Loads a model from a given file path. @@ -2052,6 +2058,9 @@ def load( For more information, read the `official documentation `__. """ + # Ensures all os.PathLike are accepted + path = Path(path) + # load the base TorchForecastingModel (does not contain the actual PyTorch LightningModule) with open(path, "rb") as fin: model: TorchForecastingModel = torch.load( @@ -2059,7 +2068,7 @@ def load( ) # if a checkpoint was saved, we also load the PyTorch LightningModule from checkpoint - path_ptl_ckpt = path + ".ckpt" + path_ptl_ckpt = path.with_name(path.name + ".ckpt") if os.path.exists(path_ptl_ckpt): model.model = model._load_from_checkpoint(path_ptl_ckpt, **kwargs) else: