300字范文,内容丰富有趣,生活中的好帮手!
300字范文 > Pytorch——保存训练好的模型参数

Pytorch——保存训练好的模型参数

时间:2021-06-20 21:32:04

相关推荐

Pytorch——保存训练好的模型参数

文章目录

1.前言2.torch.save(保存模型)3.torch.load整个网络4.torch.load网络参数(只提取参数)5.调用三个函数

1.前言

训练好了一个模型, 我们当然想要保存它, 留到下次要用的时候直接提取直接用,下面我将来讲如何存储训练好的模型参数

2.torch.save(保存模型)

首先,先搭建一个神经网络

import torchfrom torch import nnimport matplotlib.pyplot as plttorch.manual_seed(11) # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # [100] -> [100,1]y = x.pow(2) + 0.5*torch.rand(x.size()) # y的形状与x一样def make_and_save_model():network = torch.nn.Sequential(torch.nn.Linear(1, 8),torch.nn.ReLU(),torch.nn.Linear(8, 1))optimizer = torch.optim.SGD(network.parameters(), lr=0.3) #优化器criterion = torch.nn.MSELoss()#损失函数# 训练for i in range(200):prediction = network(x)#数据放入模型后得到预测值loss = criterion(prediction, y) #计算预测值与真实值之间的误差optimizer.zero_grad() #清空梯度loss.backward() #误差反向传播optimizer.step()#更新参数torch.save(network, 'network.pth') # 保存整个网络torch.save(network.state_dict(), 'network_params.pth') # 只保存网络中的参数plt.figure(1, figsize = (10,3))plt.subplot(131)plt.title('network')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)plt.pause(1)

3.torch.load整个网络

这种方式将会提取整个神经网络, 网络大的时候可能会比较慢.

def load_whole_model():network_whole = torch.load('network.pth')prediction = network_whole(x)plt.figure(1, figsize = (10,3))plt.subplot(132)plt.title('network_whole')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)plt.pause(1)

4.torch.load网络参数(只提取参数)

这种方式将会提取所有的参数, 然后再放到你的新建网络中

def load_only_params():network_params = torch.nn.Sequential(torch.nn.Linear(1, 8),torch.nn.ReLU(),torch.nn.Linear(8, 1))network_params.load_state_dict(torch.load('network_params.pth'))prediction = network_params(x)plt.figure(1, figsize = (10,3))plt.subplot(133)plt.title('network_params')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)

5.调用三个函数

会看到加载后的模型画出的图是一样的,说明模型的参数正确加载了。

make_and_save_model()load_whole_model()load_only_params()

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。