位置: IT常识 - 正文

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

编辑:rootadmin
Pytorch实现EdgeCNN(基于PyTorch实现) 文章目录前言一、导入相关库二、加载Cora数据集三、定义EdgeCNN网络3.1 定义EdgeConv层3.1.1 特征拼接3.1.2 max聚合3.1.3 特征映射3.1.4 EdgeConv层3.2 定义EdgeCNN网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch encoder decoder,pytorch embedding lookup,pytorch embedding lookup,pyTorch实现多分类预测,pytorch demo,pytorch encoder decoder,pyTorch实现多分类预测,pytorch generator,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

本文我们将使用PyTorch来简易实现一个EdgeCNN,不使用PyG库,让新手可以理解如何PyTorch来搭建一个简易的图网络实例demo。

一、导入相关库

本项目是采用自己实现的EdgeCNN,并没有使用 PyG 库,原因是为了帮助新手朋友们能够对EdgeConv的原理有个更深刻的理解,如果熟悉之后可以尝试使用PyG库直接调用 EdgeConv 这个图层即可。

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义EdgeCNN网络3.1 定义EdgeConv层

这里我们就不重点介绍EdgeCNN网络了,相信大家能够掌握基本原理,本文我们使用的是PyTorch定义网络层。

对于EdgeConv的常用参数:

nn:进行节点特征转换使用的 MLP网络,需要自己定义传入aggr:聚合邻居节点特征时采用的方式,默认为 max

我们在实现时也是考虑这几个常见参数

对于EdgeConv的传播公式为: xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=j∈N(i)∑​hθ​(xi​∣∣xj​−xi​)

上式子中的 xix_ixi​ 代表中心节点特征信息, xjx_jxj​ 代表邻居节点的特征信息,对于 hθh_{\theta}hθ​ 代表每个 EdgeConv 层的可学习参数,也就是对应传入的MLP层中的可学习参数。

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

所以我们的任务无非就是获取这几个变量,然后进行传播计算即可

3.1.1 特征拼接

该环节实现的公式为:xi∣∣xj−xix_i||x_j-x_ixi​∣∣xj​−xi​,对于这个公式来说,我们要获得两个变量,一个是中心节点 xix_ixi​(target)的特征信息,一个是邻居节点 xjx_jxj​(source)的特征信息。

对于这两个变量的获取很容易,利用 edge_index 就可以提取出来,edge_index 中保存的是每一条边的一对起始节点与终止节点,对于起始节点可以认为就是 i,对于终止节点就可以认为是 j,然后我们就会获得两个向量,分别为 row 和 col ,这两个向量就是起始顶点和终止顶点的集合。

然后我们在根据索引进行提取特征,利用 x_i = x[row] 和 x_j = x[col] 就可以将中心节点和终止节点对应的特征获取,维度为【E,feature_size】。

然后就可以按照公式实现做差然后与中心节点的特征进行拼接,获得拼接后的特征维度为原来的2倍。

row, col = edge_index # 获取target、source节点索引 [E]x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size]x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size]x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size]

对于这里 x_i 和 x_j 以及起始节点的索引初学可能混淆,所以多多打印中间结果一步一步调试进行理解。

3.1.2 max聚合

对于 EdgeConv 的默认聚合方式为 max,其实还可以使用 mean 、sum 等排列不变函数进行聚合。

对于聚合操作就是公式中求和符号那里,只不过框架给的公式是 sum ,对于聚合我们希望做的是将中心节点的邻居特征按照指定的聚合方式进行聚合。

我们可以利用 PyG 工具库中提供的 scatter 函数进行操作,该函数可以指定聚合方式以及聚合维度等参数,使用方法就是需要传入需要聚合的 Tensor ,此外还需要传入一个 index ,指明哪些向量为同一个邻居的节点,举个例子,我们传入的 index=[0,0,0,1,1] ,这就代表第一个、第二个、第三个为同一节点的邻居,所以就会将待聚合的 Tensor 的第一个向量、第二个向量、第三个向量按照指定聚合方式进行聚合。

这里说的有点抽象,自己尝试一个简单示例就明白了。

out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size]3.1.3 特征映射

在公式中有个 hθh_{\theta}hθ​,这个就代表 MLP 做特征映射做的,对于官方给的 EdgeConv 需要我们手动传入 MLP 模型,所以本项目自实现也是按照这种方式,MLP 的操作在 EdgeConv 中并没有实现,而是利用传入的模型进行操作。

这里注意一点就是定义的 MLP 模型的输入维度应该为原始维度的2倍,因为我们在这之前进行了特征拼接操作,所以特征维度进行了加倍。

out = self.mlp(out) # 特征映射 [num_nodes, out_channels]3.1.4 EdgeConv层

