当前位置:网站首页>pytorch geometric中为何要将稀疏邻接矩阵写成转置的形式adj_t
pytorch geometric中为何要将稀疏邻接矩阵写成转置的形式adj_t
2022-04-21 13:49:00 【每天都想躺平的大喵】
pytorch geometric中为何要将稀疏邻接矩阵写成转置的形式adj_t
一开始接触pytorch geometric的小伙伴可能和我有一样的疑问,为何数据中邻接矩阵要写成转置的形式。直到看了源码,我才理解作者这样写,是因为信息传递方式的原因,这里我跟大家分享一下。
edge_index
首先pytorch geometric的边信息可以有两种存储模式,第一种是edge_index,它的shape是[2, N],其中N是边的数目。第一个N维的元素存储边的原点的信息,称为source,第二个N维的元素存储边的目标点的信息,称为target。举个例子,如果我们有以下这样一张有向图,那么edge_index是这样的: tensor([[1, 2, 3, 4], [0, 0, 0, 0]]),边是(1,0), (2,0), (3,0), (3,0)

如果以上的图是无向图的话,那么0这个节点也指向1,2,3,4这几个节点,edge_index则应该是的: tensor([[1, 2, 3, 4,0, 0, 0, 0], [0, 0, 0, 0, 1, 2, 3, 4]]),边是(1,0), (2,0), (3,0), (3,0), (0,1), (0,2), (0,3), (0,4)。
edge_index这么写的原因是,在pytorch geometric中,用scatter一类的方式可以很方便地实现,从source到target,这种默认的边传递方式。(当然传递方式你也可以改成从target传递到source。)如果以上你还有不是很明白的地方,那就先记住,边传递的方式是从source到target的,后面在看源码的过程中,会慢慢明白的。
adj_t
pytorch geometric的边信息的第二种存储模式是adj_t,它是一个sparse tensor。这里我们看到作者在adj后面加上了t,说明它是邻接矩阵的转置。为什么要写成转置呢,我们接着上面edge_index讲。
首先我们为什么需要稀疏邻接矩阵,而不是直接用edge_index?那是因为如果可以用稀疏邻接矩阵可以极大地加快计算速度,节约内存。当然我们也有一些避免不了显式边传递的图算法,比如GAT这种需要在边上单独操作的图算法。
将edge_index转换成邻接矩阵的时候,自然而然地会写出以下形式:
adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=...,
sparse_sizes=(num_nodes, num_nodes))
但是我们都知道,矩阵计算AW是将行上的邻居的特征聚合的过程,其中A是邻接矩阵,W是特征矩阵。如果是刚接触图算法的小伙伴,我借用了以下一张图,可以看出来,每个节点最终生成的embedding是它在A所在行中邻居对应的特征值的求和,所以本质上是聚合列对应的信息聚到行。图中A只有邻居E,不过可以想一下点D,它有B, C, E三个邻居,因此它的特征是B, C, E三个邻居特征的和,聚合了列中B, C, E对应的信息到行中的D。

那这样就产生了一个问题,edge_index中信息传递是source to target,也就是edge_index[0] to edge_index[1],而adj中是col到row,这样就产生了不一致的问题。所以在做矩阵计算传递信息的时候,作者将adj转换成adj_t,并且将它作为默认形式,这样就保持了一致。
举例
看一个作者在文档中的关于GIN实现的例子:https://pytorch-geometric.readthedocs.io/en/latest/notes/sparse_tensor.html?highlight=adj_t#memory-efficient-aggregations
from torch_sparse import matmul
class GINConv(MessagePassing):
def __init__(self):
super().__init__(aggr="add")
def forward(self, x, edge_index):
out = self.propagate(edge_index, x=x)
return MLP((1 + eps) x + out)
def message(self, x_j):
return x_j
def message_and_aggregate(self, adj_t, x):
return matmul(adj_t, x, reduce=self.aggr)
可以看到在message_and_aggregate这一步信息传递的过程中,使用的是默认的adj_t。
我正在跟新pytorch geometric系列教程pytorch_geometric代码详解,欢迎大家评论交流。
版权声明
本文为[每天都想躺平的大喵]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_39925939/article/details/121331550
边栏推荐
- Color gradient (columns, rings, etc.)
- Common configuration items of echart (line, area, text)
- npm---环境
- Number of effective triangles (double pointer (reverse scanning) + bisection)
- String - 1 String length (10 points) the C language standard function library includes the strlen function, which is used to calculate the length of the string. As an exercise, we write a function wit
- 【leetcode】144.二叉树的前序遍历
- 招聘-长期有效
- String - 1. Longueur de la chaîne (10 points) La Bibliothèque de fonctions standard du langage C comprend une fonction strlen qui calcule la longueur de la chaîne. Comme exercice, nous écrivons nous -
- 字符串串动变化 (10 分)下列程序中,函数fun的功能是:在字符串str中找出ASCII码值最大的字符,将该字符前的所有字符向后顺序移动一个位置,然后将该字符放到第一个位置上。
- 自动化监控系统Prometheus&Grafana入门实战
猜你喜欢

Number II that occurs only once (hash, bit operation, logic circuit, finite state automata)

<2021SC@SDUSC>山东大学软件工程应用与实践JPress代码分析(十一)

Vagrant详细教程

通过区块划分提高随机生成圆球干涉检查的效率

Use of JSON server

《商用密码应用与安全性评估》第三章 商用密码标准与产品应用-小结

mysql-三星索引和cost值成本计算

软件工程-基础篇刷题

山东大学项目实训树莓派提升计划二期(一)项目概述、树莓派简介

New technology is coming again, embrace agp7 0, are you ready to say goodbye to transform?
随机推荐
Tool function - date formatting
Character sorting in the string (10 points): please write the function fun to sort the strings with a length of 8 characters in descending order.
NPM --- NPM configuration file
【leetcode】144.二叉树的前序遍历
Esgyndb about the performance improvement of delete with index
<2021SC@SDUSC>山东大学软件工程应用与实践JPress代码分析(十一)
《商用密码应用与安全性评估》第一章 密码基础知识-小结
STM32单片机初学5-IIC通信驱动OLED屏幕
软件工程-基础篇刷题
EsgynDB 关于收集core信息的小技巧
npm---package.json
The data show how fierce eth burns
机器学习笔记 - SVD奇异值分解(3) 在图像上应用 SVD
Vagrant详细教程
JVM内存分配机制详解
并发锁机制之synchronized
Unittest单元测试(一)
[special topic of stack and queue] - Dual queue simulation stack
【leetcode】144. Preorder traversal of binary tree
Software engineering - Fundamentals