当前位置:网站首页>torch使用踩坑日记,矩阵加速运算
torch使用踩坑日记,矩阵加速运算
2022-04-22 05:37:00 【Guapifang】
我们知道矩阵运算可以加速,用torch搭建的模型,我们预测数据往往都是同时读取多条(在显卡支持的条件下尽可能同时处理多条数据),这样时间会很快,如果单条数据预测非常慢的,最近在写模型处理大数据,我加载多条数据一起处理,因为数据量太大,没法直接全部转换成矩阵,所以是用列表储存了数据,然后预测的时间就是切片导入到模型中,如下代码所示。
import torch
from torch.autograd import Variable
from torch import nn
import numpy as np
from tqdm import tqdm
batch_size = 16
model = nn.Linear(100, 2)#这里假设为一个简单的线性映射,输出2个数据
lis = []
for i in range(1000000):
data = np.random.rand(100)#每次生成一条100维的数据
lis.append(data) #储存数据
#预测输出
for index in tqdm(np.arange(0, 1000000, batch_size)):
x = torch.tensor(lis[index: index+batch_size], dtype=torch.float32)
out = model(x)
上面有注意到一个细节,我用列表list储存了数据,因为数据量太大,如果储存完后直接
lis = np.array(lis)
可能会爆内存的,因为列表list和numpy数据格式储存方式不一样,占用的内存numpy要大很多,于是我就直接
x = torch.tensor(lis[index: index+batch_size], dtype=torch.float32)
然后预测,这样会报错吗?不会,也可以正常运行,但是没有启动矩阵加速的作用,在python运算的底层还是类似于一条一条的运算,时间上没有任何优化,运行截图如下:

torch进行了警告,说输入的是list,但是也能正常运行计算,但是和矩阵运算不一样,上面的运行花了16s,现在修改成矩阵格式如下:
import torch
from torch.autograd import Variable
from torch import nn
import numpy as np
from tqdm import tqdm
batch_size = 16
model = nn.Linear(100, 2)#这里假设为一个简单的线性映射,输出2个数据
lis = []
for i in range(1000000):
data = np.random.rand(100)#每次生成一条100维的数据
lis.append(data) #储存数据
#预测输出
for index in tqdm(np.arange(0, 1000000, batch_size)):
x = torch.tensor(np.array(lis[index: index+batch_size]), dtype=torch.float32)
out = model(x)
运行截图如下:

3s就跑完了,矩阵加速运算,永远的神~
版权声明
本文为[Guapifang]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_43918046/article/details/123348916
边栏推荐
- Single machine deployment redis master-slave and sentinel mode (one master, two slave and three sentinels)
- LeetCode 898. 子数组按位或操作 - set
- 数据挖掘——序列模式挖掘
- acwing854. Floyd求最短路
- ThreadLocal.ThreadLocalMap分析
- 数字三角形(动态规划dp)
- Installing mysql8 under Linux
- 等腰三角形-第九届蓝桥省赛-C组
- Fastjson determines whether the JSON string is object or list < object >
- 关于form表单点击submit按钮后,页面自动刷新的问题解决
猜你喜欢
随机推荐
Advanced part of MySQL
Judge whether there are links in the linked list
list stream: reduce的使用实例
什么是JSON?初识JSON
辰辰采草药
‘PdfFileWriter‘ object has no attribute ‘stream‘
Redis缓存负载均衡使用的一致性哈希算法
09-Redis之IO多路复用
Strong connected component of "tarjan" undirected graph
1.计算a+b
折现分割平面
LeetCode 1770. 执行乘法运算的最大分数 -- 区间DP
AcWing 836. 合并集合(并查集)
JVM探究
Circular linked list 2
最长上升子序列(lis)
POI 和 EasyExcel练习
寻找矩阵中“块“的个数(BFS)
mysql中on duplicate key update 使用详解
数的范围( 二分 经典模板题目)