接下来就可以定义EdgeConv层了,该层实现了1个函数,为 forward()

forward():这个函数定义模型的传播过程,也就是上面公式的 xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=∑j∈N(i)​hθ​(xi​∣∣xj​−xi​)# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out

对于我们实现这个网络的实现效率上来讲比PyG框架内置的 EdgeConv 层稍差一点,因为我们是按照公式来一步一步利用矩阵计算得到,没有对矩阵计算以及算法进行优化,不然初学者可能看不太懂,不利于理解EdgeConv公式的传播过程,有能力的小伙伴可以看下官方源码学习一下,框架内是按照消息传递方式实现的。

3.2 定义EdgeCNN网络

上面我们已经实现好了 EdgeConv 的网络层,之后就可以调用这个层来搭建 EdgeCNN 网络。

# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个EdgeConv层,第一层的参数的输入维度就是初始每个节点的特征维度 * 2,输出维度是16。

第二个层的输入维度为16 * 2,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

这里说明一下为什么要将输入乘以2,原因是在使用MLP进行特征转换之前,会将中心节点的特征与中心节点和邻居节点的差向量做拼接,所以得到的输出维度为节点的特征维度 * 2。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 10 # 学习轮数lr = 0.003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9629 训练精度为:0.1214【EPOCH: 】21训练损失为:1.6709 训练精度为:0.5714【EPOCH: 】41训练损失为:1.3965 训练精度为:0.7571【EPOCH: 】61训练损失为:1.1095 训练精度为:0.8643【EPOCH: 】81训练损失为:0.9088 训练精度为:0.9286【EPOCH: 】101训练损失为:0.7454 训练精度为:0.9643【EPOCH: 】121训练损失为:0.5841 训练精度为:0.9643【EPOCH: 】141训练损失为:0.4985 训练精度为:0.9714【EPOCH: 】161训练损失为:0.3954 训练精度为:0.9714【EPOCH: 】181训练损失为:0.3339 训练精度为:0.9857【Finished Training!】>>>Train Accuracy: 1.0000 Train Loss: 0.3133>>>Test Accuracy: 0.4230 Test Loss: 1.6562训练集测试集Accuracy1.00000.4230Loss0.31331.6562完整代码import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 4.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/300575.html 转载请保留说明!

上一篇:用css画一个csdn程序猿(用css画一个扇形)

