iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >深入理解Pytorch中的torch.matmul()
  • 817
分享到

深入理解Pytorch中的torch.matmul()

Pytorchtorch.matmul()torch.matmul() 2023-05-15 17:05:45 817人浏览 薄情痞子

Python 官方文档:入门教程 => 点击学习

摘要

目录torch.matmul()语法作用举例情形1: 一维 * 一维情形2: 二维 * 二维情形3: 一维 * 二维情形4: 二维 * 一维情形5:两个参数至少为一维且至少一个参数为

torch.matmul()

语法

torch.matmul(input, other, *, out=None) → Tensor

作用

两个张量的矩阵乘积

行为取决于张量的维度,如下所示:

  • 如果两个张量都是一维的,则返回点积(标量)。
  • 如果两个参数都是二维的,则返回矩阵-矩阵乘积。
  • 如果第一个参数是一维的,第二个参数是二维的,为了矩阵乘法的目的,在它的维数前面加上一个 1。在矩阵相乘之后,前置维度被移除。
  • 如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量积。
  • 如果两个参数至少为一维且至少一个参数为 N 维(其中 N > 2),则返回批处理矩阵乘法
    • 如果第一个参数是一维的,则将 1 添加到其维度,以便批量矩阵相乘并在之后删除。如果第二个参数是一维的,则将 1 附加到其维度以用于批量矩阵倍数并在之后删除
    • 非矩阵(即批次)维度是广播的(因此必须是可广播的)
    • 例如,如果输入是( j × 1 × n × n ) (j \times 1 \times n \times n)(j×1×n×n) 张量
    • 另一个是 ( k × n × n ) (k \times n \times n)(k×n×n)张量,
    • out 将是一个 ( j × k × n × n ) (j \times k \times n \times n)(j×k×n×n) 张量

请注意,广播逻辑在确定输入是否可广播时仅查看批处理维度,而不是矩阵维度

例如

  • 如果输入是 ( j × 1 × n × m ) (j \times 1 \times n \times m)(j×1×n×m) 张量
  • 另一个是 ( k × m × p ) (k \times m \times p)(k×m×p) 张量
  • 即使最后两个维度(即矩阵维度)不同,这些输入对于广播也是有效的
  • out 将是一个 ( j × k × n × p ) (j \times k \times n \times p)(j×k×n×p) 张量

该运算符支持 TensorFloat32。

在某些 ROCm 设备上,当使用 float16 输入时,此模块将使用不同的向后精度

举例

情形1: 一维 * 一维

如果两个张量都是一维的,则返回点积(标量)

tensor1 = torch.Tensor([1,2,3])
tensor2 =torch.Tensor([4,5,6])
ans = torch.matmul(tensor1, tensor2)

print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())

ans = 1 * 4 + 2 * 5 + 3 * 6 = 32

情形2: 二维 * 二维

如果两个参数都是二维的,则返回矩阵-矩阵乘积
也就是 正常的矩阵乘法 (m * n) * (n * k) = (m * k)

tensor1 = torch.Tensor([[1,2,3],[1,2,3]])
tensor2 =torch.Tensor([[4,5],[4,5],[4,5]])
ans = torch.matmul(tensor1, tensor2)

print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())

情形3: 一维 * 二维

如果第一个参数是一维的,第二个参数是二维的,为了矩阵乘法的目的,在它的维数前面加上一个 1
在矩阵相乘之后,前置维度被移除

tensor1 = torch.Tensor([1,2,3]) # 注意这里是一维
tensor2 =torch.Tensor([[4,5],[4,5],[4,5]])
ans = torch.matmul(tensor1, tensor2)

print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())

tensor1 = torch.Tensor([1,2,3]) 修改为 tensor1 = torch.Tensor([[1,2,3]])

发现一个结果是[24., 30.] 一个是[[24., 30.]]

所以,当一维 * 二维时, 开始变成 1 * m(一维的维度),也就是一个二维, 再进行正常的矩阵运算,得到[[24., 30.]], 然后再去掉开始增加的一个维度,得到[24., 30.]

