iis服务器助手广告广告
返回顶部
首页 > 资讯 > 精选 >Pytorch如何实现常用乘法算子TensorRT
  • 860
分享到

Pytorch如何实现常用乘法算子TensorRT

2023-06-30 18:06:03 860人浏览 安东尼
摘要

这篇文章主要介绍了PyTorch如何实现常用乘法算子TensorRT的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch如何实现常用乘法算子TensorRT文章都会有所收获,下面我们一起来看看吧。1.乘

这篇文章主要介绍了PyTorch如何实现常用乘法算子TensorRT的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch如何实现常用乘法算子TensorRT文章都会有所收获,下面我们一起来看看吧。

1.乘法运算总览

先把 pytorch 中的一些常用的乘法运算进行一个总览:

  • torch.mm:用于两个矩阵 (不包括向量) 的乘法,如维度 (m, n) 的矩阵乘以维度 (n, p) 的矩阵;

  • torch.bmm:用于带 batch 的三维向量的乘法,如维度 (b, m, n) 的矩阵乘以维度 (b, n, p) 的矩阵;

  • torch.mul:用于同维度矩阵的逐像素点相乘,也即点乘,如维度 (m, n) 的矩阵点乘维度 (m, n) 的矩阵。该方法支持广播,也即支持矩阵和元素点乘;

  • torch.mv:用于矩阵和向量的乘法,矩阵在前,向量在后,如维度 (m, n) 的矩阵乘以维度为 (n) 的向量,输出维度为 (m);

  • torch.matmul:用于两个张量相乘,或矩阵与向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;

  • @:作用相当于 torch.matmul;

  • *:作用相当于 torch.mul;

如上进行了一些具体罗列,可以归纳出,常用的乘法无非两种:矩阵乘 和 点乘,所以下面分这两类进行介绍。

2.乘法算子实现

2.1矩阵乘算子实现

先来看看矩阵乘法的 pytorch 的实现 (以下实现在终端):

>>> import torch>>> # torch.mm>>> a = torch.randn(66, 99)>>> b = torch.randn(99, 88)>>> c = torch.mm(a, b)>>> c.shapetorch.size([66, 88])>>>>>> # torch.bmm>>> a = torch.randn(3, 66, 99)>>> b = torch.randn(3, 99, 77)>>> c = torch.bmm(a, b)>>> c.shapetorch.size([3, 66, 77])>>>>>> # torch.mv>>> a = torch.randn(66, 99)>>> b = torch.randn(99)>>> c = torch.mv(a, b)>>> c.shapetorch.size([66])>>>>>> # torch.matmul>>> a = torch.randn(32, 3, 66, 99)>>> b = torch.randn(32, 3, 99, 55)>>> c = torch.matmul(a, b)>>> c.shapetorch.size([32, 3, 66, 55])>>>>>> # @>>> d = a @ b>>> d.shapetorch.size([32, 3, 66, 55])

来看 TensorRT 的实现,以上乘法都可使用 addMatrixMultiply 方法覆盖,对应 torch.matmul,先来看该方法的定义:

//!//! \brief Add a MatrixMultiply layer to the network.//!//! \param input0 The first input tensor (commonly A).//! \param op0 The operation to apply to input0.//! \param input1 The second input tensor (commonly B).//! \param op1 The operation to apply to input1.//!//! \see IMatrixMultiplyLayer//!//! \warning Int32 tensors are not valid input tensors.//!//! \return The new matrix multiply layer, or nullptr if it could not be created.//!IMatrixMultiplyLayer* addMatrixMultiply(  ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept{  return mImpl->addMatrixMultiply(input0, op0, input1, op1);}

可以看到这个方法有四个传参,对应两个张量和其 operation。来看这个算子在 TensorRT 中怎么添加:

// 构造张量 Tensor0nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);// 构造张量 Tensor1nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);// 添加矩阵乘法nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);// 获取输出matmulOutput = Matmul_layer->getOputput(0);

