iis服务器助手广告广告
返回顶部
首页 > 资讯 > 精选 >pytorch中torch.topk()函数怎么用
  • 513
分享到

pytorch中torch.topk()函数怎么用

2023-06-29 07:06:51 513人浏览 安东尼
摘要

这篇文章主要介绍“PyTorch中torch.topk()函数怎么用”,在日常操作中,相信很多人在pytorch中torch.topk()函数怎么用问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”pytorch

这篇文章主要介绍“PyTorch中torch.topk()函数怎么用”,在日常操作中,相信很多人在pytorch中torch.topk()函数怎么用问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”pytorch中torch.topk()函数怎么用”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

函数作用:

pytorch中torch.topk()函数怎么用

pytorch中torch.topk()函数怎么用

该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序。

通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号。

举个栗子:

import numpy as npimport torchimport torch.utils.data.dataset as Datasetfrom torch.utils.data import Dataset,DataLoader####################准备一个数组#########################tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],             [3,4,5,1,1,1,1,1,1,1,1],             [7,8,9,1,1,1,1,1,1,1,1],             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)####################打印这个原数组#########################print('tensor1:')print(tensor1)#################使用torch.topk()这个函数##################print('使用torch.topk()这个函数得到:')'''k=3代表从原数组中取得3个元素,dim=1表示从原数组中的第一维获取元素(在本例中是分别从[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、  [7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]这四个数组中获取3个元素)其中largest=True表示从大到小取元素'''print(torch.topk(tensor1, k=3, dim=1, largest=True))#################打印这个函数第一个返回值####################print('函数第一个返回值topk[0]如下')print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])#################打印这个函数第二个返回值####################print('函数第二个返回值topk[1]如下')print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])'''#######################运行结果##########################tensor1:tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])使用torch.topk()这个函数得到:'得到的values是原数组dim=1的四组从大到小的三个元素值;得到的indices是获取到的元素值在原数组dim=1中的位置。'torch.return_types.topk(values=tensor([[10., 10.,  2.],        [ 5.,  4.,  3.],        [ 9.,  8.,  7.],        [ 7.,  4.,  1.]]),indices=tensor([[ 0, 10,  2],        [ 2,  1,  0],        [ 2,  1,  0],        [ 2,  1,  0]]))函数第一个返回值topk[0]如下tensor([[10., 10.,  2.],        [ 5.,  4.,  3.],        [ 9.,  8.,  7.],        [ 7.,  4.,  1.]])        函数第二个返回值topk[1]如下tensor([[ 0, 10,  2],        [ 2,  1,  0],        [ 2,  1,  0],        [ 2,  1,  0]])'''

该函数功能经常用来获取张量或者数组中最大或者最小的元素以及索引位置,是一个经常用到的基本函数。

实例演示

任务一:

取top1(最大值):

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])print(pred)values, indices = pred.topk(1, dim=0, largest=True, sorted=True)print(indices)print(values)# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。_, indices_max = pred.max(dim=0, keepdim=True)print(indices_max)print(indices_max == indices)输出:tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])tensor([[1, 1, 1, 1, 1]])tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])tensor([[1, 1, 1, 1, 1]])tensor([[True, True, True, True, True]])

任务二:

按行取出topk,将小于topk的置为inf:

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])print(pred)top_k = 2  # 按行求出每一行的最大的前两个值filter_value=-float('Inf')indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]print(indices_to_remove)pred[indices_to_remove] = filter_value  # 对于topk之外的其他元素的logits值设为负无穷print(pred) 输出:tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])tensor([[4],        [4],        [4],        [3]])tensor([[0.4053],        [1.8823],        [1.7255],        [0.3849]])tensor([[ True, False,  True,  True, False],        [ True, False,  True,  True, False],        [ True,  True, False,  True, False],        [ True, False,  True, False,  True]])tensor([[   -inf, -0.3873,    -inf,    -inf,  0.4053],        [   -inf,  1.4164,    -inf,    -inf,  1.8823],        [   -inf,    -inf,  1.2590,    -inf,  1.7255],        [   -inf,  0.3041,    -inf,  0.3849,    -inf]])

任务三:

import numpy as npimport torchimport torch.utils.data.dataset as Datasetfrom torch.utils.data import Dataset,DataLoadertensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],             [3,4,5,1,1,1,1,1,1,1,1],             [7,8,9,1,1,1,1,1,1,1,1],             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)# tensor2=torch.tensor([[3,2,1],#                       [6,5,4],#                       [1,4,7],#                       [9,8,7]],dtype=torch.float32)#print('tensor1:')print(tensor1)print('直接输出topk,会得到两个东西,我们需要的是第二个indices')print(torch.topk(tensor1, k=3, dim=1, largest=True))print('topk[0]如下')print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])print('topk[1]如下')print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])'''tensor1:tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])直接输出topk,会得到两个东西,我们需要的是第二个indicestorch.return_types.topk(values=tensor([[10., 10.,  2.],        [ 5.,  4.,  3.],        [ 9.,  8.,  7.],        [ 7.,  4.,  1.]]),indices=tensor([[ 0, 10,  2],        [ 2,  1,  0],        [ 2,  1,  0],        [ 2,  1,  0]]))topk[0]如下tensor([[10., 10.,  2.],        [ 5.,  4.,  3.],        [ 9.,  8.,  7.],        [ 7.,  4.,  1.]])topk[1]如下tensor([[ 0, 10,  2],        [ 2,  1,  0],        [ 2,  1,  0],        [ 2,  1,  0]])'''

