广告
返回顶部
首页 > 资讯 > 精选 >如何解决Pytorch中Batch Normalization layer的问题
  • 929
分享到

如何解决Pytorch中Batch Normalization layer的问题

2023-06-15 05:06:25 929人浏览 薄情痞子
摘要

小编给大家分享一下如何解决PyTorch中Batch NORMalization layer的问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!1. 注意mome

小编给大家分享一下如何解决PyTorch中Batch NORMalization layer的问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!

1. 注意momentum的定义

Pytorch中的BN层的动量平滑和常见的动量法计算方式是相反的,默认的momentum=0.1

如何解决Pytorch中Batch Normalization layer的问题

BN层里的表达式为:

如何解决Pytorch中Batch Normalization layer的问题

其中γ和β是可以学习的参数。在Pytorch中,BN层的类的参数有:

CLASS torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

每个参数具体含义参见文档,需要注意的是,affine定义了BN层的参数γ和β是否是可学习的(不可学习默认是常数1和0).

2. 注意BN层中含有统计数据数值,即均值和方差

track_running_stats – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: True

在训练过程中model.train(),train过程的BN的统计数值—均值和方差是通过当前batch数据估计的。

并且测试时,model.eval()后,若track_running_stats=True,模型此刻所使用的统计数据是Running status 中的,即通过指数衰减规则,积累到当前的数值。否则依然使用基于当前batch数据的估计值。

3. BN层的统计数据更新

是在每一次训练阶段model.train()后的forward()方法中自动实现的,而不是在梯度计算与反向传播中更新optim.step()中完成

4. 冻结BN及其统计数据

从上面的分析可以看出来,正确的冻结BN的方式是在模型训练时,把BN单独挑出来,重新设置其状态为eval (在model.train()之后覆盖training状态).

解决方案:

You should use apply instead of searching its children, while named_children() doesn't iteratively search submodules.

def set_bn_eval(m):    classname = m.__class__.__name__    if classname.find('BatchNorm') != -1:      m.eval()model.apply(set_bn_eval)

或者,重写module中的train()方法:

def train(self, mode=True):        """        Override the default train() to freeze the BN parameters        """        super(MyNet, self).train(mode)        if self.freeze_bn:            print("Freezing Mean/Var of BatchNorm2D.")            if self.freeze_bn_affine:                print("Freezing Weight/Bias of BatchNorm2D.")        if self.freeze_bn:            for m in self.backbone.modules():                if isinstance(m, nn.BatchNorm2d):                    m.eval()                    if self.freeze_bn_affine:                        m.weight.requires_grad = False                        m.bias.requires_grad = False

5. Fix/frozen Batch Norm when training may lead to RuntimeError: expected Scalar type Half but found Float

解决办法:

import torchimport torch.nn as nnfrom torch.nn import initfrom torchvision import modelsfrom torch.autograd import Variablefrom apex.fp16_utils import *def fix_bn(m):    classname = m.__class__.__name__    if classname.find('BatchNorm') != -1:        m.eval()model = models.resnet50(pretrained=True)model.cuda()model = network_to_half(model)model.train()model.apply(fix_bn) # fix batchnorminput = Variable(torch.FloatTensor(8, 3, 224, 224).cuda().half())output = model(input)output_mean = torch.mean(output)output_mean.backward()

Please do

def fix_bn(m):    classname = m.__class__.__name__    if classname.find('BatchNorm') != -1:        m.eval().half()

Reason for this is, for regular training it is better (performance-wise) to use cudnn batch norm, which requires its weights to be in fp32, thus batch norm modules are not converted to half in network_to_half. However, cudnn does not support batchnorm backward in the eval mode , which is what you are doing, and to use pytorch implementation for this, weights have to be of the same type as inputs.

补充:深度学习总结:用pytorch做dropout和Batch Normalization时需要注意的地方,用tensorflow做dropout和BN时需要注意的地方

用pytorch做dropout和BN时需要注意的地方

pytorch做dropout:

就是train的时候使用dropout,训练的时候不使用dropout,

pytorch里面是通过net.eval()固定整个网络参数,包括不会更新一些前向的参数,没有dropout,BN参数固定,理论上对所有的validation set都要使用net.eval()

net.train()表示会纳入梯度的计算。

