When the n_centers variable has a high value, some of the n_centers clusters may end up with no patterns assigned to them.
In this case the code
for c in range(n_centers): cost += torch.norm(x[l2_cls == c] - tmp_center[c], p=2, dim=1).mean()
can produce NaN as a result: for an empty cluster, x[l2_cls == c] selects no rows, and .mean() of an empty tensor is NaN.
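
One possible workaround, as a minimal sketch: skip clusters that have no members when accumulating the cost. The function name `cluster_cost` and the toy tensors below are hypothetical and only for illustration; `x`, `l2_cls`, `tmp_center` and `n_centers` are assumed to have the same meaning as in the line above. Skipping empty clusters is an assumption about the desired behaviour, not necessarily the intended fix.

```python
import torch


def cluster_cost(x, l2_cls, tmp_center, n_centers):
    """Sum of mean L2 distances to each center, ignoring empty clusters.

    Hypothetical helper: assumes x is (N, D), l2_cls is (N,) with integer
    cluster indices, and tmp_center is (n_centers, D).
    """
    cost = x.new_zeros(())  # scalar zero on the same device/dtype as x
    for c in range(n_centers):
        members = x[l2_cls == c]
        if members.shape[0] == 0:
            # Empty cluster: .mean() over zero elements would be NaN, so skip it.
            continue
        cost = cost + torch.norm(members - tmp_center[c], p=2, dim=1).mean()
    return cost


# Toy example: 2 points, 3 centers, so clusters 1 and 2 are empty.
x = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
l2_cls = torch.tensor([0, 0])
tmp_center = torch.zeros(3, 2)
cost = cluster_cost(x, l2_cls, tmp_center, n_centers=3)
```

With the original loop the same inputs would give NaN, because the terms for clusters 1 and 2 are means over empty tensors. Whether an empty cluster should contribute zero to the cost, or be re-initialised instead, depends on the training setup.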