当前位置:网站首页>torch_geometric学习一,MessagePassing
torch_geometric学习一,MessagePassing
2022-04-23 06:11:00 【小风_】
安装版本
# for windows10
pip install torch==1.2.0 # 10.0 cuda 单卡GTX1660
pip install torch_geometric==1.4.1
pip install torch_sparse==0.4.4
pip install torch_scatter==1.4.0
pip install torch_cluster==1.4.5
torch_geometric
debug模式
with torch_geometric.debug():
out = model(data.x,data.edge_index)
torch_geometric.nn
1.消息传递MessagePassing
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) , \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i)),
卷积算子推广到不规则域,通常可以表示为一个邻域聚合,或消息传递的过程。
- x i ( k − 1 ) \mathbf{x}_i^{(k-1)} xi(k−1)表示节点 i i i在第 k − 1 k-1 k−1层的点特征;
- e j , i \mathbf{e}_{j,i} ej,i表示点 j j j到点 i i i的边特征,可选;
- □ \square □表示一个可微、置换不变性函数,例如sum、mean、max;
- γ \gamma γ和 ϕ \phi ϕ表示可微函数,例如MLPs
PyG提供了messageppassing基类,它通过自动处理消息传播来帮助创建这类消息传递图神经网络。用户只需要定义函数 ϕ \phi ϕ,即message(), γ \gamma γ即update(),以及要使用的消息传递的方案,即aggr=“add”, aggr="mean"或aggr=“max”。
class MessagePassing (aggr='add',flow='source_to_target', node_dim=0)
-
aggr = (“add”、“mean”或“max”),聚合模式;
-
flow = (“source_to_target”或“target_to_source”),消息传递的流方向;
-
node_dim,指示沿着哪个轴传播。
相关函数:
MessagePassing.propatage(edge_index, size = None)MessagePassing.message(...)MessagePassing.update(aggr_out,** **...)MessagePassing.update(aggr_out,** **...)
将应用到GCN层上,GCN层数学定义如下:
x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ( i ) ⋅ deg ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) , \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)=j∈N(i)∪{
i}∑deg(i)⋅deg(j)1⋅(Θ⋅xj(k−1)),
其中,首先对相邻节点特征,通过权值矩阵 θ \theta θ进行变换;再通过其度,进行归一化;最后进行求和。该公式可分为以下步骤:
- Add self-loops to the adjacency matrix。
- 对节点特征矩阵进行线性变换。
- 计算归一化系数。
- 节点特征归一化。
- 求和所有相邻节点特征。
前三步在消息传递之前进行计算,后两步骤通过MessagePassing类实现
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x的shape为 [N, in_channels]
# edge_index的shape为 [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: 节点特征矩阵进行线性变换
x = self.lin(x)
# Step 3: 归一化操作
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: 开始传播消息
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j的shape为 [E, out_channels]
# Step 4: 节点特征归一化
return norm.view(-1, 1) * x_j
2.图卷积GCN
论文来源: Semi-supervised Classification with Graph Convolutional Networks
版权声明
本文为[小风_]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_33952811/article/details/120979838
边栏推荐
- MySQL5.7插入中文数据,报错:`Incorrect string value: ‘\xB8\xDF\xAE\xF9\x80 at row 1`
- Itop4412 cannot display boot animation (4.0.3_r1)
- JVM basics you should know
- BottomSheetDialogFragment 与 ListView RecyclerView ScrollView 滑动冲突问题
- Cause: dx.jar is missing
- Bottom navigation bar based on bottomnavigationview
- 读书小记——Activity
- 组件化学习(1)思想及实现方式
- 统一任务分发调度执行框架
- iTOP4412 SurfaceFlinger(4.0.3_r1)
猜你喜欢

./gradlew: Permission denied

What did you do during the internship

【2021年新书推荐】Enterprise Application Development with C# 9 and .NET 5
![[recommendation of new books in 2021] practical IOT hacking](/img/9a/13ea1e7df14a53088d4777d21ab1f6.png)
[recommendation of new books in 2021] practical IOT hacking
![[recommendation of new books in 2021] enterprise application development with C 9 and NET 5](/img/1d/cc673ca857fff3c5c48a51883d96c4.png)
[recommendation of new books in 2021] enterprise application development with C 9 and NET 5

oracle表的约束详解

机器学习 三: 基于逻辑回归的分类预测

项目,怎么打包

C#新大陆物联网云平台的连接(简易理解版)

Cause: dx. jar is missing
随机推荐
AVD Pixel_ 2_ API_ 24 is already running. If that is not the case, delete the files at C:\Users\admi
./gradlew: Permission denied
DCMTK(DCM4CHE)与DICOOGLE协同工作
记录webView显示空白的又一坑
[2021 book recommendation] artistic intelligence for IOT Cookbook
接口幂等性问题
Project, how to package
【2021年新书推荐】Kubernetes in Production Best Practices
error 403 In most cases, you or one of your dependencies are requesting解决
Recyclerview batch update view: notifyitemrangeinserted, notifyitemrangeremoved, notifyitemrangechanged
MySQL5. 7 insert Chinese data and report an error: ` incorrect string value: '\ xb8 \ XDF \ AE \ xf9 \ X80 at row 1`
MySQL笔记2_数据表
三种实现ImageView以自身中心为原点旋转的方法
Kotlin征途之data class [数据类]
[2021 book recommendation] learn winui 3.0
[SM8150][Pixel4]LCD驱动
【2021年新书推荐】Learn WinUI 3.0
MarkDown基础语法笔记
基于BottomNavigationView实现底部导航栏
org.xml.sax.SAXParseException; lineNumber: 141; columnNumber: 252; cvc-complex-type.2.4.a: 发现了以元素 ‘b