300字范文,内容丰富有趣,生活中的好帮手!
300字范文 > 保存和加载pytorch模型

保存和加载pytorch模型

时间:2019-04-13 15:09:08

相关推荐

保存和加载pytorch模型

当保存和加载模型时,需要熟悉三个核心功能:

torch.save:将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化。使用此函数可以保存如模型、tensor、字典等各种对象。

torch.load:使用pickle的unpickling功能将pickle对象文件反序列化到内存。此功能还可以有助于设备加载数据。

torch.nn.Module.load_state_dict:使用反序列化函数 state_dict 来加载模型的参数字典。

Python中对于模型数据的保存和加载操作都是引用Python内置的pickle包,使用pickle.dump()和pickle.load()方法。在Pytorch中也有同样功能的方法提供。

>>>torch.save(model,'model.pkl') #保存整个模型>>>model = torch.load('model.pkl') #加载整个模型>>>torch.save(alexnet.state_dict(),'params.pkl') #保存网格中的参数>>>alexnet.load_state_dict(torch.load('params.pkl')) #加载网格中的参数

在torchvision.models模块里,PyTorch提供了一些常用的模型:

可以使用torch.util.model_zoo来预加载它们,具体设置通过参数pretrained=True来实现。

>>>import torchvision.models as models>>>ResNet18 = models.ResNet18(pretrained=True)>>>alexnet = models.alexnet(pretrained=True)>>>squeezenet = models.squeezenet1_0(pretrained=True)>>>vgg16 = models.vgg16(pretrained=True)>>>densenet = models.densenet161(pretrained=True)>>>inception = models.inception_v3(pretrained=True)

加载这类预训练模型的过程中,还可以进行微处理。

>>>pretrained_dict = model_zoo.load_url(model_urls['resnet134'])>>>model_dict = model.state_dict()>>>pretrained_dict = {k:v for k,v in pretrained_dict.items()if k in model_dict}#将pretrained_dict里不属于model_dict的键剔除掉>>>model_dict.update(pretrained_dict) #更新现在有的model_dict>>>model.load_state_dict(model_dict)

参考

《PyTorch机器学习从入门到实战》

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。