
MLflow for Darts implementation #3022

Draft
jakubchlapek wants to merge 56 commits into unit8co:master from jakubchlapek:feat/mlflow-base

Conversation

@jakubchlapek
Collaborator

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Addresses #2092.

Summary

Provides a custom MLflow flavor for Darts, implemented on Darts' side. Supports autologging, logging, saving, and loading of models.
This PR focuses on the base MLflow integration, leaving serving of the models to be discussed in the future.

An example quickstart for the integration is included; please consider all of this a draft :)
The full example code is in the .ipynb, but here is a quick reproducible snippet:

import mlflow
import tempfile
import os
from darts.metrics.metrics import smape
from darts.utils.mlflow import load_model, autolog
from darts.models import NBEATSModel, LinearRegressionModel
from darts.datasets import AirPassengersDataset
from torchmetrics import MeanAbsoluteError

# temp file setup
tmpdir = tempfile.mkdtemp()
mlflow_db = os.path.join(tmpdir, "mlflow.db")
mlflow.set_tracking_uri(f"sqlite:///{mlflow_db}")
mlflow.set_experiment("darts-forecasting")

train, val = AirPassengersDataset().load().astype("float32").split_before(0.7)

# autologging - patches .fit() on all ForecastingModel subclasses.
# for PyTorch-based models, inject_per_epoch_callbacks injects a Lightning callback
# that logs train/val loss and/or user-specified torch metrics at the end of each epoch automatically.
autolog(
    log_models=True,
    log_params=True,
    log_training_metrics=True,
    log_validation_metrics=True,   # requires val_series in .fit()
    inject_per_epoch_callbacks=True, 
    extra_metrics=[smape],         # optional extra darts metric functions
)

with mlflow.start_run(run_name="nbeats") as run:
    model = NBEATSModel(
        input_chunk_length=24, 
        output_chunk_length=12,
        torch_metrics=MeanAbsoluteError())
    # val_series is forwarded to Lightning's val_dataloaders;
    # autolog captures per-epoch val metrics via the injected callback
    model.fit(train, val_series=val, epochs=10)
    run_id = run.info.run_id


# regression/sklearn models work identically
with mlflow.start_run(run_name="linreg"):
    model = LinearRegressionModel(lags=12)
    model.fit(train)  # logs params + in-sample metrics

# load back from MLflow
loaded = load_model(f"runs:/{run_id}/model")
preds = loaded.predict(12, series=train) # need to specify series as we save with clean=True in save_model

# import shutil
# shutil.rmtree(tmpdir)

@review-notebook-app
Check out this pull request on ReviewNB.

@jakubchlapek
Collaborator Author

Hey @daidahao, adding this draft PR in the meantime so you and @dennisbader can have a look at what I have currently regarding the integration. There are still some decisions I am not too thrilled about and decisions to be made about the overall direction, but I'm happy to talk more about it during the meeting. Thanks for being so active for the library, really nice to be working together :)

daidahao added 3 commits March 5, 2026 09:34
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
try:
    import pytorch_lightning as pl  # noqa: F401

    PL_AVAILABLE = True

This comment was marked as resolved.

pyproject.toml (Outdated)

]
notorch = [
-    "catboost>=1.0.6",
+    "catboost>=1.0.6,<=1.2.9",

This comment was marked as resolved.

    "statsforecast>=1.4",
    "xgboost>=2.1.4",
]
mlflow = ["mlflow>=2.0"]
Contributor

I would like a discussion on the new option here. My understanding is that users who need the Darts-MLflow integration probably already have MLflow installed and set up properly. For users who have not, MLflow itself has options for Databricks, which some users might find useful. Could we instead direct users to the official MLflow guide for installation?

Contributor

Also, @MichaelVerdegaal raised a suggestion for mlflow>=3.0. I have not used MLflow 2.x before but I think the minimum version should be deliberated as well.

Collaborator Author

I think directing the user to the official guide sounds good, and for devs it makes sense to include mlflow in the [optional] requirements. I'm also inclined to agree with @MichaelVerdegaal regarding the versioning. MLflow 3.0 introduced some breaking changes (e.g. models are no longer saved alongside the other artifacts but in the models/ directory, and are versioned), so setting the minimum version to 3.0 could save us a lot of future headaches. If MLflow is not installed we would raise an error directing users to the official installation guide and noting that we support 3.0 onwards. Let me know what you think.

does not apply the `with_managed_run` wrapper to the specified
`patch_function`.
"""
# Enable/disable mlflow.pytorch.autolog for per-epoch metrics on torch models.
Contributor

Sorry, I don't understand why the decorator would short-circuit here if we call mlflow.pytorch.autolog() with disable=True. Looking at XGBoost flavour, it seems they are able to call mlflow.sklearn._autolog() within mlflow.xgboost.autolog(). Is it because mlflow.sklearn._autolog() is not wrapped but mlflow.pytorch.autolog() is?

Collaborator Author

Sorry, it's been a while since I tested this, so I might not have the clearest answer, but as far as I remember, when pytorch.autolog was wrapped, the disable=True would not reach it, which led to the old run being overwritten with the new run's logs. This new approach definitely worked properly for me, but if I missed something, it could be nice to simplify by calling them within each other normally.


classes_to_patch = [ForecastingModel]

for subclass in get_all_subclasses(ForecastingModel):

This comment was marked as resolved.
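The get_all_subclasses call in the fragment above suggests a recursive subclass walk so that every ForecastingModel variant gets patched. A generic sketch of such a helper (the name comes from the diff; the body here is an assumption):

```python
# Recursively collect every transitive subclass of a class.
def get_all_subclasses(cls):
    subclasses = set()
    for subclass in cls.__subclasses__():
        subclasses.add(subclass)
        # __subclasses__() only returns direct children, so recurse.
        subclasses.update(get_all_subclasses(subclass))
    return subclasses


# Toy hierarchy standing in for ForecastingModel and its descendants:
class Base: ...
class A(Base): ...
class B(A): ...
```

With this, get_all_subclasses(Base) yields both A and B, including grandchildren that a plain Base.__subclasses__() call would miss.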


def log_model(
    model,
    artifact_path: str | None = None,

This comment was marked as resolved.

log_models: bool = True,
log_params: bool = True,
log_metrics: bool = True,
inject_per_epoch_callbacks: bool = True,

This comment was marked as resolved.

A list of pip requirement strings.
"""
reqs = [_get_pinned_requirement("darts")]
if is_torch:

This comment was marked as resolved.

if code_dir_subpath is not None:
darts_flavor_conf["code"] = code_dir_subpath

default_reqs = None if pip_requirements else get_default_pip_requirements(is_torch)

This comment was marked as resolved.

bool
True if the model is a TorchForecastingModel, False otherwise.
"""
try:

This comment was marked as resolved.

daidahao added 7 commits March 5, 2026 09:47
daidahao and others added 13 commits March 5, 2026 14:52
@daidahao
Contributor

daidahao commented Mar 7, 2026

@mizeller @jakubchlapek @dennisbader

Greetings! I've addressed most of the comments here, except for a few discussion points.

I've left a TODO note on post-fitting metrics which, IMHO, are HARD to implement at this point due to how MLflow manages active runs in autolog context. In short, we would need to keep a mapping between MLflow run ids, fitted models, model predictions, and metrics, to ensure the metrics are logged under the right run id (see mlflow.sklearn).

Sincere apologies for suggesting post-fitting metrics in the first place! I didn't realise the complexity involved.

My suggestion is to skip post-fitting metrics for now or settle for compromises such as non-terminated active runs (at the risk of cross-logging).

Other than that, I am truly proud of what we have achieved here and will hand this over to @mizeller for runups and more great work.

@daidahao
Contributor

@mizeller @jakubchlapek @dennisbader

Just thought of an easy way to track TimeSeries provenance without an mlflow.sklearn-style _AutologgingMetricsManager for metric auto-logging:

Since TimeSeries here is a custom container rather than generic numpy arrays in sklearn, we could write run_id into its metadata dict whenever they are generated from forecasts. There are different ways of doing that, either via patching predict(), historical_forecasts() etc., OR changing _build_forecast_series(), etc.

Logging metrics then becomes a lot easier. When patching darts.metrics, we only need to identify run_id in TimeSeries.metadata and log metrics to those runs accordingly.

What do you think?
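The provenance idea above can be sketched with a stand-in container instead of the real darts TimeSeries (the metadata dict and run_id key follow the comment; every other name here is an assumption for illustration):

```python
from dataclasses import dataclass, field


@dataclass
class FakeSeries:
    """Stand-in for darts TimeSeries, which carries a metadata dict."""
    values: list
    metadata: dict = field(default_factory=dict)


def tag_with_run(series: FakeSeries, run_id: str) -> FakeSeries:
    # At forecast-creation time (e.g. inside a patched predict()),
    # record which MLflow run produced this series.
    series.metadata["run_id"] = run_id
    return series


def log_metric_for(series: FakeSeries, name: str, value: float, sink: dict) -> None:
    # A patched darts.metrics function only needs to read run_id back
    # from the series' metadata to route the metric to the right run.
    run_id = series.metadata.get("run_id")
    if run_id is not None:
        sink.setdefault(run_id, {})[name] = value


sink = {}
forecast = tag_with_run(FakeSeries([1.0, 2.0]), "run-123")
log_metric_for(forecast, "smape", 4.2, sink)
# sink now maps "run-123" to {"smape": 4.2}
```

The appeal is that no global mapping of runs to models to predictions is needed: the forecast itself carries enough context to log metrics under the correct run id.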

@jakubchlapek
Collaborator Author

Hi @daidahao, it's been a minute haha. Michel will be continuing on this, but I've found some time to check out the changes so far. Thanks a lot for the review and your updates, they look good and make sense to me :).

Regarding the post-fitting metrics, I agree that the issue is more complicated than originally envisioned. I think the TimeSeries.metadata patching is a smart idea that would work, but I'm not a fan of mutating the TimeSeries. Ideally the objects would not be impacted by the logging (for example, we don't want to later export a series with the run_id still attached). I would second leaving the solution as is, i.e. logging to the current non-terminated active run, with the behavior well documented. Even restricted to the context block where it works correctly, e.g.

with mlflow.start_run(run_name="whatever"):
    # metrics should log here nicely

the autologging could bring nice value and be useful. It is important that we document the risk of cross-logging for multiple active runs, though.

Another issue that I'd like to address before merging is support for historical_forecasts(retrain=True) (or an int) under autologging.
In the current implementation, since we patch fit, historical_forecasts will start a run for each retrain iteration, along with saving models and any artifacts. Two ideas come to mind: detect whether .fit was called from within historical_forecasts and ignore those internal fits, keeping only the final result, OR patch historical_forecasts directly to suppress autologging for the iterations. I haven't investigated deeper yet, so this is just brainstorming :).
Secondly, also stemming from the retrain issue, the current covariate-saving logic is based on the model's past_covariate_series attribute. Since historical_forecasts trains and fits a new internal model for each iteration, the correct flags are never passed to the final model, because it is not modified.
code snippet from forecasting_model.py L1180-1190

if apply_retrain:
    # fit a new instance of the model
    model = model.untrained_model()
    model._fit_wrapper(
        series=train_series_tf,
        past_covariates=past_covariates_tf,
        future_covariates=future_covariates_tf,
        sample_weight=sample_weight_tf,
        val_series=val_series_tf,
        **fit_kwargs,
    )

This leads to the covariates not being saved to covariates.json later. I'm not sure what the best course of action here is, but I think this could be solved easily by patching historical_forecasts and logging directly. I would prefer to avoid patching so many functions, but if it makes sense, then it's fine.
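The "suppress autologging during historical_forecasts" idea could be shaped as a re-entrancy guard around the patched fit, e.g. via a context variable. All names here are assumptions for illustration, not the PR's code:

```python
import contextvars
from contextlib import contextmanager

# Flag marking that we are inside a patched historical_forecasts call.
_in_hfc = contextvars.ContextVar("in_historical_forecasts", default=False)


@contextmanager
def suppress_autologging():
    token = _in_hfc.set(True)
    try:
        yield
    finally:
        _in_hfc.reset(token)


logged_runs = []


def patched_fit(*args, **kwargs):
    # Skip run creation for the many internal fits from retrain=True;
    # only top-level fit() calls start an MLflow run.
    if not _in_hfc.get():
        logged_runs.append("run")
    return "fitted"


def patched_historical_forecasts():
    with suppress_autologging():
        for _ in range(3):  # retrain loop: internal fits, nothing logged
            patched_fit()


patched_fit()                   # normal fit: one run logged
patched_historical_forecasts()  # three internal fits: no runs logged
```

This would let a patched historical_forecasts log a single run (with the final forecasts and, if needed, covariates) itself, while the internal retrain fits stay silent.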
Thanks a lot for all the work so far, I think we can soon have this done properly :)

@daidahao
Contributor

@jakubchlapek

Thank you for the response here. For metric logging, I agree with your suggestion to leave it as is, at some risk of cross-logging models, as long as those risks are clearly documented. In the long term, I would prefer a more robust solution using either an _AutologgingMetricsManager or metadata, while we continue exploring other options. As far as I can see, the key concern with mutating TimeSeries would be the overhead, which can be kept minimal by disabling copying, etc., while the run_id metadata is not overly intrusive given the MLflow context. Is there anything else that I missed?

PS. Speaking of exporting, AFAIK there is no official import/export functionality or file format for TimeSeries. I do wish there were a safe/easy way to do so without relying on pickle; Darts v0.34.0 last year broke the old TimeSeries pickle format and we had to rerun all experiments at one point. :(

For the historical_forecasts() support, I also agree with your solutions here. However, if time is a concern for @mizeller, I would not mind implementing the support in another PR targeting a future release, so long as the current limitations are clearly documented and the integration is marked as beta. It would be helpful to gather early feedback from Darts users while the basic integration (fit, predict, save, load, etc.) is covered.

@mizeller
Contributor

mizeller commented Apr 2, 2026

@daidahao I planned to work on the historical_forecasts support next week.

If you're available before then, feel free :) I unfortunately didn't have any capacity the last few weeks, sorry about that!
