iis服务器助手广告广告
返回顶部
首页 > 资讯 > 精选 >Pytorch中的model.train()和model.eval()怎么使用
  • 125
分享到

Pytorch中的model.train()和model.eval()怎么使用

2023-07-06 03:07:21 125人浏览 薄情痞子
摘要

本文小编为大家详细介绍“PyTorch中的model.train()和model.eval()怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中的model.train()和model.eval()怎么使用”文章能帮助

本文小编为大家详细介绍“PyTorch中的model.train()和model.eval()怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中的model.train()和model.eval()怎么使用”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。

Pytorch中的model.train() 和 model.eval() 原理与用法

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train()model.eval()

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout

如果模型中有BN层(Batch NORMalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。

三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。

补充:pytroch:model.train()、model.eval()的使用

前言:最近在把两个模型的代码整合到一起,发现有一个模型的代码整合后性能大不如前,但基本上是源码迁移,找了一天原因才发现是因为model.eval()和model.train()放错了位置!!!故在此介绍一下pytroch框架下model.train()、model.eval()的作用和不同点。

model.train、model.eval

1.model.train和model.eval放在代码什么位置

简单的说:model.train放在网络训练前,model.eval放在网络测试前。

常见的位置摆放错误(也是我犯的错误)有把model.train()放在for epoch in range(epoch):前面,同时在test或者val(测试或者评估函数)中只放置model.eval,这就导致了只有第一个epoch模型训练是使用了model.train(),之后的epoch模型训练时都采用model.eval().可能会影响训练好模型的性能。
修改方式:可以在test函数里return前面添加model.train()或者把model.train()放到for epoch in range(epoch):语句下面。

model.train()for epoch in range(epoch):    for train_batch in train_loader:        ...    zhibiao = test(epoch, test_loader, model)        def test(epoch, test_loader, model):    model.eval()    for test_batch in test_loader:        ...    return zhibiao
2.model.train和model.eval有什么作用

model.train()和model.eval()的区别主要在于Batch NormalizationDropout两层。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

下面是model.train 和model.eval的源码,可以看到是利用self.training = mode来判断是使用train还是eval。这个参数将传递到一些常用层,比如dropout、BN层等。

def train(self: T, mode: bool = True) -> T:        r"""Sets the module in training mode.        This has any effect only on certain modules. See documentations of        particular modules for details of their behaviors in training/evaluation        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,        etc.        Args:            mode (bool): whether to set training mode (``True``) or evaluation                         mode (``False``). Default: ``True``.        Returns:            Module: self        """        self.training = mode        for module in self.children():            module.train(mode)        return self    def eval(self: T) -> T:        r"""Sets the module in evaluation mode.        This has any effect only on certain modules. See documentations of        particular modules for details of their behaviors in training/evaluation        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,        etc.        This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.        Returns:            Module: self        """        return self.train(False)

拿dropout层的源码举例,可以看到传递了self.training这个参数。

class Dropout(_DropoutNd):    r"""During training, randomly zeroes some of the elements of the input    tensor with probability :attr:`p` using samples from a Bernoulli    distribution. Each channel will be zeroed out independently on every forward    call.    This has proven to be an effective technique for regularization and    preventing the co-adaptation of neurons as described in the paper    `Improving neural networks by preventing co-adaptation of feature    detectors`_ .    Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during    training. This means that during evaluation the module simply computes an    identity function.    Args:        p: probability of an element to be zeroed. Default: 0.5        inplace: If set to ``True``, will do this operation in-place. Default: ``False``    Shape:        - Input: :math:`(*)`. Input can be of any shape        - Output: :math:`(*)`. Output is of the same shape as input    Examples::        >>> m = nn.Dropout(p=0.2)        >>> input = torch.randn(20, 16)        >>> output = m(input)    .. _Improving neural networks by preventing co-adaptation of feature        detectors: https://arxiv.org/abs/1207.0580    """    def forward(self, input: Tensor) -> Tensor:        return F.dropout(input, self.p, self.training, self.inplace)
3.为什么主要区别在于BN层和dropout层

在BN层中,主要涉及到四个需要更新的参数,分别是running_mean,running_var,weight,bias。这里的weight,bias是Pytorch官方实现中的叫法,有点误导人,其实weight就是gamma,bias就是beta。当然它这样的叫法也符合实际的应用场景。其实gamma,beta就是对规范化后的值进行一个加权求和操作running_mean,running_var是当前所求得的所有batch_size下的均值和方差,每经过一个mini_batch我们都会更新running_mean,running_var.为什么要更新它?因为测试的时候,往往是一个一个的图像feed至网络的,如果你在这里对其进行计算均值方差显然是不合理的,所以model.eval()这个语句就是控制BN层中的running_mean,running_std不更新。采用训练结束后的running_mean,running_std来规范化该张图像。

dropout层在训练过程中会随机舍弃一些神经元用来提高性能,但测试过程中如果还是测试的模型还是和训练时一样随机舍弃了一些神经元(不是原模型)这就和测试的本意相违背。因为测试的模型应该是我们最终得到的模型,而这个模型应该是一个完整的模型。

4.BN层和dropout层的作用

既然都讲到这了,不了解一些BN层和dropout层的作用就说不过去了。
BN层的原理和作用建议读一下这篇博客:神经网络中BN层的原理与作用

dropout是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。注意是暂时,对于随机梯度下降来说,由于是随机丢弃,故而每一个mini-batch都在训练不同的网络。

大规模的神经网络有两个缺点:费时、容易过拟合

Dropout的出现很好的可以解决这个问题,每次做完dropout,相当于从原始的网络中找到一个更瘦的网络。因而,对于一个有N个节点的神经网络,有了dropout后,就可以看做是2^n个模型的集合了,但此时要训练的参数数目却是不变的,这就解决了费时的问题。

将dropout比作是有性繁殖,将基因随机进行拆分,可以将优秀的基因传下来,并且降低基因之间的联合适应性,使得复杂的大段大段基因联合适应性变成比较小的一个一个小段基因的联合适应性。

dropout也能达到同样的效果,它强迫一个神经单元,和随机挑选出来的其他神经单元共同工作,达到好的效果。消除减弱了神经元节点间的联合适应性,增强了泛化能力。

读到这里,这篇“Pytorch中的model.train()和model.eval()怎么使用”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注编程网精选频道。

--结束END--

本文标题: Pytorch中的model.train()和model.eval()怎么使用

本文链接: https://www.lsjlt.com/news/357496.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • Pytorch中的model.train()和model.eval()怎么使用
    本文小编为大家详细介绍“Pytorch中的model.train()和model.eval()怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中的model.train()和model.eval()怎么使用”文章能帮助...
    99+
    2023-07-06
  • 【Pytorch】model.train() 和 model.eval() 原理与用法
    文章目录 一、两种模式二、功能1. model.train()2. model.eval()为什么测试时要用 model.eval() ? 3. 总结与对比 三、Dropout 简介...
    99+
    2023-10-06
    python 机器学习 pytorch
  • model.train()和model.eval()模式怎么使用
    这篇“model.train()和model.eval()模式怎么使用”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“mode...
    99+
    2023-07-05
  • Pytorch中的model.train() 和 model.eval() 原理与用法解析
    目录Pytorch中的model.train() 和 model.eval() 原理与用法一、两种模式二、功能1. model.train()2. model.eval()3. 总结...
    99+
    2023-05-15
    Pytorch model.train() 和 model.eval() python model.train() model.eval()使用
  • pytorch中的model.eval()和BN层的使用
    看代码吧~ class ConvNet(nn.module): def __init__(self, num_class=10): super(ConvN...
    99+
    2024-04-02
  • pytorch中如何使用model.eval()和BN层
    这篇文章给大家分享的是有关pytorch中如何使用model.eval()和BN层的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。代码如下class ConvNet(nn.module): &n...
    99+
    2023-06-15
  • Pytorch中model.eval()的作用是什么
    这篇文章主要介绍了Pytorch中model.eval()的作用是什么的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch中model.eval()的作用是什么文章都会有所收获,下面我们一起来看看吧。m...
    99+
    2023-07-05
  • 详解model.train()和model.eval()两种模式的原理与用法
    一、两种模式 pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。 一般用法是:在训练开始之前写上 ...
    99+
    2023-03-23
    model.train()原理用法 model.eval()原理用法 model.train()和model.eval()
  • Pytorch中关于model.eval()的作用及分析
    目录model.eval()的作用及分析结论Pytorch踩坑之model.eval()问题比较常见的有两方面的原因1) data2)model.state_dict()model....
    99+
    2023-02-03
    Pytorch model.eval model.eval的作用 model.eval()
  • PyTorch中的train()、eval()和no_grad()怎么使用
    本篇内容介绍了“PyTorch中的train()、eval()和no_grad()怎么使用”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!什么...
    99+
    2023-07-05
  • BCELoss和BCEWithLogitsLoss怎么在Pytorch中使用
    BCELoss和BCEWithLogitsLoss怎么在Pytorch中使用?相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。BCELoss在图片多标签分类时,如果3张图片分3类,...
    99+
    2023-06-15
  • ubuntu中pytorch怎么安装和使用
    要在Ubuntu中安装PyTorch,可以使用conda或pip进行安装。以下是使用conda安装PyTorch的步骤: 首先,确...
    99+
    2024-03-01
    ubuntu pytorch
  • PyTorch中的nn.Embedding怎么使用
    这篇“PyTorch中的nn.Embedding怎么使用”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“PyTorch中的nn...
    99+
    2023-07-02
  • pytorch中nn.Dropout怎么使用
    小编给大家分享一下pytorch中nn.Dropout怎么使用,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!看代码吧~Class USeDropout(nn.Module):   &...
    99+
    2023-06-15
  • Pytorch中怎么使用TensorBoard
    本文小编为大家详细介绍“Pytorch中怎么使用TensorBoard”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中怎么使用TensorBoard”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。一...
    99+
    2023-07-02
  • pytorch中[..., 0]怎么使用
    这篇文章将为大家详细讲解有关pytorch中[..., 0]怎么使用,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。在看程序的时候看到了x[…, 0]的语句不是很理解,后来自己做实验略微了解,以此记录方便自...
    99+
    2023-06-15
  • PyTorch中torch.utils.data.DataLoader怎么使用
    这篇文章主要介绍“PyTorch中torch.utils.data.DataLoader怎么使用”,在日常操作中,相信很多人在PyTorch中torch.utils.data.DataLoader怎么使用问题上存在疑惑,小编查阅了各式资料,...
    99+
    2023-07-02
  • Pytorch中的torch.distributions库怎么使用
    本文小编为大家详细介绍“Pytorch中的torch.distributions库怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中的torch.distributions库怎么使用”文章能帮助大家解决疑惑,下面跟着小...
    99+
    2023-07-05
  • PyTorch中的nn.Module类怎么使用
    这篇文章主要讲解了“PyTorch中的nn.Module类怎么使用”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“PyTorch中的nn.Module类怎么使用”吧!PyTorch nn.Mo...
    99+
    2023-07-05
  • Python中Pytorch怎么使用
    这篇文章将为大家详细讲解有关Python中Pytorch怎么使用,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。一、TensorTensor(张量是一个统称,其中包括很多类型):0阶张量:标量、常数、0-D...
    99+
    2023-06-15
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作