众所周知,python的对象都可以通过torch.save和torch.load函数进行保存和加载(不知道?那你现在知道了(*^_^*)),比如:
x1 = {"d":"ddf","dd":'fdsf'}torch.save(x1, 'a1.pt')x2 = ["ddf",'fdsf']torch.save(x2, 'a2.pt')x3 = 1torch.save(x3, 'a3.pt')x4 = torch.ones(3)torch.save(x4, 'a4.pt')
读取的时候也是一样:
x5 = torch.load('a1.pt')x6 = torch.load('a2.pt')x7 = torch.load('a3.pt')x8 = torch.load('a4.pt')
这种非常简单粗暴,直接把整个对象扔进磁盘文件里保存,所以对于我们训练好的模型来说,因为训练好的模型也是一个对象,所以我们也可以使用这个方法把训练好的模型对象直接扔进去。但是这样有一个问题,就是模型对象开销比较大,比如最近包含1350亿个参数的那个有名的神经网络模型,如果把它保存到磁盘里面没有百八十T是保存不下的。所以我们是不是可以仅仅保存模型里面的关键数据呢?
答案是,可以!
因为决定一个模型是什么样有两方面的因素,一个是模型的结构是什么,另一个是模型的参数是什么,这两个定了,这个模型也就确定了。模型的结构在我们初始化模型对象的时候就定了,比如对于任意一个模型类,我们初始化它的两个对象,这两个对象代表的模型的结构肯定是一样的,区别就在于它们的参数不一样。所以我们保存模型的关键就是保存模型的参数,而模型的结构每次用的时候新建一个对象就好了,然后从磁盘里把模型的参数读取出来赋给这个对象。是不是超级简单?
那我们怎么拿到模型的参数呢?巧了!
模型的state_dict()函数就是返回模型的所有参数的(这个函数是nn.Module的,所以所有继承了nn.Module的模型类都有这个函数),比如:
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()net.state_dict()
输出:
OrderedDict([('hidden.weight',tensor([[-0.4195, 0.2609, 0.4325],[-0.4031, 0.2078, 0.2077]])),('hidden.bias', tensor([ 0.0755, -0.1408])),('output.weight', tensor([[0.2473, 0.6614]])),('output.bias', tensor([0.6191]))])
有的同学可能注意到了,self.act层的参数没有包含进来!
大哥,self.act层没有参数好吗(捂脸)
还有的同学可能想问,那有的层有参数、有的层没有参数,那万一加载的时候把某个参数给错了层怎么办?
完全不会!注意看,state_dict()返回的是一个字典,每一个张量都对应的有层的名字,清清楚楚,绝对没有问题。
那这样就简单了,举个例子看一下:
X = torch.randn(2, 3)Y = net(X) # 这个net就是上面创建的那个对象,我们把它的参数保存起来,然后新建一个net2,然后把保存的这些参数加载进net2,这样我们把X输入net2得到的Y2应该与Y是相等的PATH = "./net.pt"torch.save(net.state_dict(), PATH)net2 = MLP()net2.load_state_dict(torch.load(PATH))Y2 = net2(X)Y2 == Y
输出:
tensor([[1],[1]], dtype=torch.uint8)
输出的张量,代表Y2 == Y比较结果为true,也就是说是一样的,验证了我们的猜想(上面代码注释中的那个猜想)。
好了,以上就是pytorch保存和加载模型的两种方法,是不是非常简单?
阳阳:保存和加载pytorch模型的两种方法,选哪个好?