当前位置:网站首页>直观理解 torch.nn.Unfold
直观理解 torch.nn.Unfold
2022-04-23 06:12:00 【wujpbb7】
torch.nn.Unfold 是把batch中的数据按 C、Kernel_W、Kernel_H 打包,详细解释参考:
PyTorch中torch.nn.functional.unfold函数使用详解
本文主要是把 Unfold 返回的tensor的中间部分还原成 patches。
# -*- coding:utf-8 -*-
import cv2
import torch
import numpy as np
img1 = cv2.imread('../128128/1.png')
img2 = cv2.imread('../128128/2.png')
batch = torch.tensor([img1, img2]).permute(0,3,1,2)
print(batch.shape) # [2,3,128,128]
hor_block_num = 2
ver_block_num = 4
n,c,h,w = batch.shape
assert not h%ver_block_num
assert not w%hor_block_num
patch_h, patch_w = h//ver_block_num, w//hor_block_num
use_unfold = True
if (not use_unfold):
# 方法一
patches = batch.view(n, c, ver_block_num, patch_h, hor_block_num, patch_w)
print(patches.shape) # [2,3,4,32,2,64]
patches = patches.permute(0,2,4,3,5,1)
print(patches.shape) # [2,4,2,32,64,3]
else:
# 方法二
split_block = torch.nn.Unfold(kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
patches = split_block(batch.float())
print(patches.shape) # [2,6144,8]
patches = patches.permute(0,2,1).view(n,ver_block_num,hor_block_num,-1,patch_h,patch_w)
print(patches.shape) # [2,4,2,3,32,64]
patches = patches.permute(0,1,2,4,5,3).byte()
print(patches.shape) # [2,4,2,32,64,3]
# 保存patch
for i in range(patches.shape[0]):
img = patches[i]
for y in range(ver_block_num):
for x in range(hor_block_num):
patch_filename = '../128128/%d_y%d_x%d.png'%(i+1,y,x)
patch = img[y,x].numpy()
cv2.imwrite(patch_filename, patch)
效果如下:
版权声明
本文为[wujpbb7]所创,转载请带上原文链接,感谢
https://blog.csdn.net/blueblood7/article/details/120786045
边栏推荐
- 第4章 Pytorch数据处理工具箱
- Pytorch trains the basic process of a network in five steps
- Mysql database installation and configuration details
- How keras saves and loads the keras model
- 最简单完整的libwebsockets的例子
- The Cora dataset was trained and tested using the official torch GCN
- Chapter 4 pytoch data processing toolbox
- PyTorch 13. 嵌套函数和闭包(狗头)
- Machine learning III: classification prediction based on logistic regression
- Raspberry Pie: two color LED lamp experiment
猜你喜欢
MySQL installation and configuration - detailed tutorial
【点云系列】Relationship-based Point Cloud Completion
F.pad 的妙用
Paddleocr image text extraction
PaddleOCR 图片文字提取
ArcGIS license server administrator cannot start the workaround
机器学习 二:基于鸢尾花(iris)数据集的逻辑回归分类
Chapter 8 generative deep learning
Chapter 1 numpy Foundation
Summary of image classification white box anti attack technology
随机推荐
Data class of kotlin journey
face_recognition人脸检测
利用官方torch版GCN训练并测试cora数据集
PyTorch 9. 优化器
[point cloud series] pnp-3d: a plug and play for 3D point clouds
Raspberry Pie: two color LED lamp experiment
Chapter 4 pytoch data processing toolbox
【点云系列】FoldingNet:Point Cloud Auto encoder via Deep Grid Deformation
torch_ Geometric learning 1, messagepassing
第3章 Pytorch神经网络工具箱
cmder中文乱码问题
PyTorch 14. module类
GEE配置本地开发环境
[recommendation for new books in 2021] professional azure SQL managed database administration
[dynamic programming] triangle minimum path sum
C language, a number guessing game
[recommendation of new books in 2021] practical IOT hacking
【动态规划】不同路径2
第1章 NumPy基础
PyTorch 13. 嵌套函数和闭包(狗头)