net_dropped = torch.nn.Sequential(    torch.nn.Linear(1, N_HIDDEN),    torch.nn.Dropout(0.5),  # drop 50% of the neuron    torch.nn.ReLU(),    torch.nn.Linear(N_HIDDEN, N_HIDDEN),    torch.nn.Dropout(0.5),  # drop 50% of the neuron    torch.nn.ReLU(),    torch.nn.Linear(N_HIDDEN, 1),)for t in range(500):    pred_drop = net_dropped(x)    loss_drop = loss_func(pred_drop, y)    optimizer_drop.zero_grad()    loss_drop.backward()    optimizer_drop.step()    if t % 10 == 0:        # change to eval mode in order to fix drop out effect        net_dropped.eval()  # parameters for dropout differ from train mode        test_pred_drop = net_dropped(test_x)        # change back to train mode        net_dropped.train()

pytorch做Batch Normalization:

net.eval()固定整个网络参数,固定BN的参数,moving_mean 和moving_var,不懂这个看下图:

if self.do_bn:                bn = nn.BatchNorm1d(10, momentum=0.5)                setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Module                self.bns.append(bn)    for epoch in range(EPOCH):        print('Epoch: ', epoch)        for net, l in zip(nets, losses):            net.eval()              # set eval mode to fix moving_mean and moving_var            pred, layer_input, pre_act = net(test_x)            net.train()             # free moving_mean and moving_var        plot_histogram(*layer_inputs, *pre_acts)

moving_mean 和moving_var

如何解决Pytorch中Batch Normalization layer的问题

Tensorflow做dropout和BN时需要注意的地方

dropout和BN都有一个training的参数表明到底是train还是test, 表明test那dropout就是不dropout,BN就是固定住了BN的参数;

tf_is_training = tf.placeholder(tf.bool, None)  # to control dropout when training and testing# dropout netd1 = tf.layers.dense(tf_x, N_HIDDEN, tf.nn.relu)d1 = tf.layers.dropout(d1, rate=0.5, training=tf_is_training)   # drop out 50% of inputsd2 = tf.layers.dense(d1, N_HIDDEN, tf.nn.relu)d2 = tf.layers.dropout(d2, rate=0.5, training=tf_is_training)   # drop out 50% of inputsd_out = tf.layers.dense(d2, 1)for t in range(500):    sess.run([o_train, d_train], {tf_x: x, tf_y: y, tf_is_training: True})  # train, set is_training=True    if t % 10 == 0:        # plotting        plt.cla()        o_loss_, d_loss_, o_out_, d_out_ = sess.run(            [o_loss, d_loss, o_out, d_out], {tf_x: test_x, tf_y: test_y, tf_is_training: False} # test, set is_training=False        )# pytorch    def add_layer(self, x, out_size, ac=None):        x = tf.layers.dense(x, out_size, kernel_initializer=self.w_init, bias_initializer=B_INIT)        self.pre_activation.append(x)        # the momentum plays important rule. the default 0.99 is too high in this case!        if self.is_bn: x = tf.layers.batch_normalization(x, momentum=0.4, training=tf_is_train)    # when have BN        out = x if ac is None else ac(x)        return out

当BN的training的参数为train时,只是表示BN的参数是可变化的,并不是代表BN会自己更新moving_mean 和moving_var,因为这个操作是前向更新的op,在做train之前必须确保moving_mean 和moving_var更新了,更新moving_mean 和moving_var的操作在tf.GraphKeys.UPDATE_OPS

 # !! IMPORTANT !! the moving_mean and moving_variance need to be updated,        # pass the update_ops with control_dependencies to the train_op        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)        with tf.control_dependencies(update_ops):            self.train = tf.train.AdamOptimizer(LR).minimize(self.loss)

以上是“如何解决Pytorch中Batch Normalization layer的问题”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注编程网精选频道!

--结束END--

本文标题: 如何解决Pytorch中Batch Normalization layer的问题

本文链接: https://www.lsjlt.com/news/278227.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • 如何解决Pytorch中Batch Normalization layer的问题
    小编给大家分享一下如何解决Pytorch中Batch Normalization layer的问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!1. 注意mome...
    99+
    2023-06-15
  • 解决Pytorch中Batch Normalization layer踩过的坑
    1. 注意momentum的定义 Pytorch中的BN层的动量平滑和常见的动量法计算方式是相反的,默认的momentum=0.1 BN层里的表达式为: 其中γ和β是可以学习的参...
    99+
    2022-11-12
  • 如何解决layer图标icon不加载的问题
    这篇文章将为大家详细讲解有关如何解决layer图标icon不加载的问题,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。之前在项目中使用layer弹框感觉体验很好,这次的项目...
    99+
    2022-10-19
  • 如何在pytorch中解决state_dict()的拷贝问题
    如何在pytorch中解决state_dict()的拷贝问题?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。model.state_dict()是浅拷贝,返回的参...
    99+
    2023-06-06
  • 如何解决layer弹层遮罩挡住窗体的问题
    这篇文章主要介绍如何解决layer弹层遮罩挡住窗体的问题,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!使用代码:<div>    <d...
    99+
    2022-10-19
  • Jetson NX配置pytorch的问题如何解决
    这篇文章主要介绍“Jetson NX配置pytorch的问题如何解决”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“Jetson NX配置pytorch的问题如何解决”文章能帮助大...
    99+
    2023-07-05
  • 解决pytorch中的kl divergence计算问题
    偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和p...
    99+
    2022-11-12
  • 如何解决layui弹出层layer中area过大被遮挡的问题
    小编给大家分享一下如何解决layui弹出层layer中area过大被遮挡的问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!la...
    99+
    2022-10-19
  • 如何解决layer弹出层自适应页面大小的问题
    这篇文章主要介绍了如何解决layer弹出层自适应页面大小的问题,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。网上的解决方案大都是以下几种:1...
    99+
    2022-10-19
  • 解决Pytorch中的神坑:关于model.eval的问题
    有时候使用Pytorch训练完模型,在测试数据上面得到的结果令人大跌眼镜。 这个时候需要检查一下定义的Model类中有没有 BN 或 Dropout 层,如果有任何一个存在 那么在测...
    99+
    2022-11-12
  • 如何解决pytorch显存一直变大的问题
    本篇内容介绍了“如何解决pytorch显存一直变大的问题”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!在代码中添加以下两行可以解决:torc...
    99+
    2023-06-14
  • 如何解决安装pytorch时报sslerror错误的问题
    这篇文章给大家分享的是有关如何解决安装pytorch时报sslerror错误的问题的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。首先说一下 ,我是用的anaconda3装的pytorch为了方便建议你也安装一个。...
    99+
    2023-06-15
  • 在pytorch中复制模型时出现问题如何解决
    在pytorch中复制模型时出现问题如何解决?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。直接使用model2=model1会出现当更新model2时,model1的权重也...
    99+
    2023-06-06
  • pytorch部署到jupyter中的问题及解决方案
    目录pytorch部署到jupyter中两种解决方案pytorch部署到jupyter中 在安装Aconda的同时,会将jupyter notebook一起安装,不过这里的jupy...
    99+
    2022-11-11
  • 如何解决layer关闭弹出窗口触发表单提交的问题
    这篇文章主要介绍如何解决layer关闭弹出窗口触发表单提交的问题,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!1、前言 表单的代码:<form>  此处理代码...
    99+
    2022-10-19
  • pytorch网络模型构建场景的问题如何解决
    今天小编给大家分享一下pytorch网络模型构建场景的问题如何解决的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。网络模型构建...
    99+
    2023-07-05
  • Pytorch中retain_graph的坑如何解决
    本篇内容主要讲解“Pytorch中retain_graph的坑如何解决”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Pytorch中retain_graph的坑如何解决”吧!Pytorch中re...
    99+
    2023-07-05
  • 解决pytorch 损失函数中输入输出不匹配的问题
    一、pytorch 损失函数中输入输出不匹配问题 File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\...
    99+
    2022-11-12
  • 如何解决mysql中auto_increment的问题
    这篇文章将为大家详细讲解有关如何解决mysql中auto_increment的问题,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。mysql中的auto_increment...
    99+
    2022-10-19
  • 如何解决VB.NET中ReadProcessMemory的问题
    这篇文章主要介绍如何解决VB.NET中ReadProcessMemory的问题,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!在学校上VB.NET的课,VB以前学过一点点,只会看不会写,不过没有办法,学校开的,所以几个...
    99+
    2023-06-17
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作