iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >聊聊pytorch中Optimizer与optimizer.step()的用法
  • 970
分享到

聊聊pytorch中Optimizer与optimizer.step()的用法

2024-04-02 19:04:59 970人浏览 安东尼

Python 官方文档:入门教程 => 点击学习

摘要

当我们想指定每一层的学习率时: optim.SGD([ {'params': model.base.parameters()},

当我们想指定每一层的学习率时:


optim.SGD([
                    {'params': model.base.parameters()},
                    {'params': model.classifier.parameters(), 'lr': 1e-3}
                ], lr=1e-2, momentum=0.9)

这意味着model.base的参数将会使用1e-2的学习率,model.classifier的参数将会使用1e-3的学习率,并且0.9的momentum将会被用于所有的参数。

进行单次优化

所有的optimizer都实现了step()方法,这个方法会更新所有的参数。它能按两种方式来使用:


optimizer.step()

这是大多数optimizer所支持的简化版本。一旦梯度被如backward()之类的函数计算好后,我们就可以调用这个函数。

例子


for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()         
optimizer.step(closure)

一些优化算法例如Conjugate Gradient和LBFGS需要重复多次计算函数,因此你需要传入一个闭包去允许它们重新计算你的模型。这个闭包应当清空梯度,计算损失,然后返回。

例子:


for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

补充:Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别

首先需要明确optimzier优化器的作用, 形象地来说,优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用,这也是机器学习里面最一般的方法论。

从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西:

1. 优化器需要知道当前的网络或者别的什么模型的参数空间

这也就是为什么在训练文件中,正式开始训练之前需要将网络的参数放到优化器里面,比如使用PyTorch的话总会出现类似如下的代码:


optimizer_G = Adam(model_G.parameters(), lr=train_c.lr_G)   # lr 使用的是初始lr
optimizer_D = Adam(model_D.parameters(), lr=train_c.lr_D)

2. 需要知道反向传播的梯度信息

我们还是从代码入手,如下所示是Pytorch 中SGD优化算法的step()函数具体写法,具体SGD的写法放在参考部分。


def step(self, closure=None):
            """PerfORMs a single optimization step.
            Arguments:
                closure (callable, optional): A closure that reevaluates the model
                    and returns the loss.
            """
            loss = None
            if closure is not None:
                loss = closure()
     
            for group in self.param_groups:
                weight_decay = group['weight_decay']
                momentum = group['momentum']
                dampening = group['dampening']
                nesterov = group['nesterov']
     
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(1 - dampening, d_p)
                        if nesterov:
                            d_p = d_p.add(momentum, buf)
                        else:
                            d_p = buf     
                    p.data.add_(-group['lr'], d_p)     
            return loss

从上面的代码可以看到step这个函数使用的是参数空间(param_groups)中的grad,也就是当前参数空间对应的梯度,这也就解释了为什么optimzier使用之前需要zero清零一下,因为如果不清零,那么使用的这个grad就得同上一个mini-batch有关,这不是我们需要的结果。

再回过头来看,我们知道optimizer更新参数空间需要基于反向梯度,因此,当调用optimizer.step()的时候应当是loss.backward()的时候,这也就是经常会碰到,如下情况


total_loss.backward()
optimizer_G.step()

loss.backward()在前,然后跟一个step。

那么为什么optimizer.step()需要放在每一个batch训练中,而不是epoch训练中,这是因为现在的mini-batch训练模式是假定每一个训练集就只有mini-batch这样大,因此实际上可以将每一次mini-batch看做是一次训练,一次训练更新一次参数空间,因而optimizer.step()放在这里。

scheduler.step()按照Pytorch的定义是用来更新优化器的学习率的,一般是按照epoch为单位进行更换,即多少个epoch后更换一次学习率,因而scheduler.step()放在epoch这个大循环下。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: 聊聊pytorch中Optimizer与optimizer.step()的用法

本文链接: https://www.lsjlt.com/news/127230.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • 聊聊pytorch中Optimizer与optimizer.step()的用法
    当我们想指定每一层的学习率时: optim.SGD([ {'params': model.base.parameters()}, ...
    99+
    2022-11-12
  • Optimizer与optimizer.step()怎么在pytorch中使用
    今天就跟大家聊聊有关Optimizer与optimizer.step()怎么在pytorch中使用,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。当我们想指定每一层的学习率时:opti...
    99+
    2023-06-15
  • 聊聊C#中的Mixin的具体用法
    目录写在前面从一个简单例子说起在类中实现单例在父类中实现单例轮到Mixin出场定义Mixin在C#中在8.0之前从C#8.0开始写在前面 Mixin本意是指冰淇淋表面加的那些草莓酱,...
    99+
    2022-11-13
  • 聊聊PHP中Base64 、Blob与File的相互转换方法
    本篇文章给大家带来了关于php的相关知识,其中主要跟大家聊一聊Base64 、Blob和File之间是怎么相互转换的?感兴趣的朋友下面一起来看一下吧,希望对大家有帮助。前言在获取图片时,遇到需要转换格式的情况,所以记录下来分享。正文一、格式...
    99+
    2023-05-14
    前端
  • 聊聊Vue中的计算属性、方法与侦听器
    也就是说,当计算属性依赖的数据发生改变时,它会重新计算;若没有变化时,则不计算,会一直使用上一次计算的结果(这样也就提高了一些性能)。在我们的代码中,当 firstName 或 lastName 改变时,fullName 会重新计算,不变时...
    99+
    2023-05-14
    前端 JavaScript Vue.js
  • 聊聊JavaScript中.?、??、??=的用法以及含义
    目录前言可选链(.)空值合并运算符()空值赋值运算符(=)趣味问答时间:值得注意的是 : 是忽视 null ,undefined 等错误的值最后前言 在项目中我们往往要做很多很多的空...
    99+
    2022-11-13
  • 详细聊聊golang中函数的用法
    随着计算机技术的不断发展,编程语言也在不断更新换代,其中Golang是近年来非常热门的一种编程语言,它的高效、安全、易用受到了很多开发者的喜爱。在Golang中,函数是一种非常重要的编程元素,本文将详细介绍Golang函数的用法。一、函数的...
    99+
    2023-05-14
  • 聊聊php中常用的排序方法(算法)
    PHP作为一门重要的编程语言,其实在多个方面都拥有着很好的表现。在数据处理中,排序算法是最为常见和重要的一部分。PHP中提供了多种排序算法,下面详细介绍PHP中常用的排序方法。冒泡排序冒泡排序是PHP中最经典的排序算法之一。该算法通过遍历比...
    99+
    2023-05-14
    php 排序
  • 聊聊PHP中die()和sleep()函数的用法
    在上一篇《聊聊PHP中删除字符串的逗号和尾部斜杠的方法》给大家介绍了PHP删除字符串中的逗号以及尾部斜杠的方法,感兴趣的朋友可以去学习了解一下~ 本文也将给大家通过示例来讲解标题所述...
    99+
    2022-11-12
  • 聊聊php中箭头符号(->)的用法
    PHP箭头(->)是一种用于对象访问的符号。在PHP中,对象是一组属性和方法的集合。箭头符号允许开发人员访问和操作这些属性和方法。在PHP中,对象可以通过实例化类创建,然后使用箭头符号来访问对象的属性和方法。例如,下面是一个简单的PH...
    99+
    2023-05-14
    php 箭头
  • 聊聊php中“<”符的多种使用方法
    PHP是一种流行的编程语言,广泛应用于web开发和服务器端编程。在PHP中,小于号“<”有多种使用方法,我们来一一了解。1.小于号作为比较运算符小于号最常见的使用方法是作为比较运算符。我们可以使用小于号“<”来比较两个字符、数字...
    99+
    2023-05-14
  • 聊聊PHP中Public修饰符的使用方法
    PHP是一种非常流行的编程语言,被广泛应用于Web开发和服务器脚本编写。作为一种面向对象编程语言,PHP中存在许多访问修饰符,其中public是最常见的一种。public修饰符指定的成员变量或者成员函数可以被这个类的任意对象访问。类中的成员...
    99+
    2023-05-14
  • 聊聊python中令人迷惑的duplicated和drop_duplicates()用法
    前言 在算face_track_id map有感: 开始验证 data={'state':[1,1,2,2,1,2,2,2],'pop':['a','b','c','d','b','c','d','d']} fr...
    99+
    2022-06-02
    python duplicated drop_duplicates()
  • 聊聊R语言中Legend 函数的参数用法
    如下所示: legend(x, y = NULL, legend, fill = NULL, col = par("col"), border = "black", lty, l...
    99+
    2022-11-11
  • 一文聊聊Vue中provide和inject的使用方法
    Vue中如何使用provide与inject?下面本篇文章就来给大家介绍一下Vue中provide和inject的使用方法,希望对大家有所帮助!在vue2.0里面provide与inject是以选项式(配置)API的方式在组件中进行使用的,...
    99+
    2023-05-14
    Vue javascript
  • 一起聊聊Go语言中的语法糖的使用
    目录前言进入正题可变长参数声明不定长数组... 操作符切片循环忽略变量、字段或者导包短变量声明另类的返回值总结前言 由于工作变动,我现在已经开始使用Golang了。用了一段时间之后,...
    99+
    2022-11-13
  • 聊聊Git中删除用户名和密码信息的方法(两种)
    在使用Git时,有时候我们需要删除已经保存的用户名和密码信息。这种情况通常出现在Git账号密码发生变化或者需要切换账号的情况下。本文将介绍如何删除Git中保存的用户名和密码信息。查看已保存的用户名和密码信息首先,我们需要查看当前Git所保存...
    99+
    2023-10-22
  • pytorch中Schedule与warmup_steps的用法说明
    1. lr_scheduler相关 lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup...
    99+
    2022-11-12
  • Pytorch模型中的parameter与buffer用法
    Parameter 和 buffer If you have parameters in your model, which should be saved and restore...
    99+
    2022-11-12
  • Pytorch中的model.train() 和 model.eval() 原理与用法解析
    目录Pytorch中的model.train() 和 model.eval() 原理与用法一、两种模式二、功能1. model.train()2. model.eval()3. 总结...
    99+
    2023-05-15
    Pytorch model.train() 和 model.eval() python model.train() model.eval()使用
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作