RuntimeError (dtype mismatch) when training a GMM in double precision #59
Description
Hi, thank you for this wonderful repo.
Abstract
I tried to train a GMM in double precision by passing precision=64 in trainer_params, and got a RuntimeError during initialization (the KMeans step inside gmm.fit).
The error reports a dtype mismatch, and I found that adding a single line resolves it.
It seems to be a bug, so I'm reporting it.
Environment
My environment is as follows:
- pycave: 3.2.1
- pytorch-lightning: 1.9.5
- torch: 1.11.0+cu113
- torchmetrics: 0.11.4
Error message
Traceback (most recent call last):
...
File "****.py", line 97, in train_gmm
gmm = gmm.fit(X)
File "/usr/local/lib/python3.8/dist-packages/pycave/bayes/gmm/estimator.py", line 153, in fit
estimator = KMeans(
File "/usr/local/lib/python3.8/dist-packages/pycave/clustering/kmeans/estimator.py", line 129, in fit
self.trainer(max_epochs=num_epochs).fit(module, loader)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
call._call_and_handle_interrupt(
...
self.distance_sampler.update(batch, shortest_distances)
File "/usr/local/lib/python3.8/dist-packages/torchmetrics/metric.py", line 400, in wrapped_func
raise err
File "/usr/local/lib/python3.8/dist-packages/torchmetrics/metric.py", line 390, in wrapped_func
update(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pycave/clustering/kmeans/metrics.py", line 151, in update
self.choices.masked_scatter_(
RuntimeError: masked_scatter: expected self and source to have same dtypes but gotFloat and Double
In short, it reports a dtype mismatch between self.choices (Float, i.e. float32) and the data (Double, i.e. float64).
Looking at the implementation, DistanceSampler.choices in pycave/clustering/kmeans/metrics.py is a plain torch.Tensor, so it does not seem to follow the precision=64 setting.
Its dtype therefore needs to be converted to data.dtype before the scatter.
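The mismatch can be reproduced with plain PyTorch, independent of pycave. This is a minimal sketch that assumes, as the error message indicates, a float32 state tensor standing in for self.choices and a float64 batch standing in for data; the shapes, values, and mask are made up for illustration.

```python
import torch

torch.manual_seed(0)

# Stand-in for DistanceSampler.choices: created with the default dtype (float32).
choices = torch.zeros(4, 3)
# Stand-in for a batch under precision=64: float64.
data = torch.randn(5, 3, dtype=torch.float64)
# Hypothetical mask: overwrite choice slots 0 and 2 (broadcast across features).
mask = torch.tensor([True, False, True, False]).unsqueeze(1)

mismatch_raised = False
try:
    # Float32 destination, float64 source: masked_scatter_ rejects this.
    choices.masked_scatter_(mask, data[:2])
except RuntimeError as err:
    mismatch_raised = True
    print(err)

# With matching dtypes, the same call succeeds.
matched = torch.zeros(4, 3, dtype=torch.float64)
matched.masked_scatter_(mask, data[:2])
print(matched.dtype)  # torch.float64
```

Only the destination dtype differs between the two calls, which points at self.choices rather than the data as the tensor that needs converting.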
Resolving the error
I added self.choices = self.choices.to(data.dtype) at line 151, right before the masked_scatter_ call, and it now works:
# Then, we sample from the data `num_choices` times and replace if needed
choices = (squared_distances + eps).multinomial(self.num_choices, replacement=True)
+ self.choices = self.choices.to(data.dtype)
self.choices.masked_scatter_(
use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]]
)
I would appreciate it if you could look into this issue.
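The patched lines above can also be exercised outside pycave. The function below is a hypothetical standalone rewrite of that step, not pycave's actual code: the name replace_choices, its arguments, and the default eps are assumptions for illustration, and the shapes follow how the diff indexes data and choices.

```python
import torch


def replace_choices(
    choices: torch.Tensor,
    data: torch.Tensor,
    squared_distances: torch.Tensor,
    use_choice_from_data: torch.Tensor,
    eps: float = 1e-7,  # assumed value; pycave defines its own eps
) -> torch.Tensor:
    """Hypothetical standalone version of the patched sampling step.

    choices: (num_choices, num_features) state tensor, possibly float32
    data: (batch_size, num_features) batch, e.g. float64 under precision=64
    squared_distances: (batch_size,) non-negative sampling weights
    use_choice_from_data: (num_choices,) bool mask of slots to overwrite
    """
    num_choices = choices.size(0)
    # Draw one data index per choice slot, weighted by squared distance.
    indices = (squared_distances + eps).multinomial(num_choices, replacement=True)
    # The fix: align the state's dtype with the data before scattering.
    # .to() returns a new tensor when the dtype changes, so the result is reassigned.
    choices = choices.to(data.dtype)
    choices.masked_scatter_(
        use_choice_from_data.unsqueeze(1), data[indices[use_choice_from_data]]
    )
    return choices


torch.manual_seed(0)
data = torch.randn(5, 3, dtype=torch.float64)
result = replace_choices(
    choices=torch.zeros(4, 3),  # float32, as in the bug report
    data=data,
    squared_distances=torch.rand(5, dtype=torch.float64),
    use_choice_from_data=torch.tensor([True, False, True, False]),
)
print(result.dtype)  # torch.float64
```

Casting the state up to data.dtype keeps the double precision requested via precision=64. Casting the source down instead (e.g. data[...].to(choices.dtype)) would also avoid the error, but would silently round the sampled points to float32.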
Best regards,
tenk-9