iis服务器助手广告广告
返回顶部
首页 > 资讯 > 精选 >TensorFlow中Softmax逻辑回归如何识别手写数字MNIST数据集
  • 340
分享到

TensorFlow中Softmax逻辑回归如何识别手写数字MNIST数据集

2023-06-25 12:06:16 340人浏览 泡泡鱼
摘要

今天就跟大家聊聊有关Tensorflow中Softmax逻辑回归如何识别手写数字MNIST数据集,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。基于MNIST数据集的逻辑回归模型做十分

今天就跟大家聊聊有关Tensorflow中Softmax逻辑回归如何识别手写数字MNIST数据集,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。

基于MNIST数据集的逻辑回归模型做十分类任务

没有隐含层的Softmax Regression只能直接从图像的像素点推断是哪个数字,而没有特征抽象的过程。多层神经网络依靠隐含层,则可以组合出高阶特征,比如横线、竖线、圆圈等,之后可以将这些高阶特征或者说组件再组合成数字,就能实现精准的匹配和分类。

import tensorflow as tfimport numpy as npimport input_dataprint('Download and Extract MNIST dataset')mnist = input_data.read_data_sets('data/', one_hot=True) # one_hot=True意思是编码格式为01编码print("tpye of 'mnist' is %s" % (type(mnist)))print("number of train data is %d" % (mnist.train.num_examples))print("number of test data is %d" % (mnist.test.num_examples))trainimg = mnist.train.imagestrainlabel = mnist.train.labelstestimg = mnist.test.imagestestlabel = mnist.test.labelsprint("MNIST loaded")"""print("type of 'trainimg' is %s"    % (type(trainimg)))print("type of 'trainlabel' is %s"  % (type(trainlabel)))print("type of 'testimg' is %s"     % (type(testimg)))print("type of 'testlabel' is %s"   % (type(testlabel)))print("------------------------------------------------")print("shape of 'trainimg' is %s"   % (trainimg.shape,))print("shape of 'trainlabel' is %s" % (trainlabel.shape,))print("shape of 'testimg' is %s"    % (testimg.shape,))print("shape of 'testlabel' is %s"  % (testlabel.shape,))"""x = tf.placeholder(tf.float32, [None, 784])y = tf.placeholder(tf.float32, [None, 10]) # None is for infinitew = tf.Variable(tf.zeros([784, 10])) # 为了方便直接用0初始化,可以高斯初始化b = tf.Variable(tf.zeros([10])) # 10分类的任务,10种label,所以只需要初始化10个bpred = tf.nn.softmax(tf.matmul(x, w) + b) # 前向传播的预测值cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=[1])) # 交叉熵损失函数optm = tf.train.GradientDescentOptimizer(0.01).minimize(cost)corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) # tf.equal()对比预测值的索引和真实label的索引是否一样,一样返回True,不一样返回Falseaccr = tf.reduce_mean(tf.cast(corr, tf.float32))init = tf.global_variables_initializer() # 全局参数初始化器training_epochs = 100 # 所有样本迭代100次batch_size = 100 # 每进行一次迭代选择100个样本display_step = 5# SESSIONsess = tf.Session() # 定义一个Sessionsess.run(init) # 在sess里run一下初始化操作# MINI-BATCH LEARNINGfor epoch in range(training_epochs): # 每一个epoch进行循环    avg_cost = 0. # 刚开始损失值定义为0    num_batch = int(mnist.train.num_examples/batch_size)    for i in range(num_batch): # 每一个batch进行选择        batch_xs, batch_ys = mnist.train.next_batch(batch_size) # 通过next_batch()就可以一个一个batch的拿数据,        sess.run(optm, feed_dict={x: batch_xs, y: batch_ys}) # run一下用梯度下降进行求解,通过placeholder把x,y传进来        avg_cost += sess.run(cost, feed_dict={x: batch_xs, y:batch_ys})/num_batch    # DISPLAY    if epoch % display_step == 0: # display_step之前定义为5,这里每5个epoch打印一下        train_acc = sess.run(accr, feed_dict={x: batch_xs, y:batch_ys})        test_acc = sess.run(accr, feed_dict={x: mnist.test.images, y: mnist.test.labels})        print("Epoch: %03D/%03d cost: %.9f TRAIN ACCURACY: %.3f TEST ACCURACY: %.3f"              % (epoch, training_epochs, avg_cost, train_acc, test_acc))print("DONE")

