位置: IT常识 - 正文

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

编辑:rootadmin
Pytorch实现EdgeCNN(基于PyTorch实现) 文章目录前言一、导入相关库二、加载Cora数据集三、定义EdgeCNN网络3.1 定义EdgeConv层3.1.1 特征拼接3.1.2 max聚合3.1.3 特征映射3.1.4 EdgeConv层3.2 定义EdgeCNN网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch encoder decoder,pytorch embedding lookup,pytorch embedding lookup,pyTorch实现多分类预测,pytorch demo,pytorch encoder decoder,pyTorch实现多分类预测,pytorch generator,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

本文我们将使用PyTorch来简易实现一个EdgeCNN,不使用PyG库,让新手可以理解如何PyTorch来搭建一个简易的图网络实例demo。

一、导入相关库

本项目是采用自己实现的EdgeCNN,并没有使用 PyG 库,原因是为了帮助新手朋友们能够对EdgeConv的原理有个更深刻的理解,如果熟悉之后可以尝试使用PyG库直接调用 EdgeConv 这个图层即可。

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义EdgeCNN网络3.1 定义EdgeConv层

这里我们就不重点介绍EdgeCNN网络了,相信大家能够掌握基本原理,本文我们使用的是PyTorch定义网络层。

对于EdgeConv的常用参数:

nn:进行节点特征转换使用的 MLP网络,需要自己定义传入aggr:聚合邻居节点特征时采用的方式,默认为 max

我们在实现时也是考虑这几个常见参数

对于EdgeConv的传播公式为: xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=j∈N(i)∑​hθ​(xi​∣∣xj​−xi​)

上式子中的 xix_ixi​ 代表中心节点特征信息, xjx_jxj​ 代表邻居节点的特征信息,对于 hθh_{\theta}hθ​ 代表每个 EdgeConv 层的可学习参数,也就是对应传入的MLP层中的可学习参数。

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

所以我们的任务无非就是获取这几个变量,然后进行传播计算即可

3.1.1 特征拼接

该环节实现的公式为:xi∣∣xj−xix_i||x_j-x_ixi​∣∣xj​−xi​,对于这个公式来说,我们要获得两个变量,一个是中心节点 xix_ixi​(target)的特征信息,一个是邻居节点 xjx_jxj​(source)的特征信息。

对于这两个变量的获取很容易,利用 edge_index 就可以提取出来,edge_index 中保存的是每一条边的一对起始节点与终止节点,对于起始节点可以认为就是 i,对于终止节点就可以认为是 j,然后我们就会获得两个向量,分别为 row 和 col ,这两个向量就是起始顶点和终止顶点的集合。

然后我们在根据索引进行提取特征,利用 x_i = x[row] 和 x_j = x[col] 就可以将中心节点和终止节点对应的特征获取,维度为【E,feature_size】。

然后就可以按照公式实现做差然后与中心节点的特征进行拼接,获得拼接后的特征维度为原来的2倍。

row, col = edge_index # 获取target、source节点索引 [E]x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size]x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size]x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size]

对于这里 x_i 和 x_j 以及起始节点的索引初学可能混淆,所以多多打印中间结果一步一步调试进行理解。

3.1.2 max聚合

对于 EdgeConv 的默认聚合方式为 max,其实还可以使用 mean 、sum 等排列不变函数进行聚合。

对于聚合操作就是公式中求和符号那里,只不过框架给的公式是 sum ,对于聚合我们希望做的是将中心节点的邻居特征按照指定的聚合方式进行聚合。

我们可以利用 PyG 工具库中提供的 scatter 函数进行操作,该函数可以指定聚合方式以及聚合维度等参数,使用方法就是需要传入需要聚合的 Tensor ,此外还需要传入一个 index ,指明哪些向量为同一个邻居的节点,举个例子,我们传入的 index=[0,0,0,1,1] ,这就代表第一个、第二个、第三个为同一节点的邻居,所以就会将待聚合的 Tensor 的第一个向量、第二个向量、第三个向量按照指定聚合方式进行聚合。

这里说的有点抽象,自己尝试一个简单示例就明白了。

out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size]3.1.3 特征映射

在公式中有个 hθh_{\theta}hθ​,这个就代表 MLP 做特征映射做的,对于官方给的 EdgeConv 需要我们手动传入 MLP 模型,所以本项目自实现也是按照这种方式,MLP 的操作在 EdgeConv 中并没有实现,而是利用传入的模型进行操作。

这里注意一点就是定义的 MLP 模型的输入维度应该为原始维度的2倍,因为我们在这之前进行了特征拼接操作,所以特征维度进行了加倍。

out = self.mlp(out) # 特征映射 [num_nodes, out_channels]3.1.4 EdgeConv层

接下来就可以定义EdgeConv层了,该层实现了1个函数,为 forward()

forward():这个函数定义模型的传播过程,也就是上面公式的 xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=∑j∈N(i)​hθ​(xi​∣∣xj​−xi​)# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out

