iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >PyTorch实现FedProx联邦学习算法
  • 610
分享到

PyTorch实现FedProx联邦学习算法

2024-04-02 19:04:59 610人浏览 薄情痞子

Python 官方文档:入门教程 => 点击学习

摘要

目录I. 前言III. FedProx1. 模型定义2. 服务器端3. 客户端更新IV. 完整代码I. 前言 FedProx的原理请见:FedAvg联邦学习FedProx异质网络优化

I. 前言

FedProx的原理请见:FedAvg联邦学习FedProx异质网络优化实验总结

联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。

数据集为某城市十个地区的风电功率,我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。

III. FedProx

算法伪代码:

1. 模型定义

客户端的模型为一个简单的四层神经网络模型:

# -*- coding:utf-8 -*-
"""
@Time: 2022/03/03 12:23
@Author: KI
@File: model.py
@Motto: Hungry And Humble
"""
from torch import nn
class ANN(nn.Module):
    def __init__(self, args, name):
        super(ANN, self).__init__()
        self.name = name
        self.len = 0
        self.loss = 0
        self.fc1 = nn.Linear(args.input_dim, 20)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)
    def forward(self, data):
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x

2. 服务器端

服务器端和FedAvg一致,即重复进行客户端采样、参数传达、参数聚合三个步骤:

# -*- coding:utf-8 -*-
"""
@Time: 2022/03/03 12:50
@Author: KI
@File: server.py
@Motto: Hungry And Humble
"""
import copy
import random
import numpy as np
import torch
from model import ANN
from client import train, test
class FedProx:
    def __init__(self, args):
        self.args = args
        self.nn = ANN(args=self.args, name='server').to(args.device)
        self.nns = []
        for i in range(self.args.K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.args.clients[i]
            self.nns.append(temp)
    def server(self):
        for t in range(self.args.r):
            print('round', t + 1, ':')
            # sampling
            m = np.max([int(self.args.C * self.args.K), 1])
            index = random.sample(range(0, self.args.K), m)  # st
            # dispatch
            self.dispatch(index)
            # local updating
            self.client_update(index, t)
            # aggregation
            self.aggregation(index)
        return self.nn
    def aggregation(self, index):
        s = 0
        for j in index:
            # nORMal
            s += self.nns[j].len
        params = {}
        for k, v in self.nns[0].named_parameters():
            params[k] = torch.zeros_like(v.data)
        for j in index:
            for k, v in self.nns[j].named_parameters():
                params[k] += v.data * (self.nns[j].len / s)
        for k, v in self.nn.named_parameters():
            v.data = params[k].data.clone()
    def dispatch(self, index):
        for j in index:
            for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()):
                old_params.data = new_params.data.clone()
    def client_update(self, index, global_round):  # update nn
        for k in index:
            self.nns[k] = train(self.args, self.nns[k], self.nn, global_round)
    def global_test(self):
        model = self.nn
        model.eval()
        for client in self.args.clients:
            model.name = client
            test(self.args, model)

3. 客户端更新

FedProx中客户端需要优化的函数为:

作者在FedAvg损失函数的基础上,引入了一个proximal term,我们可以称之为近端项。引入近端项后,客户端在本地训练后得到的模型参数 w将不会与初始时的服务器参数wt偏离太多。

对应的代码为:

def train(args, model, server, global_round):
    model.train()
    Dtr, Dte = nn_seq_wind(model.name, args.B)
    model.len = len(Dtr)
    global_model = copy.deepcopy(server)
    if args.weight_decay != 0:
        lr = args.lr * pow(args.weight_decay, global_round)
    else:
        lr = args.lr
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9, weight_decay=args.weight_decay)
    print('training...')
    loss_function = nn.MSELoss().to(args.device)
    loss = 0
    for epoch in range(args.E):
        for (seq, label) in Dtr:
            seq = seq.to(args.device)
            label = label.to(args.device)
            y_pred = model(seq)
            optimizer.zero_grad()
            # compute proximal_term
            proximal_term = 0.0
            for w, w_t in zip(model.parameters(), global_model.parameters()):
                proximal_term += (w - w_t).norm(2)
            loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term
            loss.backward()
            optimizer.step()
        print('epoch', epoch, ':', loss.item())
    return model

我们在原有MSE损失函数的基础上加上了一个近端项:

for w, w_t in zip(model.parameters(), global_model.parameters()):
    proximal_term += (w - w_t).norm(2)

然后再反向传播求梯度,然后优化器step更新参数。

原始论文中还提出了一个不精确解的概念:

不过值得注意的是,我并没有在原始论文的实验部分找到如何选择 γ \gamma γ的说明。查了一下资料后发现是涉及到了近端梯度下降的知识,本文代码并没有考虑不精确解,后期可能会补上。

IV. 完整代码

链接:https://pan.baidu.com/s/1hj2EOcqIUmM-C6R1cyjE5Q 

提取码:fghp 

项目结构:

其中:

  • server.py为服务器端操作。
  • client.py为客户端操作。
  • data_process.py为数据处理部分。
  • model.py为模型定义文件。
  • args.py为参数定义文件。
  • main.py为主文件,如想要运行此项目可直接运行:
python main.py

以上就是PyTorch实现FedProx的联邦学习算法的详细内容,更多关于PyTorch实现FedProx算法的资料请关注编程网其它相关文章!

--结束END--

本文标题: PyTorch实现FedProx联邦学习算法

本文链接: https://www.lsjlt.com/news/117896.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • PyTorch实现FedProx联邦学习算法
    目录I. 前言III. FedProx1. 模型定义2. 服务器端3. 客户端更新IV. 完整代码I. 前言 FedProx的原理请见:FedAvg联邦学习FedProx异质网络优化...
    99+
    2022-11-11
  • PyTorch怎么实现FedProx联邦学习算法
    这篇文章主要介绍了PyTorch怎么实现FedProx联邦学习算法的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇PyTorch怎么实现FedProx联邦学习算法文章都会有所收获,下面我们一起来看看吧。I. 前言...
    99+
    2023-06-30
  • PyTorch实现联邦学习的基本算法FedAvg
    目录I. 前言II. 数据介绍特征构造III. 联邦学习1. 整体框架2. 服务器端3. 客户端IV. 代码实现1. 初始化2. 服务器端3. 客户端4. 测试V. 实验及结果VI....
    99+
    2022-11-11
  • FedAvg联邦学习FedProx异质网络优化实验总结
    目录前言I. FedAvgII. FedProxIII. 实验IV. 总结前言 题目: Federated Optimization for Heterogeneous Netwo...
    99+
    2022-11-13
  • 联邦学习神经网络FedAvg算法实现
    目录I. 前言II. 数据介绍1. 特征构造III. 联邦学习1. 整体框架2. 服务器端3. 客户端4. 代码实现4.1 初始化4.2 服务器端4.3 客户端4.4 测试IV. 实...
    99+
    2022-11-11
  • 联邦学习算法介绍-FedAvg详细案例-Python代码获取
    联邦学习算法介绍-FedAvg详细案例-Python代码获取 一、联邦学习系统框架二、联邦平均算法(FedAvg)三、联邦随梯度下降算法 (FedSGD)四、差分隐私随联邦梯度下降算法 (DP...
    99+
    2023-08-31
    python 算法 机器学习
  • 使用Pytorch实现强化学习——DQN算法
    目录 一、强化学习的主要构成 二、基于python的强化学习框架 三、gym 四、DQN算法 1.经验回放 2.目标网络 五、使用pytorch实现DQN算法 1.replay memory 2.神经网络部分 3.Agent 4.模型训练...
    99+
    2023-09-24
    python 开发语言
  • javascript算法学习实现代码
    排序 var len = 100000; var i; var arr = []; for(i=0; i...
    99+
    2022-11-21
    javascript 算法学习
  • python如何实现感知器学习算法
    这篇文章主要介绍python如何实现感知器学习算法,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!我们将研究一种判别式分类方法,其中直接学习评估 g(x)所需的 w 参数。我们将使用感知器学习算法。感知器学习算法很容易...
    99+
    2023-06-29
  • Python实现机器学习算法的分类
    Python算法的分类 对葡萄酒数据集进行测试,由于数据集是多分类且数据的样本分布不平衡,所以直接对数据测试,效果不理想。所以使用SMOTE过采样对数据进行处理,对数据去重,去空,处...
    99+
    2022-11-12
  • Python强化练习之PyTorch opp算法实现月球登陆器
    目录概述强化学习算法种类PPO 算法Actor-Critic 算法GymLunarLander-v2启动登陆器PPO 算法实现月球登录器PPOmain输出结果概述 从今天开始我们会开...
    99+
    2022-11-12
  • 【机器学习】DBSCAN聚类算法(含Python实现)
    文章目录 一、算法介绍二、例子三、Python实现3.1 例13.2 算法参数详解3.3 鸢尾花数据集 一、算法介绍 DBSCAN(Density-Based Spatial Clus...
    99+
    2023-10-01
    聚类 机器学习 python BBSCAN
  • 机器学习之决策树算法怎么实现
    决策树是一种常用的机器学习算法,主要用于分类和回归问题。下面是决策树算法的实现步骤:1. 数据预处理:将原始数据进行清洗和转换,包括...
    99+
    2023-10-11
    机器学习
  • Python机器学习k-近邻算法怎么实现
    这篇文章主要介绍“Python机器学习k-近邻算法怎么实现”,在日常操作中,相信很多人在Python机器学习k-近邻算法怎么实现问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Python机器学习k-近邻算法怎...
    99+
    2023-06-21
  • 机器学习线性回归算法怎么实现
    实现机器学习线性回归算法一般需要以下步骤:1. 导入所需的库:例如,numpy用于数值计算,matplotlib用于可视化数据等。2...
    99+
    2023-09-21
    机器学习
  • Python机器学习实战之k-近邻算法的实现
    目录K-近邻算法概述工作原理实施KNN算法示例:手写识别系统K-近邻算法概述 简单地说, k-近邻算法采用测量不同特征值之间的距离方法进行分类。 k-近邻算法 优点:精度高...
    99+
    2022-11-12
  • Python机器学习算法之决策树算法的实现与优缺点
    目录1.算法概述2.算法种类3.算法示例4.决策树构建示例5.算法实现步骤 6.算法相关概念7.算法实现代码8.算法优缺点9.算法优化总结1.算法概述 决策树算法是在已知各...
    99+
    2022-11-12
  • Go语言开发实现机器学习算法的方法与实践
    Go语言是一种简洁、快速和高效的编程语言,其在网络开发和服务器编程方面广泛应用。然而,随着人工智能和机器学习的迅猛发展,很多开发者开始关注如何在Go语言中实现机器学习算法。本文将介绍一些在Go语言中开发和实现机器学习算法的方法与实践。首先,...
    99+
    2023-11-20
    机器学习 实践 Go语言
  • python机器学习实现oneR算法(以鸢尾data为例)
    目录一、导包与获取数据二、划分为训练集和测试集三、定义函数:获取某特征值出现次数最多的类别及错误率四、定义函数:获取每个特征值下出现次数最多的类别、错误率五、调用函数,获取最佳特征值...
    99+
    2022-11-13
  • 如何使用 NumPy 实现分布式机器学习算法?
    分布式机器学习是一种利用多个计算机或处理器并行处理大规模数据集的机器学习方法。使用分布式机器学习算法可以显著提高算法的性能和效率。在本文中,我们将介绍如何使用 NumPy 实现分布式机器学习算法。 一、分布式机器学习的基本原理 分布式机器学...
    99+
    2023-10-30
    分布式 numy ide
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作