迭代100次跑一下模型,最终,在测试集上可以达到92.2%的准确率,虽然还不错,但是还达不到实用的程度。手写数字的识别的主要应用场景是识别银行支票,如果准确率不够高,可能会引起严重的后果。

Epoch: 095/100 loss: 0.283259882 train_acc: 0.940 test_acc: 0.922

插一些知识点,关于tensorflow中一些函数的用法

sess = tf.InteractiveSession()arr = np.array([[31, 23,  4, 24, 27, 34],                [18,  3, 25,  0,  6, 35],                [28, 14, 33, 22, 30,  8],                [13, 30, 21, 19,  7,  9],                [16,  1, 26, 32,  2, 29],                [17, 12,  5, 11, 10, 15]])
在tensorflow中打印要用.eval()tf.rank(arr).eval() # 打印矩阵arr的维度tf.shape(arr).eval() # 打印矩阵arr的大小tf.argmax(arr, 0).eval() # 打印最大值的索引,参数0为按列求索引,1为按行求索引

看完上述内容,你们对TensorFlow中Softmax逻辑回归如何识别手写数字MNIST数据集有进一步的了解吗?如果还想了解更多知识或者相关内容,请关注编程网精选频道,感谢大家的支持。

--结束END--

本文标题: TensorFlow中Softmax逻辑回归如何识别手写数字MNIST数据集

本文链接: https://www.lsjlt.com/news/304765.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • c++中函数返回值的类型是由什么决定的
    在 c++ 中,函数返回值类型由其函数原型的类型决定,包括:函数原型指定返回值类型:在函数名称后跟冒号,再跟返回值类型。默认返回值类型为 int:如果不指定返回值类型,默认类型为 int...
    99+
    2024-05-14
    c++
  • 在c++中,什么叫函数的返回值
    在 c++ 中,函数只能返回一个值。解决方法:引用传递、结构体或类、out 参数。没有返回值的函数可以使用 void 类型,表示不返回任何值。 什么是 C++ 中函数的返回值? 在 C...
    99+
    2024-05-14
    c++
  • c++中static的作用和用法
    c++ 中的 static 关键字用于声明静态变量、函数或类成员,使其在程序生命周期内存在或与类的每个实例关联。具体用法如下:静态变量:在函数外声明,仅创建一份副本,在程序启动时初始化且...
    99+
    2024-05-14
    c++
  • static在c和c++中的区别
    static关键字在c和c++中用于控制变量的生命周期和作用域。在c中,它延长局部变量和限制全局变量的作用域。在c++中,它还用于定义类成员变量和函数、命名空间中的变量和函数,以及函数内...
    99+
    2024-05-14
    c语言 c++ 作用域
  • c++中a++与++a的区别
    c++ 中 a++ 和 ++a 区别:后缀递增 a++ 先返回原始值,再递增;前缀递增 ++a 先递增,再返回递增后的值。 C++ 中 a++ 与 ++a 的区别 在 C++ 中,a+...
    99+
    2024-05-14
    c++
  • if else在c++中的用法
    在 c++ 中,if else 语句根据条件执行不同代码块的语法为:if (condition) { } else { }。它可用于:检查数字是否为正数根据条件执行嵌套 if els...
    99+
    2024-05-14
    c++
  • struct在c和c++中的区别
    c和c++中struct的区别包括:c中成员默认公开访问,c++中默认私有访问。c++可以在struct定义中初始化成员,c中不允许。c++支持成员函数,c不支持。c++不支持匿名str...
    99+
    2024-05-14
    c++
  • c++中的所有函数都是传值调用吗
    函数调用类型可分为传值调用和引用调用,默认采用传值调用,传值调用中形参接收实参副本,引用调用中形参接收实参引用,对形参进行的修改也会影响实参。 C++中的函数调用类型 C++中,函数调...
    99+
    2024-05-14
    c++
  • c++中ifdef的用法
    c++ 中的 #ifdef 预处理器指令用于根据预定义宏是否存在来编译或不编译代码块。它的语法是 #ifdef ,其作用包括:检查宏是否存在,如果宏已定义,则编译其后的代码块;实现条件编...
    99+
    2024-05-14
    c++
  • c++中的函数调用有哪几种方式?它们有什么区别
    c++ 中的函数调用方式有 4 种:值传递(复制实参值,不影响实参)、引用传递(传递实参地址,修改形参值会修改实参)、指针传递(传递实参指向的内存地址,修改指向的值会影响实参)、rval...
    99+
    2024-05-14
    c++
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作