iis服务器助手广告广告
返回顶部
首页 > 资讯 > 精选 >pytorch怎么实现加载保存查看checkpoint文件
  • 788
分享到

pytorch怎么实现加载保存查看checkpoint文件

2023-07-02 17:07:19 788人浏览 安东尼
摘要

这篇“PyTorch怎么实现加载保存查看checkpoint文件”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“pytorch

这篇“PyTorch怎么实现加载保存查看checkpoint文件”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“pytorch怎么实现加载保存查看checkpoint文件”文章吧。

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)# 保存torch.save(model.state_dict(), PATH)# 加载model.load_state_dict(torch.load(PATH))# 测试时不启用 BatchNORMalization 和 Dropoutmodel.eval()
# 方式二:保存加载整个模型# 保存torch.save(model, PATH)# 加载model = torch.load(PATH)model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型# 保存torch.save({            'epoch': epoch,            'model_state_dict': model.state_dict(),            ...            }, PATH)# 加载checkpoint = torch.load(PATH)start_epoch=checkpoint['epoch']model.load_state_dict(checkpoint['model_state_dict'])# 测试时model.eval()# 或者训练时model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载# 保存torch.save(model.state_dict(), PATH)# 加载device = torch.device('cpu')model.load_state_dict(torch.load(PATH, map_location=device))# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载# 保存torch.save(model.state_dict(), PATH)# 加载device = torch.device("cuda")model.load_state_dict(torch.load(PATH))model.to(device)
# CPU上保存,GPU上加载# 保存torch.save(model.state_dict(), PATH)# 加载device = torch.device("cuda")# 选择希望使用的GPUmodel.load_state_dict(torch.load(PATH, map_location="cuda:0"))  model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dictprint("Model's state_dict:")for param_tensor in model.state_dict():    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉# 创建一个不包含`module.`的新OrderedDictfrom collections import OrderedDictnew_state_dict = OrderedDict()for k, v in state_dict.items():    name = k[7:] # 去掉 `module.`    new_state_dict[name] = v# 加载参数model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存moduletorch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' import torch as torchimport torchvision as tvimport torch.nn as nnimport torch.optim as optimimport torch.nn.functional as Fimport torchvision.transforms as transformsimport os  # 参数声明batch_size = 32epochs = 10WORKERS = 0  # dataloder线程数test_flag = False  # 测试标志,True时加载保存好的模型进行测试ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径# 加载MNIST数据集transform = tv.transforms.Compose([    transforms.ToTensor(),    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform) train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)  # 构造模型class Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)        self.pool = nn.MaxPool2d(2, 2)        self.fc1 = nn.Linear(256 * 8 * 8, 1024)        self.fc2 = nn.Linear(1024, 256)        self.fc3 = nn.Linear(256, 10)     def forward(self, x):        x = F.relu(self.conv1(x))        x = self.pool(F.relu(self.conv2(x)))        x = F.relu(self.conv3(x))        x = self.pool(F.relu(self.conv4(x)))        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])        x = F.relu(self.fc1(x))        x = F.relu(self.fc2(x))        x = self.fc3(x)        return x  model = Net().cpu() criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=0.01)  # 模型训练def train(model, train_loader, epoch):    model.train()    train_loss = 0    for i, data in enumerate(train_loader, 0):        x, y = data        x = x.cpu()        y = y.cpu()         optimizer.zero_grad()        y_hat = model(x)        loss = criterion(y_hat, y)        loss.backward()        optimizer.step()        train_loss += loss        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))     loss_mean = train_loss / (i + 1)    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))  # 模型测试def test(model, test_loader):    model.eval()    test_loss = 0    correct = 0    with torch.no_grad():        for i, data in enumerate(test_loader, 0):            x, y = data            x = x.cpu()            y = y.cpu()             optimizer.zero_grad()            y_hat = model(x)            test_loss += criterion(y_hat, y).item()            pred = y_hat.max(1, keepdim=True)[1]            correct += pred.eq(y.view_as(pred)).sum().item()        test_loss /= (i + 1)        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(            test_loss, correct, len(test_data), 100. * correct / len(test_data)))  def main():    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤    if test_flag:        # 加载保存的模型直接进行测试机验证        checkpoint = torch.load(log_dir)        model.load_state_dict(checkpoint['model'])        optimizer.load_state_dict(checkpoint['optimizer'])        start_epoch = checkpoint['epoch']        test(model, test_load)        return     # 如果有保存的模型,则加载模型,并在其基础上继续训练    if os.path.exists(log_dir):        checkpoint = torch.load(log_dir)        model.load_state_dict(checkpoint['model'])        optimizer.load_state_dict(checkpoint['optimizer'])        start_epoch = checkpoint['epoch']        print('加载 epoch {} 成功!'.format(start_epoch))    else:        start_epoch = 0        print('无保存了的模型,将从头开始训练!')     for epoch in range(start_epoch+1, epochs):        train(model, train_load, epoch)        test(model, test_load)        # 保存模型        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}        torch.save(state, log_dir) if __name__ == '__main__':    main()

以上就是关于“pytorch怎么实现加载保存查看checkpoint文件”这篇文章的内容,相信大家都有了一定的了解,希望小编分享的内容对大家有帮助,若想了解更多相关的知识内容,请关注编程网精选频道。

