Calculating Sharpness Difference (SD), a metric for image reconstruction and generation
2022-04-23 20:48:00 【NuerNuer】
This metric is explained in detail in a 2016 paper co-authored by LeCun ("Deep multi-scale video prediction beyond mean square error", Mathieu, Couprie and LeCun): https://arxiv.org/abs/1511.05440
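For reference, the definition from that paper can be restated as follows. It is transcribed here as a sketch, so check it against the original. Here Y is the ground-truth image, \hat{Y} is the prediction, and N is the number of pixels:

```latex
\mathrm{SharpDiff}(Y, \hat{Y}) = 10 \log_{10}
  \frac{\max_{\hat{Y}}^{2}}
       {\frac{1}{N} \sum_{i} \sum_{j}
        \left| (\nabla_i Y + \nabla_j Y) - (\nabla_i \hat{Y} + \nabla_j \hat{Y}) \right|},
\qquad
\nabla_i Y = \left| Y_{i,j} - Y_{i-1,j} \right|, \quad
\nabla_j Y = \left| Y_{i,j} - Y_{i,j-1} \right|
```

A smaller gradient difference gives a larger value, so a higher SD means the prediction's sharpness is closer to the ground truth. The implementation below scales pixels to [0, 1], so \max_{\hat{Y}} = 1 there.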
More image quality metrics: https://blog.csdn.net/hacker_long/article/details/104509523
Implementation:
```python
import numpy as np
import cv2
import tensorflow as tf
import tensorflow.compat.v1 as tf1

# The code below builds a TF1-style graph and evaluates it in a Session.
tf1.disable_v2_behavior()


def log10(t):
    """
    Calculates the base-10 log of each element in t.
    @param t: The tensor from which to calculate the base-10 log.
    @return: A tensor with the base-10 log of each element in t.
    """
    numerator = tf1.log(t)
    denominator = tf1.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator


def sharp_diff_error(gen_frames, gt_frames):
    """
    Computes the Sharpness Difference error between the generated images and the ground truth
    images.
    @param gen_frames: A tensor of shape [batch_size, height, width, 3]. The frames generated by the
                       generator model.
    @param gt_frames: A tensor of shape [batch_size, height, width, 3]. The ground-truth frames for
                      each frame in gen_frames.
    @return: A scalar tensor. The Sharpness Difference error averaged over the frames in the batch.
    """
    # Scale 8-bit pixel values to [0, 1], so the peak value in the formula is 1.
    gen_frames = tf1.to_float(gen_frames) / 255
    gt_frames = tf1.to_float(gt_frames) / 255

    shape = tf.shape(gen_frames)
    num_pixels = tf1.to_float(shape[1] * shape[2] * shape[3])

    # Gradient difference.
    # Create filters [-1, 1] and [[1], [-1]] for differencing along the width and the height.
    # A single combined filter would not be equivalent: |dx| and |dy| are taken separately below.
    pos = tf.constant(np.identity(3), dtype=tf.float32)
    neg = -1 * pos
    filter_x = tf.expand_dims(tf.stack([neg, pos]), 0)  # [1, 2, 3, 3]
    filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [2, 1, 3, 3]
    strides = [1, 1, 1, 1]  # stride of (1, 1)
    padding = 'SAME'

    gen_dx = tf.abs(tf.nn.conv2d(gen_frames, filter_x, strides, padding=padding))
    gen_dy = tf.abs(tf.nn.conv2d(gen_frames, filter_y, strides, padding=padding))
    gt_dx = tf.abs(tf.nn.conv2d(gt_frames, filter_x, strides, padding=padding))
    gt_dy = tf.abs(tf.nn.conv2d(gt_frames, filter_y, strides, padding=padding))

    gen_grad_sum = gen_dx + gen_dy
    gt_grad_sum = gt_dx + gt_dy
    grad_diff = tf.abs(gt_grad_sum - gen_grad_sum)

    # Per-image SD = 10 * log10(1 / mean absolute gradient difference), then averaged over the batch.
    batch_errors = 10 * log10(1 / ((1 / num_pixels) * tf.reduce_sum(grad_diff, [1, 2, 3])))
    return tf.reduce_mean(batch_errors)


if __name__ == "__main__":
    ori_path = 'xxx.jpg'    # ground-truth image
    recon_path = 'xxx.png'  # reconstructed / generated image

    ori_img = cv2.imread(ori_path)
    recon_img = cv2.imread(recon_path)
    if ori_img is None or recon_img is None:
        raise FileNotFoundError("could not read the input images")

    # cv2.resize takes dsize as (width, height), so both images become 112 x 616 x 3.
    ori_img = cv2.resize(ori_img, (616, 112))
    recon_img = cv2.resize(recon_img, (616, 112))
    print(ori_img.shape)

    # Add a batch dimension: [1, height, width, 3].
    batch_ori_img_n = np.array([ori_img])
    batch_recon_img_n = np.array([recon_img])
    print(batch_recon_img_n.shape)

    sd = sharp_diff_error(batch_recon_img_n, batch_ori_img_n)
    with tf1.Session() as sess:
        sd_ = sess.run(sd)
    print(sd_)
```
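If TensorFlow is not at hand, or to sanity-check sharp_diff_error, the same computation can be sketched in plain NumPy. `sharp_diff_np` is a hypothetical helper, not part of the original post, and it assumes TensorFlow's 'SAME' padding rule for a 2-tap kernel (no padding before, one zero column or row after). It therefore mirrors the TF version's behaviour at the image border rather than the paper's formula there.

```python
import numpy as np


def _abs_grads_same(x):
    """|horizontal diff| + |vertical diff| for x of shape [N, H, W, C].

    Assumption: mimics tf.nn.conv2d with the 1x2 / 2x1 kernels and padding='SAME',
    i.e. one zero column is appended on the right and one zero row at the bottom.
    """
    px = np.pad(x, ((0, 0), (0, 0), (0, 1), (0, 0)), mode="constant")
    py = np.pad(x, ((0, 0), (0, 1), (0, 0), (0, 0)), mode="constant")
    dx = np.abs(px[:, :, 1:, :] - px[:, :, :-1, :])  # next column minus current column
    dy = np.abs(py[:, :-1, :, :] - py[:, 1:, :, :])  # current row minus next row
    return dx + dy


def sharp_diff_np(gen_frames, gt_frames):
    """Hypothetical NumPy counterpart of sharp_diff_error for 8-bit [N, H, W, C] arrays."""
    gen = np.asarray(gen_frames, dtype=np.float64) / 255.0
    gt = np.asarray(gt_frames, dtype=np.float64) / 255.0
    num_pixels = gen.shape[1] * gen.shape[2] * gen.shape[3]
    grad_diff = np.abs(_abs_grads_same(gt) - _abs_grads_same(gen))
    mean_diff = grad_diff.sum(axis=(1, 2, 3)) / num_pixels
    batch_errors = 10.0 * np.log10(1.0 / mean_diff)  # peak value is 1 after scaling
    return batch_errors.mean()
```

For example, comparing an all-black 2 x 2 image with an all-white one gives exactly 0 dB, and only because of the zero padding: a flat white image has no interior gradients, yet its last row and column are differenced against the padded zeros. If that matters for your comparison, cropping the last row and column of grad_diff (and adjusting N accordingly) follows the paper's definition more closely.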
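Note that the convolution kernels in sharp_diff_error are designed so that a single convolution computes the differences between adjacent columns (filter_x) and between adjacent rows (filter_y). filter_x has shape [1, 2, 3, 3] and filter_y has shape [2, 1, 3, 3]; the four dimensions are the filter's height, width, in_channels and out_channels. Because each 3 x 3 channel slice of a kernel is either the identity matrix or its negative, every output channel depends only on the same input channel, so the differences are taken separately for each color channel. With padding='SAME', the last column (for filter_x) and the last row (for filter_y) are differenced against zero padding. For details, see the documentation of tf.nn.conv2d.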
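To see what these kernels do without running TensorFlow, here is a small NumPy sketch of a stride-1 'SAME' convolution in the NHWC input / [height, width, in, out] kernel layout that tf.nn.conv2d uses. Like tf.nn.conv2d, it computes a cross-correlation (the kernel is not flipped). `conv2d_same_nhwc` is a hypothetical helper written for illustration, and the padding split ((k - 1) // 2 before, the rest after) is an assumption based on TensorFlow's 'SAME' rule.

```python
import numpy as np


def conv2d_same_nhwc(x, w):
    """Naive stride-1 2-D cross-correlation with 'SAME' zero padding.

    x: input of shape [N, H, W, C_in]; w: kernel of shape [kh, kw, C_in, C_out].
    Padding split (assumed TF-style): (k - 1) // 2 before, the rest after.
    """
    kh, kw, _, c_out = w.shape
    ph, pw = kh - 1, kw - 1
    xp = np.pad(x, ((0, 0), (ph // 2, ph - ph // 2), (pw // 2, pw - pw // 2), (0, 0)),
                mode="constant")
    n, h, wd, _ = x.shape
    out = np.zeros((n, h, wd, c_out))
    for i in range(kh):
        for j in range(kw):
            # Each kernel tap adds a shifted copy of the input, mixed across channels by w[i, j].
            out += np.einsum("nhwc,co->nhwo", xp[:, i:i + h, j:j + wd, :], w[i, j])
    return out


# Same kernels as in sharp_diff_error: +/- identity matrices over the 3 color channels.
eye = np.identity(3)
filter_x = np.stack([-eye, eye])[np.newaxis]              # [1, 2, 3, 3]
filter_y = np.stack([eye[np.newaxis], -eye[np.newaxis]])  # [2, 1, 3, 3]

x = np.random.default_rng(0).random((1, 4, 5, 3))
dx = conv2d_same_nhwc(x, filter_x)  # x[:, :, j + 1] - x[:, :, j], per channel
dy = conv2d_same_nhwc(x, filter_y)  # x[:, i] - x[:, i + 1], per channel
print(np.allclose(dx[:, :, :-1], x[:, :, 1:] - x[:, :, :-1]))  # True
```

Only the interior columns are plain differences; the last column of dx is -x[:, :, -1] and the last row of dy is x[:, -1], which is the zero-padding effect described above.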
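PS: PSNR and SSIM do not need to be implemented by hand; mature library functions can be called directly, for example tf.image.psnr and tf.image.ssim in TensorFlow, or skimage.metrics.peak_signal_noise_ratio and skimage.metrics.structural_similarity in scikit-image.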
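The SD formula has the same 10 * log10(peak^2 / error) shape as PSNR, with the mean squared pixel error replaced by the mean absolute difference of gradient magnitudes. For comparison, PSNR itself can be sketched in a few lines of NumPy. `psnr` here is a hypothetical helper that assumes 8-bit images (peak value 255); library functions such as tf.image.psnr may differ in details, e.g. whether the value is computed per image or over the whole array.

```python
import numpy as np


def psnr(img_a, img_b, max_val=255.0):
    """PSNR in dB: 10 * log10(max_val**2 / MSE), computed over all pixels and channels."""
    a = np.asarray(img_a, dtype=np.float64)
    b = np.asarray(img_b, dtype=np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))


black = np.zeros((2, 2, 3), dtype=np.uint8)
dot = black.copy()
dot[0, 0, :] = 255  # one white pixel out of four
print(psnr(black, dot))  # about 6.02 dB, i.e. 10 * log10(4)
```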
Copyright notice
This article was written by [NuerNuer]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204210545522493.html