下一篇:网易二面:CPU狂飙900%,该怎么处理?(网易游戏二面)

  • 附近荣耀手机售后维修点(附近荣耀手机售后维修点电话)

    附近荣耀手机售后维修点(附近荣耀手机售后维修点电话)

  • 我的快手怎么没有快接单功能(我的快手怎么没有放映厅功能)

    我的快手怎么没有快接单功能(我的快手怎么没有放映厅功能)

  • 三星s10为什么还卡顿(三星s10为什么还没上市)

    三星s10为什么还卡顿(三星s10为什么还没上市)

  • 京东闪电退款在哪儿看(京东 闪电退款)

    京东闪电退款在哪儿看(京东 闪电退款)

  • 支付宝无收款权限多久解除(支付宝无收款权限要多久恢复?)

    支付宝无收款权限多久解除(支付宝无收款权限要多久恢复?)

  • 腾讯会议怎么上课(腾讯会议怎么上传视频)

    腾讯会议怎么上课(腾讯会议怎么上传视频)

  • hmscore停止运行是什么意思(hms core显示停止运行,手机还能正常使用吗)

    hmscore停止运行是什么意思(hms core显示停止运行,手机还能正常使用吗)

  • 为什么手机qq下载的文件找不到(为什么手机qq下载的文件删除不了)

    为什么手机qq下载的文件找不到(为什么手机qq下载的文件删除不了)

  • apple watch充电的时候屏幕上显示电源线(apple watch充电方案)

    apple watch充电的时候屏幕上显示电源线(apple watch充电方案)

  • 一直通话中有几种原因(一直通话中什么原因)

    一直通话中有几种原因(一直通话中什么原因)

  • 苹果8可以升级13.5吗(苹果8可以升级16.1.1吗)

    苹果8可以升级13.5吗(苹果8可以升级16.1.1吗)

  • 耳机突然声音小了怎么回事(耳机突然声音小了)

    耳机突然声音小了怎么回事(耳机突然声音小了)

  • 微信不能在电脑上登录是怎么回事(微信不能在电脑上登录是什么原因和手机号)

    微信不能在电脑上登录是怎么回事(微信不能在电脑上登录是什么原因和手机号)

  • 抖音上的表情包怎么保存到相册(怎样保存抖音上的表情包)

    抖音上的表情包怎么保存到相册(怎样保存抖音上的表情包)

  • 网络卡顿是路由器的原因吗(网络卡顿是路由器的问题还是网络的问题)

    网络卡顿是路由器的原因吗(网络卡顿是路由器的问题还是网络的问题)

  • qq音乐内测版是每个人都有吗(qq 音乐官方内测)

    qq音乐内测版是每个人都有吗(qq 音乐官方内测)

  • qq等级加速包在哪里看(qq等级加速包怎么看)

    qq等级加速包在哪里看(qq等级加速包怎么看)

  • 网络受限是什么原因(网络受限是什么原因 登不上国内软件)

    网络受限是什么原因(网络受限是什么原因 登不上国内软件)

  • 多媒体数据压缩技术指标(多媒体数据压缩技术是多媒体关键技术之一)

    多媒体数据压缩技术指标(多媒体数据压缩技术是多媒体关键技术之一)

  • 抖音限制分享怎么解决(抖音分享被限)

    抖音限制分享怎么解决(抖音分享被限)

  • 苹果11放几个卡(苹果11放几张电话卡)

    苹果11放几个卡(苹果11放几张电话卡)

  • qq音乐怎么注销qq号(qq音乐怎么注销登录)

    qq音乐怎么注销qq号(qq音乐怎么注销登录)

  • t99平台是什么(t99咋样)

    t99平台是什么(t99咋样)

  • 抖音蓝v号是什么(抖音蓝v号有什么好处)

    抖音蓝v号是什么(抖音蓝v号有什么好处)

  • 电脑systeminfo命令打不开提示systeminfo.exe丢失怎么办?(system 命令)

    电脑systeminfo命令打不开提示systeminfo.exe丢失怎么办?(system 命令)

  • 合同解除的效力民法典
  • 应交印花税会计分录
  • 进口产品销售需要什么资质
  • 个人发票需要身份证信息吗?
  • 邮政开票税点是什么意思
  • 收回公司经营权需要做什么
  • 独资企业是向地税申报个税吗
  • 简易征收和简易计税的区别
  • 小规模纳税人销售已使用固定资产
  • 库存商品报废进项转出
  • 如果一直没到国税局办理登记怎么办
  • 公司车辆高速费用能开增值税专用发票吗
  • 转账支票怎么进账到个人账户
  • 没有核定税种怎么报税
  • 技术开发费加计扣除优惠政策
  • 纳税人识别号在哪里能查到
  • 企事业承包承租方缴纳的管理费税费
  • 维修材料费主要包括
  • 非盈利组织纳税筹划
  • 公司作账都按不含税价吗
  • 二手房产增值税率
  • 坏账准备需要做账吗
  • 收回长期股权投资账务处理成本法
  • 合同成本对应科目
  • 长期股权投资溢价购入
  • Win10系统如何修改开机密码
  • 苹果电脑mac设备在哪里
  • 收益性支出的项目有哪些
  • 销售返利如何做账
  • 逾期未收回包装物押金会计分录
  • 无形资产摊销时点
  • 个人股份转让
  • 企业收到对外投资收益交所得税吗
  • 过拟合能不能从根本上解决
  • 报废汽车残值收入如何计税
  • 不想预缴所得税能不能提前暂估费用,会计分录
  • 十四届智能车规则
  • vue中用echarts
  • 后浪是什么意思网络用语
  • 销售产品的包装费
  • mysql的备份方式
  • 增值税专用发票怎么开
  • 融资租赁的固定资产
  • 车辆购置税如何在电子税务局缴纳
  • 待报解预算收入怎么做分录
  • 合同履约成本如何设置明细科目
  • 无形资产减值准备借贷
  • 机票报销是什么发票
  • 营改增后增值税增加了什么征收范围
  • 长期待摊销费用属于
  • 简易计税开具的发票取得的进项可以抵扣嘛
  • 损益类会计科目有哪些
  • 以货物抵应收账款的分录
  • 房屋租赁合同怎么写对房东有利
  • 收到税务局退还的个税手续费怎么入账
  • 建账的基本原则是什么
  • win7系统如何彻底删除xp
  • 师说词类活用
  • xp操作系统入门
  • mac插hdmi没画面
  • rpm命令的作用是什么
  • Ubuntu14.04 的 SSH 无密码登录的设置方法
  • OS X10.10.5 Yosemite beta2发布 os x10.10.5yosemite beta2官网下载地址
  • edge以ie
  • ubuntu操作
  • win7系统每次关机都安装更新
  • win10系统最新更新
  • 64位Win7环境下vs2013配置opengl
  • js计算字符串长度 汉字长度
  • input lead
  • javascript运用
  • linux做ftp
  • js强制把网址设为密码
  • 如何使用form表单
  • 用shell脚本创建用户
  • 安卓 截图
  • js根据name取值
  • 丹麦个人所得税税率表
  • 烟丝和烟有什么区别
  • 安徽地税局领导班子名单
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设