torch_geometric learning 1: MessagePassing
2022-04-23 07:17:00 【Breeze_】
Installed versions
```shell
# Windows 10
pip install torch==1.2.0  # CUDA 10.0, single GTX 1660 GPU
pip install torch_geometric==1.4.1
pip install torch_sparse==0.4.4
pip install torch_scatter==1.4.0
pip install torch_cluster==1.4.5
```
torch_geometric
Debug mode
```python
import torch_geometric

# model and data are assumed to be defined elsewhere.
with torch_geometric.debug():
    out = model(data.x, data.edge_index)
```
torch_geometric.nn
1. Message passing: MessagePassing
$$\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right) \right),$$
Generalizing the convolution operator to irregular domains is usually expressed as a neighborhood aggregation, or message passing, scheme, as in the formula above, where:
- $\mathbf{x}_i^{(k-1)}$ denotes the features of node $i$ in layer $(k-1)$;
- $\mathbf{e}_{j,i}$ denotes the (optional) features of the edge from node $j$ to node $i$;
- $\square$ denotes a differentiable, permutation-invariant function, e.g. sum, mean or max;
- $\gamma$ and $\phi$ denote differentiable functions, such as MLPs.
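To make the roles of $\phi$, $\square$ and $\gamma$ concrete, here is a minimal plain-Python sketch of one message passing layer, written without PyG. It is an illustration under simplifying assumptions, not PyG's implementation: message_passing_step and the sum_aggr / mean_aggr / max_aggr helpers are hypothetical names, features are plain lists, and the optional edge features $\mathbf{e}_{j,i}$ are left out.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def sum_aggr(msgs: List[Vector]) -> Vector:
    # Element-wise sum: one permutation-invariant choice for the "box" operator.
    return [sum(col) for col in zip(*msgs)]


def mean_aggr(msgs: List[Vector]) -> Vector:
    return [sum(col) / len(msgs) for col in zip(*msgs)]


def max_aggr(msgs: List[Vector]) -> Vector:
    return [max(col) for col in zip(*msgs)]


def message_passing_step(
    x: Sequence[Vector],
    edges: Sequence[Tuple[int, int]],
    phi: Callable[[Vector, Vector], Vector],
    aggr: Callable[[List[Vector]], Vector],
    gamma: Callable[[Vector, Vector], Vector],
) -> List[Vector]:
    """Hypothetical helper: x_i^(k) = gamma(x_i, aggr_{j in N(i)} phi(x_i, x_j)).

    Each edge (j, i) sends one message from node j to node i.
    Assumes messages have the same length as the node features.
    """
    inbox: List[List[Vector]] = [[] for _ in x]
    for j, i in edges:
        inbox[i].append(phi(x[i], x[j]))  # phi: build the message on edge j -> i
    out = []
    for i, msgs in enumerate(inbox):
        # A node without incoming edges aggregates to a zero vector.
        agg = aggr(msgs) if msgs else [0.0] * len(x[i])
        out.append(gamma(x[i], agg))  # gamma: combine the old state with the aggregate
    return out


def copy_source(x_i: Vector, x_j: Vector) -> Vector:
    return x_j  # phi: the message is simply the sender's features


def add_to_self(x_i: Vector, agg: Vector) -> Vector:
    return [u + v for u, v in zip(x_i, agg)]  # gamma: x_i + aggregated messages


# 3 nodes with scalar features; edges 0 -> 1, 2 -> 1 and 1 -> 0.
x = [[1.0], [2.0], [3.0]]
edges = [(0, 1), (2, 1), (1, 0)]
print(message_passing_step(x, edges, copy_source, sum_aggr, add_to_self))  # [[3.0], [6.0], [3.0]]
```

Swapping sum_aggr for mean_aggr or max_aggr changes only the result at node 1 (the one node with two incoming messages), which is exactly the choice that aggr="add" / "mean" / "max" makes in PyG.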
PyG provides the MessagePassing base class, which helps create this kind of message passing graph neural network by handling message propagation automatically. The user only needs to define the function $\phi$, i.e. message(), the function $\gamma$, i.e. update(), and the aggregation scheme to use, i.e. aggr="add", aggr="mean" or aggr="max".
class MessagePassing(aggr='add', flow='source_to_target', node_dim=0)
- aggr ("add", "mean" or "max"): the aggregation scheme;
- flow ("source_to_target" or "target_to_source"): the direction of message passing;
- node_dim: the axis along which to propagate.
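The effect of flow can be illustrated with a small plain-Python helper that lists, for every receiving node $i$, the nodes $j$ whose messages it aggregates. incoming_neighbors is hypothetical, not a PyG function; it only mirrors the convention that with flow='source_to_target' messages travel from edge_index[0] to edge_index[1], and the other way round for 'target_to_source'.

```python
from typing import Dict, List, Sequence


def incoming_neighbors(
    edge_index: Sequence[Sequence[int]], flow: str = "source_to_target"
) -> Dict[int, List[int]]:
    """Hypothetical helper: map each receiving node i to the senders j it aggregates."""
    row, col = edge_index
    if flow == "source_to_target":
        senders, receivers = row, col  # j = edge_index[0], i = edge_index[1]
    elif flow == "target_to_source":
        senders, receivers = col, row  # j = edge_index[1], i = edge_index[0]
    else:
        raise ValueError(f"unknown flow: {flow!r}")
    result: Dict[int, List[int]] = {}
    for i, j in zip(receivers, senders):
        result.setdefault(i, []).append(j)
    return result


edge_index = [[0, 0, 1],
              [1, 2, 2]]
print(incoming_neighbors(edge_index))                      # {1: [0], 2: [0, 1]}
print(incoming_neighbors(edge_index, "target_to_source"))  # {0: [1, 2], 1: [2]}
```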
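Related methods:
- MessagePassing.propagate(edge_index, size=None, **kwargs): the initial call that starts propagating messages; it takes the edge indices and any additional data needed to construct messages and update node embeddings.
- MessagePassing.message(...): constructs the message for each edge, in analogy to $\phi$. It can take any argument passed to propagate(); appending _i or _j to the name of a node-level tensor (e.g. x_i, x_j) maps it to the receiving or sending node of each edge.
- MessagePassing.update(aggr_out, ...): updates each node embedding in analogy to $\gamma$, taking the aggregation output as its first argument.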
As an example, we apply this to a GCN layer. The GCN layer is mathematically defined as:
$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),$$
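For a concrete instance of this formula, consider a small hypothetical path graph with the undirected edges (0, 1) and (1, 2). After adding self-loops, $\deg(0) = 2$, $\deg(1) = 3$ and $\deg(2) = 2$, so for node 1 (whose neighborhood including itself is $\{0, 1, 2\}$) the sum expands to:

```latex
\mathbf{x}_1^{(k)}
  = \frac{1}{\sqrt{3}\sqrt{2}}\,\mathbf{\Theta}\mathbf{x}_0^{(k-1)}
  + \frac{1}{\sqrt{3}\sqrt{3}}\,\mathbf{\Theta}\mathbf{x}_1^{(k-1)}
  + \frac{1}{\sqrt{3}\sqrt{2}}\,\mathbf{\Theta}\mathbf{x}_2^{(k-1)}
  = \frac{1}{\sqrt{6}}\,\mathbf{\Theta}\mathbf{x}_0^{(k-1)}
  + \frac{1}{3}\,\mathbf{\Theta}\mathbf{x}_1^{(k-1)}
  + \frac{1}{\sqrt{6}}\,\mathbf{\Theta}\mathbf{x}_2^{(k-1)}
```

Each neighbor's contribution is damped by the degrees at both ends of its edge, so high-degree nodes neither dominate nor get drowned out by the sum.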
In this formula, the features of neighboring nodes are first transformed by the weight matrix $\mathbf{\Theta}$, then normalized by their degrees, and finally summed. The formula can be divided into the following steps:
- Add self-loops to the adjacency matrix.
- Linearly transform the node feature matrix.
- Compute the normalization coefficients.
- Normalize the node features.
- Sum up the features of neighboring nodes.
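The five steps can be traced by hand with a small, dependency-free sketch. gcn_layer is a hypothetical helper, not part of PyG; it works on nested lists, assumes edges are given as (j, i) pairs listing both directions of each undirected edge, and takes a weight matrix theta of shape [out_channels, in_channels] (the layout torch.nn.Linear uses for its weight). Unlike torch.nn.Linear, it adds no bias.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def gcn_layer(x: Matrix, edges: Sequence[Tuple[int, int]], theta: Matrix) -> Matrix:
    """Hypothetical plain-Python GCN layer; edges are (j, i) pairs, messages go j -> i.

    x has shape [N, in_channels], theta has shape [out_channels, in_channels].
    """
    n = len(x)
    # Step 1: add a self-loop (i, i) for every node.
    all_edges = list(edges) + [(i, i) for i in range(n)]
    # Step 2: linear transformation, h_j = Theta . x_j (no bias).
    h = [[sum(w * v for w, v in zip(row, x_j)) for row in theta] for x_j in x]
    # Step 3: degree of each target node (self-loop included), then one coefficient per edge.
    deg = [0] * n
    for _, i in all_edges:
        deg[i] += 1
    norm = [1.0 / math.sqrt(deg[i] * deg[j]) for j, i in all_edges]
    # Steps 4-5: scale every message by its coefficient and sum it into the target node.
    out = [[0.0] * len(theta) for _ in range(n)]
    for (j, i), coeff in zip(all_edges, norm):
        for d, value in enumerate(h[j]):
            out[i][d] += coeff * value
    return out


# Path graph 0 - 1 - 2, stored with both edge directions; 1-dim features, Theta = [[1.0]].
x = [[1.0], [2.0], [3.0]]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
out = gcn_layer(x, edges, [[1.0]])
print([[round(v, 4) for v in row] for row in out])  # [[1.3165], [2.2997], [2.3165]]
```

Counting the degree at the target node mirrors degree(col, ...) in the PyG code; for an undirected graph stored in both directions, in-degree and out-degree coincide.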
Steps 1 to 3 are computed before message passing starts, while steps 4 and 5 can be handled by the MessagePassing class:
```python
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform the node feature matrix.
        x = self.lin(x)

        # Step 3: Compute the normalization coefficients.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Steps 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j
```
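The shape [E, out_channels] of x_j reflects what propagate() does for this layer: with flow='source_to_target', the features of each edge's source node are gathered into x_j, message() scales them by norm, the results are summed at each target node (aggr='add'), and update(), which GCNConv does not override, returns that sum unchanged. The plain-Python sketch below unrolls that order of operations. propagate_sketch is a hypothetical helper, not PyG code; it uses lists instead of tensors and scatter operations.

```python
from typing import List, Sequence


def propagate_sketch(
    x: List[List[float]],
    edge_index: Sequence[Sequence[int]],
    norm: Sequence[float],
) -> List[List[float]]:
    """Hypothetical list-based stand-in for GCNConv.propagate with aggr='add'."""
    row, col = edge_index  # source_to_target: j = row (sender), i = col (receiver)
    # Gather: one feature row per edge, so x_j has shape [E, C].
    x_j = [x[j] for j in row]
    # message(): scale each edge's source features by its coefficient.
    msgs = [[c * v for v in feat] for c, feat in zip(norm, x_j)]
    # Aggregation ("add"): scatter-sum every message into its target node.
    out = [[0.0] * len(x[0]) for _ in x]
    for i, m in zip(col, msgs):
        out[i] = [a + b for a, b in zip(out[i], m)]
    # update(): the default returns the aggregated result unchanged.
    return out


# Two nodes, edges 0 -> 1 and 1 -> 0 plus both self-loops; every node has degree 2,
# so every coefficient is 1 / sqrt(2 * 2) = 0.5.
x = [[1.0, 0.0], [0.0, 1.0]]
edge_index = [[0, 1, 0, 1],
              [1, 0, 0, 1]]
print(propagate_sketch(x, edge_index, [0.5, 0.5, 0.5, 0.5]))  # [[0.5, 0.5], [0.5, 0.5]]
```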
2. Graph convolution: GCN
Paper: Semi-Supervised Classification with Graph Convolutional Networks
Copyright notice
This article was written by [Breeze_]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230610323101.html