iis服务器助手广告广告
返回顶部
首页 > 资讯 > 精选 >Pytorch写数字怎么识别LeNet模型
  • 754
分享到

Pytorch写数字怎么识别LeNet模型

2023-06-28 17:06:11 754人浏览 独家记忆
摘要

这篇文章主要为大家分析了PyTorch写数字怎么识别LeNet模型的相关知识点,内容详细易懂,操作细节合理,具有一定参考价值。如果感兴趣的话,不妨跟着跟随小编一起来看看,下面跟着小编一起深入学习“Pytorch写数字怎么识别LeNet模型”

这篇文章主要为大家分析了PyTorch写数字怎么识别LeNet模型的相关知识点,内容详细易懂,操作细节合理,具有一定参考价值。如果感兴趣的话,不妨跟着跟随小编一起来看看,下面跟着小编一起深入学习“Pytorch写数字怎么识别LeNet模型”的知识吧。

LeNet网络

Pytorch写数字怎么识别LeNet模型

LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下

from PIL import Imageimport cv2import matplotlib.pyplot as pltimport torchvisionfrom torchvision import transfORMsimport torchfrom torch.utils.data import DataLoaderimport torch.nn as nnimport numpy as npimport tqdm as tqdmclass LeNet(nn.Module):    def __init__(self) -> None:        super().__init__()        self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),                                        nn.AvgPool2d(kernel_size=2,stride=2),                                        nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),                                        nn.AvgPool2d(kernel_size=2,stride=2),                                        nn.Flatten(),                                        nn.Linear(16*25,120),nn.Sigmoid(),                                        nn.Linear(120,84),nn.Sigmoid(),                                        nn.Linear(84,10))                def forward(self,x):        return self.sequential(x)class MLP(nn.Module):    def __init__(self) -> None:        super().__init__()        self.sequential = nn.Sequential(nn.Flatten(),                          nn.Linear(28*28,120),nn.Sigmoid(),                          nn.Linear(120,84),nn.Sigmoid(),                          nn.Linear(84,10))                def forward(self,x):        return self.sequential(x)epochs = 15batch = 32lr=0.9loss = nn.CrossEntropyLoss()model = LeNet()optimizer = torch.optim.SGD(model.parameters(),lr)device = torch.device('cuda')root = r"./"trans_compose  = transforms.Compose([transforms.ToTensor(),                    ])train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True)test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True)train_loader = DataLoader(train_data,batch_size=batch,shuffle=True)test_loader = DataLoader(test_data,batch_size=batch,shuffle=False)model.to(device)loss.to(device)# model.apply(init_weights)for epoch in range(epochs):  train_loss = 0  test_loss = 0  correct_train = 0  correct_test = 0  for index,(x,y) in enumerate(train_loader):    x = x.to(device)    y = y.to(device)    predict = model(x)    L = loss(predict,y)    optimizer.zero_grad()    L.backward()    optimizer.step()    train_loss = train_loss + L    correct_train += (predict.argmax(dim=1)==y).sum()  acc_train = correct_train/(batch*len(train_loader))  with torch.no_grad():    for index,(x,y) in enumerate(test_loader):      [x,y] = [x.to(device),y.to(device)]      predict = model(x)      L1 = loss(predict,y)      test_loss = test_loss + L1      correct_test += (predict.argmax(dim=1)==y).sum()    acc_test = correct_test/(batch*len(test_loader))  print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')

训练结果

epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229

泛化能力测试

找了一张图片,将其分割成只含一个数字的图片进行测试

Pytorch写数字怎么识别LeNet模型

images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE)h,w = images_np.shapeimages_np = np.array(255*torch.ones(h,w))-images_np#图片反色images = Image.fromarray(images_np)plt.figure(1)plt.imshow(images)test_images = []for i in range(10):  for j in range(16):    test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16])sample = test_images[77]sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))predict = model(sample_tensor)output = predict.argmax()print(output)plt.figure(2)plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))

Pytorch写数字怎么识别LeNet模型

此时预测结果为4,预测正确。从这段代码中可以看到有一个反色的步骤,若不反色,结果会受到影响,如下图所示,预测为0,错误。
模型用于输入的图片是单通道的黑白图片,这里由于可视化出现了黄色,但实际上是黑白色,反色操作说明了数据的预处理十分的重要,很多数据如果是不清理过是无法直接用于推理的。

Pytorch写数字怎么识别LeNet模型

将所有用来泛化性测试的图片进行准确率测试:

correct = 0i = 0cnt = 1for sample in test_images:  sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)  sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))  predict = model(sample_tensor)  output = predict.argmax()  if(output==i):    correct+=1  if(cnt%16==0):    i+=1  cnt+=1acc_g = correct/len(test_images)print(f'acc_g:{acc_g}')

如果不反色,acc_g=0.15

acc_g:0.50625

pytorch的优点

1.PyTorch是相当简洁且高效快速的框架;2.设计追求最少的封装;3.设计符合人类思维,它让用户尽可能地专注于实现自己的想法;4.与Google的Tensorflow类似,FAIR的支持足以确保PyTorch获得持续的开发更新;5.PyTorch作者亲自维护的论坛 供用户交流和求教问题6.入门简单

关于“Pytorch写数字怎么识别LeNet模型”就介绍到这了,更多相关内容可以搜索编程网以前的文章,希望能够帮助大家答疑解惑,请多多支持编程网网站!

--结束END--

本文标题: Pytorch写数字怎么识别LeNet模型

