广告
返回顶部
首页 > 资讯 > 后端开发 > Python >Python实战之MNIST手写数字识别详解
  • 671
分享到

Python实战之MNIST手写数字识别详解

2024-04-02 19:04:59 671人浏览 薄情痞子

Python 官方文档:入门教程 => 点击学习

摘要

目录数据集介绍1.数据预处理2.网络搭建3.网络配置关于优化器关于损失函数关于指标4.网络训练与测试5.绘制loss和accuracy随着epochs的变化图6.完整代码数据集介绍

数据集介绍

MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片,且内置于keras。本文采用Tensorflow下Keras(Keras中文文档)神经网络api进行网络搭建。

开始之前,先回忆下机器学习的通用工作流程( √表示本文用到,×表示本文没有用到 )

1.定义问题,收集数据集(√)

2.选择衡量成功的指标(√)

3.确定评估的方法(√)

4.准备数据(√)

5.开发比基准更好的模型(×)

6.扩大模型规模(×)

7.模型正则化与调节参数(×)

关于最后一层激活函数与损失函数的选择

下面开始正文~

1.数据预处理

首先导入数据,要使用mnist.load()函数

我们来看看它的源码声明:

def load_data(path='mnist.npz'):
  """Loads the [MNIST dataset](Http://yann.lecun.com/exdb/mnist/).

  This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
  along with a test set of 10,000 images.
  More info can be found at the
  [MNIST homepage](http://yann.lecun.com/exdb/mnist/).


  Arguments:
      path: path where to cache the dataset locally
          (relative to `~/.keras/datasets`).

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
      **x_train, x_test**: uint8 arrays of grayscale image data with shapes
        (num_samples, 28, 28).

      **y_train, y_test**: uint8 arrays of digit labels (integers in range 0-9)
        with shapes (num_samples,).
  """

可以看到,里面包含了数据集的下载链接,以及数据集规模、尺寸以及数据类型的声明,且函数返回的是四个numpy array组成的两个元组。

导入数据集并reshape至想要形状,再标准化处理。

其中内置于keras的to_cateGorical()就是one-hot编码——将每个标签表示为全零向量,只有标签索引对应的元素为1.

eg: col=10

[0,1,9]-------->[ [1,0,0,0,0,0,0,0,0,0],
                  [0,1,0,0,0,0,0,0,0,0],
                  [0,0,0,0,0,0,0,0,0,1] ]        

我们可以手动实现它:

def one_hot(sequences,col):
        resuts=np.zeros((len(sequences),col))
        # for i,sequence in enumerate(sequences):
        #         resuts[i,sequence]=1
        for i in range(len(sequences)):
                for j in range(len(sequences[i])):
                        resuts[i,sequences[i][j]]=1
        return resuts

下面是预处理过程

def data_preprocess():
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    train_images = train_images.reshape((60000, 28, 28, 1))
    train_images = train_images.astype('float32') / 255
    #print(train_images[0])
    test_images = test_images.reshape((10000, 28, 28, 1))
    test_images = test_images.astype('float32') / 255

    train_labels = to_categorical(train_labels)
    test_labels = to_categorical(test_labels)
    return train_images,train_labels,test_images,test_labels

2.网络搭建

这里我们搭建的是卷积神经网络,就是包含一些卷积、池化、全连接的简单线性堆积。我们知道多个线性层堆叠实现的仍然是线性运算,添加层数并不会扩展假设空间(从输入数据到输出数据的所有可能的线性变换集合),因此需要添加非线性或激活函数。relu是最常用的激活函数,也可以用prelu、elu

def build_module():
    model = models.Sequential()
    #第一层卷积层,首层需要指出input_shape形状
    model.add(layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)))
    #第二层最大池化层
    model.add(layers.MaxPooling2D((2,2)))
    #第三层卷积层
    model.add(layers.Conv2D(64, (3,3), activation='relu'))
    #第四层最大池化层
    model.add(layers.MaxPooling2D((2,2)))
    #第五层卷积层
    model.add(layers.Conv2D(64, (3,3), activation='relu'))
    #第六层Flatten层,将3D张量平铺为向量
    model.add(layers.Flatten())
    #第七层全连接层
    model.add(layers.Dense(64, activation='relu'))
    #第八层softmax层,进行分类
    model.add(layers.Dense(10, activation='softmax'))
    return model

