iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >Pytorch写数字识别LeNet模型
  • 882
分享到

Pytorch写数字识别LeNet模型

2024-04-02 19:04:59 882人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

目录LeNet网络训练结果泛化能力测试LeNet网络 LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下 from PIL import Image imp

LeNet网络

LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下

from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torchvision
from torchvision import transfORMs
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
import tqdm as tqdm

class LeNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
                                        nn.AvgPool2d(kernel_size=2,stride=2),
                                        nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
                                        nn.AvgPool2d(kernel_size=2,stride=2),
                                        nn.Flatten(),
                                        nn.Linear(16*25,120),nn.Sigmoid(),
                                        nn.Linear(120,84),nn.Sigmoid(),
                                        nn.Linear(84,10))
        
    
    def forward(self,x):
        return self.sequential(x)

class MLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sequential = nn.Sequential(nn.Flatten(),
                          nn.Linear(28*28,120),nn.Sigmoid(),
                          nn.Linear(120,84),nn.Sigmoid(),
                          nn.Linear(84,10))
        
    
    def forward(self,x):
        return self.sequential(x)

epochs = 15
batch = 32
lr=0.9
loss = nn.CrossEntropyLoss()
model = LeNet()
optimizer = torch.optim.SGD(model.parameters(),lr)
device = torch.device('cuda')
root = r"./"
trans_compose  = transforms.Compose([transforms.ToTensor(),
                    ])
train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True)
test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True)
train_loader = DataLoader(train_data,batch_size=batch,shuffle=True)
test_loader = DataLoader(test_data,batch_size=batch,shuffle=False)

model.to(device)
loss.to(device)
# model.apply(init_weights)
for epoch in range(epochs):
  train_loss = 0
  test_loss = 0
  correct_train = 0
  correct_test = 0
  for index,(x,y) in enumerate(train_loader):
    x = x.to(device)
    y = y.to(device)
    predict = model(x)
    L = loss(predict,y)
    optimizer.zero_grad()
    L.backward()
    optimizer.step()
    train_loss = train_loss + L
    correct_train += (predict.argmax(dim=1)==y).sum()
  acc_train = correct_train/(batch*len(train_loader))
  with torch.no_grad():
    for index,(x,y) in enumerate(test_loader):
      [x,y] = [x.to(device),y.to(device)]
      predict = model(x)
      L1 = loss(predict,y)
      test_loss = test_loss + L1
      correct_test += (predict.argmax(dim=1)==y).sum()
    acc_test = correct_test/(batch*len(test_loader))
  print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')

训练结果

epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229

泛化能力测试

找了一张图片,将其分割成只含一个数字的图片进行测试

images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE)
h,w = images_np.shape
images_np = np.array(255*torch.ones(h,w))-images_np#图片反色
images = Image.fromarray(images_np)
plt.figure(1)
plt.imshow(images)
test_images = []
for i in range(10):
  for j in range(16):
    test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16])
sample = test_images[77]
sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
predict = model(sample_tensor)
output = predict.argmax()
print(output)
plt.figure(2)
plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))

此时预测结果为4,预测正确。从这段代码中可以看到有一个反色的步骤,若不反色,结果会受到影响,如下图所示,预测为0,错误。
模型用于输入的图片是单通道的黑白图片,这里由于可视化出现了黄色,但实际上是黑白色,反色操作说明了数据的预处理十分的重要,很多数据如果是不清理过是无法直接用于推理的。

将所有用来泛化性测试的图片进行准确率测试:

correct = 0
i = 0
cnt = 1
for sample in test_images:
  sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
  sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
  predict = model(sample_tensor)
  output = predict.argmax()
  if(output==i):
    correct+=1
  if(cnt%16==0):
    i+=1
  cnt+=1
acc_g = correct/len(test_images)
print(f'acc_g:{acc_g}')

如果不反色,acc_g=0.15

acc_g:0.50625

到此这篇关于PyTorch写数字识别LeNet模型的文章就介绍到这了,更多相关Pytorch写数字识别LeNet模型内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: Pytorch写数字识别LeNet模型

