最近需要训练图卷积神经网络(Graph Convolution Neural Network, GCNN),在配置GCNN环境上总结了一些经验。

我觉得对于初学者而言,图神经网络的训练会有2个难点:

①环境配置

②数据集制作

一、环境配置

我最初光想到要给GCNN配环境就觉得有些困难,感觉相比于目标检测、分类识别这些任务用规则数据,图神经网络的模型、数据都是图,所以内心觉得会比较难。

我之前更有一个误区,就是觉得不规则结构的图数据不能用CUDA进行并行加速。实际上,图,在电脑里也是以张量这种规则结构数据存在的,完全能用CUDA进行加速计算,训练GCN前配置CUDA完全OK。


以下是我配置的环境,可用CUDA成功运行link_pred.py

几个关键包的版本:

torch                                2.4.1
torch-geometric               2.3.1
torchaudio                       2.4.1
torchvision                       0.14.0
torchviz                            0.0.2

pandas                             1.0.3

numpy                              1.20.0

 CUDA: 11.8

注意要先安装好CUDA,显示了:

 

再安装GPU版本的torch,不然python检测安装的是cpu版本的torch。这时,就得卸载重新安装了

环境配置成功:

print(torch.__version__)
print(torch.cuda.is_available())

如果CUDA环境安装失败,会打印:

2.4.1+cpu
False

其实只安装torch和CUDA还好,如果你的python中有numpy和pandas可能解决版本之间的冲突会耗费不少时间,我就是在numpy和pandas版本上试了很久,最终找到现在的版本是相互兼容的。

CUDA的版本切换可以参考我的另一篇博客:

CUDA版本切换

二、数据集制作

掌握图数据集制作的关键在于掌握slices切片:

for ...data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index,             edge_label=Edge_label)data_list.append(data)
data_, slices = self.collate(data_list)  # 将不同大小的图数据对齐,填充
torch.save((data_, slices), self.processed_paths[0])

和CNN不同的是,GCN没有样本维度,需要把所有样本拼成一张大图喂给GCN进行训练 

数据集生成代码:

#作者:zhouzhichao
#创建时间:2025/5/30
#内容:生成200个样本的PYG数据集import h5py
import hdf5storage
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import negative_samplingbase_dir = "D:\\无线通信网络认知\\论文1\\experiment\\直推式拓扑推理实验\\拓扑生成\\200样本\\"N = 30
grapg_size = N
train_n = 31
M = 3000class graph_data(InMemoryDataset):def __init__(self, root, signals=None, tp_list = None, transform=None, pre_transform=None):# self.Signals = Signals# self.Tp_list = Tp_listself.signals = signalsself.tp_list = tp_listsuper().__init__(root, transform, pre_transform)# self.data, self.slices = torch.load(self.processed_paths[0])self.data = torch.load(self.processed_paths[0])# 返回process方法所需的保存文件名。你之后保存的数据集名字和列表里的一致@propertydef processed_file_names(self):return ['gcn_data.pt']# 生成数据集所用的方法def process(self):# data_list = []# for k in range(200):# signals = self.Signals[:, :, k]# tp_list = np.array(mat_file[self.Tp_list[0, k]])signals = self.signalstp_list =self.tp_list# tp = Tp[:,:,k]X = torch.tensor(signals, dtype=torch.float)# 所有的边Edge_index = torch.tensor(tp_list, dtype=torch.long)# 所有的边1标签edge_label = np.ones((tp_list.shape[1]))# edge_label = np.zeros((tp_list.shape[1]))Edge_label = torch.tensor(edge_label, dtype=torch.float)neg_edge_index = negative_sampling(edge_index=Edge_index, num_nodes=grapg_size,num_neg_samples=Edge_index.shape[1], method='sparse')# 拼接正负样本索引# c = 0# for i in range(31):#     for i in range(31):#         if torch.equal(Edge_index[:, i], neg_edge_index[:, i]):#             c = c + 1# print("c: ",c)Edge_label_index = Edge_indexperm = torch.randperm(Edge_index.size(1))Edge_index = Edge_index[:, perm]Edge_index = Edge_index[:, :train_n]Edge_label_index = torch.cat([Edge_label_index, neg_edge_index],dim=-1,)# 拼接正负样本Edge_label = torch.cat([Edge_label,Edge_label.new_zeros(neg_edge_index.size(1))], dim=0)# Edge_label = torch.cat([#     Edge_label,#     Edge_label.new_ones(neg_edge_index.size(1))# ], dim=0)data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index, edge_label=Edge_label)torch.save(data, self.processed_paths[0])# data_list.append(data)# data_, slices = self.collate(data_list)  # 将不同大小的图数据对齐,填充# torch.save((data_, slices), self.processed_paths[0])for snr in [0,20,40]:print("snr: ", snr)mat_file = h5py.File(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')# mat_file = hdf5storage.loadmat(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')# 获取数据集Signals = mat_file["Signals"][()]# signals = np.swapaxes(signals, 1, 0)Tp = mat_file["Tp"][()]Tp_list = mat_file["Tp_list"][()]# tp_list = tp_list - 1# 关闭文件# mat_file.close()# graph_data("gcn_data")# n = Signals.shape[2]n = 10for i in range(n):signals = Signals[:,:,i]tp_list = np.array(mat_file[Tp_list[0, i]])root = "gcn_data-"+str(i)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)graph_data(root, signals = signals, tp_list = tp_list)print("")print("...图数据生成完成...")

