iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >PyTorch零基础入门之逻辑斯蒂回归
  • 786
分享到

PyTorch零基础入门之逻辑斯蒂回归

2024-04-02 19:04:59 786人浏览 八月长安

Python 官方文档:入门教程 => 点击学习

摘要

目录学习总结一、sigmoid函数二、和Linear的区别三、逻辑斯蒂回归(分类)PyTorch实现Reference学习总结 (1)和上一讲的模型训练是类似的,只是在线性模型的基础

学习总结

(1)和上一讲的模型训练是类似的,只是在线性模型的基础上加个sigmoid,然后loss函数改为交叉熵BCE函数(当然也可以用其他函数),另外一开始的数据y_data也从数值改为类别0和1(本例为二分类,注意x_datay_data这里也是矩阵的形式)。

一、sigmoid函数

loGIStic function是一种sigmoid函数(还有其他sigmoid函数),但由于使用过于广泛,pytorch默认logistic function叫为sigmoid函数。还有如下的各种sigmoid函数:

在这里插入图片描述

二、和Linear的区别

逻辑斯蒂和线性模型的unit区别如下图:

在这里插入图片描述

sigmoid函数是不需要参数的,所以不用对其初始化(直接调用nn.functional.sigmoid即可)。
另外loss函数从MSE改用交叉熵BCE:尽可能和真实分类贴近。

在这里插入图片描述

如下图右方表格所示,当 y ^ \hat{y} y^​越接近y时则BCE Loss值越小。

在这里插入图片描述

三、逻辑斯蒂回归(分类)PyTorch实现


# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 08:35:00 2021

@author: 86493
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt  
import torch.nn.functional as F
import numpy as np

# 准备数据
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])


losslst = []

class LogisticRegressionModel(nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
        
    def forward(self, x):
    	# 和线性模型的网络的唯一区别在这句,多了F.sigmoid
        y_predict = F.sigmoid(self.linear(x))
        return y_predict
    
model = LogisticRegressionModel()

# 使用交叉熵作损失函数
criterion = torch.nn.BCELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(), 
                            lr = 0.01)

# 训练
for epoch in range(1000):
    y_predict = model(x_data)
    loss = criterion(y_predict, y_data)
    # 打印loss对象会自动调用__str__
    print(epoch, loss.item())
    losslst.append(loss.item())
    # 梯度清零后反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 画图
plt.plot(range(1000), losslst)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.show()


# test
# 每周学习的时间,200个点
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
# 画 probability of pass = 0.5的红色横线
plt.plot([0, 10], [0.5, 0.5], c = 'r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

在这里插入图片描述

可以看出处于通过和不通过的分界线是Hours=2.5。

在这里插入图片描述

Reference

pytorch官方文档

到此这篇关于PyTorch零基础入门之逻辑斯蒂回归的文章就介绍到这了,更多相关PyTorch 逻辑斯蒂回归内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: PyTorch零基础入门之逻辑斯蒂回归

本文链接: https://www.lsjlt.com/news/154807.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • PyTorch零基础入门之逻辑斯蒂回归
    目录学习总结一、sigmoid函数二、和Linear的区别三、逻辑斯蒂回归(分类)PyTorch实现Reference学习总结 (1)和上一讲的模型训练是类似的,只是在线性模型的基础...
    99+
    2022-11-12
  • python回归分析逻辑斯蒂模型之多分类任务详解
    目录逻辑斯蒂回归模型多分类任务1.ovr策略2.one vs one策略3.softmax策略逻辑斯蒂回归模型多分类案例实现逻辑斯蒂回归模型多分类任务 上节中,我们使用逻辑斯蒂回归完...
    99+
    2022-11-11
  • PyTorch零基础入门之构建模型基础
    目录一、神经网络的构造二、神经网络中常见的层2.1不含模型参数的层2.2含模型参数的层(1)代码栗子1(2)代码栗子22.3二维卷积层stride2.4池化层三、LeNet模型栗子三...
    99+
    2022-11-12
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作