torch_geometric learning 1: MessagePassing
2022-04-23 07:17:00 【Breeze_】
Installed versions
# for Windows 10
pip install torch==1.2.0  # CUDA 10.0, single GTX 1660 GPU
pip install torch_geometric==1.4.1
pip install torch_sparse==0.4.4
pip install torch_scatter==1.4.0
pip install torch_cluster==1.4.5
torch_geometric
Debug mode
import torch_geometric

# model and data are assumed to be defined elsewhere
with torch_geometric.debug():
    out = model(data.x, data.edge_index)
torch_geometric.nn
1. Message passing: MessagePassing
$$\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right) \right)$$
Generalizing the convolution operator to irregular domains is usually expressed as a neighborhood aggregation, or message passing, scheme like the one above, where:
- $\mathbf{x}_i^{(k-1)}$ denotes the features of node $i$ in layer $k-1$;
- $\mathbf{e}_{j,i}$ denotes the (optional) features of the edge from node $j$ to node $i$;
- $\square$ denotes a differentiable, permutation-invariant function, e.g. sum, mean or max;
- $\gamma$ and $\phi$ denote differentiable functions, e.g. MLPs.
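To make the general formula concrete, here is a minimal, framework-free sketch of one message-passing step with scalar node features. The helper name message_passing_step and the particular choices of $\phi$ (edge weight times neighbour feature), sum as $\square$, and $\gamma$ (add the aggregated message to the node's own feature) are assumptions for illustration, not PyG code.

```python
# Framework-free sketch of one generic message-passing step:
#   x_i' = gamma(x_i, AGG_{j in N(i)} phi(x_i, x_j, e_ji))
# Features are plain floats here; in a real GNN they are vectors/tensors.


def message_passing_step(x, edges, phi, aggregate, gamma, edge_attr=None):
    """x: list of node features; edges: list of (j, i) pairs meaning j -> i.

    phi(x_i, x_j, e_ji) builds one message, aggregate(list) should be
    permutation invariant, gamma(x_i, aggregated) gives the new feature.
    """
    inbox = [[] for _ in x]  # messages received by each node i
    for k, (j, i) in enumerate(edges):
        e_ji = edge_attr[k] if edge_attr is not None else None
        inbox[i].append(phi(x[i], x[j], e_ji))
    return [gamma(x[i], aggregate(msgs)) for i, msgs in enumerate(inbox)]


# Hypothetical choices: phi scales x_j by the edge weight, the aggregation
# is sum, and gamma adds the aggregated message to x_i.
x = [1.0, 2.0, 3.0]
edges = [(0, 1), (2, 1), (1, 0)]  # 0 -> 1, 2 -> 1, 1 -> 0
weights = [0.5, 1.0, 2.0]

out = message_passing_step(
    x,
    edges,
    phi=lambda x_i, x_j, e: e * x_j,
    aggregate=sum,
    gamma=lambda x_i, m: x_i + m,
    edge_attr=weights,
)
print(out)  # [5.0, 5.5, 3.0]
```

Node 1 receives 0.5 * 1.0 from node 0 and 1.0 * 3.0 from node 2, so its new feature is 2.0 + 3.5 = 5.5. Swapping sum for max would fail in this sketch on node 2, which has no incoming edges, so such nodes would need a default value.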
PyG provides the MessagePassing base class, which helps build such message-passing graph neural networks by handling message propagation automatically. The user only needs to define the function $\phi$, i.e. message(), and $\gamma$, i.e. update(), as well as the aggregation scheme to use, i.e. aggr="add", aggr="mean" or aggr="max".
class MessagePassing(aggr='add', flow='source_to_target', node_dim=0)
- aggr ("add", "mean" or "max"): the aggregation scheme to use;
- flow ("source_to_target" or "target_to_source"): the direction of message passing;
- node_dim: the axis along which to propagate.
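The effect of aggr and flow can be imitated in plain Python. This is an illustration under stated assumptions, not PyG's implementation: messages are single floats, edge_index is a pair of Python lists in the same [2, E] layout, the function name aggregate_messages is hypothetical, and nodes that receive no message get 0.0. Under "source_to_target", node j = edge_index[0] sends to node i = edge_index[1]; "target_to_source" swaps the two roles.

```python
# Imitates how aggr and flow decide where messages end up (not PyG's code).


def aggregate_messages(edge_index, messages, num_nodes,
                       aggr="add", flow="source_to_target"):
    """edge_index: [sources, targets] lists, like PyG's [2, E] tensor.
    messages: one float per edge. Returns one aggregated float per node."""
    if flow == "source_to_target":
        receivers = edge_index[1]  # i = edge_index[1], j = edge_index[0]
    elif flow == "target_to_source":
        receivers = edge_index[0]  # roles of i and j swapped
    else:
        raise ValueError(f"unknown flow: {flow}")

    buckets = [[] for _ in range(num_nodes)]
    for node, msg in zip(receivers, messages):
        buckets[node].append(msg)

    out = []
    for msgs in buckets:
        if not msgs:
            out.append(0.0)  # assumption: nodes without messages get 0.0
        elif aggr == "add":
            out.append(sum(msgs))
        elif aggr == "mean":
            out.append(sum(msgs) / len(msgs))
        elif aggr == "max":
            out.append(max(msgs))
        else:
            raise ValueError(f"unknown aggr: {aggr}")
    return out


edge_index = [[0, 1, 2],  # sources
              [1, 2, 1]]  # targets
messages = [1.0, 4.0, 3.0]

print(aggregate_messages(edge_index, messages, 3, aggr="add"))   # [0.0, 4.0, 4.0]
print(aggregate_messages(edge_index, messages, 3, aggr="mean"))  # [0.0, 2.0, 4.0]
print(aggregate_messages(edge_index, messages, 3, aggr="max"))   # [0.0, 3.0, 4.0]
print(aggregate_messages(edge_index, messages, 3,
                         flow="target_to_source"))               # [1.0, 4.0, 3.0]
```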
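Related methods:
- MessagePassing.propagate(edge_index, size=None, **kwargs): starts message propagation; it takes the edge indices plus any additional data needed to construct messages and update node embeddings.
- MessagePassing.message(...): constructs the message sent to node i along each edge (j, i), analogous to $\phi$. Tensors passed to propagate() can be mapped to the target node i or the source node j by appending _i or _j to the argument name, e.g. x_i and x_j.
- MessagePassing.update(aggr_out, ...): updates the node embeddings from the aggregated messages, analogous to $\gamma$.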
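How propagate() hands data to message() and update() can be shown with a toy class. ToyMessagePassing and ScaledSum are hypothetical names; the sketch only mirrors PyG's convention that message() arguments ending in _i or _j are looked up at the target or source node of each edge, while other arguments are passed through unchanged. It hard-codes "add" aggregation with source_to_target flow and works on Python lists, whereas PyG's real class works on tensors and does much more.

```python
import inspect


class ToyMessagePassing:
    """Toy imitation of propagate -> message -> aggregate -> update.

    Hypothetical class for illustration; "add" aggregation and
    source_to_target flow (j = edge_index[0], i = edge_index[1]) only.
    """

    def propagate(self, edge_index, size=None, **kwargs):
        src, dst = edge_index  # j = src, i = dst
        if size is None:
            size = 1 + max(max(src), max(dst))  # infer node count from edges
        msg_args = {}
        for name in inspect.signature(self.message).parameters:
            if name.endswith("_i"):  # value at target node i of each edge
                msg_args[name] = [kwargs[name[:-2]][i] for i in dst]
            elif name.endswith("_j"):  # value at source node j of each edge
                msg_args[name] = [kwargs[name[:-2]][j] for j in src]
            else:  # e.g. per-edge data, passed through unchanged
                msg_args[name] = kwargs[name]
        messages = self.message(**msg_args)  # one message per edge
        aggr_out = [0.0] * size
        for i, m in zip(dst, messages):  # "add" aggregation at node i
            aggr_out[i] += m
        return self.update(aggr_out)

    def message(self, x_j):
        return x_j  # default: send the source feature unchanged

    def update(self, aggr_out):
        return aggr_out  # default: identity


class ScaledSum(ToyMessagePassing):
    """Hypothetical subclass: each message is x_j times a per-edge weight."""

    def message(self, x_j, weight):
        return [w * xj for w, xj in zip(weight, x_j)]


edge_index = [[0, 2, 1], [1, 1, 0]]  # 0 -> 1, 2 -> 1, 1 -> 0
x = [1.0, 2.0, 3.0]
out = ScaledSum().propagate(edge_index, x=x, weight=[0.5, 1.0, 2.0])
print(out)  # [4.0, 3.5, 0.0]
```

Subclasses override only message() (and optionally update()); the lifting of x into x_j and the summation stay in the base class, which is the division of labour the MessagePassing API is built around.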
As an example, we apply this to the GCN layer, which is mathematically defined as follows:
$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right)$$
Here the features of neighboring nodes are first transformed by the weight matrix $\mathbf{\Theta}$, then normalized by their degrees, and finally summed up. The formula can be broken into the following steps:
- Add self-loops to the adjacency matrix.
- Linearly transform the node feature matrix.
- Compute the normalization coefficients.
- Normalize the node features in $\phi$.
- Sum up the neighboring node features ("add" aggregation).
Steps 1 to 3 are computed before message passing starts; steps 4 and 5 are handled by the MessagePassing class.
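To check that steps 1 to 5 really compute the GCN formula, the following framework-free sketch runs the edge-wise procedure on scalar features (a scalar theta stands in for $\mathbf{\Theta}$) and compares it with the dense matrix form $\hat{D}^{-1/2}(A + I)\hat{D}^{-1/2} X \Theta$. The function names are hypothetical, and the sketch assumes edge_index contains no duplicate edges and no pre-existing self-loops.

```python
import math


def gcn_layer_edgewise(x, edge_index, theta):
    """Scalar-feature GCN layer following steps 1-5 (theta is a float)."""
    n = len(x)
    src, dst = list(edge_index[0]), list(edge_index[1])
    # Step 1: add a self-loop i -> i for every node
    src += list(range(n))
    dst += list(range(n))
    # Step 2: linear transformation
    h = [theta * v for v in x]
    # Step 3: degree of each target node, then one coefficient per edge
    deg = [0] * n
    for i in dst:
        deg[i] += 1
    norm = [1.0 / math.sqrt(deg[j] * deg[i]) for j, i in zip(src, dst)]
    # Steps 4-5: scale each message h_j by its coefficient, sum at node i
    out = [0.0] * n
    for j, i, c in zip(src, dst, norm):
        out[i] += c * h[j]
    return out


def gcn_layer_dense(x, edge_index, theta):
    """Same layer via D^-1/2 (A + I) D^-1/2 X theta with a dense matrix."""
    n = len(x)
    a = [[0.0] * n for _ in range(n)]
    for j, i in zip(edge_index[0], edge_index[1]):
        a[i][j] = 1.0  # row i collects messages from node j
    for i in range(n):
        a[i][i] = 1.0  # self-loops
    deg = [sum(row) for row in a]
    return [sum(a[i][j] / math.sqrt(deg[i] * deg[j]) * theta * x[j]
                for j in range(n)) for i in range(n)]


# Undirected path 0 - 1 - 2, each edge listed in both directions
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
x = [1.0, 2.0, 3.0]
print(gcn_layer_edgewise(x, edge_index, 0.5))
print(gcn_layer_dense(x, edge_index, 0.5))
```

Both versions agree; for example, node 0 gets 0.25 from its self-loop (coefficient 1/2 times 0.5) plus 1/sqrt(6) from node 1. The edge-wise form is what message passing scales to sparse graphs, since it never builds the N x N matrix.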
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform the node feature matrix.
        x = self.lin(x)

        # Step 3: Compute the normalization coefficients.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # guard deg == 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Steps 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize the node features.
        return norm.view(-1, 1) * x_j
2. Graph convolution (GCN)
Paper: Semi-Supervised Classification with Graph Convolutional Networks
Copyright notice
This article was written by [Breeze_]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230610323101.html