iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > GO >pytorch中的transforms.ToTensor和transforms.Normalize的实现
  • 305
分享到

pytorch中的transforms.ToTensor和transforms.Normalize的实现

2024-04-02 19:04:59 305人浏览 安东尼
摘要

目录transfORMs.ToTensortransforms.Normalize?transforms.ToTensor 最近看PyTorch时,遇到了对图像数据的归一化,如下图所

transforms.ToTensor

最近看PyTorch时,遇到了对图像数据的归一化,如下图所示:

image-20220416115017669

该怎么理解这串代码呢?我们一句一句的来看,先看transforms.ToTensor(),我们可以先转到官方给的定义,如下图所示:

image-20220416115331930

大概的意思就是说,transforms.ToTensor()可以将PIL和numpy格式的数据从[0,255]范围转换到[0,1] ,具体做法其实就是将原始数据除以255。另外原始数据的shape是(H x W x C),通过transforms.ToTensor()后shape会变为(C x H x W)。这样说我觉得大家应该也是能理解的,这部分并不难,但想着还是用一些例子来加深大家的映像???

先导入一些包

import cv2
import numpy as np
import torch
from torchvision import transforms

定义一个数组模型图片,注意数组数据类型需要时np.uint8【官方图示中给出】

data = np.array([
                [[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1]],
                [[2,2,2],[2,2,2],[2,2,2],[2,2,2],[2,2,2]],
                [[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3]],
                [[4,4,4],[4,4,4],[4,4,4],[4,4,4],[4,4,4]],
                [[5,5,5],[5,5,5],[5,5,5],[5,5,5],[5,5,5]]
        ],dtype='uint8')

这是可以看看data的shape,注意现在为(W H C)。

image-20220416120518895

使用transforms.ToTensor()将data进行转换

data = transforms.ToTensor()(data)

这时候我们来看看data中的数据及shape。

image-20220416120811156

​ 很明显,数据现在都映射到了[0, 1]之间,并且data的shape发生了变换。

**注意:不知道大家是如何理解三维数组的,这里提供我的一个方法。**???

?原始的data的shape为(5,5,3),则其表示有5个(5 , 3)的二维数组,即我们把最外层的[]去掉就得到了5个五行三列的数据。

?同样的,变换后data的shape为(3,5,5),则其表示有3个(5 , 5)的二维数组,即我们把最外层的[]去掉就得到了3个五行五列的数据。

transforms.Normalize?

相信通过前面的叙述大家应该对transforms.ToTensor有了一定的了解,下面将来说说这个transforms.Normalize???同样的,我们先给出官方的定义,如下图所示:

image-20220416195418909

可以看到这个函数的输出output[channel] = (input[channel] - mean[channel]) / std[channel]。这里[channel]的意思是指对特征图的每个通道都进行这样的操作。【mean为均值,std为标准差】接下来我们看第一张图片中的代码,即

image-20220416200305159

这里的第一个参数(0.5,0.5,0.5)表示每个通道的均值都是0.5,第二个参数(0.5,0.5,0.5)表示每个通道的方差都为0.5。【因为图像一般是三个通道,所以这里的向量都是1x3的???】有了这两个参数后,当我们传入一个图像时,就会按照上面的公式对图像进行变换。【注意:这里说图像其实也不够准确,因为这个函数传入的格式不能为PIL Image,我们应该先将其转换为Tensor格式

说了这么多,那么这个函数到底有什么用呢?我们通过前面的ToTensor已经将数据归一化到了0-1之间,现在又接上了一个Normalize函数有什么用呢?其实Normalize函数做的是将数据变换到了[-1,1]之间。之前的数据为0-1,当取0时,output =(0 - 0.5)/ 0.5 = -1;当取1时,output =(1 - 0.5)/ 0.5 = 1。这样就把数据统一到了[-1,1]之间了???那么问题又来了,数据统一到[-1,1]有什么好处呢?数据如果分布在(0,1)之间,可能实际的bias,就是神经网络的输入b会比较大,而模型初始化时b=0的,这样会导致神经网络收敛比较慢,经过Normalize后,可以加快模型的收敛速度。【这句话是再网络上找到最多的解释,自己也不确定其正确性】