本文链接: https://www.lsjlt.com/news/320995.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • Pytorch写数字识别LeNet模型
    目录LeNet网络训练结果泛化能力测试LeNet网络 LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下 from PIL import Image imp...
    99+
    2024-04-02
  • Pytorch写数字怎么识别LeNet模型
    这篇文章主要为大家分析了Pytorch写数字怎么识别LeNet模型的相关知识点,内容详细易懂,操作细节合理,具有一定参考价值。如果感兴趣的话,不妨跟着跟随小编一起来看看,下面跟着小编一起深入学习“Pytorch写数字怎么识别LeNet模型”...
    99+
    2023-06-28
  • pytorch实现手写数字图片识别
    本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下 数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备...
    99+
    2024-04-02
  • 用PyTorch构建基于卷积神经网络的手写数字识别模型
    本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 目录 一、MINST数据集介绍与分析 二、卷积神经网络 三、基于卷积神经网络的手写数字识别 一、MINST数据集...
    99+
    2023-09-06
    python 机器学习 人工智能 深度学习
  • pytorch实现mnist手写彩色数字识别
    目录前言一 前期工作1.设置GPU或者cpu2.导入数据二 数据预处理1.加载数据2.可视化数据3.再次检查数据三 搭建网络四 训练模型1.设置学习率2.模型训练五 模型评估1.Lo...
    99+
    2024-04-02
  • pytorch如何实现手写数字图片识别
    这篇文章给大家分享的是有关pytorch如何实现手写数字图片识别的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。具体内容如下数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备...
    99+
    2023-06-15
  • PyTorch实现MNIST数据集手写数字识别详情
    目录一、PyTorch是什么?二、程序示例1.引入必要库2.下载数据集3.加载数据集4.搭建CNN模型并实例化5.交叉熵损失函数损失函数及SGD算法优化器6.训练函数7.测试函数8....
    99+
    2024-04-02
  • PyTorch实现手写数字识别的示例代码
    目录加载手写数字的数据数据加载器(分批加载)建立模型模型训练测试集抽取数据,查看预测结果计算模型精度自己手写数字进行预测加载手写数字的数据 组成训练集和测试集,这里已经下载好了,所以...
    99+
    2024-04-02
  • PyTorch简单手写数字识别的实现过程
    目录一、包导入及所需数据的下载关于数据集引入的改动二、进行数据处理变换操作三、数据预览测试和数据装载四、模型搭建和参数优化关于模型搭建的改动总代码:测试总结具体流程: ① 导入相应...
    99+
    2024-04-02
  • pytorch如何利用ResNet18进行手写数字识别
    目录利用ResNet18进行手写数字识别先写resnet18.py再写绘图utils.py最后是主函数mnist_train.py总结利用ResNet18进行手写数字识别 先写res...
    99+
    2023-02-02
    pytorch ResNet18 ResNet18手写数字识别 pytorch手写数字识别
  • 详解Python手写数字识别模型的构建与使用
    目录一:手写数字模型构建与保存1 加载数据集2 特征数据 标签数据3 训练集 测试集4 数据流图 输入层5 隐藏层6 损失函数7 梯度下降算法8 输出损失值 9 模型 保存...
    99+
    2022-12-22
    Python手写数字识别 Python手写数字识别模型 Python 数字 识别
  • 用Pytorch构建一个喵咪识别模型
     本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 目录 一、前言 二、问题阐述及理论流程 2.1问题阐述 2.2猫咪图片识别原理  三、用PyTorch 实现  3...
    99+
    2023-09-03
    深度学习 人工智能 python
  • PyTorch简单手写数字识别的实现过程是怎样的
    本篇文章给大家分享的是有关PyTorch简单手写数字识别的实现过程是怎样的,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。一、包导入及所需数据的下载torchvision包的主要...
    99+
    2023-06-25
  • PyTorch实现手写数字的识别入门小白教程
    目录手写数字识别(小白入门)1.数据预处理2.训练模型3.测试模型,保存4.调用模型5.完整代码手写数字识别(小白入门) 今早刚刚上了节实验课,关于逻辑回归,所以手有点刺挠就想发个博...
    99+
    2024-04-02
  • pytorch教程实现mnist手写数字识别代码示例
    目录1.构建网络2.编写训练代码3.编写测试代码4.指导程序train和test5.完整代码 1.构建网络 nn.Moudle是pytorch官方指定的编写Net模块,在init函数...
    99+
    2024-04-02
  • Pytorch实现图像识别之数字识别(附详细注释)
    使用了两个卷积层加上两个全连接层实现 本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份 详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到...
    99+
    2024-04-02
  • 超详细PyTorch实现手写数字识别器的示例代码
    前言 深度学习中有很多玩具数据,mnist就是其中一个,一个人能否入门深度学习往往就是以能否玩转mnist数据来判断的,在前面很多基础介绍后我们就可以来实现一个简单的手写数字识别的网...
    99+
    2024-04-02
  • Python怎么构建人脸识别模型
    这篇文章主要讲解了“Python怎么构建人脸识别模型”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Python怎么构建人脸识别模型”吧!01 介绍你是否意识到,每当你上传照片到Faceboo...
    99+
    2023-06-16
  • PyTorch怎么实现图像识别
    这篇文章主要介绍“PyTorch怎么实现图像识别”,在日常操作中,相信很多人在PyTorch怎么实现图像识别问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”PyTorch怎么实现图像识别”的疑惑有所帮助!接下来...
    99+
    2023-06-29
  • C++OpenCV实战之手写数字识别
    目录前言一、准备数据集二、KNN训练三、模型预测及结果显示四、源码总结前言 本案例通过使用machine learning机器学习模块进行手写数字识别。源码注释也写得比较清楚啦,大家...
    99+
    2022-11-13
    C++ OpenCV手写数字识别 C++ OpenCV数字识别 OpenCV 数字识别
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作