iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >使用Pytorch如何完成多分类问题
  • 182
分享到

使用Pytorch如何完成多分类问题

Pytorch多分类Pytorch完成多分类Pytorch多分类问题 2023-02-02 12:02:18 182人浏览 安东尼

Python 官方文档:入门教程 => 点击学习

摘要

目录PyTorch如何完成多分类为什么要用transfORM归一化模型总结Pytorch如何完成多分类 多分类问题在最后的输出层采用的Softmax Layer,其具有两个特点:1.

Pytorch如何完成多分类

多分类问题在最后的输出层采用的Softmax Layer,其具有两个特点:1.每个输出的值都是在(0,1);2.所有值加起来和为1.

假设是最后线性层的输出,则对应的Softmax function为:

输出经过sigmoid运算即可是西安输出的分类概率都大于0且总和为1。

上图的交叉熵损失就包含了softmax计算和右边的标签输入计算(即框起来的部分)

所以在使用交叉熵损失的时候,神经网络的最后一层是不要做激活的,因为把它做成分布的激活是包含在交叉熵损失里面的,最后一层不要做非线性变换,直接交给交叉熵损失。

如上图,做交叉熵损失时要求y是一个长整型的张量,构造时直接用

criterion = torch.nn.CrossEntropyLoss()

3个类别,分别是2,0,1

Y_pred1 ,Y_pred2还是线性输出,没经过softmax,还不是概率分布,比如Y_pred1,0.9最大,表示对应为第3个的概率最大,和2吻合,1.1最大,表示对应为第1个的概率最大,和0吻合,2.1最大,表示对应为第2个的概率最大,和1吻合,那么Y_pred1 的损失会比较小

对于Y_pred2,0.8最大,表示对应为第1个的概率最大,和0不吻合,0.5最大,表示对应为第3个的概率最大,和2不吻合,0.5最大,表示对应为第3个的概率最大,和2不吻合,那么Y_pred2 的损失会比较大

Exercise 9-1: CrossEntropyLoss vs NLLLoss

What are the differences?

• Reading the document:

https://pytorch.org/docs/stable/nn.html#crossentropyloss

Https://pytorch.org/docs/stable/nn.html#nllloss

• Try to know why:

• CrossEntropyLoss <==> LogSoftmax + NLLLoss

为什么要用transform

transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ])

PyTorch读图像用的是python的imageLibrary,就是PIL,现在用的都是pillow,pillow读进来的图像用神经网络处理的时候,神经网络有一个特点就是希望输入的数值比较小,最好是在-1到+1之间,最好是输入遵从正态分布,这样的输入对神经网络训练是最有帮助的

原始图像是28*28的像素值在0到255之间,我们把它转变成图像张量,像素值是0到1

在视觉里面,灰度图就是一个矩阵,但实际上并不是一个矩阵,我们把它叫做单通道图像,彩色图像是3通道,通道有宽度和高度,一般我们读进来的图像张量是WHC(宽高通道)

在PyTorch里面我们需要转化成CWH,把通道放在前面是为了在PyTorch里面进行更高效的图像处理,卷积运算。所以拿到图像之后,我们就把它先转化成pytorch里面的一个Tensor,把0到255的值变成0到1的浮点数,然后把维度由2828变成128*28的张量,由单通道变成多通道,

这个过程可以用transforms的ToTensor这个函数实现

归一化

transforms.Normalize((0.1307, ), (0.3081, ))

这里的0.1307,0.3081是对Mnist数据集所有的像素求均值方差得到的

也就是说,将来拿到了图像,先变成张量,然后Normalize,切换到0,1分布,然后供神经网络训练

如上图,定义好transform变换之后,直接把它放到数据集里面,为什么要放在数据集里面呢,是为了在读取第i个数据的时候,直接用transform处理

 

模型

输入是一组图像,激活层改用Relu

全连接神经网络要求输入是一个矩阵

所以需要把输入的张量变成一阶的,这里的N表示有N个图片

view函数可以改变张量的形状,-1表示将来自动去算它的值是多少,比如输入是n128*28

将来会自动把n算出来,输入了张量就知道形状,就知道有多少个数值

最后输出是(N,10)因为是有0-9这10个标签嘛,10表示该图像属于某一个标签的概率,现在还是线性值,我们再用softmax把它变成概率

 #沿着第一个维度找最大值的下标,返回值有两个,因为是10列嘛,返回值一个是每一行的最大值,另一个是最大值的下标(每一个样本就是一行,每一行有10个量)(行是第0个维度,列是第1个维度)