使用model.summary()查看搭建的网路结构:

3.网络配置

网络搭建好之后还需要关键的一步设置配置。比如:优化器——网络梯度下降进行参数更新的具体方法、损失函数——衡量生成值与目标值之间的距离、评估指标等。配置这些可以通过 model.compile() 参数传递做到。

我们来看看model.compile()的源码分析下:

  def compile(self,
              optimizer='rmsprop',
              loss=None,
              metrics=None,
              loss_weights=None,
              weighted_metrics=None,
              run_eagerly=None,
              steps_per_execution=None,
              **kwargs):
    """Configures the model for training.

关于优化器

优化器:字符串(优化器名称)或优化器实例。

字符串格式:比如使用优化器的默认参数

实例优化器进行参数传入:

keras.optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=None, decay=0.0)
model.compile(optimizer='rmsprop',loss='mean_squared_error')

建议使用优化器的默认参数 (除了学习率 lr,它可以被自由调节)

参数:

lr: float >= 0. 学习率。
rho: float >= 0. RMSProp梯度平方的移动均值的衰减率.
epsilon: float >= 0. 模糊因子. 若为 None, 默认为 K.epsilon()。
decay: float >= 0. 每次参数更新后学习率衰减值。

类似还有好多优化器,比如SGD、Adagrad、Adadelta、Adam、Adamax、Nadam等

关于损失函数

取决于具体任务,一般来说损失函数要能够很好的刻画任务。比如

1.回归问题

希望神经网络输出的值与ground-truth的距离更近,选取能刻画距离的loss应该会更合适,比如L1 Loss、MSE Loss等

2.分类问题

希望神经网络输出的类别与ground-truth的类别一致,选取能刻画类别分布的loss应该会更合适,比如cross_entropy

具体常见选择可查看文章开始处关于损失函数的选择

关于指标

常规使用查看上述列表即可。下面说说自定义评价函数:它应该在编译的时候(compile)传递进去。该函数需要以 (y_true, y_pred) 作为输入参数,并返回一个张量作为输出结果。

import keras.backend as K
def mean_pred(y_true, y_pred):
    return K.mean(y_pred)

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy', mean_pred])

4.网络训练与测试

1.训练(拟合)

使用model.fit(),它可以接受的参数列表

def fit(self,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_batch_size=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False):

这个源码有300多行长,具体的解读放在下次。

我们对训练数据进行划分,以64个样本为小批量进行网络传递,对所有数据迭代5次

model.fit(train_images, train_labels, epochs = 5, batch_size=64)

2.测试

 

使用model.evaluates()函数
 

test_loss, test_acc = model.evaluate(test_images, test_labels)

关于测试函数的返回声明:

Returns:
        Scalar test loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the scalar outputs.

5.绘制loss和accuracy随着epochs的变化图

model.fit()返回一个History对象,它包含一个history成员,记录了训练过程的所有数据。

我们采用matplotlib.pyplot进行绘图,具体见后面完整代码。

Returns:
        A `History` object. Its `History.history` attribute is
        a record of training loss values and metrics values
        at successive epochs, as well as validation loss values
        and validation metrics values (if applicable).
def draw_loss(history):
    loss=history.history['loss']
    epochs=range(1,len(loss)+1)
    plt.subplot(1,2,1)#第一张图
    plt.plot(epochs,loss,'bo',label='Training loss')
    plt.title("Training loss")
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1,2,2)#第二张图
    accuracy=history.history['accuracy']
    plt.plot(epochs,accuracy,'bo',label='Training accuracy')
    plt.title("Training accuracy")
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.suptitle("Train data")
    plt.legend()
    plt.show()

6.完整代码

from tensorflow.keras.datasets import mnist
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np
def data_preprocess():
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    train_images = train_images.reshape((60000, 28, 28, 1))
    train_images = train_images.astype('float32') / 255
    #print(train_images[0])
    test_images = test_images.reshape((10000, 28, 28, 1))
    test_images = test_images.astype('float32') / 255

    train_labels = to_categorical(train_labels)
    test_labels = to_categorical(test_labels)
    return train_images,train_labels,test_images,test_labels

#搭建网络
def build_module():
    model = models.Sequential()
    #第一层卷积层
    model.add(layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)))
    #第二层最大池化层
    model.add(layers.MaxPooling2D((2,2)))
    #第三层卷积层
    model.add(layers.Conv2D(64, (3,3), activation='relu'))
    #第四层最大池化层
    model.add(layers.MaxPooling2D((2,2)))
    #第五层卷积层
    model.add(layers.Conv2D(64, (3,3), activation='relu'))
    #第六层Flatten层,将3D张量平铺为向量
    model.add(layers.Flatten())
    #第七层全连接层
    model.add(layers.Dense(64, activation='relu'))
    #第八层softmax层,进行分类
    model.add(layers.Dense(10, activation='softmax'))
    return model
def draw_loss(history):
    loss=history.history['loss']
    epochs=range(1,len(loss)+1)
    plt.subplot(1,2,1)#第一张图
    plt.plot(epochs,loss,'bo',label='Training loss')
    plt.title("Training loss")
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1,2,2)#第二张图
    accuracy=history.history['accuracy']
    plt.plot(epochs,accuracy,'bo',label='Training accuracy')
    plt.title("Training accuracy")
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.suptitle("Train data")
    plt.legend()
    plt.show()
if __name__=='__main__':
    train_images,train_labels,test_images,test_labels=data_preprocess()
    model=build_module()
    print(model.summary())
    model.compile(optimizer='rmsprop', loss = 'categorical_crossentropy', metrics=['accuracy'])
    history=model.fit(train_images, train_labels, epochs = 5, batch_size=64)
    draw_loss(history)
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print('test_loss=',test_loss,'  test_acc = ', test_acc)


迭代训练过程中loss和accuracy的变化

由于数据集比较简单,随便的神经网络设计在测试集的准确率可达到99.2%

以上就是python实战之MNIST手写数字识别详解的详细内容,更多关于Python MNIST手写数字识别的资料请关注编程网其它相关文章!

--结束END--

本文标题: Python实战之MNIST手写数字识别详解

本文链接: https://www.lsjlt.com/news/161670.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • Python实战之MNIST手写数字识别详解
    目录数据集介绍1.数据预处理2.网络搭建3.网络配置关于优化器关于损失函数关于指标4.网络训练与测试5.绘制loss和accuracy随着epochs的变化图6.完整代码数据集介绍 ...
    99+
    2022-11-12
  • Python实战小项目之Mnist手写数字识别
    目录程序流程分析图:传播过程:代码展示:创建环境准备数据集下载数据集下载测试集绘制图像搭建神经网络训练模型测试模型保存训练模型运行结果展示:程序流程分析图: 传播过程: 代码展...
    99+
    2022-11-12
  • caffe的python接口之手写数字识别mnist实例
    目录引言一、数据准备二、导入caffe库,并设定文件路径二、生成配置文件三、生成参数文件solver四、开始训练模型五、完成的python文件引言 深度学习的第一个实例一般都是mni...
    99+
    2022-11-11
  • PyTorch实现MNIST数据集手写数字识别详情
    目录一、PyTorch是什么?二、程序示例1.引入必要库2.下载数据集3.加载数据集4.搭建CNN模型并实例化5.交叉熵损失函数损失函数及SGD算法优化器6.训练函数7.测试函数8....
    99+
    2022-11-11
  • Python MNIST手写体识别详解与试练
    【人工智能项目】MNIST手写体识别实验及分析 1.实验内容简述 1.1 实验环境 本实验采用的软硬件实验环境如表所示: 在Windows操作系统下,采用基于Tensorflow...
    99+
    2022-11-12
  • C++OpenCV实战之手写数字识别
    目录前言一、准备数据集二、KNN训练三、模型预测及结果显示四、源码总结前言 本案例通过使用machine learning机器学习模块进行手写数字识别。源码注释也写得比较清楚啦,大家...
    99+
    2022-11-13
    C++ OpenCV手写数字识别 C++ OpenCV数字识别 OpenCV 数字识别
  • pytorch实现mnist手写彩色数字识别
    目录前言一 前期工作1.设置GPU或者cpu2.导入数据二 数据预处理1.加载数据2.可视化数据3.再次检查数据三 搭建网络四 训练模型1.设置学习率2.模型训练五 模型评估1.Lo...
    99+
    2022-11-11
  • 机器学习python实战之手写数字识别
    看了上一篇内容之后,相信对K近邻算法有了一个清晰的认识,今天的内容——手写数字识别是对上一篇内容的延续,这里也是为了自己能更熟练的掌握k-NN算法。 我们有大约2000个训练样本和1000个左右测试样本,训...
    99+
    2022-06-04
    实战 机器 数字
  • Python中如何实现MNIST手写数字识别功能
    这篇文章主要为大家展示了“Python中如何实现MNIST手写数字识别功能”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Python中如何实现MNIST手写数字识别功能”这篇文章吧。数据集介绍M...
    99+
    2023-06-22
  • Java实现BP神经网络MNIST手写数字识别的示例详解
    目录一、神经网络的构建二、系统架构服务器客户端采用MVC架构一、神经网络的构建 (1):构建神经网络层次结构 由训练集数据可知,手写输入的数据维数为784维,而对应的输出结果为分别为...
    99+
    2023-01-31
    Java实现手写数字识别 Java手写数字识别 Java数字识别
  • Python中如何实现MNIST手写体识别
    这篇文章主要介绍Python中如何实现MNIST手写体识别,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!1.实验内容简述1.1 实验环境本实验采用的软硬件实验环境如表所示:在Windows操作系统下,采用基于Tens...
    99+
    2023-06-25
  • pytorch教程实现mnist手写数字识别代码示例
    目录1.构建网络2.编写训练代码3.编写测试代码4.指导程序train和test5.完整代码 1.构建网络 nn.Moudle是pytorch官方指定的编写Net模块,在init函数...
    99+
    2022-11-12
  • Python-OpenCV实战:利用KNN算法识别手写数字
    目录前言手写数字数据集 MNIST 介绍基准模型——利用 KNN 算法识别手写数字改进模型1——参数 K 对识别手写数字精确度的影响改进模型2——训练数据量对识别手写数字精确度的影响...
    99+
    2022-11-12
  • TensorFlow教程Softmax逻辑回归识别手写数字MNIST数据集
    基于MNIST数据集的逻辑回归模型做十分类任务 没有隐含层的Softmax Regression只能直接从图像的像素点推断是哪个数字,而没有特征抽象的过程。多层神经网络依靠隐含层,则...
    99+
    2022-11-12
  • TensorFlow中Softmax逻辑回归如何识别手写数字MNIST数据集
    今天就跟大家聊聊有关TensorFlow中Softmax逻辑回归如何识别手写数字MNIST数据集,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。基于MNIST数据集的逻辑回归模型做十分...
    99+
    2023-06-25
  • python神经网络编程之手写数字识别
    目录写在之前一、代码框架二、准备工作三、框架的开始四、训练模型构建五、手写数字的识别六、源码七、思考写在之前 首先是写在之前的一些建议: 首先是关于这本书,我真的认为他是将神经网络里...
    99+
    2022-11-12
  • Python实战之实现截图识别文字
    目录前言一、获取百度智能云token二、百度借口调用三、搭建窗口化的程序以便于使用四、实现截图的自动保存五、将识别到的文字输出显示在窗口文本框中并将文字发送到剪切板六、提取识别后文字...
    99+
    2022-11-12
  • 详解Python手写数字识别模型的构建与使用
    目录一:手写数字模型构建与保存1 加载数据集2 特征数据 标签数据3 训练集 测试集4 数据流图 输入层5 隐藏层6 损失函数7 梯度下降算法8 输出损失值 9 模型 保存...
    99+
    2022-12-22
    Python手写数字识别 Python手写数字识别模型 Python 数字 识别
  • Python实现带GUI界面的手写数字识别
    目录1.效果图2.数据集3.关于模型4.关于GUI设计5.缺点6.遗留问题1.效果图 有点low,轻喷 点击选择图片会优先从当前目录查找 2.数据集 这部分我是对MNIST数据...
    99+
    2022-11-12
  • Python利用SVM算法实现识别手写数字
    目录前言使用 SVM 进行手写数字识别参数 C 和 γ 对识别手写数字精确度的影响完整代码前言 支持向量机 (Support Vector Machine, SVM) 是一种监督学习...
    99+
    2022-11-12
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作