位置: IT常识 - 正文

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

编辑:rootadmin
Pytorch实现EdgeCNN(基于PyTorch实现) 文章目录前言一、导入相关库二、加载Cora数据集三、定义EdgeCNN网络3.1 定义EdgeConv层3.1.1 特征拼接3.1.2 max聚合3.1.3 特征映射3.1.4 EdgeConv层3.2 定义EdgeCNN网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch encoder decoder,pytorch embedding lookup,pytorch embedding lookup,pyTorch实现多分类预测,pytorch demo,pytorch encoder decoder,pyTorch实现多分类预测,pytorch generator,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

本文我们将使用PyTorch来简易实现一个EdgeCNN,不使用PyG库,让新手可以理解如何PyTorch来搭建一个简易的图网络实例demo。

一、导入相关库

本项目是采用自己实现的EdgeCNN,并没有使用 PyG 库,原因是为了帮助新手朋友们能够对EdgeConv的原理有个更深刻的理解,如果熟悉之后可以尝试使用PyG库直接调用 EdgeConv 这个图层即可。

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义EdgeCNN网络3.1 定义EdgeConv层

这里我们就不重点介绍EdgeCNN网络了,相信大家能够掌握基本原理,本文我们使用的是PyTorch定义网络层。

对于EdgeConv的常用参数:

nn:进行节点特征转换使用的 MLP网络,需要自己定义传入aggr:聚合邻居节点特征时采用的方式,默认为 max

我们在实现时也是考虑这几个常见参数

对于EdgeConv的传播公式为: xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=j∈N(i)∑​hθ​(xi​∣∣xj​−xi​)

上式子中的 xix_ixi​ 代表中心节点特征信息, xjx_jxj​ 代表邻居节点的特征信息,对于 hθh_{\theta}hθ​ 代表每个 EdgeConv 层的可学习参数,也就是对应传入的MLP层中的可学习参数。

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

所以我们的任务无非就是获取这几个变量,然后进行传播计算即可

3.1.1 特征拼接

该环节实现的公式为:xi∣∣xj−xix_i||x_j-x_ixi​∣∣xj​−xi​,对于这个公式来说,我们要获得两个变量,一个是中心节点 xix_ixi​(target)的特征信息,一个是邻居节点 xjx_jxj​(source)的特征信息。

对于这两个变量的获取很容易,利用 edge_index 就可以提取出来,edge_index 中保存的是每一条边的一对起始节点与终止节点,对于起始节点可以认为就是 i,对于终止节点就可以认为是 j,然后我们就会获得两个向量,分别为 row 和 col ,这两个向量就是起始顶点和终止顶点的集合。

然后我们在根据索引进行提取特征,利用 x_i = x[row] 和 x_j = x[col] 就可以将中心节点和终止节点对应的特征获取,维度为【E,feature_size】。

然后就可以按照公式实现做差然后与中心节点的特征进行拼接,获得拼接后的特征维度为原来的2倍。

row, col = edge_index # 获取target、source节点索引 [E]x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size]x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size]x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size]

对于这里 x_i 和 x_j 以及起始节点的索引初学可能混淆,所以多多打印中间结果一步一步调试进行理解。

3.1.2 max聚合

对于 EdgeConv 的默认聚合方式为 max,其实还可以使用 mean 、sum 等排列不变函数进行聚合。

对于聚合操作就是公式中求和符号那里,只不过框架给的公式是 sum ,对于聚合我们希望做的是将中心节点的邻居特征按照指定的聚合方式进行聚合。

我们可以利用 PyG 工具库中提供的 scatter 函数进行操作,该函数可以指定聚合方式以及聚合维度等参数,使用方法就是需要传入需要聚合的 Tensor ,此外还需要传入一个 index ,指明哪些向量为同一个邻居的节点,举个例子,我们传入的 index=[0,0,0,1,1] ,这就代表第一个、第二个、第三个为同一节点的邻居,所以就会将待聚合的 Tensor 的第一个向量、第二个向量、第三个向量按照指定聚合方式进行聚合。

这里说的有点抽象,自己尝试一个简单示例就明白了。

out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size]3.1.3 特征映射

在公式中有个 hθh_{\theta}hθ​,这个就代表 MLP 做特征映射做的,对于官方给的 EdgeConv 需要我们手动传入 MLP 模型,所以本项目自实现也是按照这种方式,MLP 的操作在 EdgeConv 中并没有实现,而是利用传入的模型进行操作。

这里注意一点就是定义的 MLP 模型的输入维度应该为原始维度的2倍,因为我们在这之前进行了特征拼接操作,所以特征维度进行了加倍。

out = self.mlp(out) # 特征映射 [num_nodes, out_channels]3.1.4 EdgeConv层

接下来就可以定义EdgeConv层了,该层实现了1个函数,为 forward()

forward():这个函数定义模型的传播过程,也就是上面公式的 xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=∑j∈N(i)​hθ​(xi​∣∣xj​−xi​)# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out

