300字范文,内容丰富有趣,生活中的好帮手!
300字范文 > pytorch保存和加载文件的方法 从断点处继续训练

pytorch保存和加载文件的方法 从断点处继续训练

时间:2020-11-02 13:18:56

相关推荐

pytorch保存和加载文件的方法 从断点处继续训练

"""Example: saving and loading a PyTorch checkpoint, and resuming training
from the last completed epoch.

Trains a small VGG-style CNN on CIFAR-10. After every epoch the model
weights, optimizer state and epoch number are written to ``log_dir`` so an
interrupted run can pick up where it stopped.
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision as tv
import torchvision.transforms as transforms

# ---- configuration ----
batch_size = 32
epochs = 10
WORKERS = 0        # DataLoader worker processes
test_flag = False  # True: load the saved model and run evaluation only
ROOT = '/home/pxt/pytorch/cifar'                    # CIFAR-10 dataset root
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # checkpoint file path

# ---- data ----
# BUG FIX: the original called `pose([...])`, a NameError left over from a
# mangled paste; the intended call is transforms.Compose.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True,
                                 transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False,
                                transform=transform)

train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                         shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                        shuffle=False, num_workers=WORKERS)


class Net(nn.Module):
    """Small VGG-style CNN for 32x32 RGB inputs, 10 output classes."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Two 2x2 poolings shrink 32x32 -> 8x8, hence 256 * 8 * 8 features.
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))  # 32x32 -> 16x16
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))  # 16x16 -> 8x8
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = Net().cpu()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(model, train_loader, epoch):
    """Run one training epoch over *train_loader* and print the mean loss."""
    model.train()
    train_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        # BUG FIX: accumulate the Python float, not the tensor — summing the
        # tensor keeps every batch's autograd graph alive and leaks memory.
        train_loss += loss.item()
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch, i))
    loss_mean = train_loss / (i + 1)  # assumes a non-empty loader
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))


def test(model, test_loader):
    """Evaluate *model* on *test_loader*; print average loss and accuracy."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
            # (removed a pointless optimizer.zero_grad(): no gradients are
            # produced under torch.no_grad())
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
    test_loss /= (i + 1)  # mean over batches; assumes a non-empty loader
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_data), 100. * correct / len(test_data)))


def main():
    """Entry point: resume from a checkpoint if one exists, else train fresh.

    If ``test_flag`` is set, the saved checkpoint is loaded and only
    evaluation is run; none of the later steps execute.
    """
    if test_flag:
        # Load the saved state and run validation on the test set only.
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return

    # If a checkpoint exists, restore it and continue training from there.
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')

    # 'epoch' in the checkpoint is the last *completed* epoch, so resume at
    # start_epoch + 1. BUG FIX: the original range stopped at `epochs`,
    # running only epochs-1 training epochs on a fresh start (1..epochs-1);
    # the end bound is now epochs + 1 so exactly `epochs` epochs run.
    for epoch in range(start_epoch + 1, epochs + 1):
        train(model, train_load, epoch)
        test(model, test_load)
        # Checkpoint after every epoch: weights, optimizer state, epoch id.
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        torch.save(state, log_dir)


if __name__ == '__main__':
    main()

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。