Python 官方文档:入门教程 => 点击学习
tensor.squeeze() 和 tensor.unsqueeze() 是 PyTorch 中用于改变 tensor 形状的两个函数,它们的作用如下: tensor.squeez
tensor.squeeze() 和 tensor.unsqueeze() 是 PyTorch 中用于改变 tensor 形状的两个函数,它们的作用如下:
下面给出例子来说明它们的使用。
import torch
# 创建一个形状为 (1, 3, 1, 2) 的 tensor
x = torch.randn(1, 3, 1, 2)
print(x.shape) # torch.Size([1, 3, 1, 2])
# 压缩尺寸为 1 的维度
y = x.squeeze()
print(y.shape) # torch.Size([3, 2])
# 指定要压缩的维度
y = x.squeeze(dim=0)
print(y.shape) # torch.Size([3, 1, 2])
在上面的例子中,我们创建了一个形状为 (1, 3, 1, 2) 的 tensor,然后使用 squeeze() 函数压缩了尺寸为 1 的维度。在第二个 squeeze() 调用中,我们指定了要压缩的维度为 0,也就是第一个维度,因此第一个维度的大小被压缩为 1,变成了形状为 (3, 1, 2) 的 tensor。
import torch
# 创建一个形状为 (3, 2) 的 tensor
x = torch.randn(3, 2)
print(x.shape) # torch.Size([3, 2])
# 在维度 0 上插入新维度
y = x.unsqueeze(dim=0)
print(y.shape) # torch.Size([1, 3, 2])
# 在维度 1 上插入新维度
y = x.unsqueeze(dim=1)
print(y.shape) # torch.Size([3, 1, 2])
# 在倒数第二个维度上插入新维度
y = x.unsqueeze(dim=-2)
print(y.shape) # torch.Size([3, 1, 2])
在上面的例子中,我们创建了一个形状为 (3, 2) 的 tensor,然后使用 unsqueeze() 函数在不同的位置插入了新维度。在第一个 unsqueeze() 调用中,我们在维度 0 上插入了新维度,因此新的 tensor 形状为 (1, 3, 2)。在第二个和第三个 unsqueeze() 调用中,我们分别在维度 1 和倒数第二个维度上插入了新维度,分别得到了形状为 (3, 1, 2) 和 (3, 2, 1) 的 tensor。
到此这篇关于tensor.squeeze函数和tensor.unsqueeze函数的使用详解的文章就介绍到这了,更多相关tensor.squeeze函数和tensor.unsqueeze函数内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!
--结束END--
本文标题: tensor.squeeze函数和tensor.unsqueeze函数的使用详解
本文链接: https://www.lsjlt.com/news/199069.html(转载时请注明来源链接)
有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341
下载Word文档到电脑,方便收藏和打印~
2024-03-01
2024-03-01
2024-03-01
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
回答
回答
回答
回答
回答
回答
回答
回答
回答
回答
0