读到这里大家是不是以为就完了呢?这里还想和大家唠上一唠???上面的两个参数(0.5,0.5,0.5)是怎么得来的呢?这是根据数据集中的数据计算出的均值和标准差,所以往往不同的数据集这两个值是不同的???这里再举一个例子帮助大家理解其计算过程。同样采用上文例子中提到的数据。

上文已经得到了经ToTensor转换后的数据,现需要求出该数据每个通道的mean和std。【这一部分建议大家自己运行看看每一步的结果???】

# 需要对数据进行扩维,增加batch维度
data = torch.unsqueeze(data,0)    #在pytorch中一般都是(batch,C,H,W)
nb_samples = 0.
#创建3维的空列表
channel_mean = torch.zeros(3)
channel_std = torch.zeros(3)
N, C, H, W = data.shape[:4]
data = data.view(N, C, -1)  #将数据的H,W合并
#展平后,w,h属于第2维度,对他们求平均,sum(0)为将同一纬度的数据累加
channel_mean += data.mean(2).sum(0)  
#展平后,w,h属于第2维度,对他们求标准差,sum(0)为将同一纬度的数据累加
channel_std += data.std(2).sum(0)
#获取所有batch的数据,这里为1
nb_samples += N
#获取同一batch的均值和标准差
channel_mean /= nb_samples
channel_std /= nb_samples
print(channel_mean, channel_std)   #结果为tensor([0.0118, 0.0118, 0.0118]) tensor([0.0057, 0.0057, 0.0057])

将上述得到的mean和std带入公式,计算输出。

for i in range(3):
    data[i] = (data[i] - channel_mean[i]) / channel_std[i]
print(data)

输出结果:

image-20220416205341050

​ 从结果可以看出,我们计算的mean和std并不是0.5,且最后的结果也没有在[-1,1]之间。

最后我们再来看一个有意思的例子,我们得到了最终的结果,要是我们想要变回去怎么办,其实很简单啦,就是一个逆运算,即input = std*output + mean,然后再乘上255就可以得到原始的结果了。很多人获取吐槽了,这也叫有趣!!??哈哈哈这里我想说的是另外的一个事,如果我们对一张图像进行了归一化,这时候你用归一化后的数据显示这张图像的时候,会发现同样会是原图。

参考链接1:https://zhuanlan.zhihu.com/p/414242338

参考链接2:Https://blog.csdn.net/peacefairy/article/details/108020179

到此这篇关于pytorch中的transforms.ToTensor和transforms.Normalize的实现的文章就介绍到这了,更多相关pytorch transforms.ToTensor和transforms.Normalize内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

您可能感兴趣的文档:

--结束END--

本文标题: pytorch中的transforms.ToTensor和transforms.Normalize的实现