对于我们实现这个网络的实现效率上来讲比PyG框架内置的 EdgeConv 层稍差一点,因为我们是按照公式来一步一步利用矩阵计算得到,没有对矩阵计算以及算法进行优化,不然初学者可能看不太懂,不利于理解EdgeConv公式的传播过程,有能力的小伙伴可以看下官方源码学习一下,框架内是按照消息传递方式实现的。

3.2 定义EdgeCNN网络

上面我们已经实现好了 EdgeConv 的网络层,之后就可以调用这个层来搭建 EdgeCNN 网络。

# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个EdgeConv层,第一层的参数的输入维度就是初始每个节点的特征维度 * 2,输出维度是16。

第二个层的输入维度为16 * 2,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

这里说明一下为什么要将输入乘以2,原因是在使用MLP进行特征转换之前,会将中心节点的特征与中心节点和邻居节点的差向量做拼接,所以得到的输出维度为节点的特征维度 * 2。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 10 # 学习轮数lr = 0.003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9629 训练精度为:0.1214【EPOCH: 】21训练损失为:1.6709 训练精度为:0.5714【EPOCH: 】41训练损失为:1.3965 训练精度为:0.7571【EPOCH: 】61训练损失为:1.1095 训练精度为:0.8643【EPOCH: 】81训练损失为:0.9088 训练精度为:0.9286【EPOCH: 】101训练损失为:0.7454 训练精度为:0.9643【EPOCH: 】121训练损失为:0.5841 训练精度为:0.9643【EPOCH: 】141训练损失为:0.4985 训练精度为:0.9714【EPOCH: 】161训练损失为:0.3954 训练精度为:0.9714【EPOCH: 】181训练损失为:0.3339 训练精度为:0.9857【Finished Training!】>>>Train Accuracy: 1.0000 Train Loss: 0.3133>>>Test Accuracy: 0.4230 Test Loss: 1.6562训练集测试集Accuracy1.00000.4230Loss0.31331.6562完整代码import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 4.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/300575.html 转载请保留说明!

上一篇:用css画一个csdn程序猿(用css画一个扇形)

