300字范文,内容丰富有趣,生活中的好帮手!
300字范文 > PyTorch 预训练模型 保存 读取和更新模型参数以及多 GPU 训练模型

PyTorch 预训练模型 保存 读取和更新模型参数以及多 GPU 训练模型

时间:2023-02-01 14:06:29

相关推荐

PyTorch 预训练模型 保存 读取和更新模型参数以及多 GPU 训练模型

加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发者互动交流!更有机会与李开复老师等大牛群内互动!

同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注极市平台公众号回复加群,立刻申请入群~

作者:没头脑

/p/75563856

来源:知乎,已获作者授权转载,禁止二次转载。

目录

PyTorch 预训练模型

保存模型参数

读取模型参数

冻结部分模型参数,进行 fine-tuning

模型训练与测试的设置

利用 torch.nn.DataParallel 进行多 GPU 训练

1. PyTorch 预训练模型

Pytorch 提供了许多 Pre-Trained Model on ImageNet,仅需调用 torchvision.models 即可,具体细节可查看官方文档。

往往我们需要对 Pre-Trained Model 进行相应的修改,以适应我们的任务。这种情况下,我们可以先输出 Pre-Trained Model 的结构,确定好对哪些层修改,或者添加哪些层,接着,再将其修改即可。

比如,我需要将 ResNet-50 的 Layer 3 后的所有层去掉,在分别连接十个分类器,分类器由 ResNet-50.layer4 和 AvgPool Layer 和 FC Layer 构成。这里就需要用到 torch.nn.ModuleList 了,比如::self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

代码中的 [nn.Linear(10, 10) for i in range(10)] 是一个python列表,必须要把它转换成一个Module Llist列表才可以被 PyTorch 使用,否则在运行的时候会报错:RuntimeError: Input type (CUDAFloatTensor) and weight type (CPUFloatTensor) should be the same

2. 保存模型参数

PyTorch 中保存模型的方式有许多种:

# 保存整个网络torch.save(model, PATH)# 保存网络中的参数, 速度快,占空间少torch.save(model.state_dict(),PATH)# 选择保存网络中的一部分参数或者额外保存其余的参数torch.save({"state_dict": model.state_dict(), "fc_dict":model.fc.state_dict(),"optimizer": optimizer.state_dict(),"alpha": loss.alpha, "gamma": loss.gamma},PATH)

3. 读取模型参数

同样的,PyTorch 中读取模型参数的方式也有许多种:

# 读取整个网络model = torch.load(PATH)# 读取网络中的参数model.load_state_dict(torch.load(PATH))# 读取网络中的部分参数(本质其实就是更新字典)pretrained_dict = torch.load(pretrained_model_weight)model_dict = model.state_dict()pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}model_dict.update(pretrained_dict)

4. 冻结部分模型参数,进行 fine-tuning

加载完 Pre-Trained Model 后,我们需要对其进行 Finetune。但是在此之前,我们往往需要冻结一部分的模型参数:

# 第一种方式for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 Falsep.requires_grad = Falsefor p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 Truep.requires_grad = True# 将需要 fine-tuning 的参数放入optimizer 中optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)# 第二种方式optim_param = []for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 Falsep.requires_grad = Falsefor p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 Truep.requires_grad = Trueoptim_param.append(p)optimizer.SGD(optim_param, lr=1e-3) # 将需要 fine-tuning 的参数放入optimizer 中

5. 模型训练与测试的设置

训练时,应调用 model.train() ;测试时,应调用 model.eval(),以及 with torch.no_grad():

model.train():使 model 变成训练模式,此时 dropout 和 batch normalization 的操作在训练起到防止网络过拟合的问题。

model.eval():PyTorch会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。不然的话,一旦测试集的 Batch Size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。

with torch.no_grad():PyTorch 将不再计算梯度,这将使得模型 forward 的时候,显存的需求大幅减少,速度大幅提高。

注意:若模型中具有 Batch Normalization 操作,想固定该操作进行训练时,需调用对应的 module 的 eval() 函数。这是因为 BN Module 除了参数以外,还会对输入的数据进行统计,若不调用 eval(),统计量将发生改变!具体代码可以这样写:

for module in model.modules():module.eval()

在其他地方看到的解释:

model.eval() will notify all your layers that you are in eval mode, that way, batchnorm or dropout layers will work in eval model instead of training mode.

torch.no_grad() impacts the autograd engine and deactivate it. It will reduce memory usage and speed up computations but you won’t be able to backprop (which you don’t want in an eval script).

6. 利用 torch.nn.DataParallel 进行多 GPU 训练

import torchimport torch.nn as nnimport torchvision.models as models# 生成模型# 利用 torch.nn.DataParallel 进行载入模型,默认使用所有GPU(可以用 CUDA_VISIBLE_DEVICES 设置所使用的 GPU)model = nn.DataParallel(models.resnet18())# 冻结参数for param in model.module.layer4.parameters():param.requires_grad = Falseparam_optim = filter(lambda p:p.requires_grad, model.parameters())# 设置测试模式model.module.layer4.eval()# 保存模型参数(读取所保存模型参数后,再进行并行化操作,否则无法利用之前的代码进行读取)torch.save(model.module.state_dict(),"./CheckPoint.pkl")

-完-

*延伸阅读

PyTorch语义分割开源库semsegPyTorch

福利,PyTorch中文版官方教程来了

添加极市小助手微信(ID : cv-mart),备注:研究方向-姓名-学校/公司-城市(如:目标检测-小极-北大-深圳),即可申请加入目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群,更有每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流,一起来让思想之光照的更远吧~

△长按添加极市小助手

△长按关注极市平台

觉得有用麻烦给个在看啦~

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。