训练代码:

#作者:zhouzhichao
#创建时间:25年5月29日
#内容:统计图中有关系节点和无关系节点的GCN特征欧式距离import sys
import torch
import random
import numpy as np
import pandas as pd
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
sys.path.append('D:\无线通信网络认知\论文1\experiment\直推式拓扑推理实验\GCN推理')
from gcn_dataset import graph_data
print(torch.__version__)
print(torch.cuda.is_available())mode = "gcn"class Net(torch.nn.Module):def __init__(self):super().__init__()self.conv1 = GCNConv(Input_L, 1000)self.conv2 = GCNConv(1000, 20)def encode(self, x, edge_index):x1 = self.conv1(x, edge_index)x1_1 = x1.relu()x2 = self.conv2(x1_1, edge_index)x2_2 = x2.relu()return x2_2def decode(self, z, edge_label_index):# 节点和边都是矩阵,不同的计算方法致使:节点->节点,节点->边# nodes_relation = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)# distances  = torch.norm(z[edge_label_index[0]] - z[edge_label_index[1]], dim=-1)distance_squared = torch.sum((z[edge_label_index[0]] - z[edge_label_index[1]]) ** 2, dim=-1)# print("distance_squared: ",distance_squared)return distance_squareddef decode_all(self, z):prob_adj = z @ z.t()  # 得到所有边概率矩阵return (prob_adj > 0).nonzero(as_tuple=False).t()  # 返回概率大于0的边,以edge_index的形式@torch.no_grad()def test(self,input_data):model.eval()z = model.encode(input_data.x, input_data.edge_index)out = model.decode(z, input_data.edge_label_index).view(-1)out = 1 - outN = 30
train_n = 31
M = 3000
# snr = -20
# for train_n in range(1,51):
# for M in range(3000, 499, -100):
for snr in [0,20,40]:print("snr: ", snr)for I in range(10):root = "gcn_data-"+str(I)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)gcn_data = graph_data(root)Input_L = gcn_data.x.shape[1]model = Net()# model = Net().to(device)optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()z = model.encode(gcn_data.x, gcn_data.edge_index)# out = model.decode(z, train_data.edge_label_index).view(-1).sigmoid()out = model.decode(z, gcn_data.edge_label_index).view(-1)out = 1 - outloss = criterion(out, gcn_data.edge_label)loss.backward()optimizer.step()return lossmin_loss = 99999count = 0#早停for epoch in range(10000):loss = train()if loss<min_loss:min_loss = losscount = 0count = count + 1if count>100:breakprint("epoch:  ",epoch,"   loss: ",round(loss.item(),2), "   min_loss: ",round(min_loss.item(),2))z = model.encode(gcn_data.x, gcn_data.edge_index)out = model.decode(z, gcn_data.edge_label_index).view(-1)list_0 = []list_1 = []for i in range(len(gcn_data.edge_label)):true_label = gcn_data.edge_label[i].item()euclidean_distance_value = out[i].item()if true_label==1:list_1.append(euclidean_distance_value)if true_label==0:list_0.append(euclidean_distance_value)minlength = min(len(list_1), len(list_0))list_1 = random.sample(list_1, minlength)list_0 = random.sample(list_0, minlength)value = list_1 + list_0large_class = list(np.full(len(value), snr))small_class = list(np.full(len(list_1), 1)) + list(np.full(len(list_0), 0))data = {'large_class': large_class,'small_class': small_class,'value': value}# 创建一个 DataFramedf = pd.DataFrame(data)## # 保存到 Excel 文件file_path = 'D:\无线通信网络认知\论文1\大修意见\图聚类、阈值相似性图实验补充\\' + mode + '_similarity_' + str(snr) + 'db_'+str(I)+'.xlsx'df.to_excel(file_path, index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/908052.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/908052.shtml
英文地址,请注明出处:http://en.pswp.cn/news/908052.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025年微信小程序开发:AR/VR与电商的最新案例

引言 微信小程序自2017年推出以来&#xff0c;已成为中国移动互联网生态的核心组成部分。根据最新数据&#xff0c;截至2025年&#xff0c;微信小程序的日活跃用户超过4.5亿&#xff0c;总数超过430万&#xff0c;覆盖电商、社交、线下服务等多个领域&#xff08;WeChat Mini …

互联网向左,区块链向右

2008年&#xff0c;中本聪首次提出了比特币的设想&#xff0c;这打开了去中心化的大门。 比特币白皮书清晰的描述了去中心化支付的解决方案&#xff0c;并分别从以下几个方面阐述了他的理念&#xff1a; 一、由转账双方点对点的通讯&#xff0c;而不通过中心化的第三方&#xf…

PV操作的C++代码示例讲解

文章目录 一、PV操作基本概念&#xff08;一&#xff09;信号量&#xff08;二&#xff09;P操作&#xff08;三&#xff09;V操作 二、PV操作的意义三、C中实现PV操作的方法&#xff08;一&#xff09;使用信号量实现PV操作代码解释&#xff1a; &#xff08;二&#xff09;使…

《对象创建的秘密:Java 内存布局、逃逸分析与 TLAB 优化详解》

大家好呀&#xff01;今天我们来聊聊Java世界里那些"看不见摸不着"但又超级重要的东西——对象在内存里是怎么"住"的&#xff0c;以及JVM这个"超级管家"是怎么帮我们优化管理的。放心&#xff0c;我会用最接地气的方式讲解&#xff0c;保证连小学…

简单实现Ajax基础应用

Ajax不是一种技术&#xff0c;而是一个编程概念。HTML 和 CSS 可以组合使用来标记和设置信息样式。JavaScript 可以修改网页以动态显示&#xff0c;并允许用户与新信息进行交互。内置的 XMLHttpRequest 对象用于在网页上执行 Ajax&#xff0c;允许网站将内容加载到屏幕上而无需…

详解开漏输出和推挽输出

开漏输出和推挽输出 以上是 GPIO 配置为输出时的内部示意图&#xff0c;我们要关注的其实就是这两个 MOS 管的开关状态&#xff0c;可以组合出四种状态&#xff1a; 两个 MOS 管都关闭时&#xff0c;输出处于一个浮空状态&#xff0c;此时他对其他点的电阻是无穷大的&#xff…

Matlab实现LSTM-SVM回归预测,作者:机器学习之心

Matlab实现LSTM-SVM回归预测&#xff0c;作者&#xff1a;机器学习之心 目录 Matlab实现LSTM-SVM回归预测&#xff0c;作者&#xff1a;机器学习之心效果一览基本介绍程序设计参考资料 效果一览 基本介绍 代码主要功能 该代码实现了一个LSTM-SVM回归预测模型&#xff0c;核心流…

Leetcode - 周赛 452

目录 一&#xff0c;3566. 等积子集的划分方案二&#xff0c;3567. 子矩阵的最小绝对差三&#xff0c;3568. 清理教室的最少移动四&#xff0c;3569. 分割数组后不同质数的最大数目 一&#xff0c;3566. 等积子集的划分方案 题目列表 本题有两种做法&#xff0c;dfs 选或不选…

【FAQ】HarmonyOS SDK 闭源开放能力 —Account Kit(5)

1.问题描述&#xff1a; 集成华为一键登录的LoginWithHuaweiIDButton&#xff0c; 但是Button默认名字叫 “华为账号一键登录”&#xff0c;太长无法显示&#xff0c;能否简写成“一键登录”与其他端一致&#xff1f; 解决方案&#xff1a; 问题分两个场景&#xff1a; 一、…

Asp.Net Core SignalR的分布式部署