下一篇:网易二面:CPU狂飙900%,该怎么处理?(网易游戏二面)

  • 微信语音通话怎么设置和来电一样(微信语音通话怎么设置铃声)

    微信语音通话怎么设置和来电一样(微信语音通话怎么设置铃声)

  • vivox70pro怎么录屏(vivox70pro怎么录屏幕视频)

    vivox70pro怎么录屏(vivox70pro怎么录屏幕视频)

  • 中国银行手机盾怎么重新开通(中国银行手机盾怎么开通)

    中国银行手机盾怎么重新开通(中国银行手机盾怎么开通)

  • 华为剪切板记录在哪里找(华为剪切板记录不显示)

    华为剪切板记录在哪里找(华为剪切板记录不显示)

  • 苹果手机流量超了怎么自动关闭通知(苹果手机流量超过200m不能下载)

    苹果手机流量超了怎么自动关闭通知(苹果手机流量超过200m不能下载)

  • 美版抖音怎么不能看(下了美版抖音打开什么都没有)

    美版抖音怎么不能看(下了美版抖音打开什么都没有)

  • 学信网密保问题顺序有要求吗(学信网密保问题设置错误是什么意思)

    学信网密保问题顺序有要求吗(学信网密保问题设置错误是什么意思)

  • 华为微信视频怎么美颜在哪里设置(华为微信视频怎么录音录像)

    华为微信视频怎么美颜在哪里设置(华为微信视频怎么录音录像)

  • 淘宝视频尺寸(淘宝视频尺寸不符合要求怎么办)

    淘宝视频尺寸(淘宝视频尺寸不符合要求怎么办)

  • 键盘三个灯都不亮了怎么回事(键盘三个灯都不亮怎么办)

    键盘三个灯都不亮了怎么回事(键盘三个灯都不亮怎么办)

  • 苹果se2用的什么基带(苹果se2用的什么芯片)

    苹果se2用的什么基带(苹果se2用的什么芯片)

  • 苹果x换屏后有影响吗(苹果x换屏后有延迟)

    苹果x换屏后有影响吗(苹果x换屏后有延迟)

  • 安卓cajviewer用不了(cajviewer安卓手机版不能使用)

    安卓cajviewer用不了(cajviewer安卓手机版不能使用)

  • coocaa是什么牌子的电视(coocaa是什么牌子的电视遥控器)

    coocaa是什么牌子的电视(coocaa是什么牌子的电视遥控器)

  • 健康码申请怎么不用手机号(健康码申请怎么写)

    健康码申请怎么不用手机号(健康码申请怎么写)

  • qq小太阳怎么获得(qq的小太阳图标是什么样子的)

    qq小太阳怎么获得(qq的小太阳图标是什么样子的)

  • nex3s带无线充电吗(vivo nex3s支不支持无线充电)

    nex3s带无线充电吗(vivo nex3s支不支持无线充电)

  • 华为p20智能遥控在哪(华为P20智能遥控)

    华为p20智能遥控在哪(华为P20智能遥控)

  • 小米手表可以浏览照片吗(小米手表可以浏览网页吗)

    小米手表可以浏览照片吗(小米手表可以浏览网页吗)

  • 2mbps是多少网速(472mbps是多少网速)

    2mbps是多少网速(472mbps是多少网速)

  • vivo手机运动计步怎么开启(vivo手机运动计步在哪)

    vivo手机运动计步怎么开启(vivo手机运动计步在哪)

  • 手机运行内存越来越小(手机运行内存越大玩游戏越流畅吗)

    手机运行内存越来越小(手机运行内存越大玩游戏越流畅吗)

  • qq电脑里的图片怎么打印(qq里的图片在电脑哪里)

    qq电脑里的图片怎么打印(qq里的图片在电脑哪里)

  • 打印图片怎么设置白底(打印图片怎么设置彩色)

    打印图片怎么设置白底(打印图片怎么设置彩色)

  • 微信聊天会被第三方看见吗(微信聊天会被监控吗 微信聊天别人能看到吗)

    微信聊天会被第三方看见吗(微信聊天会被监控吗 微信聊天别人能看到吗)

  • ipad怎么连校园网(ipad怎么连校园网登陆网页)

    ipad怎么连校园网(ipad怎么连校园网登陆网页)

  • 极路由和普通路由有什么区别 极路由和普通路由区别详解(极路由好用吗)

    极路由和普通路由有什么区别 极路由和普通路由区别详解(极路由好用吗)

  • bash命令  命令解释器(bash详解)

    bash命令 命令解释器(bash详解)

  • 商业写字楼
  • 安装服务费增值税专票税率多少
  • 增值税普通发票税率
  • 普通发票可以抵税点吗
  • 进口货物的完税价格不包括
  • 网银发工资怎么增员的
  • 免税蔬菜税额用什么表示
  • 固定资产清理属于流动资产吗
  • 冲减收入怎么做账
  • 经营活动现金净流量在报表上怎么看
  • 融资租入固定资产属于资产吗
  • 应交税费进项税额属于什么科目借贷方向
  • 销货清单和发票金额不一致
  • 办公场所转租赁需要交哪些税费
  • 提前退休取得的一次性补贴收入
  • 核定征收的个人所得税怎么申报
  • 发票记账联丢失怎么写情况说明
  • 金税三期啥意思
  • 发票为什么会查不到信息
  • 增值税进项税额在借方还是贷方
  • 个体工商户允许哪些经营范围
  • 工程回扣增值税如何处理?
  • 购进货物赠送客户增值税处理
  • 苹果静音模式siri
  • 腾达路由器管理员密码
  • world超链接
  • 关于其他应收款账户的说法
  • win 11 bug
  • 没按时报税罚款多少
  • win10开机强制进入修复模式
  • 工程竣工结算和决算的区别
  • vueajax请求的五个步骤
  • 多源传感器融合
  • PHP:mcrypt_ecb()的用法_Mcrypt函数
  • 购买方发票已认证丢失了如何处理
  • 应付债券利息费用
  • 表格uplook
  • 双色球python算法
  • 职工薪酬纳税调整明细表怎么填写
  • tensorflow gui
  • 使用ajax实现页面分页
  • 转增资本属于什么会计科目
  • 税控会计分录
  • 织梦不更新了
  • 不动产进项税额抵扣新政策2021
  • 信用证保证金账户属于什么账户性质
  • 税盘维护费可以年年抵扣吗
  • 项目清算后未售房产怎么纳税
  • 合伙企业是否需要缴纳印花税
  • 存货的发出计价方法有哪些
  • 收到的业务赔偿如何入账
  • 社保年度汇算清缴怎么做
  • 借款合同的印花税计税依据
  • 进项发票冲红退回怎么做账
  • 新成立公司会计要做哪些事情
  • 请问,制造企业有哪些?
  • windowsxp不能启动怎么修复
  • win2003进入安全模式
  • windows中的帐户类别administrator为
  • linux用户管理器在哪
  • linux系统怎么共享
  • win7系统怎么关闭屏幕保护
  • win7系统屏幕保护设置禁用如何开启
  • js中如何实现数字相加
  • android真机调试解析包错误
  • unity3d怎么做游戏
  • 前端面试题csdn
  • javascript对象的种类
  • linux中的shell命令
  • python erf
  • vue的ssr渲染
  • css3瀑布流布局
  • python字符串中的反斜杠
  • django 表单
  • jquery.load()方法,刷新网页
  • 胡世军简历年龄多大
  • 宁夏地税领导班子名单
  • 盐城合作医疗在手机上怎么交
  • 建筑企业个人所得税管理办法
  • 广西地税代收工作怎么样
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设