Lightning 2.x support #149

@maxwelltsai

Hi @cweniger et al.,

I noticed that swyft is currently pinned to pytorch-lightning==1.9.5. Do you have any plans to upgrade this legacy dependency to a newer version, e.g., lightning==2.4.x?

A bit of background: in a collaboration with DAMTP (Cambridge), I am porting swyft to Intel GPUs, because their supercomputer "Dawn" is powered by Intel GPUs. My colleagues and I have built a version of Lightning that supports Intel GPUs, but it is based on Lightning 2.x. I therefore changed the swyft code to bump Lightning to 2.4, but swyft seems to rely on an API that is no longer available in 2.x. The following error occurs when I call trainer.infer(network, obs, prior_samples):

Traceback (most recent call last):
  File "/nfs/site/home/xucai/Works/swyft/tests/truncation.py", line 78, in <module>
    predictions, bounds, samples = round(obs, bounds = bounds)
  File "/nfs/site/home/xucai/Works/swyft/tests/truncation.py", line 68, in round
    predictions = trainer.infer(network, obs, prior_samples)
  File "/nfs/site/home/xucai/Works/swyft/swyft/lightning/core.py", line 318, in infer
    ratio_batches = self.predict(model, dl)
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 858, in predict
    return call._call_and_handle_interrupt(
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 897, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1020, in _run_stage
    return self.predict_loop.run()
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/prediction_loop.py", line 107, in run
    self.reset()
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/prediction_loop.py", line 176, in reset
    raise ValueError('`trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.')
ValueError: `trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.

How can we get around this? Do we need to change the mode of the CombinedLoader that swyft's infer passes to trainer.predict()?
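For context on why simply switching modes may not be a drop-in fix: in Lightning 2.x, `CombinedLoader(mode="min_size")` yields one batch from each iterable per step (stopping at the shortest), whereas `mode="sequential"` walks through each iterable one after another, so batches are never paired. The pure-Python sketch below imitates the two behaviours in simplified form (the real loaders also yield batch and dataloader indices). It assumes, without having checked swyft's source, that the `dl` passed to `self.predict` in `infer` pairs observation batches with prior-sample batches; `obs_batches` and `prior_batches` are hypothetical stand-ins.

```python
from itertools import chain


def min_size_batches(*loaders):
    """Pair the i-th batch of every loader, stopping at the shortest one.

    Simplified imitation of CombinedLoader(mode="min_size"): each step
    yields one batch from *each* iterable.
    """
    return list(zip(*loaders))


def sequential_batches(*loaders):
    """Yield every batch of the first loader, then every batch of the next.

    Simplified imitation of CombinedLoader(mode="sequential"), the only mode
    trainer.predict() accepts in Lightning 2.x: batches are never paired.
    """
    return list(chain(*loaders))


# Hypothetical stand-ins: the observation repeated once per prior batch,
# and three batches of prior samples.
obs_batches = ["obs", "obs", "obs"]
prior_batches = ["prior_0", "prior_1", "prior_2"]

print(min_size_batches(obs_batches, prior_batches))
# [('obs', 'prior_0'), ('obs', 'prior_1'), ('obs', 'prior_2')]
print(sequential_batches(obs_batches, prior_batches))
# ['obs', 'obs', 'obs', 'prior_0', 'prior_1', 'prior_2']
```

If infer does depend on paired batches, sequential mode would hand the model observations and prior samples in separate steps, so the pairing would have to happen elsewhere, e.g. inside a single dataset or dataloader.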

If it helps, I can submit a PR with the changes that bump Lightning to 2.4.
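If such a PR needs to keep the 1.9.5 code path working during the transition, one hypothetical option is to detect the installed Lightning major version at import time and branch on it. The helper below, `lightning_major_version`, is an illustrative name rather than anything in swyft; it only uses the standard-library `importlib.metadata` and checks the unified `lightning` distribution before the standalone `pytorch-lightning` one.

```python
from importlib import metadata


def lightning_major_version():
    """Return the major version of the installed Lightning package, or None.

    Hypothetical helper: checks the unified ``lightning`` distribution first,
    then the standalone ``pytorch-lightning`` distribution swyft pins today.
    """
    for dist_name in ("lightning", "pytorch-lightning"):
        try:
            version = metadata.version(dist_name)
        except metadata.PackageNotFoundError:
            continue  # this distribution is not installed; try the next one
        return int(version.split(".")[0])
    return None


major = lightning_major_version()
if major is None:
    print("Lightning is not installed")
elif major >= 2:
    print(f"Lightning {major}.x: predict() requires sequential mode")
else:
    print(f"Lightning {major}.x: legacy code path")
```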

Thanks,
Maxwell

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests