300字范文,内容丰富有趣,生活中的好帮手!
300字范文 > 加载dict_PyTorch 7.保存和加载pytorch模型的两种方法

加载dict_PyTorch 7.保存和加载pytorch模型的两种方法

时间:2019-05-05 06:33:37

相关推荐

加载dict_PyTorch 7.保存和加载pytorch模型的两种方法

众所周知,python的对象都可以通过torch.save和torch.load函数进行保存和加载(不知道?那你现在知道了(*^_^*)),比如:

x1 = {"d":"ddf","dd":'fdsf'}torch.save(x1, 'a1.pt')x2 = ["ddf",'fdsf']torch.save(x2, 'a2.pt')x3 = 1torch.save(x3, 'a3.pt')x4 = torch.ones(3)torch.save(x4, 'a4.pt')

读取的时候也是一样:

x5 = torch.load('a1.pt')x6 = torch.load('a2.pt')x7 = torch.load('a3.pt')x8 = torch.load('a4.pt')

这种非常简单粗暴,直接把整个对象扔进磁盘文件里保存,所以对于我们训练好的模型来说,因为训练好的模型也是一个对象,所以我们也可以使用这个方法把训练好的模型对象直接扔进去。但是这样有一个问题,就是模型对象开销比较大,比如最近包含1350亿个参数的那个有名的神经网络模型,如果把它保存到磁盘里面没有百八十T是保存不下的。所以我们是不是可以仅仅保存模型里面的关键数据呢?

答案是,可以!

因为决定一个模型是什么样有两方面的因素,一个是模型的结构是什么,另一个是模型的参数是什么,这两个定了,这个模型也就确定了。模型的结构在我们初始化模型对象的时候就定了,比如对于任意一个模型类,我们初始化它的两个对象,这两个对象代表的模型的结构肯定是一样的,区别就在于它们的参数不一样。所以我们保存模型的关键就是保存模型的参数,而模型的结构每次用的时候新建一个对象就好了,然后从磁盘里把模型的参数读取出来赋给这个对象。是不是超级简单?

那我们怎么拿到模型的参数呢?巧了!

模型的state_dict()函数就是返回模型的所有参数的(这个函数是nn.Module的,所以所有继承了nn.Module的模型类都有这个函数),比如:

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()net.state_dict()

输出:

OrderedDict([('hidden.weight',tensor([[-0.4195, 0.2609, 0.4325],[-0.4031, 0.2078, 0.2077]])),('hidden.bias', tensor([ 0.0755, -0.1408])),('output.weight', tensor([[0.2473, 0.6614]])),('output.bias', tensor([0.6191]))])

有的同学可能注意到了,self.act层的参数没有包含进来!

大哥,self.act层没有参数好吗(捂脸)

还有的同学可能想问,那有的层有参数、有的层没有参数,那万一加载的时候把某个参数给错了层怎么办?

完全不会!注意看,state_dict()返回的是一个字典,每一个张量都对应的有层的名字,清清楚楚,绝对没有问题。

那这样就简单了,举个例子看一下:

X = torch.randn(2, 3)Y = net(X) # 这个net就是上面创建的那个对象,我们把它的参数保存起来,然后新建一个net2,然后把保存的这些参数加载进net2,这样我们把X输入net2得到的Y2应该与Y是相等的PATH = "./net.pt"torch.save(net.state_dict(), PATH)net2 = MLP()net2.load_state_dict(torch.load(PATH))Y2 = net2(X)Y2 == Y

输出:

tensor([[1],[1]], dtype=torch.uint8)

输出的张量,代表Y2 == Y比较结果为true,也就是说是一样的,验证了我们的猜想(上面代码注释中的那个猜想)。

好了,以上就是pytorch保存和加载模型的两种方法,是不是非常简单?

阳阳:保存和加载pytorch模型的两种方法,选哪个好?​

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。