本文链接: https://www.lsjlt.com/news/163183.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • Pytorch写数字识别LeNet模型
    目录LeNet网络训练结果泛化能力测试LeNet网络 LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下 from PIL import Image imp...
    99+
    2024-04-02
  • Pytorch写数字怎么识别LeNet模型
    这篇文章主要为大家分析了Pytorch写数字怎么识别LeNet模型的相关知识点,内容详细易懂,操作细节合理,具有一定参考价值。如果感兴趣的话,不妨跟着跟随小编一起来看看,下面跟着小编一起深入学习“Pytorch写数字怎么识别LeNet模型”...
    99+
    2023-06-28
  • pytorch实现手写数字图片识别
    本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下 数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备...
    99+
    2024-04-02
  • 用PyTorch构建基于卷积神经网络的手写数字识别模型
    本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 目录 一、MINST数据集介绍与分析 二、卷积神经网络 三、基于卷积神经网络的手写数字识别 一、MINST数据集...
    99+
    2023-09-06
    python 机器学习 人工智能 深度学习
  • pytorch实现mnist手写彩色数字识别
    目录前言一 前期工作1.设置GPU或者cpu2.导入数据二 数据预处理1.加载数据2.可视化数据3.再次检查数据三 搭建网络四 训练模型1.设置学习率2.模型训练五 模型评估1.Lo...
    99+
    2024-04-02
  • pytorch如何实现手写数字图片识别
    这篇文章给大家分享的是有关pytorch如何实现手写数字图片识别的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。具体内容如下数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备...
    99+
    2023-06-15
  • PyTorch实现MNIST数据集手写数字识别详情
    目录一、PyTorch是什么?二、程序示例1.引入必要库2.下载数据集3.加载数据集4.搭建CNN模型并实例化5.交叉熵损失函数损失函数及SGD算法优化器6.训练函数7.测试函数8....
    99+
    2024-04-02
  • PyTorch实现手写数字识别的示例代码
    目录加载手写数字的数据数据加载器(分批加载)建立模型模型训练测试集抽取数据,查看预测结果计算模型精度自己手写数字进行预测加载手写数字的数据 组成训练集和测试集,这里已经下载好了,所以...
    99+
    2024-04-02
  • PyTorch简单手写数字识别的实现过程
    目录一、包导入及所需数据的下载关于数据集引入的改动二、进行数据处理变换操作三、数据预览测试和数据装载四、模型搭建和参数优化关于模型搭建的改动总代码:测试总结具体流程: ① 导入相应...
    99+
    2024-04-02
  • pytorch如何利用ResNet18进行手写数字识别
    目录利用ResNet18进行手写数字识别先写resnet18.py再写绘图utils.py最后是主函数mnist_train.py总结利用ResNet18进行手写数字识别 先写res...
    99+
    2023-02-02
    pytorch ResNet18 ResNet18手写数字识别 pytorch手写数字识别
  • 详解Python手写数字识别模型的构建与使用
    目录一:手写数字模型构建与保存1 加载数据集2 特征数据 标签数据3 训练集 测试集4 数据流图 输入层5 隐藏层6 损失函数7 梯度下降算法8 输出损失值 9 模型 保存...
    99+
    2022-12-22
    Python手写数字识别 Python手写数字识别模型 Python 数字 识别
  • 用Pytorch构建一个喵咪识别模型
     本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 目录 一、前言 二、问题阐述及理论流程 2.1问题阐述 2.2猫咪图片识别原理  三、用PyTorch 实现  3...
    99+
    2023-09-03
    深度学习 人工智能 python
  • PyTorch实现手写数字的识别入门小白教程
    目录手写数字识别(小白入门)1.数据预处理2.训练模型3.测试模型,保存4.调用模型5.完整代码手写数字识别(小白入门) 今早刚刚上了节实验课,关于逻辑回归,所以手有点刺挠就想发个博...
    99+
    2024-04-02
  • pytorch教程实现mnist手写数字识别代码示例
    目录1.构建网络2.编写训练代码3.编写测试代码4.指导程序train和test5.完整代码 1.构建网络 nn.Moudle是pytorch官方指定的编写Net模块,在init函数...
    99+
    2024-04-02
  • PyTorch简单手写数字识别的实现过程是怎样的
    本篇文章给大家分享的是有关PyTorch简单手写数字识别的实现过程是怎样的,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。一、包导入及所需数据的下载torchvision包的主要...
    99+
    2023-06-25
  • Pytorch实现图像识别之数字识别(附详细注释)
    使用了两个卷积层加上两个全连接层实现 本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份 详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到...
    99+
    2024-04-02
  • 超详细PyTorch实现手写数字识别器的示例代码
    前言 深度学习中有很多玩具数据,mnist就是其中一个,一个人能否入门深度学习往往就是以能否玩转mnist数据来判断的,在前面很多基础介绍后我们就可以来实现一个简单的手写数字识别的网...
    99+
    2024-04-02
  • C++OpenCV实战之手写数字识别
    目录前言一、准备数据集二、KNN训练三、模型预测及结果显示四、源码总结前言 本案例通过使用machine learning机器学习模块进行手写数字识别。源码注释也写得比较清楚啦,大家...
    99+
    2022-11-13
    C++ OpenCV手写数字识别 C++ OpenCV数字识别 OpenCV 数字识别
  • DenseNet121模型实现26个英文字母识别任务
    目录一、任务概述二、DenseNet介绍三、数据集介绍四、模型实现五、实验结果与分析六、总结一、任务概述 26个英文字母识别是一个基于计算机视觉的图像分类任务,旨在从包含26个不同字...
    99+
    2023-05-17
    DenseNet121识别英文字母 DenseNet121模型识别
  • Python实战之MNIST手写数字识别详解
    目录数据集介绍1.数据预处理2.网络搭建3.网络配置关于优化器关于损失函数关于指标4.网络训练与测试5.绘制loss和accuracy随着epochs的变化图6.完整代码数据集介绍 ...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作