A small problem in the code #2

@lyq998

I recently ran the DBN code on CUDA and hit a small runtime problem.
Line 52 of dbn.py:
s = (torch.rand(p.size())< p).float().to(self.dvc)
raises the following error:
Traceback (most recent call last):
File "mnist_cls.py", line 55, in <module>
model.run(e=3, pre_e=3)
File "../core/epoch.py", line 98, in run
self.pre_batch_training(pre_e, b)
File "../core/pre_module.py", line 60, in pre_batch_training
module.batch_training(i)
File "../model/dbn.py", line 99, in batch_training
v0,h0,vk,hk = self.forward(data)
File "../model/dbn.py", line 65, in forward
ph0, h0 = self.transfrom(v0,'v2h')
File "../model/dbn.py", line 52, in transfrom
s = (torch.rand(p.size())< p).float().to(self.dvc)
RuntimeError: expected device cpu but got device cuda:0

After changing it to s = (torch.rand(p.size())< p.cpu()).float().to(self.dvc), the code runs normally.
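
For context, the error arises because torch.rand(p.size()) allocates its tensor on the CPU by default, while p lives on cuda:0. The p.cpu() workaround removes the mismatch, but it copies p to the host and the sample back to the GPU on every call. A possible alternative, sketched below, draws the uniform noise directly on p's device. The function name sample_bernoulli is hypothetical and not part of the repository, and the sketch assumes p is a floating-point tensor of probabilities.

```python
import torch


def sample_bernoulli(p: torch.Tensor) -> torch.Tensor:
    """Draw a 0/1 sample for each probability in p, on p's own device.

    Hypothetical helper for illustration; not taken from dbn.py.
    """
    # rand_like creates uniform noise with the same shape, dtype and device
    # as p, so the comparison never mixes CPU and CUDA tensors.
    return (torch.rand_like(p) < p).float()


# Falls back to the CPU when no GPU is available, so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
p = torch.sigmoid(torch.randn(4, 3, device=device))  # example probabilities
s = sample_bernoulli(p)
```

In dbn.py the equivalent one-line change would presumably be s = (torch.rand_like(p) < p).float(), assuming p is already on self.dvc at that point; torch.rand(p.size(), device=p.device) should behave the same way.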
