iis服务器助手广告广告
返回顶部
首页 > 资讯 > 后端开发 > Python >Pytorch+PyG实现GraphSAGE过程示例详解
  • 380
分享到

Pytorch+PyG实现GraphSAGE过程示例详解

Pytorch PyG实现GraphSAGEPytorch PyG 2023-05-17 05:05:36 380人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

目录GraphSAGE简介实现步骤数据准备实现模型模型训练GraphSAGE简介 GraphSAGE(Graph Sampling and Aggregation)是一种常见的图神经

GraphSAGE简介

GraphSAGE(Graph Sampling and Aggregation)是一种常见的图神经网络模型,主要用于结点级别的表征学习。该模型基于采样和聚合策略,将一个结点及其邻居节点信息融合在一起,得到其表征表示,并通过多轮迭代更新来提高表征的精度。

实现步骤

数据准备

在本次实现中,我们仍然使用Cora数据集作为示例进行测试,由于GraphSage主要聚焦于单一节点特征的更新,因此这里不需要对数据集做特别处理,只需要将数据转化成PyG格式即可。

import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx, to_networkx
# 加载cora数据集
dataset = Planetoid(root='./cora', name='Cora')
data = dataset[0]
# 将nx.Graph形式的图转换成PyG需要的格式
graph = to_networkx(data)
data = from_networkx(graph)
# 获取节点数量和特征向量维度
num_nodes = data.num_nodes
num_features = dataset.num_features
num_classes = dataset.num_classes
# 建立需要训练的节点分割数据集
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[:num_nodes - 1000] = True
data.test_mask[-1000:] = True
data.val_mask[num_nodes - 2000: num_nodes - 1000] = True

实现模型

接下来,我们需要定义GraphSAGE模型。与传统的GCN中只需要一层卷积操作不同,GraphSAGE包含两层卷积和采样(也称“聚合”)操作。

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers):
        super(GraphSAGE, self).__init__()
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_channels = hidden_channels if i != 0 else num_features
            out_channels = num_classes if i == num_layers - 1 else hidden_channels
            self.convs.append(SAGEConv(in_channels, out_channels))
    def forward(self, x, edge_index):
        for _, conv in enumerate(self.convs[:-1]):
            x = F.relu(conv(x, edge_index))
        # 最后一层不用激活函数
        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=-1)

在上述代码中,我们实现了多层GraphSAGE卷积和相应的聚合函数,并使用ReLU和softmax函数来进行特征提取和分类分数的输出。

模型训练

定义好模型之后,就可以开始针对Cora数据集进行模型训练。首先还是需要先指定优化器和损失函数,并设定一些参数用于记录训练过程中的信息,如Epochs、Batch size、学习率等。

# 初始化GraphSage并指定参数
num_layers = 2
hidden_channels = 256
model = GraphSAGE(hidden_channels, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.CrossEntropyLoss()
# 训练过程
for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    out = model(data.x.to(device), data.edge_index.to(device))
    loss = loss_func(out[data.train_mask], data.y.to(device)[data.train_mask])
    loss.backward()
    optimizer.step()
    # 在各个测试阶段检测一下准确率
    if epoch % 10 == 0:
        with torch.no_grad():
            _, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1)
            correct = float(pred[data.test_mask].eq(data.y.to(device)[data.test_mask]).sum().item())
            acc = correct / data.test_mask.sum().item()
            print("Epoch {:03D}, Train Loss {:.4f}, Test Acc {:.4f}".fORMat(
                epoch, loss.item(), acc))

在上述代码中,我们使用有标记的训练数据拟合GraphSAGE模型,在各个验证阶段测试准确率,并通过梯度下降法优化损失函数。

以上就是PyTorch+PyG实现GraphSAGE过程示例详解的详细内容,更多关于Pytorch PyG实现GraphSAGE的资料请关注编程网其它相关文章!

--结束END--

本文标题: Pytorch+PyG实现GraphSAGE过程示例详解

本文链接: https://www.lsjlt.com/news/210678.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作