文章目录 前言一、核心二、解决方案架构三、实现方案1.使用 Azure SignalR Service2.Redis Backplane(Redis 背板方案&#xff09;3.负载均衡配置粘性会话要求无粘性会话方案&#xff08;仅WebSockets&#xff09;完整部署示例&#xff08;Redis Docker&#xff09;性能优化技…

L2-054 三点共线 - java

L2-054 三点共线 语言时间限制内存限制代码长度限制栈限制Java (javac)2600 ms512 MB16KB8192 KBPython (python3)2000 ms256 MB16KB8192 KB其他编译器2000 ms64 MB16KB8192 KB 题目描述&#xff1a; 给定平面上 n n n 个点的坐标 ( x _ i , y _ i ) ( i 1 , ⋯ , n ) (x\_i…

【 java 基础知识 第一篇 】

目录 1.概念 1.1.java的特定有哪些&#xff1f; 1.2.java有哪些优势哪些劣势&#xff1f; 1.3.java为什么可以跨平台&#xff1f; 1.4JVM,JDK,JRE它们有什么区别&#xff1f; 1.5.编译型语言与解释型语言的区别&#xff1f; 2.数据类型 2.1.long与int类型可以互转吗&…

高效背诵英语四级范文

以下是结合认知科学和实战验证的 ​​高效背诵英语作文五步法​​&#xff0c;助你在30分钟内牢固记忆一篇作文&#xff0c;特别适配考前冲刺场景&#xff1a; &#x1f4dd; ​​一、解构作文&#xff08;5分钟&#xff09;​​ ​​拆解逻辑框架​​ 用荧光笔标出&#xff…

RHEL7安装教程

RHEL7安装教程 下载RHEL7镜像 通过网盘分享的文件&#xff1a;RHEL 7.zip 链接: https://pan.baidu.com/s/1ExLhdJigj-tcrHJxIca5XA?pwdjrrj 提取码: jrrj --来自百度网盘超级会员v6的分享安装 1.打开VMware&#xff0c;新建虚拟机&#xff0c;选择自定义然后下一步 2.点击…

结构型设计模式之Decorator(装饰器)

结构型设计模式之Decorator&#xff08;装饰器&#xff09; 前言&#xff1a; 本案例通过李四举例&#xff0c;不改变源代码的情况下 对“才艺”进行增强。 摘要&#xff1a; 摘要&#xff1a; 装饰器模式是一种结构型设计模式&#xff0c;允许动态地为对象添加功能而不改变其…

Kotlin委托机制使用方式和原理

目录 类委托属性委托简单的实现属性委托Kotlin标准库中提供的几个委托延迟属性LazyLazy委托参数可观察属性Observable委托vetoable委托属性储存在Map中 实践方式双击back退出Fragment/Activity传参ViewBinding和委托 类委托 类委托有点类似于Java中的代理模式 interface Base…

SpringBoot接入Kimi实践记录轻松上手

kimi简单使用 什么是Kimi API 官网&#xff1a;https://platform.moonshot.cn/ Kimi API 并不是一个我所熟知的广泛通用的术语。我的推测是&#xff0c;你可能想问的是关于 API 的一些基础知识。API&#xff08;Application Programming Interface&#xff0c;应用程序编程接…

书籍在其他数都出现k次的数组中找到只出现一次的数(7)0603

题目 给定一个整型数组arr和一个大于1的整数k。已知arr中只有1个数出现了1次&#xff0c;其他的数都出现了k次&#xff0c;请返回只出现了1次的数。 解答&#xff1a; 对此题进行思路转换&#xff0c;可以将此题&#xff0c;转换成k进制数。 k进制的两个数c和d&#xff0c;…

React 项目初始化与搭建指南

React 项目初始化有多种方式&#xff0c;可以选择已有的脚手架工具快速创建项目&#xff0c;也可以自定义项目结构并使用构建工具实现项目的构建打包流程。 1. 脚手架方案 1.1. Vite 通过 Vite 创建 React 项目非常简单&#xff0c;只需一行命令即可完成。Vite 的工程初始化…

大模型模型推理的成本过高,如何进行量化或蒸馏优化

在人工智能的浪潮中,大模型已经成为推动技术革新的核心引擎。从自然语言处理到图像生成,再到复杂的多模态任务,像GPT、BERT、T5这样的庞大模型展现出了惊人的能力。它们在翻译、对话系统、内容生成等领域大放异彩,甚至在医疗、金融等行业中也开始扮演重要角色。可以说,这些…