iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >python中的Pytorch建模流程汇总
  • 302
分享到

python中的Pytorch建模流程汇总

2024-04-02 19:04:59 302人浏览 八月长安

Python 官方文档:入门教程 => 点击学习

摘要

目录1导入库2设置初始值3导入并制作数据集4定义神经网络架构5定义训练流程6训练模型本节内容学习帮助大家梳理神经网络训练的架构。 一般我们训练神经网络有以下步骤: 导入库设置训练参数

本节内容学习帮助大家梳理神经网络训练的架构。

一般我们训练神经网络有以下步骤:

  • 导入库
  • 设置训练参数的初始值
  • 导入数据集并制作数据集
  • 定义神经网络架构
  • 定义训练流程
  • 训练模型

推荐文章:

python实现可视化大屏

分享4款 Python 自动数据分析神器

以下,我就将上述步骤使用代码进行注释讲解:

1 导入库

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader, DataLoader
import torchvision
import torchvision.transfORMs as transforms

2 设置初始值

# 学习率
lr = 0.15
# 优化算法参数
gamma = 0.8
# 每次小批次训练个数
bs = 128
# 整体数据循环次数
epochs = 10

3 导入并制作数据集

本次我们使用FashionMNIST图像数据集,每个图像是一个28*28的像素数组,共有10个衣物类别,比如连衣裙、运动鞋、包等。

注:初次运行下载需要等待较长时间。

# 导入数据集
mnist = torchvision.datasets.FashionMNIST(
    root = './Datastes'
    , train = True
    , download = True
    , transform = transforms.ToTensor())
    
# 制作数据集
batchdata = DataLoader(mnist
                       , batch_size = bs
                       , shuffle = True
                       , drop_last = False)

我们可以对数据进行检查:

for x, y in batchdata:
    print(x.shape)
    print(y.shape)
    break

# torch.Size([128, 1, 28, 28])
# torch.Size([128])

可以看到一个batch中有128个样本,每个样本的维度是1*28*28。

之后我们确定模型的输入维度与输出维度:

# 输入的维度
input_ = mnist.data[0].numel()
# 784

# 输出的维度
output_ = len(mnist.targets.unique())
# 10

4 定义神经网络架构

先使用一个128个神经元的全连接层,然后用relu激活函数,再将其结果映射到标签的维度,并使用softmax进行激活。

