MLflow for Darts implementation#3022
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
|
Hey @daidahao, adding this draft PR in the meantime so you and @dennisbader can have a look at what I have currently regarding the integration. There are still some decisions I am not too thrilled about and decisions to be made about the overall direction, but I'm happy to talk more about it during the meeting. Thanks for being so active for the library, really nice to be working together :) |
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
darts/utils/utils.py
Outdated
| try: | ||
| import pytorch_lightning as pl # noqa: F401 | ||
|
|
||
| PL_AVAILABLE = True |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
pyproject.toml
Outdated
| ] | ||
| notorch = [ | ||
| "catboost>=1.0.6", | ||
| "catboost>=1.0.6,<=1.2.9", |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
| "statsforecast>=1.4", | ||
| "xgboost>=2.1.4", | ||
| ] | ||
| mlflow = ["mlflow>=2.0"] |
There was a problem hiding this comment.
I would like a discussion on the new option here. My understanding is that users who would need the Darts-MLflow integration probably have MLFlow installed already and set up properly. For users who have not, MLflow itself has options for databricks, which some users might find useful. Could we instead direct users to MLflow official guide for installation?
There was a problem hiding this comment.
Also, @MichaelVerdegaal raised a suggestion for mlflow>=3.0. I have not used MLflow 2.x before but I think the minimum version should be deliberated as well.
There was a problem hiding this comment.
I think directing the user to the official guide sounds good, and for devs it makes sense including the mlflow in the [optional] requirements. I'm also inclined to agree with @MichaelVerdegaal regarding the versioning. MLflow 3.0 introduced some breaking changes (i.e. the models no longer being saved with other artifacts but in the models/ directory and versioned), so I think it can solve a lot of future headaches for us to cap the version at 3.0 minimum. So if MLflow is not installed we would raise an error, directing to the official installation and notifying that we support from 3.0 onwards. lmk what you think
| does not apply the `with_managed_run` wrapper to the specified | ||
| `patch_function`. | ||
| """ | ||
| # Enable/disable mlflow.pytorch.autolog for per-epoch metrics on torch models. |
There was a problem hiding this comment.
Sorry, I don't understand why the decorator would short-circuit here if we call mlflow.pytorch.autolog() with disable=True. Looking at XGBoost flavour, it seems they are able to call mlflow.sklearn._autolog() within mlflow.xgboost.autolog(). Is it because mlflow.sklearn._autolog() is not wrapped but mlflow.pytorch.autolog() is?
There was a problem hiding this comment.
Sorry, was a while since I tested this, so I might not have the clearest answer, but as far as I remember during testing when pytorch.autolog was wrapped the disable=True would not reach it, leading to overwriting the old run with new run's logs. This new approach definitely worked properly for me, but if I missed something it could prove nice to simplify with calling them within each other normally.
darts/utils/mlflow.py
Outdated
|
|
||
| classes_to_patch = [ForecastingModel] | ||
|
|
||
| for subclass in get_all_subclasses(ForecastingModel): |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
|
|
||
| def log_model( | ||
| model, | ||
| artifact_path: str | None = None, |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
darts/utils/mlflow.py
Outdated
| log_models: bool = True, | ||
| log_params: bool = True, | ||
| log_metrics: bool = True, | ||
| inject_per_epoch_callbacks: bool = True, |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
darts/utils/mlflow.py
Outdated
| A list of pip requirement strings. | ||
| """ | ||
| reqs = [_get_pinned_requirement("darts")] | ||
| if is_torch: |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
darts/utils/mlflow.py
Outdated
| if code_dir_subpath is not None: | ||
| darts_flavor_conf["code"] = code_dir_subpath | ||
|
|
||
| default_reqs = None if pip_requirements else get_default_pip_requirements(is_torch) |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
darts/utils/mlflow.py
Outdated
| bool | ||
| True if the model is a TorchForecastingModel, False otherwise. | ||
| """ | ||
| try: |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
|
@mizeller @jakubchlapek @dennisbader Greetings! I've addressed most of the comments here, except for a few discussion points. I've left a Sincere apologies for suggesting post-fitting metrics in the first place! I didn't realise the complexity involved. My suggestion is to skip post-fitting metrics for now or settle for compromises such as non-terminated active runs (at the risk of cross-logging). Other than that, I am truly proud of what we have achieved here and will hand this over to @mizeller for runups and more great work. |
|
@mizeller @jakubchlapek @dennisbader Just thought of an easy way to track Since Logging metrics then becomes a lot easier. When patching What do you think? |
|
Hi @daidahao, it's been a minute haha, Michel will be continuing on this, but I've found some time to checkout the changes so far. Thanks a lot for the review and your updates, they look good and make sense to me :). Regarding the post-fitting metrics, I agree that the issue is more complicated than originally envisioned. I think the with mlflow.start_run(run_name="whatever"):
# metrics should log here nicelycould bring nice value and be useful. Important that we document the risk of crosslogging for multiple active runs though. Another issue that I'd like to focus on first to have in the merged PR would be the support for if apply_retrain:
# fit a new instance of the model
model = model.untrained_model()
model._fit_wrapper(
series=train_series_tf,
past_covariates=past_covariates_tf,
future_covariates=future_covariates_tf,
sample_weight=sample_weight_tf,
val_series=val_series_tf,
**fit_kwargs,
)This leads to the covariates not saving to |
|
Thank you for the response here. For metric logging, I agree with your suggestion on leaving it as is at some risks of cross-logging models as long as those risks are clearly documented. In the long term, I would prefer a more robust solution using either a
For the |
|
@daidahao I planned to work on the If you're available before, feel free:) I unfortunately didnt have any capacity the last few weeks - sorry about that! |
Checklist before merging this PR:
Addresses #2092 .
Summary
Provides a custom MLflow flavor for Darts on Darts' side. Supports autologging, logging, saving and loading of the models.
This PR focuses on the base MLflow integration, leaving serving of the models to be discussed in the future.
Included an example quickstart for the integration, however consider all of this a draft :)
Find example code in the .ipynb, however also providing a code snippet here as a quick reproducible example: