iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >PythonPaddlePaddle机器学习之求解线性模型
  • 424
分享到

PythonPaddlePaddle机器学习之求解线性模型

2024-04-02 19:04:59 424人浏览 泡泡鱼

Python 官方文档:入门教程 => 点击学习

摘要

目录前言1. 任务描述2. 代码演练2.1 数组转张量前言 飞桨(PaddlePaddle)是集深度学习核心框架、工具组件和服务平台为一体的技术先进、功能完备的开源深度学习平台 1.

前言

飞桨(PaddlePaddle)是集深度学习核心框架工具组件和服务平台为一体的技术先进、功能完备的开源深度学习平台

1. 任务描述

  • 乘坐出租车的时候,会有一个10元的起步价,只要上车就需要收取该起步价。
  • 出租车每行驶1公里,需要再支付2元的行驶费用(2元/公里)
  • 当一个乘客做完出租车之后,车上的计价器需要算出来该乘客需要支付的乘车费用。

如果以数学模型的角度可以很容易的解除该题的线性关系,及 Y=2x+10Y=2x+10,其中YY 为最终所需费用,xx 为行驶公里数。

试想,我们用机器学习的方法进行训练是不是也可以解决该问题呢,让机器来给我们推算出 YY 与 xx 的关系。即:知道乘客乘坐公里数和支付费用,但是并不知道每公里行驶费和起步价。

2. 代码演练

首先,我们以数学模型建立关系式,定义计价收费函数。该函数用来生成机器学习的数据集。定义好函数以后,接下来,我们传入6个数据(x),该函数可以计算出对应的Y值(也就是机器学习训练用到的真实值)。

def calculate_fee(distance_travelled):
    return 10+ 2*distance_travelled
for x in [1.0, 3.0, 5.0, 9.0, 10.0, 20.0]:
    print(calculate_fee(x))

接下来开始搭建线性回归。

2.1 数组转张量

将输入数据与输出结果数组转为张量:

import paddle
import numpy
x_data = paddle.to_tensor([[1.0], [3.0], [5.0], [9.0], [10.0], [20.0]])
y_data = paddle.to_tensor([[12.0],[16.0],[20.0],[28.0],[30.0],[50.0]])
linear = paddle.nn.Linear(in_features=1,out_features=1)

# 随机初始化w,b
w_before_opt = linear.weight.numpy().item()
b_before_opt = linear.bias.numpy().item()
# 打印初始w,b
print(w_before_opt,b_before_opt)

mse_loss = paddle.nn.MSELoss()
sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())

total_epoch = 5000

for i in range(total_epoch):
    y_predict = linear(x_data)
    loss = mse_loss(y_predict,y_data)

    # 反向传播(求梯度)
    loss.backward()
    # 优化器往前走一步:求出的梯度给优化器用调参
    sgd_optimizer.step()
    # 优化器把调完参数所用的梯度去清掉,下次再去求
    sgd_optimizer.clear_gradients()

    # 打印信息
    if i % 1000 == 0:
        print(i,loss.numpy())
print("finish training, loss = {}".fORMat(loss.numpy()) )

w_after_opt = linear.weight.numpy().item()
b_after_opt = linear.bias.numpy().item()
print(w_after_opt,b_after_opt)

以上就是python PaddlePaddle机器学习之求解线性模型的详细内容,更多关于Python 线性模型的资料请关注编程网其它相关文章!

--结束END--

本文标题: PythonPaddlePaddle机器学习之求解线性模型

本文链接: https://www.lsjlt.com/news/119980.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作