对于我们实现这个网络的实现效率上来讲比PyG框架内置的 EdgeConv 层稍差一点,因为我们是按照公式来一步一步利用矩阵计算得到,没有对矩阵计算以及算法进行优化,不然初学者可能看不太懂,不利于理解EdgeConv公式的传播过程,有能力的小伙伴可以看下官方源码学习一下,框架内是按照消息传递方式实现的。

3.2 定义EdgeCNN网络

上面我们已经实现好了 EdgeConv 的网络层,之后就可以调用这个层来搭建 EdgeCNN 网络。

# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个EdgeConv层,第一层的参数的输入维度就是初始每个节点的特征维度 * 2,输出维度是16。

第二个层的输入维度为16 * 2,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

这里说明一下为什么要将输入乘以2,原因是在使用MLP进行特征转换之前,会将中心节点的特征与中心节点和邻居节点的差向量做拼接,所以得到的输出维度为节点的特征维度 * 2。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 10 # 学习轮数lr = 0.003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9629 训练精度为:0.1214【EPOCH: 】21训练损失为:1.6709 训练精度为:0.5714【EPOCH: 】41训练损失为:1.3965 训练精度为:0.7571【EPOCH: 】61训练损失为:1.1095 训练精度为:0.8643【EPOCH: 】81训练损失为:0.9088 训练精度为:0.9286【EPOCH: 】101训练损失为:0.7454 训练精度为:0.9643【EPOCH: 】121训练损失为:0.5841 训练精度为:0.9643【EPOCH: 】141训练损失为:0.4985 训练精度为:0.9714【EPOCH: 】161训练损失为:0.3954 训练精度为:0.9714【EPOCH: 】181训练损失为:0.3339 训练精度为:0.9857【Finished Training!】>>>Train Accuracy: 1.0000 Train Loss: 0.3133>>>Test Accuracy: 0.4230 Test Loss: 1.6562训练集测试集Accuracy1.00000.4230Loss0.31331.6562完整代码import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 4.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/300575.html 转载请保留说明!

上一篇:用css画一个csdn程序猿(用css画一个扇形)

