iis服务器助手广告
返回顶部
首页 > 资讯 > 后端开发 > Python >pytorchgeometric的GNN、GCN的节点分类方式
  • 926
分享到

pytorchgeometric的GNN、GCN的节点分类方式

pytorchgeometricGNN的节点分类GCN的节点分类 2022-12-17 12:12:31 926人浏览 安东尼

Python 官方文档:入门教程 => 点击学习

摘要

目录PyTorch geometric的GNN、GCN节点分类pytorch下GCN代码解读总结pytorch geometric的GNN、GCN节点分类 # -*- coding:

pytorch geometric的GNN、GCN节点分类

# -*- coding: utf-8 -*-

import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import Planetoid
import torch_geometric.nn as pyg_nn
import torch_geometric.transfORMs as T


# load dataset
def get_data(folder="node_classify/cora", data_name="cora"):
    # dataset = Planetoid(root=folder, name=data_name)
    dataset = Planetoid(root=folder, name=data_name,
                        transform=T.NormalizeFeatures())
    return dataset


# create the graph cnn model
class GraphCNN(nn.Module):
    def __init__(self, in_c, hid_c, out_c):
        super(GraphCNN, self).__init__()
        self.conv1 = pyg_nn.GCNConv(in_channels=in_c, out_channels=hid_c)
        self.conv2 = pyg_nn.GCNConv(in_channels=hid_c, out_channels=out_c)

    def forward(self, data):
        # data.x data.edge_index
        x = data.x  # [N, C]
        edge_index = data.edge_index  # [2 ,E]

        hid = self.conv1(x=x, edge_index=edge_index)  # [N, D]
        hid = F.relu(hid)

        out = self.conv2(x=hid, edge_index=edge_index)  # [N, out_c]

        out = F.log_softmax(out, dim=1)  # [N, out_c]

        return out


class OwnGCN(nn.Module):
    def __init__(self, in_c, hid_c, out_c):
        super(OwnGCN, self).__init__()
        self.in_ = pyg_nn.SGConv(in_c, hid_c, K=2)

        self.conv1 = pyg_nn.APPNP(K=2, alpha=0.1)
        self.conv2 = pyg_nn.APPNP(K=2, alpha=0.1)

        self.out_ = pyg_nn.SGConv(hid_c, out_c, K=2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.in_(x, edge_index)
        x = F.dropout(x, p=0.1, training=self.training)

        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.1, training=self.training)

        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.1, training=self.training)

        x = self.out_(x, edge_index)

        return F.log_softmax(x, dim=1)


# todo list
class YourOwnGCN(nn.Module):
    pass


def analysis_data(dataset):
    print("Basic Info:      ", dataset[0])
    print("# Nodes:         ", dataset[0].num_nodes)
    print("# Features:      ", dataset[0].num_features)
    print("# Edges:         ", dataset[0].num_edges)
    print("# Classes:       ", dataset.num_classes)
    print("# Train samples: ", dataset[0].train_mask.sum().item())
    print("# Valid samples: ", dataset[0].val_mask.sum().item())
    print("# Test samples:  ", dataset[0].test_mask.sum().item())
    print("Undirected:      ", dataset[0].is_undirected())


def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    cora_dataset = get_data()

    # todo list
    # my_net = GraphCNN(in_c=cora_dataset.num_features, hid_c=150, out_c=cora_dataset.num_classes)
    my_net = OwnGCN(in_c=cora_dataset.num_features, hid_c=300, out_c=cora_dataset.num_classes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    my_net = my_net.to(device)
    data = cora_dataset[0].to(device)

    optimizer = torch.optim.Adam(my_net.parameters(), lr=1e-2, weight_decay=1e-3)
    """
    # model train
    my_net.train()
    for epoch in range(500):
        optimizer.zero_grad()

        output = my_net(data)
        loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        _, prediction = output.max(dim=1)

        valid_correct = prediction[data.val_mask].eq(data.y[data.val_mask]).sum().item()
        valid_number = data.val_mask.sum().item()

        valid_acc = valid_correct / valid_number
        print("Epoch: {:03D}".format(epoch + 1), "Loss: {:.04f}".format(loss.item()),
              "Valid Accuracy:: {:.4f}".format(valid_acc))
    """

    # model test
    my_net = torch.load("node_classify/best.pth")
    my_net.eval()

    _, prediction = my_net(data).max(dim=1)

    target = data.y

    test_correct = prediction[data.test_mask].eq(target[data.test_mask]).sum().item()
    test_number = data.test_mask.sum().item()

    train_correct = prediction[data.train_mask].eq(target[data.train_mask]).sum().item()
    train_number = data.train_mask.sum().item()

    print("==" * 20)

    print("Accuracy of Train Samples: {:.04f}".format(train_correct / train_number))

    print("Accuracy of Test  Samples: {:.04f}".format(test_correct / test_number))


def test_main():
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cora_dataset = get_data()
    data = cora_dataset[0].to(device)

    my_net = torch.load("node_classify/best.pth")

    my_net.eval()
    _, prediction = my_net(data).max(dim=1)

    target = data.y

    test_correct = prediction[data.test_mask].eq(target[data.test_mask]).sum().item()
    test_number = data.test_mask.sum().item()

    train_correct = prediction[data.train_mask].eq(target[data.train_mask]).sum().item()
    train_number = data.train_mask.sum().item()

    print("==" * 20)

    print("Accuracy of Train Samples: {:.04f}".format(train_correct / train_number))

    print("Accuracy of Test  Samples: {:.04f}".format(test_correct / test_number))


if __name__ == '__main__':
    test_main()
    # main()
    # dataset = get_data()
    # analysis_data(dataset)

pytorch下GCN代码解读

def main():
    print("hello world")
main()

import os.path as osp
import argparse

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv  # noqa

#GCN用于提取图的特征参数然后用于分类

#数据集初始化部分
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
                    help='Use GDC preprocessing.')
args = parser.parse_args()#是否使用GDC优化
dataset = 'CiteSeer'#训练用的数据集
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)#数据集存放位置
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())#数据初始化类,其dataset的基类是一个torch.utils.data.Dataset对象
data = dataset[0]#只有一个图作为训练数据
#print(data)

#预处理和模型定义
if args.use_gdc:
    gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
                normalization_out='col',
                diffusion_kwargs=dict(method='ppr', alpha=0.05),
                sparsification_kwargs=dict(method='topk', k=128,
                                           dim=0), exact=True)
    data = gdc(data)#图扩散卷积用于预处理

#搭建模型
class Net(torch.nn.Module):
    #放置参数层(一般为可学习层,不可学习层也可放置,若不放置,则在forward中用functional实现)
    def __init__(self):
        super(Net, self).__init__()#在不覆盖Module的Init函数的情况下设置Net的init函数
        self.conv1 = GCNConv(dataset.num_features, 16, cached=True,
                             normalize=not args.use_gdc)#第一层GCN卷积运算输入特征向量大小为num_features输出大小为16
        #GCNConv的def init需要in_channel和out_channel(卷积核的数量)的参数,并对in_channel和out_channel调用linear函数,而该函数的作用为构建全连接层
        self.conv2 = GCNConv(16, dataset.num_classes, cached=True,
                             normalize=not args.use_gdc)#第二层GCN卷积运算输入为16(第一层的输出)输出为num_class
        # self.conv1 = ChebConv(data.num_features, 16, K=2)
        # self.conv2 = ChebConv(16, data.num_features, K=2)

    #实现模型的功能各个层之间的连接关系(具体实现)
    def forward(self):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr#赋值data.x特征向量edge_index图的形状,edge_attr权重矩阵
        x = F.relu(self.conv1(x, edge_index, edge_weight))#第一层用非线性激活函数relu
        #x,edge_index,edge_weight特征矩阵,邻接矩阵,权重矩阵组成GCN核心公式
        x = F.dropout(x, training=self.training)#用dropout函数防止过拟合
        x = self.conv2(x, edge_index, edge_weight)#第二层输出
        return F.log_softmax(x, dim=1)#log_softmax激活函数用于最后一层返回分类结果
#Z=log_softmax(A*RELU(A*X*W0)*W1)A连接关系X特征矩阵W参数矩阵
#得到Z后即可用于分类
#softmax(dim=1)行和为1再取log  x为节点的embedding

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#指定设备
model, data = Net().to(device), data.to(device)#copy model,data到device上

#优化算法
optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),#权重衰减避免过拟合
    dict(params=model.conv2.parameters(), weight_decay=0)#需要优化的参数
], lr=0.01)  # Only perform weight-decay on first convolution.
#lr步长因子或者是学习率

#模型训练
def train():
    model.train()#设置成train模式
    optimizer.zero_grad()#清空所有被优化的变量的梯度
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()#损失函数训练参数用于节点分类
    optimizer.step()#步长
     
@torch.no_grad()#不需要计算梯度,也不进行反向传播

#测试
def test():
    model.eval()#设置成evaluation模式
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):#mask矩阵,掩膜作用与之相应的部分不会被更新
        pred = logits[mask].max(1)[1]#mask对应点的输出向量最大值并取序号1
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()#判断是否相等计算准确度
        accs.append(acc)
    return accs

best_val_acc = test_acc = 0

#执行
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()#训练准确率,实际输入的准确率,测试准确率
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'#类型及保留位数
    print(log.format(epoch, train_acc, best_val_acc, test_acc))#输出格式化函数'''

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: pytorchgeometric的GNN、GCN的节点分类方式

本文链接: https://www.lsjlt.com/news/175136.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • pytorchgeometric的GNN、GCN的节点分类方式
    目录pytorch geometric的GNN、GCN节点分类pytorch下GCN代码解读总结pytorch geometric的GNN、GCN节点分类 # -*- coding:...
    99+
    2022-12-17
    pytorch geometric GNN的节点分类 GCN的节点分类
  • PyG搭建GCN模型实现节点分类GCNConv参数详解
    目录前言模型搭建1. 前向传播2. 反向传播3. 训练4. 测试完整代码前言 在上一篇文章PyG搭建GCN前的准备:了解PyG中的数据格式中,大致了解了PyG中的数据格式,这篇文章主...
    99+
    2024-04-02
  • 【RAC】Oracle10g RAC 节点重配的方式
    前段时间说过Oracle11g RAC节点重配的一些说明,相对于Oracle10g来说,更方便更便于管理。那么Oracle10 RAC 需要通过什么方式呢,或者需要注意什么呢  ...
    99+
    2024-04-02
  • Vue3获取DOM节点的方式有哪些
    这篇文章主要讲解了“Vue3获取DOM节点的方式有哪些”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Vue3获取DOM节点的方式有哪些”吧!1 .原生js获取 DOM 节点:document...
    99+
    2023-07-05
  • Oracle Rac添加节点的方式有哪些
    这篇文章给大家介绍Oracle Rac添加节点的方式有哪些,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。 添加节点两种方式:1.克隆一个已有网格主目录。&...
    99+
    2024-04-02
  • 详解QTreeWidget隐藏节点的两种方式
    目录简述方法一:直接隐藏式方法二:间接隐藏式结尾简述 关于QTreeWidget隐藏节点有两种方式,一种是直接隐藏,一种是间接隐藏,但是两种方式各有差异,下面请听具体解说。 方法一:...
    99+
    2024-04-02
  • Vue3获取DOM节点的3种方式实例
    目录1 .原生js获取 DOM 节点:2. vue2中获取当前组件的实例对象:3.vue3中获取当前组件的实例对象:总结1 .原生js获取 DOM 节点: document.quer...
    99+
    2023-02-23
    vue3 获取dom vue3.0 获取dom vue怎么获取dom节点
  • JavaScript添加节点的方法
    小编给大家分享一下JavaScript添加节点的方法,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!JavaScript添加节点的方法:1、使用appendChil...
    99+
    2023-06-14
  • Javascript removeChild()删除节点及删除子节点的方法
    在JavaScript中,可以使用`removeChild()`方法删除一个指定的子节点。要删除一个节点及其子节点,需要先遍历该节点的子节点,并递归调用`removeChild()`方法来删除每个子节点。以下是一个示例代码,演示如何使用...
    99+
    2023-08-09
    Java
  • Netty分布式ByteBuf的分类方式源码解析
    目录ByteBuf根据不同的分类方式 会有不同的分类结果1.Pooled和Unpooled2.基于直接内存的ByteBuf和基于堆内存的ByteBuf3.safe和unsafe上一小...
    99+
    2024-04-02
  • JavaScript的节点操作实例分析
    今天小编给大家分享一下JavaScript的节点操作实例分析的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一...
    99+
    2024-04-02
  • JQuery中DOM节点的示例分析
    小编给大家分享一下JQuery中DOM节点的示例分析,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!具体如下:Jquery中DOM...
    99+
    2024-04-02
  • HTML中DOM节点的示例分析
    小编给大家分享一下HTML中DOM节点的示例分析,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!  在HTML DOM中,所有事物...
    99+
    2024-04-02
  • javascript创建新节点的方法
    这篇文章主要介绍“javascript创建新节点的方法”,在日常操作中,相信很多人在javascript创建新节点的方法问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”java...
    99+
    2024-04-02
  • javascript删除子节点的方法
    小编给大家分享一下javascript删除子节点的方法,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!javascript删除子节点的方法:首先获取父节点对象和子节...
    99+
    2023-06-14
  • html中删除节点的方法
    这篇文章将为大家详细讲解有关html中删除节点的方法,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。在html中,可以利用DOM Element对象的removeChild()方法来删除节点;需要先获取到指...
    99+
    2023-06-15
  • javascript中节点的删除方法
    这篇文章将为大家详细讲解有关javascript中节点的删除方法,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。javascript删除节点的方法:1、使用remove()方法,可用于删除父节点上的所有元素...
    99+
    2023-06-14
  • javascript删除div节点的方法
    小编给大家分享一下javascript删除div节点的方法,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!方法:1、首先获取div节点,然后使用remove()来删...
    99+
    2023-06-14
  • jQuery操作元素节点的方法
    本篇内容主要讲解“jQuery操作元素节点的方法”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“jQuery操作元素节点的方法”吧!一、查找节点示例:<!DOCTYPE html&...
    99+
    2023-06-29
  • 服务器分类的方式有哪些
    根据不同的标准和功能,服务器可以分为以下几类:1. 按照用途分类:Web服务器、数据库服务器、邮件服务器、文件服务器、应用服务器等。...
    99+
    2023-06-06
    服务器分类 服务器
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作