广告
返回顶部
首页 > 资讯 > 后端开发 > Python >pytorch+sklearn实现数据加载的流程
  • 404
分享到

pytorch+sklearn实现数据加载的流程

pytorch数据加载pytorchsklearn数据加载pytorch加载数据 2022-11-21 22:11:48 404人浏览 泡泡鱼

Python 官方文档:入门教程 => 点击学习

摘要

目录PyTorch+sklearn实现数据加载epoch & batch_size & iteration优化算法——梯度下降Batch gr

之前在训练网络的时候加载数据都是稀里糊涂的放进去的,也没有理清楚里面的流程,今天整理一下,加深理解,也方便以后查阅。

pytorch+sklearn实现数据加载

epoch & batch_size & iteration

  • epoch:1个epoch等于使用训练集中的全部样本训练一次,通俗的讲epoch的值就是整个数据集被轮几次。
  • batch_size:批大小。在深度学习中,一般采用SGD训练,即每次训练在训练集中取batchsize个样本训练;
  • iteration:1个iteration等于使用batch_size个样本训练一次;

优化算法——梯度下降

深度学习的优化算法,说白了就是梯度下降。每次的参数更新有两种方式。

Batch gradient descent

第一种,遍历全部数据集算一次损失函数,然后算函数对各个参数的梯度,更新梯度,这称为批梯度下降(Batch gradient descent)

这样做至少有 2 个好处:其一,由全数据集确定的方向能够更好地代表样本总体,从而更准确地朝向极值所在的方向。其二,由于不同权重的梯度值差别巨大,因此选取一个全局的学习率很困难。 Full Batch Learning 可以使用 Rprop 只基于梯度符号并且针对性单独更新各权值。

对于更大的数据集,以上 2 个好处又变成了 2 个坏处:其一,随着数据集的海量增长和内存限制,一次性载入所有的数据进来变得越来越不可行。其二,以 Rprop 的方式迭代,会由于各个 Batch 之间的采样差异性,各次梯度修正值相互抵消,无法修正。这才有了后来 RMSProp 的妥协方案。

Stochastic gradient descent

另一种,每看一个数据就算一下损失函数,然后求梯度更新参数,这个称为随机梯度下降(Stochastic gradient descent)。这个方法速度比较快,但是收敛性能不太好,可能在最优点附近晃来晃去,达不到最优点。两次参数的更新也有可能互相抵消掉,造成目标函数震荡的比较剧烈。

Mini-batch gradient decent

为了克服两种方法的缺点,现在一般采用的是一种折中手段,mini-batch gradient decent,小批的梯度下降,这种方法把数据分为若干个批,按批来更新参数,这样,一个批中的一组数据共同决定了本次梯度的方向,下降起来就不容易跑偏,减少了随机性。另一方面因为批的样本数与整个数据集相比小了很多,计算量也不是很大。

现在用的优化器SGD是stochastic gradient descent的缩写,但不代表是一个样本就更新一回,还是基于mini-batch的。

  • 批量梯度下降:批量大小=训练集的大小
  • 随机梯度下降:批量大小= 1
  • 小批量梯度下降:1 <批量大小<训练集的大小

在小批量梯度下降的情况下,流行的批量大小包括32,64和128个样本。

再谈Batch_Size

在合理范围内,增大 Batch_Size 有何好处?

  • 内存利用率提高了,大矩阵乘法的并行化效率提高。
  • 跑完一次 epoch(全数据集)所需的迭代次数减少,对于相同数据量的处理速度进一步加快。
  • 在一定范围内,一般来说 Batch_Size 越大,其确定的下降方向越准,引起训练震荡越小。

盲目增大 Batch_Size 有何坏处?

  • 内存利用率提高了,但是内存容量可能撑不住了。
  • 跑完一次 epoch(全数据集)所需的迭代次数减少,要想达到相同的精度,其所花费的时间大大增加了,从而对参数的修正也就显得更加缓慢。
  • Batch_Size 增大到一定程度,其确定的下降方向已经基本不再变化。

深度学习的第一项任务——数据加载

数据加载流程——重要

以BCICIV_2a数据为例

import mne
import numpy as np
import torch
import torch.nn as nn
class LoadData:
    def __init__(self,eeg_file_path: str):
        self.eeg_file_path = eeg_file_path

    def load_raw_data_gdf(self,file_to_load):
        self.raw_eeg_subject = mne.io.read_raw_gdf(self.eeg_file_path + '/' + file_to_load)
        return self

    def load_raw_data_mat(self,file_to_load):
        import scipy.io as sio
        self.raw_eeg_subject = sio.loadmat(self.eeg_file_path + '/' + file_to_load)

    def get_all_files(self,file_path_extension: str =''):
        if file_path_extension:
            return glob.glob(self.eeg_file_path+'/'+file_path_extension)
        return os.listdir(self.eeg_file_path)
class LoadBCIC(LoadData):
    '''Subclass of LoadData for loading BCI Competition IV Dataset 2a'''
    def __init__(self, file_to_load, *args):
        self.stimcodes=('769','770','771','772')
        # self.epoched_data={}
        self.file_to_load = file_to_load
        self.channels_to_remove = ['EOG-left', 'EOG-central', 'EOG-right']
        super(LoadBCIC,self).__init__(*args)

    def get_epochs(self, tmin=0,tmax=1,baseline=None):
        self.load_raw_data_gdf(self.file_to_load)
        raw_data = self.raw_eeg_subject
        # raw_downsampled = raw_data.copy().resample(sfreq=128)
        self.fs = raw_data.info.get('sfreq')
        events, event_ids = mne.events_from_annotations(raw_data)
        stims =[value for key, value in event_ids.items() if key in self.stimcodes]
        epochs = mne.Epochs(raw_data, events, event_id=stims, tmin=tmin, tmax=tmax, event_repeated='drop',
                            baseline=baseline, preload=True, proj=False, reject_by_annotation=False)
        epochs = epochs.drop_channels(self.channels_to_remove)
        self.y_labels = epochs.events[:, -1] - min(epochs.events[:, -1])
        self.x_data = epochs.get_data()*1e6
        eeg_data={'x_data':self.x_data,
                  'y_labels':self.y_labels,
                  'fs':self.fs}
        return eeg_data
data_path = "/home/pytorch/LiangXiaohan/MI_Dataverse/BCICIV_2a_gdf"
file_to_load = 'A01T.gdf'
'''for BCIC Dataset'''
bcic_data = LoadBCIC(file_to_load, data_path)
eeg_data = bcic_data.get_epochs() # {'x_data':, 'y_labels':, 'fs':}

X = eeg_data.get('x_data')
Y = eeg_data.get('y_labels')
Y.shape

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0)
X_train.shape

from sklearn.model_selection import StratifiedKFold
train_idx = {}
eval_idx = {}
skf = StratifiedKFold(n_splits=4, shuffle=True)
i = 0
for train_indices, eval_indices in skf.split(X_train, y_train):
    train_idx.update({i: train_indices})
    eval_idx.update({i: eval_indices})
    i += 1
train_idx.get(1).shape

def split_xdata(eeg_data, train_idx, eval_idx):
    x_train=np.copy(eeg_data[train_idx,:,:])
    x_eval=np.copy(eeg_data[eval_idx,:,:])
    x_train = torch.from_numpy(x_train).to(torch.float32)
    x_eval = torch.from_numpy(x_eval).to(torch.float32)
    return x_train, x_eval
def split_ydata(y_true, train_idx, eval_idx):
    y_train = np.copy(y_true[train_idx])
    y_eval = np.copy(y_true[eval_idx])
    y_train = torch.from_numpy(y_train)
    y_eval = torch.from_numpy(y_eval)
    return y_train, y_eval
x_train, x_eval = split_xdata(X_train, train_idx.get(1), eval_idx.get(1))
y_train, y_eval = split_ydata(Y_train, train_idx.get(1), eval_idx.get(1))
y_train.shape

from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm
def BCICDataLoader(x_train, y_train, batch_size=64, num_workers=2, shuffle=True):
    
    data = TensorDataset(x_train, y_train)

    train_data = DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    return train_data
train_data = BCICDataLoader(x_train, y_train, batch_size=32)
for inputs, target in tqdm(train_data):
    print(target)

到此数据就读出来了!!!

相关API解释

sklearn.model_selection.train_test_split

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html?highlight=train_test_split

sklearn.model_selection.StratifiedKFold

Https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html?highlight=stratifiedkfold#sklearn.model_selection.StratifiedKFold

torch.utils.data.TensorDataset

https://pytorch.org/docs/stable/data.html?highlight=tensordataset#torch.utils.data.TensorDataset

torch.utils.data.DataLoader

https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

参考资料

深度学习中的batch、epoch、iteration的含义

神经网络中Batch和Epoch之间的区别是什么?

谈谈深度学习中的 Batch_Size

到此这篇关于pytorch+sklearn实现数据加载的文章就介绍到这了,更多相关pytorch数据加载内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: pytorch+sklearn实现数据加载的流程