--结束END--

本文标题: pytorch怎么实现加载保存查看checkpoint文件

本文链接: https://www.lsjlt.com/news/343656.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • pytorch实现加载保存查看checkpoint文件
    目录1.保存加载checkpoint文件2.跨gpu和cpu3.查看checkpoint文件内容4.常见问题pytorch保存和加载文件的方法,从断点处继续训练1.保存加载check...
    99+
    2024-04-02
  • pytorch怎么实现加载保存查看checkpoint文件
    这篇“pytorch怎么实现加载保存查看checkpoint文件”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“pytorch...
    99+
    2023-07-02
  • pytorch模型保存与加载问题怎么解决
    这篇“pytorch模型保存与加载问题怎么解决”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“pytorch模型保存与加载问题...
    99+
    2023-07-04
  • php如何实现保存下载文件
    这篇文章主要介绍“php如何实现保存下载文件”,在日常操作中,相信很多人在php如何实现保存下载文件问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”php如何实现保存下载文件”的疑惑有所帮助!接下来,请跟着小编...
    99+
    2023-06-20
  • win10下载文件怎么查看
    本篇内容主要讲解“win10下载文件怎么查看”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“win10下载文件怎么查看”吧!第一种:如果是默认位置请看 点击此电脑。 选择下载,就可以看见下载的文件...
    99+
    2023-07-01
  • pycharm怎么查找保存的文件
    您可以使用 pycharm 的文件浏览器、快捷键、项目视图、文件列表或搜索栏查找已保存的文件。 如何使用 PyCharm 查找已保存的文件 PyCharm 是一款功能强大的 Pytho...
    99+
    2024-04-18
    linux python macos pycharm
  • 微信小程序怎么实现自动保存下载文件名
    本篇内容介绍了“微信小程序怎么实现自动保存下载文件名”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!问题的提出小程序使用wx.playVoic...
    99+
    2023-06-19
  • mybatis xml文件热加载怎么实现
    这篇文章主要介绍了mybatis xml文件热加载怎么实现的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇mybatis xml文件热加载怎么实现文章都会有所收获,下面我们一起来看看吧。一、x...
    99+
    2023-07-05
  • Python怎么实现批量文件分类保存
    今天小编给大家分享一下Python怎么实现批量文件分类保存的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。我们这里以这两百多个...
    99+
    2023-06-30
  • redis存放文件路径怎么查看
    要查看Redis中存放的文件路径,可以通过以下步骤进行:1. 进入Redis的命令行界面。可以通过运行redis-cli命令来打开R...
    99+
    2023-08-24
    redis
  • Linux怎么查看文件是否存在
    小编给大家分享一下Linux怎么查看文件是否存在,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!1、使用ls命令进行判断ls -l /home/...
    99+
    2023-06-28
  • git怎么查看暂存区的文件
    要查看暂存区的文件,可以使用以下命令:1. 使用git status命令来查看暂存区的文件状态。暂存区内的文件会在"Changes ...
    99+
    2023-10-18
    git
  • sql查询结果怎么保存到文件
    在 SQL 查询中,可以使用以下方法将查询结果保存到文件: 使用 SQL 查询语句的结果导出功能。不同的数据库管理系统(DBMS)...
    99+
    2024-04-09
    sql
  • 怎么使用JavaScript实现保存文件夹功能
    这篇“怎么使用JavaScript实现保存文件夹功能”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“怎么使用JavaScrip...
    99+
    2023-07-06
  • Java实现文件上传下载以及查看功能
    目录项目的目录结构代码IOUtils.javaDownServlet.javaUploadHandleServlet.javaweb.xmlupload.jspdown.jsp运行效...
    99+
    2024-04-02
  • 怎么查看mysql文件储存在哪里
    这篇文章主要介绍怎么查看mysql文件储存在哪里,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!mysql文件储存在mysql安装目录的data文件夹中;查看文件储存路径的方法:1、打...
    99+
    2024-04-02
  • Python怎么实现从文件中加载数据
    这篇文章主要介绍“Python怎么实现从文件中加载数据”,在日常操作中,相信很多人在Python怎么实现从文件中加载数据问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Python怎么实现从文件中加载数据”的疑...
    99+
    2023-06-30
  • Java文件缓存该怎么实现?看看LeetCode题解吧!
    在现代计算机科学中,缓存是一个非常重要的概念。在我们的日常生活中,我们经常使用缓存来提高我们的应用程序的性能和响应速度。在Java编程中,文件缓存也是一个非常重要的概念,因为它可以帮助我们提高Java程序的性能。在本文中,我们将讨论Jav...
    99+
    2023-07-25
    文件 缓存 leetcode
  • java中怎么下载文件流保存到本地
    在Java中,可以使用`InputStream`和`OutputStream`来下载文件流并保存到本地。下面的代码演示了如何使用`U...
    99+
    2023-09-05
    java
  • Windows IE浏览器缓存文件怎么查看
    要查看Windows IE浏览器的缓存文件,可以按照以下步骤进行操作:1. 打开Internet Explorer浏览器。2. 点击...
    99+
    2023-10-20
    Windows
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作