Googolxx/STF

请问加载模型时遇到的问题

Closed this issue · 6 comments

代码如

device = "cuda" if torch.cuda.is_available() else "cpu"
net = WACNN()
model = torch.load("ckpt/cnn_0036_best.pth.tar")
net.load_state_dict(model["state_dict"])
net.eval()

会引发错误
5@~09K$J@@G_RN_BL9(T62Y

而训练指令在使用chechpoint如
CUDA_VISIBLE_DEVICES=0,1,2 python train.py -d suim_data -e 1000 --batch-size 64 --save --save_path ckpt/cnn_0036.pth.tar --checkpoint ckpt/cnn_0036.pth.tar -m cnn --cuda --lambda 0.0035
会直接无效
SBJ7V%F}2NFP59UC$F$Z 9G

已解决,抱歉是我没有注意到并行训练的问题。

已解决,抱歉是我没有注意到并行训练的问题。

您好,可以详细说一说为什么吗,我用两张显卡训练的,也遇到了这样的问题

已解决,抱歉是我没有注意到并行训练的问题。

您好,可以详细说一说为什么吗,我用两张显卡训练的,也遇到了这样的问题

model被from torch.nn import DataParallel或者from torch.nn.parallel import DistributedDataParallel包围住后,会在这个变量名后面加上module。加载权重时需要去掉变量名中的module,或者用 DataParallel | DistributedDataParallel包围model

已解决,抱歉是我没有注意到并行训练的问题。

您好,可以详细说一说为什么吗,我用两张显卡训练的,也遇到了这样的问题

或者使用作者在compressai.utils.eval_model编写好的代码,我使用作者的代码没有发现问题。

已解决,抱歉是我没有注意到并行训练的问题。

您好,可以详细说一说为什么吗,我用两张显卡训练的,也遇到了这样的问题

model被from torch.nn import DataParallel或者from torch.nn.parallel import DistributedDataParallel包围住后,会在这个变量名后面加上module。加载权重时需要去掉变量名中的module,或者用 DataParallel | DistributedDataParallel包围model

好的,谢谢您的指点!

已解决,抱歉是我没有注意到并行训练的问题。

您好,可以详细说一说为什么吗,我用两张显卡训练的,也遇到了这样的问题

model被from torch.nn import DataParallel或者from torch.nn.parallel import DistributedDataParallel包围住后,会在这个变量名后面加上module。加载权重时需要去掉变量名中的module,或者用 DataParallel | DistributedDataParallel包围model

好的,谢谢您的指点!

不必客气