iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >python深度学习之多标签分类器及pytorch实现源码
  • 363
分享到

python深度学习之多标签分类器及pytorch实现源码

2024-04-02 19:04:59 363人浏览 泡泡鱼

Python 官方文档:入门教程 => 点击学习

摘要

目录多标签分类器多标签分类器损失函数代码实现多标签分类器 多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分

多标签分类器

多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:

  • 类标数量不确定,有些样本可能只有一个类标,有些样本的类标可能高达几十甚至上百个
  • 类标之间相互依赖,例如包含蓝天类标的样本很大概率上包含白云

如下图所示,即为一个多标签分类学习的一个例子,一张图片里有多个类别,房子,树,云等,深度学习模型需要将其一一分类识别出来。

多标签分类器损失函数

代码实现

针对图像的多标签分类器PyTorch的简化代码实现如下所示。因为图像的多标签分类器的数据集比较难获取,所以可以通过对mnist数据集中的每个图片打上特定的多标签,例如类别1的多标签可以为[1,1,0,1,0,1,0,0,1],然后再利用重新打标后的数据集训练出一个mnist的多标签分类器。

from torchvision import datasets, transfORMs
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.Sq1 = nn.Sequential(         
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),   # (16, 28, 28)                           #  output: (16, 28, 28)
            nn.ReLU(),                    
            nn.MaxPool2d(kernel_size=2),    # (16, 14, 14)
        )
        self.Sq2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),  # (32, 14, 14)
            nn.ReLU(),                      
            nn.MaxPool2d(2),                # (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 100)  
    def forward(self, x):
        x = self.Sq1(x)
        x = self.Sq2(x)
        x = x.view(x.size(0), -1)    
        x = self.out(x)
        ## Sigmoid activation   
        output = F.sigmoid(x)  # 1/(1+e**(-x))
        return output
def loss_fn(pred, target):
    return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum()
def multilabel_generate(label):
    Y1 = F.one_hot(label, num_classes = 100)
    Y2 = F.one_hot(label+10, num_classes = 100)
    Y3 = F.one_hot(label+50, num_classes = 100) 	
    multilabel = Y1+Y2+Y3
    return multilabel
        
# def multilabel_generate(label):
# 	multilabel_dict = {}
# 	multi_list = []
# 	for i in range(label.shape[0]):
# 		multi_list.append(multilabel_dict[label[i].item()])
# 	multilabel_tensor = torch.tensor(multi_list)
#     return multilabel
def train():
    epoches = 10
    mnist_net = CNN()
    mnist_net.train()
    opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True)
    for epoch in range(epoches):
    	loss = 0 
    	for batch_X, batch_Y in train_loader:
    		opitimizer.zero_grad()
    		outputs = mnist_net(batch_X)
    		loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0]
    		loss.backward()
    		opitimizer.step()
    		print(loss)
if __name__ == '__main__':
	train()

以上就是python深度学习之多标签分类器及pytorch源码的详细内容,更多关于多标签分类器pytorch源码的资料请关注编程网其它相关文章!

--结束END--

本文标题: python深度学习之多标签分类器及pytorch实现源码

本文链接: https://www.lsjlt.com/news/138045.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • python深度学习之多标签分类器及pytorch实现源码
    目录多标签分类器多标签分类器损失函数代码实现多标签分类器 多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分...
    99+
    2022-11-13
  • Python Pytorch深度学习之图像分类器
    目录一、简介二、数据集三、训练一个图像分类器1、导入package吧2、归一化处理+贴标签吧3、先来康康训练集中的照片吧4、定义一个神经网络吧5、定义一个损失函数和优化器吧6、训练网...
    99+
    2022-11-12
  • Pytorch深度学习之实现病虫害图像分类
    目录一、pytorch框架1.1、概念1.2、机器学习与深度学习的区别1.3、在python中导入pytorch成功截图二、数据集三、代码复现3.1、导入第三方库3.2、CNN代码3...
    99+
    2022-11-12
  • Python深度学习pytorch实现图像分类数据集
    目录读取数据集读取小批量整合所有组件目前广泛使用的图像分类数据集之一是MNIST数据集。如今,MNIST数据集更像是一个健全的检查,而不是一个基准。 为了提高难度,我们将在接下来的章...
    99+
    2022-11-12
  • python深度学习借助多标签分类器进行对抗训练
    目录1 摘要2 方法介绍2.1 多分类任务对抗样本2.2 多标签任务对抗样本2.3 双分类器对抗训练人脸表情对抗训练1 摘要 当前深度模型抵御对抗攻击最有效的方式就是对抗训练,神经网...
    99+
    2022-11-13
  • Python深度学习之FastText实现文本分类详解
    FastText是一个三层的神经网络,输入层、隐含层和输出层。 FastText的优点: 使用浅层的神经网络实现了word2vec以及文本分类功能,效果与深层网络差不多,节约资源,...
    99+
    2022-11-11
  • Python机器学习之基于Pytorch实现猫狗分类
    目录一、环境配置二、数据集的准备三、猫狗分类的实例四、实现分类预测测试五、参考资料一、环境配置 安装Anaconda 具体安装过程,请点击本文 配置Pytorch pip install -i https://...
    99+
    2022-06-02
    Pytorch实现猫狗分类 Python Pytorch
  • Python机器学习之如何基于Pytorch实现猫狗分类
    这篇文章给大家分享的是有关Python机器学习之如何基于Pytorch实现猫狗分类的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。一、环境配置安装Anaconda具体安装过程,请点击本文配置Pytorchpip&n...
    99+
    2023-06-15
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作