想象为二维 * 二维(前置维度为1),最后结果去掉一个维度即可

情形4: 二维 * 一维

如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量积

tensor1 =torch.Tensor([[4,5,6],[7,8,9]])
tensor2 = torch.Tensor([1,2,3])
ans = torch.matmul(tensor1, tensor2)

print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())

理解为:

  • 把第一个二维中,想象为多个行向量
  • 第二个一维想象为一个列向量
  • 行向量与列向量进行矩阵乘法,得到一个标量
  • 再按照行堆叠起来即可

情形5:两个参数至少为一维且至少一个参数为 N 维(其中 N > 2),则返回批处理矩阵乘法

第一个参数为N维,第二个参数为一维时

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
print(torch.matmul(tensor1, tensor2).size())

(4) 先添加一个维度 (4 * 1)
得到(10 * 3 * 4) *( 4 * 1) = (10 * 3 * 1)
再删除最后一个维度(添加的那个)
得到结果(10 * 3)

tensor1 = torch.randn(10,2, 3, 4) # 
tensor2 = torch.randn(4)
print(torch.matmul(tensor1, tensor2).size())

(10 * 2 * 3 * 4) * (4 * 1) = (10 * 2 * 3) 【抵消4,删1】

第一个参数为一维,第二个参数为二维时

tensor1 = torch.randn(4)
tensor2 = torch.randn(10, 4, 3)
print(torch.matmul(tensor1, tensor2).size())

tensor2 中第一个10理解为批次, 10个(4 * 3)
(1 * 4)与每个(4 * 3) 相乘得到(1,3),去除1,得到(3)
批次为10,得到(10,3)

tensor1 = torch.randn(4)
tensor2 = torch.randn(10,2, 4, 3)
print(torch.matmul(tensor1, tensor2).size())

这里批次理解为[10, 2]即可

tensor1 = torch.randn(4)
tensor2 = torch.randn(10,4, 2,4,1)
print(torch.matmul(tensor1, tensor2).size())

个人理解:当一个参数为一维时,它要去匹配另一个参数的最后两个维度(二维 * 二维)

比如上面的例子就是(1 * 4) 匹配 (4,1), 批次为(10,4,2)

高维 * 高维时

注:这不太好理解 … 感觉就是要找准批次,再进行乘法(靠感觉了 哈哈 离谱)

参考 https://PyTorch.org/docs/stable/generated/torch.matmul.html#torch.matmul 

到此这篇关于深入理解Pytorch中的torch. matmul()的文章就介绍到这了,更多相关Pytorch torch. matmul()内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: 深入理解Pytorch中的torch.matmul()