MNIST数据集训练代码

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
 
# prepare dataset
 
batch_size = 64
 
transform = transforms.Compose([
    transforms.ToTensor(), #先将图像变换成一个张量tensor。
    transforms.Normalize((0.1307,), (0.3081,))
    #其中的0.1307是MNIST数据集的均值,0.3081是MNIST数据集的标准差。
])  # 归一化,均值和方差
 
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True,
                               download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
 
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False,
                               download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
 
# design model using class
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)
 
    def forward(self, x):
        # 28 * 28 = 784
        # 784 = 28 * 28,即将N *1*28*28转化成 N *1*784
        x = x.view(-1, 784)  # -1其实就是自动获取mini_batch
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)  # 最后一层不做激活,不进行非线性变换
 
model = Net()
 
#CrossEntropyLoss <==> LogSoftmax + NLLLoss。
#也就是说使用CrossEntropyLoss最后一层(线性层)是不需要做其他变化的;
#使用NLLLoss之前,需要对最后一层(线性层)先进行SoftMax处理,再进行log操作。
 
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
#momentum 是带有优化的一个训练过程参数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
# training cycle forward, backward, update
 
def train(epoch):
    running_loss = 0.0
    #enumerate()函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,
    #同时列出数据和数据下标,一般用在 for 循环当中。
    #enumerate(sequence, [start=0])
    for batch_idx, data in enumerate(train_loader, 0):
        # 获得一个批次的数据和标签
        inputs, target = data
        optimizer.zero_grad()
 
        #forward + backward + update
        # 获得模型预测结果(64, 10)
        outputs = model(inputs)
        # 交叉熵代价函数outputs(64,10),target(64)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
 
def test():
    correct = 0
    total = 0
    with torch.no_grad():#不需要计算梯度。
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            #orch.max的返回值有两个,第一个是每一行的最大值是多少,第二个是每一行最大值的下标(索引)是多少。
            _, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度
            total += labels.size(0)
            correct += (predicted == labels).sum().item()  # 张量之间的比较运算
    print('accuracy on test set: %d %% ' % (100 * correct / total))
 
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: 使用Pytorch如何完成多分类问题

