iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >Pytorch实现常用乘法算子TensorRT的示例代码
  • 589
分享到

Pytorch实现常用乘法算子TensorRT的示例代码

2024-04-02 19:04:59 589人浏览 安东尼

Python 官方文档:入门教程 => 点击学习

摘要

目录1.乘法运算总览2.乘法算子实现2.1矩阵乘算子实现2.2点乘算子实现本文介绍一下 PyTorch 中常用乘法的 TensorRT 实现。 pytorch 用于训练,Tensor

本文介绍一下 PyTorch 中常用乘法的 TensorRT 实现。

pytorch 用于训练,TensorRT 用于推理是很多 ai 应用开发的标配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,这里拿比较常用的乘法运算在两种框架下的实现做一个对比,可能会有更加直观一些的认识。

1.乘法运算总览

先把 pytorch 中的一些常用的乘法运算进行一个总览:

  • torch.mm:用于两个矩阵 (不包括向量) 的乘法,如维度 (m, n) 的矩阵乘以维度 (n, p) 的矩阵;
  • torch.bmm:用于带 batch 的三维向量的乘法,如维度 (b, m, n) 的矩阵乘以维度 (b, n, p) 的矩阵;
  • torch.mul:用于同维度矩阵的逐像素点相乘,也即点乘,如维度 (m, n) 的矩阵点乘维度 (m, n) 的矩阵。该方法支持广播,也即支持矩阵和元素点乘;
  • torch.mv:用于矩阵和向量的乘法,矩阵在前,向量在后,如维度 (m, n) 的矩阵乘以维度为 (n) 的向量,输出维度为 (m);
  • torch.matmul:用于两个张量相乘,或矩阵与向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;
  • @:作用相当于 torch.matmul;
  • *:作用相当于 torch.mul;

如上进行了一些具体罗列,可以归纳出,常用的乘法无非两种:矩阵乘 和 点乘,所以下面分这两类进行介绍。

2.乘法算子实现

2.1矩阵乘算子实现

先来看看矩阵乘法的 pytorch 的实现 (以下实现在终端):

>>> import torch
>>> # torch.mm
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99, 88)
>>> c = torch.mm(a, b)
>>> c.shape
torch.size([66, 88])
>>>
>>> # torch.bmm
>>> a = torch.randn(3, 66, 99)
>>> b = torch.randn(3, 99, 77)
>>> c = torch.bmm(a, b)
>>> c.shape
torch.size([3, 66, 77])
>>>
>>> # torch.mv
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99)
>>> c = torch.mv(a, b)
>>> c.shape
torch.size([66])
>>>
>>> # torch.matmul
>>> a = torch.randn(32, 3, 66, 99)
>>> b = torch.randn(32, 3, 99, 55)
>>> c = torch.matmul(a, b)
>>> c.shape
torch.size([32, 3, 66, 55])
>>>
>>> # @
>>> d = a @ b
>>> d.shape
torch.size([32, 3, 66, 55])

来看 TensorRT 的实现,以上乘法都可使用 addMatrixMultiply 方法覆盖,对应 torch.matmul,先来看该方法的定义:

//!
//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
  ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
  return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}

可以看到这个方法有四个传参,对应两个张量和其 operation。来看这个算子在 TensorRT 中怎么添加:

// 构造张量 Tensor0
nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);
// 构造张量 Tensor1
nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);

// 添加矩阵乘法
nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);

// 获取输出
matmulOutput = Matmul_layer->getOputput(0);

2.2点乘算子实现

再来看看点乘的 pytorch 的实现 (以下实现在终端):

>>> import torch
>>> # torch.mul
>>> a = torch.randn(66, 99)
>>> b = torch.randn(66, 99)
>>> c = torch.mul(a, b)
>>> c.shape
torch.size([66, 99])
>>> d = 0.125
>>> e = torch.mul(a, d)
>>> e.shape
torch.size([66, 99])
>>> # *
>>> f = a * b
>>> f.shape
torch.size([66, 99])

来看 TensorRT 的实现,以上乘法都可使用 addScale 方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:

//!
//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//!              This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//!              and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
  return mImpl->addScale(input, mode, shift, scale, power);
}

 可以看到有三个模式:

  • kUNIFORM:weights 为一个值,对应张量乘一个元素;
  • kCHANNEL:weights 维度和输入张量通道的 c 维度对应,可以做一些以通道为基准的预处理;
  • kELEMENTWISE:weights 维度和输入张量的 c、h、w 对应,不考虑 batch,所以是输入的后三维;

再来看这个算子在 TensorRT 中怎么添加:

// 构造张量 input
nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);

// scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;

// 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };

// !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行

// 添加张量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);

// 获取输出
scaleOutput = Scale_layer->getOputput(0);

有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。

到此这篇关于Pytorch实现常用乘法算子TensorRT的示例代码的文章就介绍到这了,更多相关Pytorch乘法算子TensorRT内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: Pytorch实现常用乘法算子TensorRT的示例代码