本文链接: https://www.lsjlt.com/news/209323.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • 深入理解Pytorch中的torch.matmul()
    目录torch.matmul()语法作用举例情形1: 一维 * 一维情形2: 二维 * 二维情形3: 一维 * 二维情形4: 二维 * 一维情形5:两个参数至少为一维且至少一个参数为...
    99+
    2023-05-15
    Pytorch torch. matmul() torch. matmul()
  • 深入理解pytorch库的dockerfile
    目录0. dockerfile命令1. 使用指令的注意点2. dockerfile3. 参考4. 存在的问题0. dockerfile命令 FROM # 基础镜像,一切从这里开...
    99+
    2024-04-02
  • 深入理解PyTorch中的nn.Embedding的使用
    目录一、前置知识1.1 语料库(Corpus)1.2 词元(Token)1.3 词表(Vocabulary)二、nn.Embedding 基础2.1 为什么要 embedding?2...
    99+
    2024-04-02
  • Pytorch中torch.stack()函数的深入解析
    目录一. torch.stack()函数解析1. 函数说明:2. 代码举例总结一. torch.stack()函数解析 1. 函数说明: 1.1 官网:torch.stack(),函...
    99+
    2024-04-02
  • 深入理解Pytorch微调torchvision模型
    目录一、简介二、导入相关包三、数据输入四、辅助函数1、模型训练和验证2、设置模型参数的'.requires_grad属性'一、简介 在本小节,深入探讨如何对torchvision进行...
    99+
    2024-04-02
  • 如何深入理解Pytorch微调torchvision模型
    如何深入理解Pytorch微调torchvision模型,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。一、简介在本小节,深入探讨如何对torchvision进行微调和特征提...
    99+
    2023-06-25
  • Android 中ThreadLocal的深入理解
    ThreadLocal前言:    ThreadLocal很容易让人望文生义,想当然地认为是一个“本地线程”。其实,ThreadLocal并不是一个Thread,ThreadLocal是一个线程内部的数据存储类...
    99+
    2023-05-30
    android threadlocal roi
  • 深入浅析Pytorch中stack()方法
    目录1. 概念2. 参数3. 举例3.1 四个shape为[3, 3]的张量3.1.1 dim=0的情况下,直接来看结果。3.1.2 dim=1的情况下3.1.2 dim=2的情况下...
    99+
    2024-04-02
  • JavaWeb中Servlet的深入理解
    1.servlet:定义: 接口 2.配置servlet: public class HelloServlet extends HttpServlet {} HttpServlet...
    99+
    2024-04-02
  • 深入理解vue3中的reactive()
    目录开始调试复杂数据类型基本数据类型proxy对象ref类型Map类型和Set类型在vue3的开发中,reactive是提供实现响应式数据的方法。日常开发这个是使用频率很高的api。...
    99+
    2023-02-17
    vue3 reactive() vue  reactive
  • 深入理解Java中的HashMap
    目录一、HashMap的结构图示二、HashMap的成员变量以及含义2.1、hash方法说明2.2、tableSizeFor方法说明三、HashMap的构造方法四、HashMap元素...
    99+
    2024-04-02
  • 深入理解python中的ThreadLocal
    ThreadLocal在threading模块中,可以见得它是为我们的线程服务的。 它的主要作用是存储当前线程的变量,各个线程之间的变量名是可以相同的,但是线程之间的变量是隔离的,也...
    99+
    2023-03-08
    python ThreadLocal
  • 深入理解html5中的position
    在HTML5中,position属性用于控制元素在文档中的定位方式。它有四个可选值:static(默认值)、relative、fix...
    99+
    2023-09-14
    html5
  • 深入理解Python中的__builti
    0.说明        这里的说明主要是以Python 2.7为例,因为在Python 3+中,__builtin__模块被命名为builtins,下面主要是探讨Python 2.x中__builtin__模块和__builtins__模块...
    99+
    2023-01-31
    Python __builti
  • 深入理解php中unset()
    目录概述变化情况情况一:情况二:情况三:概述 unset()经常会被用到,用于销毁指定的变量,但它有自己的行为模式,如果不仔细的话可能会被中文解释给迷惑: 先来看看官方文档的说法: ...
    99+
    2024-04-02
  • 深入理解.NET中的异步
    目录一、前言二、初看异步三、多线程编程四、异步编程五、Task (ValueTask)六、Task.Run七、自己封装异步逻辑八、同步方式调用异步代码九、void async 是什么...
    99+
    2024-04-02
  • 对Vue3中reactive的深入理解
    目录Vue3 reactive的理解1.什么是reactive2.reactive注意点Vue3笔记 reactive函数Vue3 reactive的理解 1.什么是reactive...
    99+
    2024-04-02
  • 深入理解vue中的 slot-scope=“scope“
    目录理解vue的 slot-scope=“scope“vue中的slot和slot-scope使用插槽的作用具名插槽 作用域插槽总结理解vue的 s...
    99+
    2022-12-09
    vue slot-scope=scope slot-scope=scope
  • 深入理解Node.js中的Worker线程
    目录概述Node.js 中 CPU 密集型应用的历史为 CPU 密集型操作使用 worker 线程Worker 线程是如何工作的?Node.js 的 workers 是如何并行的?跨...
    99+
    2024-04-02
  • C++中对象&类的深入理解
    什么是对象 任何事物都是一个对象, 也就是传说中的万物皆为对象. 对象的组成: 数据: 描述对象的属性 函数: 描述对象的行为, 根据外界的信息进行相应操作的代码...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作