iis服务器助手广告
返回顶部
首页 > 资讯 > 精选 >CNN如何解决Flowers图像分类任务
  • 676
分享到

CNN如何解决Flowers图像分类任务

2023-07-05 10:07:35 676人浏览 安东尼
摘要

本篇内容介绍了“CNN如何解决Flowers图像分类任务”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!加载并展示数据(1)该数据需要从网上下

本篇内容介绍了“CNN如何解决Flowers图像分类任务”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!

加载并展示数据

(1)该数据需要从网上下载,需要耐心等待片刻,下载下来自动会存放在“你的主目录.keras\datasets\flower_photos”。

(2)数据中总共有 5 种类,分别是 daisy、 dandelion、roses、sunflowers、tulips,总共包含了 3670 张图片。

(3) 随机展示了一张花朵的图片。

import matplotlib.pyplot as pltimport numpy as npimport PILimport Tensorflow as tfimport pathlibfrom tensorflow import kerasfrom tensorflow.keras import layersfrom tensorflow.keras.models import Sequentialimport randomdataset_url = "https://storage.Googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)data_dir = pathlib.Path(data_dir)image_count = len(list(data_dir.glob('*/*.jpg')))print("总共包含%d张图片,下面随便展示一张玫瑰的图片样例:"%image_count)roses = list(data_dir.glob('roses/*'))PIL.Image.open(str(random.choice(roses)))

构件处理图像的 pipeline

(1)使用 tf.keras.utils.image_dataset_from_directory 可以将我们的花朵图片数据,从磁盘加载到内存中,并形成 tensorflow 高效的 tf.data.Dataset 类型。

(2)我们将数据集 shuffle 之后,进行二八比例的随机抽取分配,80% 的数据作为我们的训练集,共 2936 张图片, 20% 的数据集作为我们的测试集,共 734 张图片。

(3)我们使用 Dataset.cache 和 Dataset.prefetch 来提升数据的处理速度,使用 cache 在将数据从磁盘加载到 cache 之后,就可以将数据一直放 cache 中便于我们的后续访问,这可以保证在训练过程中数据的处理不会成为计算的瓶颈。另外使用 prefetch 可以在 GPU 训练模型的时候,CPU 将之后需要的数据提前进行处理放入 cache 中,也是为了提高数据的处理性能,加快整个训练过程,不至于训练模型时浪费时间等待数据。

(4)我们随便选取了 6 张图像进行展示,可以看到它们的图片以及对应的标签。

batch_size = 32img_height = 180img_width = 180train_ds = tf.keras.utils.image_dataset_from_directory( data_dir, validation_split=0.2, subset="training", seed=1, image_size=(img_height, img_width), batch_size=batch_size)val_ds = tf.keras.utils.image_dataset_from_directory( data_dir,  validation_split=0.2, subset="validation", seed=1, image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_namesnum_classes = len(class_names)AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)plt.figure(figsize=(5, 5))for images, labels in train_ds.take(1):    for i in range(6):        ax = plt.subplot(2, 3, i + 1)        plt.imshow(images[i].numpy().astype("uint8"))        plt.title(class_names[labels[i]])        plt.axis("off")

结果打印:

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.

搭建深度学习分类模型

(1)因为最初的图片都是 RGB 三通道图片,像素点的值在 [0,255] 之间,为了加速模型的收敛,我们要将所有的数据进行归一化操作。所以在模型的第一层加入了 layers.Rescaling 对图片进行处理。

(2)使用了三个卷积块,每个卷积块中包含了卷积层和池化层,并且每一个卷积层中都添加了 relu 激活函数,卷积层不断提取图片的特征,池化层可以有效的所见特征矩阵的尺寸,同时也可以减少最后连接层的中的参数数量,权重参数少的同时也起到了加快计算速度和防止过拟合的作用。

(3)最后加入了两层全连接层,输出对图片的分类预测 logit

(4)使用 Adam 作为我们的模型优化器,使用 SparseCategoricalCrossentropy 计算我们的损失值,在训练过程中观察 accuracy 指标。

