Conversation
- Implement TiRexModel based on Darts FoundationModel API
- Add optional tirex-ts integration via load_model() + forecast()
- Require accept_license=True to acknowledge NXAI Community License
- Support deterministic + quantile probabilistic forecasting
- Add unit tests with stub TiRex pipeline (no external dependency required)
- Update .gitignore (add .python-version)
Added detailed docstrings to tirex_model.py; set fixed chunk lengths
Added examples notebook. There is still something wrong with the forecast; need to investigate!
added example notebook
- align TiRexModel.fit with FoundationModel API (incl. val/trainer kwargs)
- expose correct covariate capability flags (no past/future covariates)
- keep strict univariate/covariate validation for train + validation inputs
- decouple Chronos2/TimesFM2p5/TiRex imports in darts.models.__init__
- extend TiRex tests for API compatibility and capability flags
- add TiRex notebook to docs examples and TiRex row to model support table
- clean TiRex example notebook wording and TiRex docstring examples
- fix TiRex docstrings to be Sphinx/numpydoc compatible
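The stub-pipeline testing approach mentioned in the commits above can be sketched roughly as follows. Note that `StubTiRexPipeline`, `TiRexModelSketch`, and their signatures are illustrative stand-ins, not the actual Darts or tirex-ts API:

```python
# Illustrative sketch only: a stub pipeline lets the license gate and the
# forecast plumbing be unit-tested without installing tirex-ts.

class StubTiRexPipeline:
    """Stand-in for a tirex-ts model: returns trivial quantile forecasts."""

    def forecast(self, context, prediction_length, quantile_levels):
        last = context[-1]
        # One row of quantiles per future step; here simply repeat the last value.
        return [[last for _ in quantile_levels] for _ in range(prediction_length)]


class TiRexModelSketch:
    def __init__(self, accept_license=False, pipeline=None):
        if not accept_license:
            raise ValueError(
                "Pass accept_license=True to acknowledge the NXAI Community License."
            )
        self.pipeline = pipeline or StubTiRexPipeline()

    def predict(self, n, series, quantile_levels=(0.1, 0.5, 0.9)):
        return self.pipeline.forecast(series, n, quantile_levels)


model = TiRexModelSketch(accept_license=True)
forecast = model.predict(3, [1.0, 2.0, 3.0])
# → [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]
```

The real model presumably wires a tirex-ts pipeline behind the same interface; the stub keeps the unit tests independent of the external package, as the commit message describes.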
Thank you for the PR! The integration is in very good shape already. I scanned the code and noticed some small issues like duplicated validation logic and class signature. Because I don't know how busy you might be, I wonder if it would be easier for you to let me modify the code directly (I wrote the initial
Hi @daidahao!
@lukfischer If you could do that, I could then apply necessary edits to your code in the coming days. There might be a few optional edits you might consider, like adding fidelity tests; I will leave some comments for those and we can discuss them.
@daidahao |
- Update `example.rst` as well
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
- Use Darts default docstring where necessary
- Format notes and warnings
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
- Add back TorchForecastingModel argument docstring
- Remove any that are training-specific
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
- Remove `device` argument to model.
- Use `PLForecastingModule.device` as the device for initializing the TiRex model. This prevents moving tensors across devices during inference.
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
- Fold multivariate target components into the batch dimension so multivariate forecasting can be supported.
- Remove duplicated covariate validation logic.
Note that all torch forecasting models in Darts support multivariate forecasting.
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
According to the docstring, when `enable_finetuning=None`, the foundation model is not fine-tuned. However, `model_params` was not being set correctly.
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
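One of the commits above folds multivariate target components into the batch dimension so a univariate forecaster can handle each component as its own series. The idea can be sketched in plain Python; the helper names and list-based shapes are illustrative, not the actual Darts implementation:

```python
# Fold a (batch, time, components) structure into (batch * components, time, 1)
# so each component is forecast as an independent univariate series, then
# unfold the results back. Plain-Python sketch of the folding idea.

def fold_components(batch):
    folded = []
    for series in batch:                      # series: (time, components)
        n_components = len(series[0])
        for c in range(n_components):
            folded.append([[step[c]] for step in series])
    return folded

def unfold_components(folded, n_components):
    unfolded = []
    for i in range(0, len(folded), n_components):
        group = folded[i : i + n_components]  # n_components univariate series
        time_len = len(group[0])
        unfolded.append(
            [[group[c][t][0] for c in range(n_components)] for t in range(time_len)]
        )
    return unfolded

# One batch entry, 2 time steps, 2 components.
batch = [[[1.0, 10.0], [2.0, 20.0]]]
folded = fold_components(batch)
assert len(folded) == 2                       # components moved into batch dim
assert unfold_components(folded, 2) == batch  # round-trips losslessly
```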
I think the fidelity tests would still fail on GitHub CI, but not locally. I investigated the tirex codebase and could not find randomness or sampling that could explain the numerical differences. Any ideas why that happened? @lukfischer @martinloretzzz Other than that, I think this PR is review-ready. @dennisbader
We observe the same thing in our repositories; there's a small dependence on the actual hardware it is run on (TiRex is a recurrent model, so floating point rounding errors can get amplified over the sequence length; we also use BF16). The only thing we did about it was to increase the error tolerances so that the tests still pass on all hardware we test on.
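The amplification described here is easy to reproduce with a toy recurrence: a tiny difference in the initial state grows multiplicatively over the sequence, which is why such tests must compare with a tolerance rather than exact equality. A sketch (the gain and perturbation values are arbitrary illustrations, not TiRex's actual dynamics):

```python
import math

# A toy recurrent update: a tiny perturbation in the initial state grows
# multiplicatively over the sequence, mimicking hardware-dependent rounding.
def run(state, steps, gain=1.05):
    for _ in range(steps):
        state = gain * state + 0.1
    return state

a = run(1.0, 200)
b = run(1.0 + 1e-12, 200)                # one tiny "rounding" difference at step 0

assert a != b                            # exact equality fails...
assert math.isclose(a, b, rel_tol=1e-6)  # ...but a relative tolerance passes
```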
dennisbader
left a comment
Thanks a lot for this great PR and the contribution @lukfischer, @martinloretzzz and @daidahao 🚀 TiRex will be a really nice addition to Darts :)
The implementation is solid! I added a couple of suggestions in the comments which we should address before merging. The main one being that we should enable fine-tuning for TiRex to be aligned with all general torch (and foundation) model support.
pyproject.toml
Outdated
    "ray>=2.53.0",
    "plotly>=6.5.2",
    "neuralforecast>=3.0.0",
    "tirex-ts>=1.4.0; python_version >= '3.11'",
As @daidahao mentioned, it would be great if we could support Python 3.10 as well.
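Assuming upstream relaxes its Python floor (as later happened with tirex-ts v1.4.1), the optional dependency entry could presumably be declared with a relaxed environment marker along these lines (versions shown are illustrative):

```toml
# Hypothetical relaxed entry, assuming a tirex-ts release that supports 3.10:
"tirex-ts>=1.4.1; python_version >= '3.10'",
```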
    logger,
)

if kwargs.get("enable_finetuning", False) not in (None, False):
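The guard shown in this diff treats both `None` and `False` as "fine-tuning disabled", so training-specific parameters are only forwarded when fine-tuning is genuinely requested. A minimal plain-Python sketch of that gating (the helper name `resolve_model_params` is hypothetical, not the actual Darts code):

```python
# Sketch of the gating logic: both None and False must mean "no fine-tuning",
# so training-specific parameters are only kept when fine-tuning is enabled.

def resolve_model_params(**kwargs):
    if kwargs.get("enable_finetuning", False) not in (None, False):
        # Fine-tuning requested: keep the training-specific parameters.
        return dict(kwargs)
    # Foundation model used as-is: no training-specific parameters.
    return {}

assert resolve_model_params() == {}
assert resolve_model_params(enable_finetuning=None) == {}
assert resolve_model_params(enable_finetuning=False) == {}
assert resolve_model_params(enable_finetuning=True)["enable_finetuning"] is True
```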
@dennisbader Thanks for the review here. Besides the fine-tuning support, which the TiRex team is perhaps best positioned to address, most of the comments are minor, so I will try to address them in the coming days.
@dennisbader thx a lot for the thorough review!
Thanks a lot for the answer @lukfischer :) I think we're talking about two different things. I'm referring to our (Darts) own existing ability to fine-tune (or train) the foundation models on your own data. This fine-tuning logic is already in place for all our torch-based models including the existing foundation models. There is no magic behind our default training scenario, so it would still be up to the user / your services to perform robust / quality training. The same thing would apply if a user downloads the model from Hugging Face / the tirex library. I do hope that we can enable this, as our idea was that all foundation models in Darts would support it. Let me know what you think.
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Hi everyone, all the minor issues should be addressed with the new commits. There might still be line coverage gaps, but I will try to address them once I have the codecov report.
We've relaxed the Python requirements to include 3.10; the package is published to PyPI as v1.4.1: https://github.com/NX-AI/tirex/releases/tag/v1.4.1
Oh ok, sorry, my bad! I get your point, and after an initial discussion with the team, we tend to agree that we indeed should enable this. Give me some more time to double-check this with our management. But I'm positive that we will be able to enable this.
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
@martinloretzzz That's great to hear! I've also removed the py bound for
@lukfischer Great to hear that the TiRex team is supportive of the initiative! @dennisbader Could you please start the CI workflow for testing?
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
@martinloretzzz and @lukfischer, thanks a lot for adding support for Python 3.10 and the re-consideration of the "fine-tuning" 🚀
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Codecov Report

❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##           master    #3038      +/-   ##
==========================================
- Coverage   96.28%   96.21%   -0.08%
==========================================
  Files         160      161       +1
  Lines       17207    17271      +64
==========================================
+ Hits        16568    16617      +49
- Misses        639      654      +15

☔ View full report in Codecov by Sentry.
Checklist before merging this PR:
Fixes #3016.
Addresses NX-AI/tirex#29.
Summary
This PR adds and hardens the `TiRexModel` integration in Darts.

It includes the TiRex forecasting wrapper, model registration in `darts.models`, targeted tests, and example notebook/docs integration. It also fixes a few integration issues discovered during review:

- Align `TiRexModel.fit()` with the standard `FoundationModel`/Darts API

The integration also requires explicit license acknowledgement. Users must pass `accept_license=True` when constructing `TiRexModel` to confirm acceptance of the NXAI Community License.

Other Information
Validation performed:

- `pytest -q darts/tests/models/forecasting/test_tirex.py`
- `pre-commit run --all-files`
- `make build-all-docs SPHINXOPTS="-W"`

The changelog entry under `Unreleased` still needs to be added before merge.