本文链接: https://www.lsjlt.com/news/146272.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • pytorch中的transforms.ToTensor和transforms.Normalize的实现
    目录transforms.ToTensortransforms.Normalizetransforms.ToTensor 最近看pytorch时,遇到了对图像数据的归一化,如下图所示...
    99+
    2022-11-13
  • pytorch中的transforms.ToTensor和transforms.Normalize怎么实现
    本文小编为大家详细介绍“pytorch中的transforms.ToTensor和transforms.Normalize怎么实现”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch中的transforms.ToTensor和tr...
    99+
    2023-06-30
  • CNN的Pytorch实现(LeNet)
    目录CNN的Pytorch实现(LeNet)1. 任务目标2. 库的导入3. 模型定义4. 数据加载、处理5.模型训练整个代码CNN的Pytorch实现(LeNet)  ...
    99+
    2022-11-12
  • Pytorch中实现CPU和GPU之间的切换的两种方法
    目录方法一:.to(device)1.不知道电脑GPU可不可用时:2.指定GPU时3.指定cpu时:方法二:总结:如何在pytorch中指定CPU和GPU进行训练,以及cpu和gpu...
    99+
    2023-01-28
    Pytorch CPU和GPU切换 Pytorch CPU GPU
  • Pytorch从0实现Transformer的实践
    目录摘要一、构造数据1.1 句子长度1.2 生成句子1.3 生成字典1.4 得到向量化的句子二、位置编码2.1 计算括号内的值2.2 得到位置编码三、多头注意力3.1 self ma...
    99+
    2022-11-11
  • PyTorch中dropout设置训练和测试模式的实现示例
    这篇文章主要介绍PyTorch中dropout设置训练和测试模式的实现示例,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!看代码吧~class Net(nn.Module):…model =&nbs...
    99+
    2023-06-15
  • beam search及pytorch的实现方式
    主要记录两种不同的beam search版本 版本一 使用类似层次遍历的方式进行搜索,用队列进行维护,每次循环对当前层的所有节点进行搜索,这些节点每个分别对应topk个节点作为下一层...
    99+
    2022-11-12
  • pytorch实现textCNN的具体操作
    1. 原理 2014年的一篇文章,开创cnn用到文本分类的先河。 Convolutional Neural Networks for Sentence Classification ...
    99+
    2022-11-12
  • PyTorch中torch.tensor()和torch.to_tensor()的区别
    目录前言1、torch.as_tensor()2、torch.tensor()总结前言 在跑模型的时候,遇到如下报错 UserWarning: To copy construct f...
    99+
    2023-01-28
    torch.tensor和torch.Tensor的区别 torch.tensor()
  • PyTorch dropout设置训练和测试模式的实现
    看代码吧~ class Net(nn.Module): … model = Net() … model.train() # 把module设成训练模式,对Dropout和Batc...
    99+
    2022-11-12
  • 在 pytorch 中实现计算图和自动求导
    前言: 今天聊一聊 pytorch 的计算图和自动求导,我们先从一个简单例子来看,下面是一个简单函数建立了 yy 和 xx 之间的关系 然后...
    99+
    2022-11-11
  • pytorch实现ResNet结构的实例代码
    目录1.ResNet的创新1)亮点2)原因2.ResNet的结构1)浅层的残差结构2)深层的残差结构3)总结3.Batch Normalization4.参考代码1.ResNet的创...
    99+
    2022-11-12
  • PyTorch中的train()、eval()和no_grad()的使用
    目录什么是train()函数?什么是eval()函数?什么是no_grad()函数?train()、eval()和no_grad()函数的联系总结在PyTorch中,train()、...
    99+
    2023-05-14
    PyTorch train() eval() no_grad()
  • pytorch中的model.eval()和BN层的使用
    看代码吧~ class ConvNet(nn.module): def __init__(self, num_class=10): super(ConvN...
    99+
    2022-11-12
  • pytorch实现线性回归的方法
    这篇文章主要介绍“pytorch实现线性回归的方法”,在日常操作中,相信很多人在pytorch实现线性回归的方法问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”pytorch实现线性回归的方法”的疑惑有所帮助!...
    99+
    2023-06-14
  • PyTorch数据读取的实现示例
    前言 PyTorch作为一款深度学习框架,已经帮助我们实现了很多很多的功能了,包括数据的读取和转换了,那么这一章节就介绍一下PyTorch内置的数据读取模块吧 模块介绍 pan...
    99+
    2022-11-11
  • M1 mac安装PyTorch的实现步骤
    目录第一步 -安装和配置Miniforge第二步-创建虚拟环境第三步 -安装PyTorch第四步 -测试最后总结M1 macbook已经不是什么新产品了。TensorFlow官方已经...
    99+
    2022-11-12
  • pytorch 搭建神经网路的实现
    目录1 数据 (1)导入数据(2)数据集可视化(3)为自己制作的数据集创建类(4)数据集批处理(5)数据预处理2 神经网络(1)定义神经网络类(3)模型参数3 最优化模型参...
    99+
    2022-11-12
  • Pytorch实现全连接层的操作
    全连接神经网络(FC) 全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC。 FC的准则很简单:神经网络中除输入层之外的每个节点都和上一...
    99+
    2022-11-12
  • Anaconda配置各版本Pytorch的实现
    目录1. 前言2. 配置镜像源3. pytorch,torchvision,python 版本对应4. 创建并进入虚拟环境5. Pytorch 0.4.16. Pytorch 1.0...
    99+
    2022-11-12
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作