当保存和加载模型时,需要熟悉三个核心功能:
torch.save:将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化。使用此函数可以保存如模型、tensor、字典等各种对象。
torch.load:使用pickle的unpickling功能将pickle对象文件反序列化到内存。此功能还可以有助于设备加载数据。
torch.nn.Module.load_state_dict:使用反序列化函数 state_dict 来加载模型的参数字典。
Python中对于模型数据的保存和加载操作都是引用Python内置的pickle包,使用pickle.dump()和pickle.load()方法。在Pytorch中也有同样功能的方法提供。
>>>torch.save(model,'model.pkl') #保存整个模型>>>model = torch.load('model.pkl') #加载整个模型>>>torch.save(alexnet.state_dict(),'params.pkl') #保存网格中的参数>>>alexnet.load_state_dict(torch.load('params.pkl')) #加载网格中的参数
在torchvision.models模块里,PyTorch提供了一些常用的模型:
可以使用torch.util.model_zoo来预加载它们,具体设置通过参数pretrained=True来实现。
>>>import torchvision.models as models>>>ResNet18 = models.ResNet18(pretrained=True)>>>alexnet = models.alexnet(pretrained=True)>>>squeezenet = models.squeezenet1_0(pretrained=True)>>>vgg16 = models.vgg16(pretrained=True)>>>densenet = models.densenet161(pretrained=True)>>>inception = models.inception_v3(pretrained=True)
加载这类预训练模型的过程中,还可以进行微处理。
>>>pretrained_dict = model_zoo.load_url(model_urls['resnet134'])>>>model_dict = model.state_dict()>>>pretrained_dict = {k:v for k,v in pretrained_dict.items()if k in model_dict}#将pretrained_dict里不属于model_dict的键剔除掉>>>model_dict.update(pretrained_dict) #更新现在有的model_dict>>>model.load_state_dict(model_dict)
参考
《PyTorch机器学习从入门到实战》