到此,关于“pytorch中torch.topk()函数怎么用”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注编程网网站,小编会继续努力为大家带来更多实用的文章!

--结束END--

本文标题: pytorch中torch.topk()函数怎么用

本文链接: https://www.lsjlt.com/news/323426.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • pytorch中torch.topk()函数怎么用
    这篇文章主要介绍“pytorch中torch.topk()函数怎么用”,在日常操作中,相信很多人在pytorch中torch.topk()函数怎么用问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”pytorch...
    99+
    2023-06-29
  • pytorch中torch.topk()函数的快速理解
    目录函数作用:举个栗子:实例演示总结函数作用: 该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序。 通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组...
    99+
    2024-04-02
  • PyTorch中torch.matmul()函数怎么使用
    这篇文章主要介绍了PyTorch中torch.matmul()函数怎么使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇PyTorch中torch.matmul()函数怎么使用文章都会有所收获,下面我们一起来看...
    99+
    2023-07-06
  • pytorch中Parameter函数怎么使用
    这篇文章主要介绍了pytorch中Parameter函数怎么使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇pytorch中Parameter函数怎么使用文章都会有所收获,下面我们一起来看看吧。用法介绍pyt...
    99+
    2023-06-29
  • Pytorch中怎么调用forward()函数
    这篇文章主要讲解了“Pytorch中怎么调用forward()函数”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Pytorch中怎么调用forward()函数”吧!Pytorch调用forw...
    99+
    2023-07-05
  • Pytorch中的torch.gather()函数怎么用
    这篇文章将为大家详细讲解有关Pytorch中的torch.gather()函数怎么用,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。参数说明以官方说明为例,gather()函数需要三个参数,输入input,...
    99+
    2023-06-25
  • pytorch中的torch.nn.Conv2d()函数怎么用
    这篇文章主要为大家展示了“pytorch中的torch.nn.Conv2d()函数怎么用”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“pytorch中的torch.nn.Conv2d()函数怎么...
    99+
    2023-06-29
  • pytorch中的view()函数怎么使用
    这篇文章主要介绍了pytorch中的view()函数怎么使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇pytorch中的view()函数怎么使用文章都会有所收获,下面我们一起来看看吧。一、普通用法 (手动调...
    99+
    2023-06-29
  • pytorch中BatchNorm2d函数的参数怎么使用
    本篇内容主要讲解“pytorch中BatchNorm2d函数的参数怎么使用”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“pytorch中BatchNorm2d函数的参数怎么使用”吧!BN原理、作...
    99+
    2023-07-04
  • Pytorch中backward()多个loss函数怎么用
    这篇文章主要介绍Pytorch中backward()多个loss函数怎么用,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!Pytorch的backward()函数假若有多个loss函数,如何进行反向传播和更新呢?&nb...
    99+
    2023-06-15
  • Pytorch中的backward()多个loss函数怎么用
    这篇文章主要介绍了Pytorch中的backward()多个loss函数怎么用,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。Pytorch的backward()函数假若有多个...
    99+
    2023-06-15
  • pytorch中的squeeze函数、cat函数使用
    1 squeeze(): 去除size为1的维度,包括行和列。 至于维度大于等于2时,squeeze()不起作用。 行、例: >>> torch.rand(4,...
    99+
    2024-04-02
  • PyTorch中怎么定义损失函数
    在PyTorch中,我们可以使用torch.nn模块中的各种损失函数来定义损失函数。以下是一些常用的损失函数及其定义方法: 均方误...
    99+
    2024-04-02
  • pytorch中nn.Flatten()函数如何使用
    这篇文章主要介绍了pytorch中nn.Flatten()函数如何使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇pytorch中nn.Flatten()函数如何使用文章都会有所收获,下面我们一起来看看吧。t...
    99+
    2023-07-04
  • Pytorch中如何调用forward()函数
    目录Pytorch调用forward()函数Pytorch函数调用的问题和源码解读总结Pytorch调用forward()函数 Module类是nn模块里提供的一个模型构造类,是所有...
    99+
    2023-02-17
    Pytorch调用forward函数 Pytorch forward函数 Pytorch forward()函数
  • pytorch中Parameter函数用法示例
    目录用法介绍代码介绍用法介绍 pytorch中的Parameter函数可以对某个张量进行参数化。它可以将不可训练的张量转化为可训练的参数类型,同时将转化后的张量绑定到模型可训练参数的...
    99+
    2024-04-02
  • PyTorch中Torch.arange函数详解
    目录torch.arange函数详解函数原型用法参数说明关键字参数代码示例pyTorch中torch.range()和torch.arange()的区别总结torch.arange函...
    99+
    2023-02-03
    pytorch torch.arange函数 torch.arange() torch.arange函数
  • Pytorch中torch.cat()函数解析
    一. torch.cat()函数解析 1. 函数说明 1.1 官网:torch.cat(),函数定义及参数说明如下图所示: 1.2 函数功能 函数将两个张量(tensor)按指定维度拼接在一起,注意...
    99+
    2023-10-20
    pytorch 深度学习 python 神经网络
  • pytorch中的numel函数用法说明
    获取tensor中一共包含多少个元素 import torch x = torch.randn(3,3) print("number elements of x is ",x.n...
    99+
    2024-04-02
  • 如何在pytorch中使用numel函数
    本篇文章给大家分享的是有关如何在pytorch中使用numel函数,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。获取tensor中一共包含多少个元素import to...
    99+
    2023-06-15
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作