当前位置:网站首页>torch_geometric学习一,MessagePassing
torch_geometric学习一,MessagePassing
2022-04-23 06:11:00 【小风_】
安装版本
# for windows10
pip install torch==1.2.0 # 10.0 cuda 单卡GTX1660
pip install torch_geometric==1.4.1
pip install torch_sparse==0.4.4
pip install torch_scatter==1.4.0
pip install torch_cluster==1.4.5
torch_geometric
debug模式
with torch_geometric.debug():
out = model(data.x,data.edge_index)
torch_geometric.nn
1.消息传递MessagePassing
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) , \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i)),
卷积算子推广到不规则域,通常可以表示为一个邻域聚合,或消息传递的过程。
- x i ( k − 1 ) \mathbf{x}_i^{(k-1)} xi(k−1)表示节点 i i i在第 k − 1 k-1 k−1层的点特征;
- e j , i \mathbf{e}_{j,i} ej,i表示点 j j j到点 i i i的边特征,可选;
- □ \square □表示一个可微、置换不变性函数,例如sum、mean、max;
- γ \gamma γ和 ϕ \phi ϕ表示可微函数,例如MLPs
PyG提供了messageppassing
基类,它通过自动处理消息传播来帮助创建这类消息传递图神经网络。用户只需要定义函数 ϕ \phi ϕ,即message(), γ \gamma γ即update(),以及要使用的消息传递的方案,即aggr=“add”, aggr="mean"或aggr=“max”。
class MessagePassing (aggr='add',flow='source_to_target', node_dim=0)
-
aggr = (“add”、“mean”或“max”),聚合模式;
-
flow = (“source_to_target”或“target_to_source”),消息传递的流方向;
-
node_dim,指示沿着哪个轴传播。
相关函数:
MessagePassing.propatage(edge_index, size = None)
MessagePassing.message(...)
MessagePassing.update(aggr_out,** **...)
MessagePassing.update(aggr_out,** **...)
将应用到GCN层上,GCN层数学定义如下:
x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ( i ) ⋅ deg ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) , \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)=j∈N(i)∪{
i}∑deg(i)⋅deg(j)1⋅(Θ⋅xj(k−1)),
其中,首先对相邻节点特征,通过权值矩阵 θ \theta θ进行变换;再通过其度,进行归一化;最后进行求和。该公式可分为以下步骤:
- Add self-loops to the adjacency matrix。
- 对节点特征矩阵进行线性变换。
- 计算归一化系数。
- 节点特征归一化。
- 求和所有相邻节点特征。
前三步在消息传递之前进行计算,后两步骤通过MessagePassing类实现
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x的shape为 [N, in_channels]
# edge_index的shape为 [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: 节点特征矩阵进行线性变换
x = self.lin(x)
# Step 3: 归一化操作
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: 开始传播消息
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j的shape为 [E, out_channels]
# Step 4: 节点特征归一化
return norm.view(-1, 1) * x_j
2.图卷积GCN
论文来源: Semi-supervised Classification with Graph Convolutional Networks
版权声明
本文为[小风_]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_33952811/article/details/120979838
边栏推荐
- Using stack to realize queue out and in
- 三子棋小游戏
- ArcGIS License Server Administrator 无法启动解决方法
- adb shell常用模拟按键keycode
- 组件化学习(1)思想及实现方式
- 记录webView显示空白的又一坑
- [recommendation of new books in 2021] practical IOT hacking
- [2021 book recommendation] effortless app development with Oracle visual builder
- ./gradlew: Permission denied
- js时间获取本周一、周日,判断时间是今天,今天前、后
猜你喜欢
树莓派:双色LED灯实验
Cancel remote dependency and use local dependency
Encapsulate a set of project network request framework from 0
Itop4412 HDMI display (4.4.4_r1)
[2021 book recommendation] kubernetes in production best practices
[2021 book recommendation] artistic intelligence for IOT Cookbook
JVM basics you should know
Google AdMob advertising learning
【2021年新书推荐】Enterprise Application Development with C# 9 and .NET 5
【2021年新书推荐】Artificial Intelligence for IoT Cookbook
随机推荐
iTOP4412 HDMI显示(4.4.4_r1)
“Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated
組件化學習
Thanos.sh灭霸脚本,轻松随机删除系统一半的文件
【2021年新书推荐】Effortless App Development with Oracle Visual Builder
组件化学习
素数求解的n种境界
Recyclerview batch update view: notifyitemrangeinserted, notifyitemrangeremoved, notifyitemrangechanged
org. xml. sax. SAXParseException; lineNumber: 141; columnNumber: 252; cvc-complex-type. 2.4. a: Found element 'B
ArcGIS License Server Administrator 无法启动解决方法
MySQL笔记3_约束_主键约束
BottomSheetDialogFragment + ViewPager+Fragment+RecyclerView 滑动问题
Bottomsheetdialogfragment conflicts with listview recyclerview Scrollview sliding
[2021 book recommendation] Red Hat Certified Engineer (RHCE) Study Guide
PaddleOCR 图片文字提取
[2021 book recommendation] red hat rhcsa 8 cert Guide: ex200
MySQL notes 4_ Primary key auto_increment
iTOP4412内核反复重启
[exynos4412] [itop4412] [android-k] add product options
iTOP4412 SurfaceFlinger(4.0.3_r1)