model = Sequential([  layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),  layers.Conv2D(16, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Conv2D(32, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Conv2D(64, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Flatten(),  layers.Dense(128, activation='relu'),  layers.Dense(num_classes)])model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

训练模型并观察结果

(1)我们使用训练集进行模型的训练,使用验证集进行模型的验证,总共训练 5 个 epoch 。

(2)我们通过对训练过程中产生的准确率和损失值,与验证过程中产生的准确率和损失值进行绘图对比,训练时的准确率高出验证时的准确率很多,训练时的损失值远远低于验证时的损失值,这说明模型存在过拟合风险。正常的情况这两个指标应该是大体呈现同一个发展趋势。

epochs = 5history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)acc = history.history['accuracy']val_acc = history.history['val_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(8, 8))plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')plt.plot(epochs_range, val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)plt.plot(epochs_range, loss, label='Training Loss')plt.plot(epochs_range, val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()

结果打印:

Epoch 1/5
92/92 [==============================] - 45s 494ms/step - loss: 0.2932 - accuracy: 0.8992 - val_loss: 1.2603 - val_accuracy: 0.6417
Epoch 2/5
92/92 [==============================] - 40s 436ms/step - loss: 0.1814 - accuracy: 0.9414 - val_loss: 1.5241 - val_accuracy: 0.6267
Epoch 3/5
92/92 [==============================] - 36s 394ms/step - loss: 0.0949 - accuracy: 0.9745 - val_loss: 1.6629 - val_accuracy: 0.6499
Epoch 4/5
92/92 [==============================] - 48s 518ms/step - loss: 0.0554 - accuracy: 0.9860 - val_loss: 1.7566 - val_accuracy: 0.6621
Epoch 5/5
92/92 [==============================] - 39s 419ms/step - loss: 0.0341 - accuracy: 0.9918 - val_loss: 2.1150 - val_accuracy: 0.6335

CNN如何解决Flowers图像分类任务

加入了抑制过拟合措施并重新进行模型的训练和测试

(1)当训练样本数量较少时,通常会发生过拟合现象。我们可以操作数据增强技术,通过随机翻转、旋转等方式来增加样本的丰富程度。常见的数据增强处理方式有:tf.keras.layers.RandomFlip、tf.keras.layers.RandomRotation和 tf.keras.layers.RandomZoom。这些方法可以像其他层一样包含在模型中,并在 GPU 上运行。

(2)这里挑选了一张图片,对其进行 6 次执行数据增强,可以看到得到了经过一定程度缩放、旋转、反转的数据集。

data_augmentation = keras.Sequential([    layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),    layers.RandomRotation(0.1),    layers.RandomZoom(0.5)])plt.figure(figsize=(5, 5))for images, _ in train_ds.take(1):    for i in range(6):        augmented_images = data_augmentation(images)        ax = plt.subplot(2, 3, i + 1)        plt.imshow(augmented_images[0].numpy().astype("uint8"))        plt.axis("off")

CNN如何解决Flowers图像分类任务

(3)在模型架构的开始加入数据增强层,同时在全连接层的地方加入 Dropout ,进行神经元的随机失活,这两个方法的加入可以有效抑制模型过拟合的风险。其他的模型结构、优化器、损失函数、观测值和之前相同。通过绘制数据图我们发现,使用这些措施很明显减少了过拟合的风险。

model = Sequential([  data_augmentation,  layers.Rescaling(1./255),  layers.Conv2D(16, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Conv2D(32, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Conv2D(64, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Dropout(0.2),  layers.Flatten(),  layers.Dense(128, activation='relu'),  layers.Dense(num_classes, name="outputs")])model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])epochs = 15history = model.fit( train_ds, validation_data=val_ds, epochs=epochs)acc = history.history['accuracy']val_acc = history.history['val_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(8, 8))plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')plt.plot(epochs_range, val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)plt.plot(epochs_range, loss, label='Training Loss')plt.plot(epochs_range, val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()

结果打印:

92/92 [==============================] - 57s 584ms/step - loss: 1.3080 - accuracy: 0.4373 - val_loss: 1.0929 - val_accuracy: 0.5749
Epoch 2/15
92/92 [==============================] - 41s 445ms/step - loss: 1.0763 - accuracy: 0.5596 - val_loss: 1.3068 - val_accuracy: 0.5204
...
Epoch 14/15
92/92 [==============================] - 59s 643ms/step - loss: 0.6306 - accuracy: 0.7585 - val_loss: 0.7963 - val_accuracy: 0.7044
Epoch 15/15
92/92 [==============================] - 42s 452ms/step - loss: 0.6155 - accuracy: 0.7691 - val_loss: 0.8513 - val_accuracy: 0.6975

CNN如何解决Flowers图像分类任务

(4)最后我们使用一张随机下载的图片,用模型进行类别的预测,发现可以识别出来。

sunflower_url = "Https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)img = tf.keras.utils.load_img(  sunflower_path, target_size=(img_height, img_width) )img_array = tf.keras.utils.img_to_array(img)img_array = tf.expand_dims(img_array, 0) predictions = model.predict(img_array)score = tf.nn.softmax(predictions[0])print(  "这张图片最有可能属于 {} ,有 {:.2f} 的置信度。".fORMat(class_names[np.argmax(score)], 100 * np.max(score)))

结果打印:

这张图片最有可能属于 sunflowers ,有 97.39 的置信度。

“CNN如何解决Flowers图像分类任务”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注编程网网站,小编将为大家输出更多高质量的实用文章!

--结束END--

本文标题: CNN如何解决Flowers图像分类任务

本文链接: https://www.lsjlt.com/news/351332.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • CNN如何解决Flowers图像分类任务
    本篇内容介绍了“CNN如何解决Flowers图像分类任务”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!加载并展示数据(1)该数据需要从网上下...
    99+
    2023-07-05
  • 一文详解CNN解决Flowers图像分类任务
    目录前言加载并展示数据构件处理图像的 pipeline搭建深度学习分类模型训练模型并观察结果加入了抑制过拟合措施并重新进行模型的训练和测试前言 本文主要任务是使用通过 tf.ker...
    99+
    2023-03-10
    CNN Flowers图像分类 CNN Flowers
  • Keras如何实现图像分类任务
    在Keras中实现图像分类任务通常需要遵循以下步骤: 准备数据集:首先需要准备包含图像和对应标签的数据集。可以使用Keras中的...
    99+
    2024-04-02
  • 如何使用Pytorch完成图像分类任务详解
    目录概述:一. 数据准备二.定义一个卷积神经网络三.完整代码如下:总结概述: 本文将通过组织自己的训练数据,使用Pytorch深度学习框架来训练自己的模型,最终实现自己的图像分类!本...
    99+
    2024-04-02
  • PaddlePaddle中的图像分类任务如何实现
    在PaddlePaddle中实现图像分类任务通常使用卷积神经网络(CNN)。以下是一个简单的图像分类示例: 导入必要的库和模块: ...
    99+
    2024-04-02
  • PyTorch如何实现一个简单的CNN图像分类器
    这篇文章给大家分享的是有关PyTorch如何实现一个简单的CNN图像分类器的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。一. 加载数据Pytorch的数据加载一般是用torch.utils.data.Datase...
    99+
    2023-06-15
  • 如何在Torch中进行图像分类任务
    在Torch中进行图像分类任务通常涉及以下步骤: 准备数据集:首先,需要准备包含训练和测试图像的数据集。可以使用Torch的im...
    99+
    2024-04-02
  • DeepLearning4j仔面处理图像分类任务
    DeepLearning4j是一个用于深度学习的开源软件库,可以用于处理各种机器学习任务,包括图像分类。在DeepLearning4...
    99+
    2024-04-02
  • Tensorflow2.10实现图像分割任务示例详解
    目录前言准备大纲实现1. 获取数据2. 处理数据3. 搭建模型4. 编译、训练模型5. 预测前言 图像分割在医学成像、自动驾驶汽车和卫星成像等方面有很多应用,本质其实就是图像像素分...
    99+
    2023-01-05
    Tensorflow 图像分割 Tensorflow 分割
  • Python如何给图像分类(图像识别模型构建)
    在日常生活中总是有给图像分类的场景,比如垃圾分类、不同场景的图像分类等;今天的文章主要是基于图像识别场景进行模型构建。图像识别是通过 Python深度学习来进行模型训练,再使用模型对...
    99+
    2024-04-02
  • win11任务栏图标重叠如何解决
    这篇文章主要介绍“win11任务栏图标重叠如何解决”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“win11任务栏图标重叠如何解决”文章能帮助大家解决问题。 win11任务栏图标重叠怎么办首先我们可以...
    99+
    2023-07-01
  • Torch中如何处理多类分类任务
    在Torch中处理多类分类任务通常使用交叉熵损失函数和softmax函数。首先,定义一个包含所有可能类别的输出层,并使用softma...
    99+
    2024-04-02
  • win11显示任务栏图标黑如何解决
    这篇文章主要讲解了“win11显示任务栏图标黑如何解决”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“win11显示任务栏图标黑如何解决”吧!首先点击桌面空白处,然后在右键菜单中选择“个性化”...
    99+
    2023-07-01
  • win10任务栏图标不对齐如何解决
    如果在Windows 10中,任务栏的图标不对齐,您可以尝试以下方法来解决问题:1. 重新启动Windows资源管理器:- 按下Ct...
    99+
    2023-08-22
    win10
  • Pytorch中如何实现病虫害图像分类
    本篇文章给大家分享的是有关Pytorch中如何实现病虫害图像分类,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。一、pytorch框架1.1、概念PyTorch是一个开源的Pyt...
    99+
    2023-06-22
  • win7任务栏无输入法图标如何解决
    如果你的Windows 7任务栏上没有输入法图标,可以尝试以下解决方法:1. 检查输入法设置:右键点击任务栏空白处,选择“工具栏”>...
    99+
    2023-09-05
    win7
  • Keras如何实现文本分类任务
    Keras是一个高级神经网络库,可以用来构建和训练深度学习模型。在Keras中实现文本分类任务通常需要以下步骤: 数据预处理:首...
    99+
    2024-04-02
  • win7任务栏预览缩略图没了如何解决
    如果Windows 7任务栏预览缩略图消失了,你可以尝试以下解决方法:1. 重新启用任务栏预览功能:右键点击任务栏空白处,选择“属性...
    99+
    2023-09-06
    win7
  • windows任务栏图标重叠在一起如何解决
    本文小编为大家详细介绍“windows任务栏图标重叠在一起如何解决”,内容详细,步骤清晰,细节处理妥当,希望这篇“windows任务栏图标重叠在一起如何解决”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。任务栏图标...
    99+
    2023-07-02
  • win10任务栏右下角图标空白如何解决
    如果Windows 10任务栏右下角的图标显示为空白,可能是由于系统错误或者某些应用程序的冲突导致的。以下是一些解决此问题的方法:1...
    99+
    2023-09-02
    win10
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作