2.2点乘算子实现

再来看看点乘的 pytorch 的实现 (以下实现在终端):

>>> import torch>>> # torch.mul>>> a = torch.randn(66, 99)>>> b = torch.randn(66, 99)>>> c = torch.mul(a, b)>>> c.shapetorch.size([66, 99])>>> d = 0.125>>> e = torch.mul(a, d)>>> e.shapetorch.size([66, 99])>>> # *>>> f = a * b>>> f.shapetorch.size([66, 99])

来看 TensorRT 的实现,以上乘法都可使用 addScale 方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:

//!//! \brief Add a Scale layer to the network.//!//! \param input The input tensor to the layer.//!              This tensor is required to have a minimum of 3 dimensions in implicit batch mode//!              and a minimum of 4 dimensions in explicit batch mode.//! \param mode The scaling mode.//! \param shift The shift value.//! \param scale The scale value.//! \param power The power value.//!//! If the weights are available, then the size of weights are dependent on the ScaleMode.//! For ::kUNIFORM, the number of weights equals 1.//! For ::kCHANNEL, the number of weights equals the channel dimension.//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.//!//! \see addScaleNd//! \see IScaleLayer//! \warning Int32 tensors are not valid input tensors.//!//! \return The new Scale layer, or nullptr if it could not be created.//!IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept{  return mImpl->addScale(input, mode, shift, scale, power);}

 可以看到有三个模式:

  • kUNIFORM:weights 为一个值,对应张量乘一个元素;

  • kCHANNEL:weights 维度和输入张量通道的 c 维度对应,可以做一些以通道为基准的预处理;

  • kELEMENTWISE:weights 维度和输入张量的 c、h、w 对应,不考虑 batch,所以是输入的后三维;

再来看这个算子在 TensorRT 中怎么添加:

// 构造张量 inputnvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);// scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISEscalemode = kUNIFORM;// 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };// !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行// 添加张量乘法nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);// 获取输出scaleOutput = Scale_layer->getOputput(0);

有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。

关于“Pytorch如何实现常用乘法算子TensorRT”这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对“Pytorch如何实现常用乘法算子TensorRT”知识都有一定的了解,大家如果还想学习更多知识,欢迎关注编程网精选频道。

--结束END--

本文标题: Pytorch如何实现常用乘法算子TensorRT