本文链接: https://www.lsjlt.com/news/173369.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • pytorch+sklearn实现数据加载的流程
    目录pytorch+sklearn实现数据加载epoch & batch_size & iteration优化算法——梯度下降Batch gr...
    99+
    2022-11-21
    pytorch数据加载 pytorch sklearn数据加载 pytorch加载数据
  • js实现动态加载数据瀑布流
    本文实例为大家分享了js实现动态加载数据瀑布流的具体代码,供大家参考,具体内容如下 实现的功能 1.每次下拉到底部会自动加载下一页的数据2.图片逐渐显示 首先html <!DO...
    99+
    2022-11-13
  • python sklearn与pandas实现缺失值数据预处理流程详解
    注:代码用 jupyter notebook跑的,分割线线上为代码,分割线下为运行结果 1.导入库生成缺失值 通过pandas生成一个6行4列的矩阵,列名分别为'col1&#...
    99+
    2022-11-11
  • VueopenLayers实现图层数据切换与加载流程详解
    目录openlayers介绍一、实现效果预览二、代码实现openlayers介绍 OpenLayers是一个用于开发WebGIS客户端的JavaScript包。OpenLayers ...
    99+
    2022-11-13
  • pytorch中怎么加载自己的数据集
    在PyTorch中,可以通过创建一个自定义的数据集类来加载自己的数据集。首先,需要导入以下必要的库和模块:```pythonimpo...
    99+
    2023-10-09
    pytorch
  • js实现瀑布流触底动态加载数据
    本文实例为大家分享了js实现瀑布流触底动态加载数据的具体代码,供大家参考,具体内容如下 // onScrollEvent 滚动条事件 <div class="box" ...
    99+
    2022-11-12
  • pytorch怎么加载自己的图片数据集
    本文小编为大家详细介绍“pytorch怎么加载自己的图片数据集”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch怎么加载自己的图片数据集”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。ImageFold...
    99+
    2023-07-02
  • pytorch加载自己的数据集源码分享
    目录一、标准的数据集流程梳理数据来源二、实现加载自己的数据集1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的)2. 在继承da...
    99+
    2022-11-11
  • PyTorch深度学习模型的保存和加载流程详解
    一、模型参数的保存和加载  torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经...
    99+
    2022-11-12
  • Oracle数据加载和卸载的实现方法
    在日常工作中;经常会遇到这样的需求: Oracle 数据表跟文本或者文件格式进行交互;即将指定文件内容导入对应的 Oracle 数据表中;或者从 Oracle 数据表导出。 其他数据库中的表跟Or...
    99+
    2022-10-18
  • Ajax实现异步加载数据
    本文实例为大家分享了Ajax实现异步加载数据的具体代码,供大家参考,具体内容如下 项目结构如下 (需要导入一个JQuery的包,配置文件web.xml和springmvc-servl...
    99+
    2022-11-12
  • vue加载视频流,实现直播功能的过程
    目录前言一、视频流是什么?二、vue加载rtmp视频流1.方法一:video.js2.方法二:ckplayer三、vue加载hls视频流1.index.html中2.video-pl...
    99+
    2022-11-13
  • 小程序实现瀑布流动态加载列表
    本文实例为大家分享了小程序实现瀑布流动态加载列表的具体代码,供大家参考,具体内容如下 最近业务需要做一个商城列表,就自己写了一个瀑布流来加载列表。 这个列表在很多地方都需要用到,就...
    99+
    2022-11-13
  • 使用pytorch加载并读取COCO数据集的详细操作
    目录环境配置基础知识:元祖、字典、数组利用PyTorch读取COCO数据集利用PyTorch读取自己制作的数据集如何使用pytorch加载并读取COCO数据集 环境配置基础知识:元祖...
    99+
    2022-11-11
  • Android实现ListView分页加载数据
    本文实例为大家分享了ListView分页加载数据的具体代码,供大家参考,具体内容如下 FenyeActivity package com.example.myapplicatio...
    99+
    2022-11-12
  • Android自定义加载控件实现数据加载动画
    本文实例为大家分享了Android自定义加载控件,第一次小人跑动的加载效果眼前一亮,相比传统的PrograssBar高大上不止一点,于是走起,自定义了控件LoadingView...
    99+
    2022-06-06
    数据 动画 Android
  • html5用video标签流式加载的实现
    这篇文章主要介绍了html5用video标签流式加载的实现,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。video标签浏览器的video标签通常是接收一个src属性,然后浏览...
    99+
    2023-06-09
  • pytorch加载自己的图片数据集的2种方法详解
    目录ImageFolder 加载数据集使用pytorch提供的Dataset类创建自己的数据集。Dataset加载数据集总结pytorch加载图片数据集有两种方法。 1.ImageF...
    99+
    2022-11-11
  • Android实现滑动加载数据的方法
    本文实例讲述了Android实现滑动加载数据的方法。分享给大家供大家参考。具体实现方法如下: EndLessActivity.java如下: package com.Scro...
    99+
    2022-06-06
    方法 数据 Android
  • PyTorch数据读取的实现示例
    前言 PyTorch作为一款深度学习框架,已经帮助我们实现了很多很多的功能了,包括数据的读取和转换了,那么这一章节就介绍一下PyTorch内置的数据读取模块吧 模块介绍 pan...
    99+
    2022-11-11
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作