本文链接: https://www.lsjlt.com/news/118499.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • Pytorch实现常用乘法算子TensorRT的示例代码
    目录1.乘法运算总览2.乘法算子实现2.1矩阵乘算子实现2.2点乘算子实现本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现。 pytorch 用于训练,Tensor...
    99+
    2022-11-11
  • Pytorch如何实现常用乘法算子TensorRT
    这篇文章主要介绍了Pytorch如何实现常用乘法算子TensorRT的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch如何实现常用乘法算子TensorRT文章都会有所收获,下面我们一起来看看吧。1.乘...
    99+
    2023-06-30
  • python3实现常见的排序算法(示例代码)
    冒泡排序 冒泡排序是一种简单的排序算法。它重复地走访过要排序的数列,一次比较两个元素,如果它们的顺序错误就把它们交换过来。走访数列的工作是重复地进行直到没有再需要交换,也就是说该数列...
    99+
    2022-11-12
  • C#实现常见加密算法的示例代码
    目录前言1. Base64编码1.1 原理介绍1.2 C#代码2. 凯撒密码2.1 原理介绍2.2 C#代码3. Vigenere密码3.1 原理介绍3.2 C#代码4. DES4....
    99+
    2022-11-13
  • Golang实现常见排序算法的示例代码
    目录前言五种基础排序算法对比1、冒泡排序2、选择排序3、插入排序4、快速排序前言 现在的面试真的是越来越卷了,算法已经成为了面试过程中必不可少的一个环节,你如果想进稍微好一点的公司,...
    99+
    2022-11-13
  • PHP实现常见排序算法的示例代码
    目录1、冒泡排序2、选择排序3、快速排序4、插入排序补充1、冒泡排序 两两相比,每循环一轮就不用再比较最后一个元素了,因为最后一个元素已经是最大或者最小。 function maop...
    99+
    2022-11-13
  • Java实现常见的排序算法的示例代码
    目录一、优化后的冒泡排序二、选择排序三、插入排序四、希尔排序五、快速排序六、随机化快速排序七、归并排序八、可处理负数的基数排序一、优化后的冒泡排序 package com.yzh.s...
    99+
    2022-11-13
    Java常见排序算法 Java排序算法 Java排序
  • Golang实现常见的限流算法的示例代码
    目录固定窗口滑动窗口漏桶算法令牌桶滑动日志总结限流是项目中经常需要使用到的一种工具,一般用于限制用户的请求的频率,也可以避免瞬间流量过大导致系统崩溃,或者稳定消息处理速率 这个文章主...
    99+
    2023-05-14
    Golang常见限流算法 Golang限流算法 Go 限流算法
  • Go语言实现常用排序算法的示例代码
    目录冒泡排序快速排序选择排序插入排序排序算法是在生活中随处可见,也是算法基础,因为其实现代码较短,应用较常见。所以在面试中经常会问到排序算法及其相关的问题,可以说是每个程序员都必须得...
    99+
    2022-11-11
  • Go+Redis实现常见限流算法的示例代码
    目录固定窗口滑动窗口hash实现list实现漏桶算法令牌桶滑动日志总结限流是项目中经常需要使用到的一种工具,一般用于限制用户的请求的频率,也可以避免瞬间流量过大导致系统崩溃,或者稳定消息处理速率。并且有时候我们还需要使用...
    99+
    2023-04-02
    Go Redis实现限流算法 Go Redis限流算法 Go 限流算法
  • Python/JS实现常见加密算法的示例代码
    目录前言一、编码,加密二、常见编码1.Base642. Base64 - JS实现3. Base64 - Python实现4.Unicode5.Urlencode三、线性散列算法(签...
    99+
    2022-11-11
  • Python实现异常检测LOF算法的示例代码
    目录背景LOF 算法1. k邻近距离2. k距离领域3. 可达距离4. 局部可达密度5. 局部异常因子LOF算法流程LOF优缺点Python 实现 LOFPyODSklearn大家好...
    99+
    2022-11-13
  • Java实现Kruskal算法的示例代码
    目录介绍一、构建后的图二、代码三、测试介绍 构造最小生成树还有一种算法,即 Kruskal 算法:设图 G=(V,E)是无向连通带权图,V={1,2,...n};设最小生成树 T=(...
    99+
    2022-11-13
  • PHP实现LRU算法的示例代码
    本篇文章主要给大家介绍了PHP的相关知识,LRU是Least Recently Used 近期最少使用算法, 内存管理的一种页面置换算法,下面将详解LRU算法的原理以及实现,下面一起来看一下,希望对大家有帮助。(推荐教程:PHP视频教程)原...
    99+
    2022-08-08
    php
  • Java实现Floyd算法的示例代码
    目录一 问题描述二 代码三 实现一 问题描述 求节点0到节点2的最短路径。 二 代码 package graph.floyd; ...
    99+
    2022-11-13
  • C++实现Dijkstra算法的示例代码
    目录一、算法原理二、具体代码1.graph类2.PathFinder类3. main.cpp三、示例一、算法原理 链接: Dijkstra算法及其C++实现参考这篇文章 二、具体代码...
    99+
    2022-11-13
  • Java实现Dijkstra算法的示例代码
    目录一 问题描述二 实现三 测试一 问题描述 小明为位置1,求他到其他各顶点的距离。 二 实现 package graph.dij...
    99+
    2022-11-13
  • Java实现雪花算法的示例代码
    一、介绍 SnowFlow算法是Twitter推出的分布式id生成算法,主要核心思想就是利用64bit的long类型的数字作为全局的id。在分布式系统中经常应用到,并且,在id中加入...
    99+
    2022-11-13
  • Java实现抽奖算法的示例代码
    目录一、题目描述二、解题思路三、代码详解四、优化抽奖算法解题思路代码详解一、题目描述 题目: 小虚竹为了给粉丝送福利,决定在参与学习打卡活动的粉丝中抽一位幸运粉丝,送份小礼物。为了公...
    99+
    2022-11-13
  • Java 实现LZ78压缩算法的示例代码
    LZ78 压缩算法的 Java 实现 1、压缩算法的实现 通过多路搜索树提高检索速度 package com.wretchant.lz78; import java.util....
    99+
    2022-11-12
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作