Open
Description
I recently ran the DBN code on CUDA and hit a small runtime problem.
In dbn.py, line 52:
s = (torch.rand(p.size())< p).float().to(self.dvc)
raises the following error:
Traceback (most recent call last):
  File "mnist_cls.py", line 55, in <module>
    model.run(e=3, pre_e=3)
  File "../core/epoch.py", line 98, in run
    self.pre_batch_training(pre_e, b)
  File "../core/pre_module.py", line 60, in pre_batch_training
    module.batch_training(i)
  File "../model/dbn.py", line 99, in batch_training
    v0,h0,vk,hk = self.forward(data)
  File "../model/dbn.py", line 65, in forward
    ph0, h0 = self.transfrom(v0,'v2h')
  File "../model/dbn.py", line 52, in transfrom
    s = (torch.rand(p.size())< p).float().to(self.dvc)
RuntimeError: expected device cpu but got device cuda:0
After changing the line to s = (torch.rand(p.size())< p.cpu()).float().to(self.dvc), it runs normally.
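
A possible alternative, offered as a sketch rather than the repository's own fix: the error appears to come from torch.rand(p.size()) allocating its tensor on the CPU while p lives on cuda:0, so the comparison mixes devices. Moving p to the CPU works but adds a device transfer on every call. Generating the noise directly on p's device with torch.rand_like avoids both the mismatch and the round trip. The helper name sample_bernoulli below is hypothetical and is not part of dbn.py.

```python
import torch


def sample_bernoulli(p: torch.Tensor) -> torch.Tensor:
    """Return a 0/1 sample with P(s = 1) = p, on the same device as p.

    Hypothetical helper for illustration; not taken from dbn.py.
    """
    # rand_like creates uniform noise with p's shape, dtype and device,
    # so the comparison never mixes CPU and CUDA tensors.
    return (torch.rand_like(p) < p).float()


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
p = torch.full((4, 3), 0.5, device=device)
s = sample_bernoulli(p)
```

Under this assumption, line 52 could instead read s = (torch.rand_like(p) < p).float(), with the trailing .to(self.dvc) becoming unnecessary whenever p is already on self.dvc.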