本文链接: https://www.lsjlt.com/news/193977.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • 使用Pytorch如何完成多分类问题
    目录Pytorch如何完成多分类为什么要用transform归一化模型总结Pytorch如何完成多分类 多分类问题在最后的输出层采用的Softmax Layer,其具有两个特点:1....
    99+
    2023-02-02
    Pytorch多分类 Pytorch完成多分类 Pytorch多分类问题
  • 如何使用Pytorch完成图像分类任务详解
    目录概述:一. 数据准备二.定义一个卷积神经网络三.完整代码如下:总结概述: 本文将通过组织自己的训练数据,使用Pytorch深度学习框架来训练自己的模型,最终实现自己的图像分类!本...
    99+
    2022-11-11
  • Pytorch如何继承Subset类完成自定义数据拆分
    这篇文章主要介绍“Pytorch如何继承Subset类完成自定义数据拆分”,在日常操作中,相信很多人在Pytorch如何继承Subset类完成自定义数据拆分问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Pyt...
    99+
    2023-06-29
  • 如何使用Pytorch训练分类器
    如何使用Pytorch训练分类器,相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。一、 数据通常来说,当你处理图像,文本,语音或者视频数据时,你可以使用标准python包将数据加载...
    99+
    2023-06-02
  • 如何用conda安装PyTorch(windows、GPU)最全安装教程(cudatoolkit、python、PyTorch、Anaconda版本对应问题)(完美解决安装CPU而不是GPU的问题)
    一、开发环境         安装PyTorch的开发环境:Anaconda+CUDA+cuDNN+PyCharm Community 二、安装过程 1、Anaconda的安装  1.1 版本选择 第一步就是最关键的版本对应问题(这决定你能...
    99+
    2023-10-21
    conda pytorch 人工智能 python 深度学习 pycharm windows
  • lombok 子类中如何使用@Builder问题
    目录lombok子类中如何使用@Builder子类使用lombok的@Builder注解正确姿势分析一下lombok子类中如何使用@Builder lombok大家都知道,在使用PO...
    99+
    2022-11-13
  • Excel如何使用Ctrl + e完成数据分裂操作
    这篇文章主要为大家展示了“Excel如何使用Ctrl + e完成数据分裂操作”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Excel如何使用Ctrl + e完成数据分裂操作”这篇文章吧。Ctrl...
    99+
    2023-06-27
  • 如何使用四象限法分析问题
    本篇内容介绍了“如何使用四象限法分析问题”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!对时间管理有了解的同...
    99+
    2022-10-19
  • 如何使用Python程序完成描述性统计分析需求
    目录一、前言1.1 关于描述性统计分析1.2 本篇目的1.3 提示二、程序内容的编写2.1 导入数据与前期处理 2.2 描述性统计分析所要计算的数据2.3 数据可视化2.3...
    99+
    2023-03-22
    python python统计分析 python分析需求
  • 如何在Windows环境下使用PHP完成LeetCode高难度算法题?
    在Windows环境下使用PHP完成LeetCode高难度算法题可能是许多程序员们的一个挑战。但是,这并不是不可能的。在本文中,我们将向您展示如何使用PHP编写高质量的算法代码,以及如何在Windows环境下运行和测试它们。 首先,我们需要...
    99+
    2023-10-15
    windows leetcode 大数据
  • linux中如何使用awk完成更多结构化的复杂任务
    这篇文章将为大家详细讲解有关linux中如何使用awk完成更多结构化的复杂任务,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。awk 的程序结构awk 脚本是由 {}(大括号)包围的功能...
    99+
    2023-06-15
  • LeetCode算法题中,如何使用容器完成复杂的计算任务?
    随着计算机技术的发展,越来越多的计算任务需要通过算法来完成。其中,LeetCode算法题是程序员们不断挑战自我的一个平台。在这些题目中,容器成为了一个非常重要的工具。本文将介绍在LeetCode算法题中如何使用容器完成复杂的计算任务,并通...
    99+
    2023-06-01
    leetcode 编程算法 容器
  • 如何使用GO语言自然语言处理API解决文本分类问题?
    当今互联网时代,数据量已经达到了惊人的数量,而这些数据中包含着各种各样的信息,其中文本信息占据了很大一部分。因此,如何对文本信息进行分类和分析,已经成为了一项十分重要的工作。而自然语言处理技术则是实现这一目标的关键技术之一。在本文中,我们将...
    99+
    2023-09-22
    自然语言处理 api http
  • 如何使用Spring解决ibatis多数据源的问题
    本篇内容介绍了“如何使用Spring解决ibatis多数据源的问题”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!iBatis多数据源的苦恼在...
    99+
    2023-06-18
  • C#如何使用Task类解决线程的等待问题
    目录使用Task类解决线程的等待问题Task类用法示例小结C#代码执行中等待10秒使用Task类解决线程的等待问题 在任何的编程语言中,面对耗时任务时,我们都会有这样的需求:让任务执...
    99+
    2022-11-13
  • C++分析讲解类的静态成员函数如何使用
    目录一、未完成的需求二、问题分析三、静态成员函数四、小结一、未完成的需求 统计在程序运行期间某个类的对象数目保证程序的安全性(不能使用全局变量)随时可以获取当前对象的数目 在C++分...
    99+
    2022-11-13
  • jpa使用注解生成表时无外键问题如何解决
    这篇“jpa使用注解生成表时无外键问题如何解决”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“jpa使用注解生成表时无外键问题...
    99+
    2023-07-02
  • SpringBoot如何使用 Redis 分布式锁解决并发问题
    这期内容当中小编将会给大家带来有关SpringBoot如何使用 Redis 分布式锁解决并发问题,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。问题背景现在的应用程序架构中,很多服务都是多副本运行,从而保证...
    99+
    2023-06-25
  • 如何解决Https页面使用百度分享的问题
    小编给大家分享一下如何解决Https页面使用百度分享的问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!解决方法解决方法来源于细语呢喃,注意网站可能被墙,需要代理...
    99+
    2023-06-08
  • 如何解决使用feign传递参数类型为MultipartFile的问题
    这篇文章主要介绍如何解决使用feign传递参数类型为MultipartFile的问题,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!feign传递参数类型为MultipartFilefeign默认是不支持多媒体文件类型...
    99+
    2023-06-29
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作