300字范文,内容丰富有趣,生活中的好帮手!
300字范文 > pytorch 保存 加载模型

pytorch 保存 加载模型

时间:2023-12-20 12:13:37

相关推荐

pytorch 保存 加载模型

一般保存为.pt格式,保存模型使用:

torch.save(model, '保存位置')

加载模型使用:

model_load = torch.load('加载模型的位置')

完整代码

import torchimport torch.nn as nnclass LinearRegressionModel(nn.Module):def __init__(self, input_shape, output_shape):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_shape, output_shape)def forward(self, x):out = self.linear(x)return outif __name__ == '__main__':model = LinearRegressionModel(10, 1)torch.save(model, 'my_linear_model.pt')model_load = torch.load('my_linear_model.pt')

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。