本文链接: https://www.lsjlt.com/news/330666.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • Pytorch如何实现常用乘法算子TensorRT
    这篇文章主要介绍了Pytorch如何实现常用乘法算子TensorRT的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch如何实现常用乘法算子TensorRT文章都会有所收获,下面我们一起来看看吧。1.乘...
    99+
    2023-06-30
  • Pytorch实现常用乘法算子TensorRT的示例代码
    目录1.乘法运算总览2.乘法算子实现2.1矩阵乘算子实现2.2点乘算子实现本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现。 pytorch 用于训练,Tensor...
    99+
    2022-11-11
  • php如何实现乘法运算
    本篇内容主要讲解“php如何实现乘法运算”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“php如何实现乘法运算”吧!在PHP中,可以利用“*”算术运算符实现乘法运算,该运算符用于计算前后两个数的乘...
    99+
    2023-06-29
  • 如何实现大整数乘法运算与分治算法
    本篇内容主要讲解“如何实现大整数乘法运算与分治算法”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“如何实现大整数乘法运算与分治算法”吧!普通乘数运算对于乘数运算有...
    99+
    2022-10-19
  • C语言如何使用移位实现乘除法运算
    这篇文章主要为大家展示了“C语言如何使用移位实现乘除法运算”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“C语言如何使用移位实现乘除法运算”这篇文章吧。移位实现乘...
    99+
    2022-10-19
  • Pytorch 如何实现常用正则化
    Stochastic Depth 论文:Deep Networks with Stochastic Depth 本文的正则化针对于ResNet中的残差结构,类似于dropout的原理...
    99+
    2022-11-12
  • pytorch液态算法如何实现瘦脸效果
    这篇文章主要介绍了pytorch液态算法如何实现瘦脸效果,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。算法思路:假设当前点为(x,y),手动指定变形区域的中心点为C(cx,c...
    99+
    2023-06-21
  • 如何使用php递归函数实现阶乘计算
    以下是使用PHP递归函数实现阶乘计算的示例代码:```phpfunction factorial($n) {if ($n ...
    99+
    2023-09-15
    php
  • PHP如何实现常见排序算法
    本篇内容介绍了“PHP如何实现常见排序算法”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!1、冒泡排序两两相比,每循环一轮就不用再比较最后一个...
    99+
    2023-07-01
  • 如何使用批处理实现九九乘法表
    这篇文章将为大家详细讲解有关如何使用批处理实现九九乘法表,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。代码如下:@echo off :: 九九乘法表 set num=0 for /l %%i in (1,...
    99+
    2023-06-08
  • shell中如何使用awk实现九九乘法表
    小编给大家分享一下shell中如何使用awk实现九九乘法表,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!实现代码:awk ‘BEGIN{for(i=1;...
    99+
    2023-06-09
  • 如何在Go语言中实现常用的算法?
    Go语言是一种快速、简单、可靠的编程语言,它已经成为了云计算、网络编程、大数据等领域的热门语言。在软件开发过程中,算法是一个不可或缺的部分。本文将介绍如何在Go语言中实现常用的算法。 一、排序算法 1.冒泡排序 冒泡排序是一种简单的排序算法...
    99+
    2023-06-17
    教程 编程算法 numy
  • python3如何实现常见的排序算法
    小编给大家分享一下python3如何实现常见的排序算法,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!冒泡排序冒泡排序是一种简单的排序算法。它重复地走访过要排序的数...
    99+
    2023-06-20
  • 如何用shell脚本编程实现9*9乘法表
    本篇内容介绍了“如何用shell脚本编程实现9*9乘法表”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!脚本内容代码如下:#!/bin/bas...
    99+
    2023-06-09
  • PHP如何在不使用加减乘除运算符号的情况下实现加法
    这篇文章主要讲解了“PHP如何在不使用加减乘除运算符号的情况下实现加法”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“PHP如何在不使用加减乘除运算符号的情况下实现加法”吧!写一个函数,求两个...
    99+
    2023-06-20
  • PHP如何使用数组循环来实现矩阵乘法
    这篇文章主要介绍“PHP如何使用数组循环来实现矩阵乘法”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“PHP如何使用数组循环来实现矩阵乘法”文章能帮助大家解决问题。什么是矩阵乘法在数学中,一个矩阵是由...
    99+
    2023-07-06
  • python如何实现常用的五种排序算法详解
    目录一、冒泡排序 二、选择排序 三、插入排序 四、希尔排序 五、快速排序 总结一、冒泡排序 原理: 比较相邻的元素。如果第一个比第二个大就交换他们两个 每一对相邻...
    99+
    2022-11-12
  • Django编程中如何实现常用的字符串算法?
    Django是一款优秀的Python web框架,它能够帮助开发者快速构建高质量的web应用。在Django编程中,常常需要使用字符串算法来处理各种字符串相关的问题。本文将介绍Django中常用的字符串算法,并且通过演示代码来说明如何实现这...
    99+
    2023-09-25
    编程算法 django laravel
  • 如何用Nacos实现Raft算法
    这篇文章主要介绍“如何用Nacos实现Raft算法”,在日常操作中,相信很多人在如何用Nacos实现Raft算法问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”如何用Nacos实现Raft算法”的疑惑有所帮助!...
    99+
    2023-06-02
  • 如何用PHP实现递归算法
    要使用PHP实现递归算法,首先需要定义一个递归函数。递归函数是指在函数内部调用函数本身的一种方法。下面是一个使用PHP实现递归算法的...
    99+
    2023-08-24
    PHP
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作