# 定义神经网络架构
class Model(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear1 = nn.Linear(in_features, 128, bias = True)
        self.output = nn.Linear(128, out_features, bias = True)
    
    def forward(self, x):
        x = x.view(-1, 28*28)
        sigma1 = torch.relu(self.linear1(x))
        sigma2 = F.log_softmax(self.output(sigma1), dim = -1)
        return sigma2

5 定义训练流程

在实际应用中,我们一般会将训练模型部分封装成一个函数,而这个函数可以继续细分为以下几步:

  • 定义损失函数与优化器
  • 完成向前传播
  • 计算损失
  • 反向传播
  • 梯度更新
  • 梯度清零

在此六步核心操作的基础上,我们通常还需要对模型的训练进度、损失值与准确度进行监视。

注释代码如下:

# 封装训练模型的函数
def fit(net, batchdata, lr, gamma, epochs):
# 参数:模型架构、数据、学习率、优化算法参数、遍历数据次数

    # 5.1 定义损失函数
    criterion = nn.NLLLoss()
    # 5.1 定义优化算法
    opt = optim.SGD(net.parameters(), lr = lr, momentum = gamma)
    
    # 监视进度:循环之前,一个样本都没有看过
    samples = 0
    # 监视准确度:循环之前,预测正确的个数为0
    corrects = 0
    
    # 全数据训练几次
    for epoch in range(epochs):
        # 对每个batch进行训练
        for batch_idx, (x, y) in enumerate(batchdata):
            # 保险起见,将标签转为1维,与样本对齐
            y = y.view(x.shape[0])
            
            # 5.2 正向传播
            sigma = net.forward(x)
            # 5.3 计算损失
            loss = criterion(sigma, y)
            # 5.4 反向传播
            loss.backward()
            # 5.5 更新梯度
            opt.step()
            # 5.6 梯度清零
            opt.zero_grad()
            
            # 监视进度:每训练一个batch,模型见过的数据就会增加x.shape[0]
            samples += x.shape[0]
            
            # 求解准确度:全部判断正确的样本量/已经看过的总样本量
            # 得到预测标签
            yhat = torch.max(sigma, -1)[1]
            # 将正确的加起来
            corrects += torch.sum(yhat == y)
            
            # 每200个batch和最后结束时,打印模型的进度
            if (batch_idx + 1) % 200 == 0 or batch_idx == (len(batchdata) - 1):
                # 监督模型进度
                print("Epoch{}:[{}/{} {: .0f}%], Loss:{:.6f}, Accuracy:{:.6f}".format(
                    epoch + 1
                    , samples
                    , epochs*len(batchdata.dataset)
                    , 100*samples/(epochs*len(batchdata.dataset))
                    , loss.data.item()
                    , float(100.0*corrects/samples)))

6 训练模型

# 设置随机种子
torch.manual_seed(51)

# 实例化模型
net = Model(input_, output_)

# 训练模型
fit(net, batchdata, lr, gamma, epochs)
# Epoch1:[25600/600000  4%], Loss:0.524430, Accuracy:69.570312
# Epoch1:[51200/600000  9%], Loss:0.363422, Accuracy:74.984375
# ......
# Epoch10:[600000/600000  100%], Loss:0.284664, Accuracy:85.771835

现在我们已经用PyTorch训练了最基础的神经网络,并且可以查看其训练成果。大家可以将代码复制进行运行!

虽然没有用到复杂的模型,但是我们在每次建模时的基本思想都是一致的

到此这篇关于python中的Pytorch建模流程汇总的文章就介绍到这了,更多相关Pytorch建模流程内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: python中的Pytorch建模流程汇总

本文链接: https://www.lsjlt.com/news/141125.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • python中的Pytorch建模流程汇总
    目录1导入库2设置初始值3导入并制作数据集4定义神经网络架构5定义训练流程6训练模型本节内容学习帮助大家梳理神经网络训练的架构。 一般我们训练神经网络有以下步骤: 导入库设置训练参数...
    99+
    2024-04-02
  • python中的Pytorch建模流程是什么
    小编给大家分享一下python中的Pytorch建模流程是什么,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!一般我们训练神经网络有以下步骤:导入库设置训练参数的初...
    99+
    2023-06-29
  • Python中psutil模块使用汇总
    简介:psutil(进程和系统实用程序)是一个跨平台库,用于检索Python中运行进程和系统利用率(CPU、内存、磁盘、网络、传感器)的信息。它主要用于系统监视、分析和限制进程资源以...
    99+
    2024-04-02
  • python中常用的内置模块汇总
    内置模块(一) Python内置的模块有很多,我们也已经接触了不少相关模块,接下来咱们就来做一些汇总和介绍。 内置模块有很多 & 模块中的功能也非常多,我们是没有办法注意全局...
    99+
    2024-04-02
  • python中的json模块常用方法汇总
    目录一、概述二、方法详解1.dump()2.dumps3.load4.loads三、代码实战1.dumps()2.dump()4.loads()一、概述 推荐使用参考网站: json...
    99+
    2024-04-02
  • Java流程控制语句最全汇总(中篇)
    目录前文Java switch case语句详解switch 语句格式switchcasedefaultbreak例 1例 2嵌套 switch 语句if 语句和 switch 语句...
    99+
    2023-01-13
    Java流程控制语句 流程控制语句 流程控制语句结构
  • 使用数学软件Matlab建模画图程序汇总
    目录1. 二维数据曲线图1.1 绘制二维曲线的基本函数1.plot()函数2.含多个输入参数的plot函数3.含选项的plot函数4.双纵坐标函数plotyy1.2 绘制图形的辅助操...
    99+
    2024-04-02
  • MySQL中创建表的三种方法汇总
    目录CREATE TABLECREATE TABLE … LIKECREATE TABLE … SELECT总结SQL 标准使用 CREATE TABLE 语句创建数据表;mysql ...
    99+
    2023-02-18
    MySQL创建表 MySQL创建表的方法 MySQL表创建
  • pytorch教程之网络的构建流程笔记
    目录构建网络定义一个网络loss FunctionBackprop更新权值参考网址 构建网络 我们可以通过torch.nn包来构建网络,现在你已经看过了autograd,nn在aut...
    99+
    2024-04-02
  • Python中pyautogui库的使用方法汇总
    目录常用操作鼠标操作键盘操作弹窗操作图像操作在使用Python做脚本的话,有两个库可以使用,一个为PyUserInput库,另一个为pyautogui库。就本人而言,我更喜欢使用py...
    99+
    2024-04-02
  • Python中列表的基本操作汇总
    目录1、列表的创建与遍历2、添加元素2.1、append()方法2.2、extend()方法2.3、insert()方法3、删除元素3.1、del命令3.2、pop()方法3.3、r...
    99+
    2024-04-02
  • Python中执行调用JS的多种方法汇总
    1. 写在前面   做爬虫的人大家都知道,现在国内Web或App普遍防护都做的很好,且越有价值的网站这方面越强 再小再弱的网站现在或多或少都要整点反爬 JS在反爬中应用非常广泛,现在做爬虫工程师基本...
    99+
    2023-08-31
    python javascript
  • Numpy库常用函数汇总:实现数据分析与建模的利器
    Numpy是Python中最常用的数学库之一,它集成了许多最佳的数学函数和操作。Numpy的使用非常广泛,包括统计、线性代数、图像处理、机器学习、神经网络等领域。在数据分析和建模方面,Numpy更是必不可少的工具之一。本文将分享...
    99+
    2024-01-19
    数据分析 Numpy 建模
  • PyTorch深度学习模型的保存和加载流程详解
    一、模型参数的保存和加载  torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经...
    99+
    2024-04-02
  • Go语言的中文文档与教程资源汇总
    Go语言的中文文档与教程资源汇总 Go语言是一种由Google开发的编程语言,它简单、高效,并具有强大的并发处理能力。随着Go语言在云计算、网络编程等领域的广泛应用,越来越多的程序员开...
    99+
    2024-04-02
  • 总结 90 条写 Python 程序的建议
    “ 阅读本文大概需要 3 分钟。 ”本文于网络整理,版权归原作者所有,如来源信息有误或侵犯权益,请联系我删除。自己写 Python 也有四五年了,一直是用自己的“强迫症”在维持自己代码的质量。都有去看 Google 的 Python 代码规...
    99+
    2023-06-01
  • 软件测试过程中常见的英文单词汇总
    一、专业名词篇 A: Automated Test 自动化测试Alpha Test a测试Acceptance Test 验收测试Agile Testing 敏捷测试Accuracy...
    99+
    2024-04-02
  • C++中新手容易犯的十种编程错误汇总
    目录前言1、有些关键字在cpp文件中多写了2、函数参数的默认值写到函数实现中了3、在编写类的时候,在类的结尾处忘记添加";"分号了4、只添加了函数声明,没有函数实现5、cpp文件忘记...
    99+
    2024-04-02
  • Python中re模块的常用方法总结
    前言 正则表达式作为计算机科学的一个概念,通常被用来检索、替换那些符合某个规则的文本。正则表达式是对字符串操作的一种逻辑公式,用事先定义好的规则字符串对字符串进行过滤逻辑处理。 re...
    99+
    2024-04-02
  • Python中os模块的12种用法总结
    目录一、先总结,再详谈二、详谈各种方法的使用1、getcwd() :返回当前工作目录2、chdir(path) :改变工作目录3、listdir(path) :列举指定目录中的文件名...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作