下一篇:网易二面:CPU狂飙900%,该怎么处理?(网易游戏二面)

  • 抖音粉丝灯牌名称怎么改(抖音粉丝灯牌名称在哪里更改)

    抖音粉丝灯牌名称怎么改(抖音粉丝灯牌名称在哪里更改)

  • 抖音直播怎么给别人点赞(抖音直播怎么给管理员)

    抖音直播怎么给别人点赞(抖音直播怎么给管理员)

  • 爱奇艺会员最多可以几个人用(爱奇艺会员最多可以共享几台)

    爱奇艺会员最多可以几个人用(爱奇艺会员最多可以共享几台)

  • 蓝屏代码0x0000000a5(蓝屏代码0x0000000A原因)

    蓝屏代码0x0000000a5(蓝屏代码0x0000000A原因)

  • 微视为什么领不到红包(微视不能领钱了)

    微视为什么领不到红包(微视不能领钱了)

  • 显示卡在电脑哪里(电脑显示卡的作用和种类有哪些?)

    显示卡在电脑哪里(电脑显示卡的作用和种类有哪些?)

  • miui稳定版升级开发版会清除数据吗(小米稳定版升级)

    miui稳定版升级开发版会清除数据吗(小米稳定版升级)

  • 苹果se2有人脸识别吗(苹果se2有人脸识别解锁吗)

    苹果se2有人脸识别吗(苹果se2有人脸识别解锁吗)

  • 蓝牙耳机有一只不亮(蓝牙耳机有一只找不到了怎么办)

    蓝牙耳机有一只不亮(蓝牙耳机有一只找不到了怎么办)

  • 显示预览是什么意思(预览和显示的不一样)

    显示预览是什么意思(预览和显示的不一样)

  • 对方微信停用什么体现(对方微信号停用是什么状况)

    对方微信停用什么体现(对方微信号停用是什么状况)

  • 如何从手机上传照片到电脑(如何从手机上传到u盘)

    如何从手机上传照片到电脑(如何从手机上传到u盘)

  • 苹果耳机放洗衣机里洗了还能用吗(苹果无线耳机放洗衣机洗了还能用吗)

    苹果耳机放洗衣机里洗了还能用吗(苹果无线耳机放洗衣机洗了还能用吗)

  • 创客空间平台是做什么的(创客空间平台是什么)

    创客空间平台是做什么的(创客空间平台是什么)

  • 华为手机支付宝怎么分身(华为手机支付宝在桌面不显示不出来)

    华为手机支付宝怎么分身(华为手机支付宝在桌面不显示不出来)

  • 联通虚商是什么意思啊(联通虚商号码安全吗)

    联通虚商是什么意思啊(联通虚商号码安全吗)

  • 快手删作品会影响权重(快手删作品会影响浏览吗)

    快手删作品会影响权重(快手删作品会影响浏览吗)

  • 手机双卡双待怎么用(手机双卡双待怎么设置关掉一个卡)

    手机双卡双待怎么用(手机双卡双待怎么设置关掉一个卡)

  • 华为mate30pro耐摔吗(华为mate30pro抗摔吗?)

    华为mate30pro耐摔吗(华为mate30pro抗摔吗?)

  • 华为mate30怎么打开灭屏显示(华为mate30怎么打开单手操作)

    华为mate30怎么打开灭屏显示(华为mate30怎么打开单手操作)

  • 蓝牙耳机可以分开用吗(蓝牙耳机可以分开用吗,一个放家里一个放办公室)

    蓝牙耳机可以分开用吗(蓝牙耳机可以分开用吗,一个放家里一个放办公室)

  • 小米8se有红外线功能吗(小米8se有红外线遥控功能吗)

    小米8se有红外线功能吗(小米8se有红外线遥控功能吗)

  • 快手开店用淘宝还是魔筷(在快手开店好还是在淘宝开店好)

    快手开店用淘宝还是魔筷(在快手开店好还是在淘宝开店好)

  • 抖音点赞有什么好处(抖音点赞有什么奖励)

    抖音点赞有什么好处(抖音点赞有什么奖励)

  • solo3和studio3区别(solo3和studio3音质差别很大吗)

    solo3和studio3区别(solo3和studio3音质差别很大吗)

  • 查移动手机通话清单(查移动手机通话记录忘了服务密码怎)

    查移动手机通话清单(查移动手机通话记录忘了服务密码怎)

  • 关于 Vue “__ob__:Observer“ 属性的解决方案(关于减肥的好方法)

    关于 Vue “__ob__:Observer“ 属性的解决方案(关于减肥的好方法)

  • 个体工商户怎么补交个人所得税
  • 养殖合作社属于什么行业
  • 怎么导出银行对账单流水
  • 2021年成本类科目
  • 固定资产到期怎么处理
  • 报销职工住院费多久到账
  • 哪些属于不动产权
  • 固定资产当月入账下月计提折旧
  • 加权平均净资产收益率反映什么
  • 收取加盟费会计如何入账
  • 淘宝企业店铺的钱会打到哪里
  • 什么情况下发票不能冲红
  • 正规沙场需要缴纳税吗
  • 学校应该缴纳的税
  • 增值税电子普通发票可以抵扣吗
  • 现金流量表本月数和本年累计数是相等的么
  • 外贸企业增值税发票需要认证吗
  • 电子商业汇票背书是什么意思
  • win11如何安装安卓app
  • 销售免费样品账务处理
  • msoobe.exe是什么
  • 怎么计算应缴所得税
  • 潘塔纳尔湿地的主要成因
  • 小微企业报税后多久缴税
  • php 获取网页内容
  • NovelAi + Webui + Stable-diffusion本地配置
  • 网易游戏二面
  • php yii
  • shapecfg命令 管制网络设备的流量
  • 自产产品用于福利要交增值税吗
  • 建筑施工企业增值税老项目过度期成本票
  • 公司向股东借的钱怎么还
  • sqlplus分页查询
  • 免征增值税的会计处理
  • 非营利组织注册
  • 进项税额转出期限是多久
  • 公司注销时退还实收资本要交个税吗
  • 增值税发票税率计算公式
  • 记账凭证会计核算形式的程序
  • 收到上年度所得税返还会计分录
  • 固定资产多入账怎么写情况说明
  • 公司股东可以买公司股票吗
  • 制造费用需要本年累计吗
  • 公司向个人借款的会计分录怎么做
  • 技术服务费如何赋码
  • 房屋租赁合同印花税的税率
  • 开出发票后直接做账吗?
  • 技术咨询费属于什么类别
  • 工程未完工开了发票怎么做账
  • 存货明细账余额合计与存货总账余额相同
  • mysql数据库高可用方案
  • win10怎么关闭右下角图标
  • 安装windows server 2008 r2
  • windows time同步系统时间的服务无法启动报错1058解决方法
  • windows1020h2版本怎么样
  • rdclient怎么用win10怎么设置
  • windows7禁止开机启动
  • linux网络设备有哪些
  • android app 源码
  • android:View的setTag和getTag使用
  • cocos2d-js游戏开发
  • bootstrap怎么学
  • javascript的面向对象
  • jquery获取document对象
  • jquery.handleerror
  • 获取外网ip地址有什么用
  • unity ui批处理
  • javascript instanceof 与typeof使用说明
  • jquery 列表控件
  • staticlayout 换行
  • android+
  • 全面解析白羊座o型血女
  • 简述python语言
  • 晋税通注册
  • 增值税纳税申报操作流程
  • 房产的原值以什么为准
  • 税务师事务所行政登记表怎么办理
  • 天津普通发票查询平台
  • 虚开增值税专用发票罪量刑标准2023
  • 广西地税代收工作怎么样
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设