当前位置：网站首页 > 医学图像数据增强-归一化
医学图像数据增强-归一化
2022-08-08 13:44:00 【智能之心】
normalizations.py
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
# Modifications by DJ; based on:
# https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/augmentations/normalizations.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
# Summary of the public API defined below (signatures with their defaults) and
# the import line for using these functions from another file.
r"""
range_normalization(data, per_channel=True, rnge=(0, 1), eps=1e-8)
min_max_normalization(data, eps)
zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8)
mean_std_normalization(data, mean, std, per_channel=True)
cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_channel=False)
-focus:
from normalizations import (range_normalization, min_max_normalization, zero_mean_unit_variance_normalization, mean_std_normalization, cut_off_outliers)
"""
def range_normalization(data, per_channel=True, rnge=(0, 1), eps=1e-8):
    r"""Min-max normalize ``data`` into the range ``rnge`` (default ``(0, 1)``).

    Each sample -- or, with ``per_channel=True``, each channel of each sample --
    is mapped with ``(x - min) / (max - min + eps)`` and then linearly rescaled
    to ``[rnge[0], rnge[1]]``.

    @Args:
        data: array of shape (b, c, z(, y(, x))).
        per_channel: if True, normalize every channel of every sample
            independently; otherwise normalize all channels of a sample jointly.
        rnge: (low, high) target range, default (0, 1).
        eps: small constant added to the denominator so constant inputs do not
            divide by zero (they map to ``rnge[0]``).
    @Returns:
        data_normalized: np.float32 array with the same shape as ``data``.
    """
    # Promote to float64 before subtracting: in the input integer dtype,
    # ``x - min`` and ``max - min`` can wrap around (e.g. int8 data spanning
    # -128..127, or int16 volumes padded with -32768).
    values = np.asarray(data).astype(np.float64, copy=False)
    # Per channel: reduce over the spatial axes only; otherwise also over the
    # channel axis, i.e. everything except the batch axis.
    first_reduced_axis = 2 if per_channel else 1
    axes = tuple(range(first_reduced_axis, values.ndim))
    mn = values.min(axis=axes, keepdims=True)
    mx = values.max(axis=axes, keepdims=True)
    data_normalized = ((values - mn) / (mx - mn + eps)).astype(np.float32)
    data_normalized *= (rnge[1] - rnge[0])
    data_normalized += rnge[0]
    return data_normalized
def min_max_normalization(data, eps):
    r"""Scale ``data`` to [0, 1] with ``(x - min) / (max - min + eps)``.

    @Args:
        data: array of any shape, e.g. (z(, y(, x))); min and max are taken
            over all of its elements.
        eps: small constant added to the denominator so a constant input maps
            to 0 instead of dividing by zero.
    @Returns:
        A new np.float32 array with the same shape as ``data``.
    """
    # Work in float64: for integer inputs ``data - mn`` and ``mx - mn`` would
    # otherwise be evaluated in the input dtype and can overflow/wrap around
    # (e.g. int16 data spanning -32768..3000).
    values = np.asarray(data, dtype=np.float64)
    mn = values.min()
    mx = values.max()
    data_normalized = (values - mn) / (mx - mn + eps)
    return np.array(data_normalized, dtype=np.float32)
def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8):
    r"""Standardize ``data`` to zero mean and unit variance (z-score).

    Each sample -- or, with ``per_channel=True``, each channel of each sample --
    is transformed as ``(x - mean) / (std + epsilon)``.

    @Args:
        data: array of shape (b, c, z(, y(, x))).
        per_channel: if True, standardize every channel of every sample
            separately; otherwise use one mean/std over all channels of a sample.
        epsilon: added to the standard deviation to avoid division by zero.
    @Returns:
        np.float32 array with the same shape as ``data``.
    """
    def _slices():
        # Yield the index of every independently standardized sub-array.
        for sample in range(data.shape[0]):
            if not per_channel:
                yield (sample,)
                continue
            for channel in range(data.shape[1]):
                yield sample, channel

    result = np.zeros(data.shape, dtype=np.float32)
    for index in _slices():
        chunk = data[index]
        centre = chunk.mean()
        spread = chunk.std() + epsilon
        result[index] = (chunk - centre) / spread
    return result
def mean_std_normalization(data, mean, std, per_channel=True):
    r"""Normalize ``data`` with user-supplied statistics: ``(x - mean) / std``.

    @Args:
        data: numpy array of shape (b, c, z(, y(, x))), or a non-empty list/tuple
            of numpy arrays of shape (c, z(, y(, x))); the output shape is taken
            from the first element, so all elements are assumed to share it.
        mean: scalar, or (with ``per_channel=True``) a sequence with one value
            per channel.
        std: scalar, or (with ``per_channel=True``) a sequence with one value
            per channel.
        per_channel: if True, channel ``c`` uses ``mean[c]`` and ``std[c]``; a
            scalar is reused for every channel. If False, ``mean`` and ``std``
            are applied to each whole sample with ordinary numpy broadcasting.
    @Returns:
        np.float32 array of shape (b, c, z(, y(, x))).
    @Raises:
        TypeError: if ``data`` is neither a numpy array nor a list/tuple.
        AssertionError: if list/tuple ``data`` is empty or its first element is
            not an ndarray, or if a per-channel ``mean``/``std`` sequence does
            not have exactly one entry per channel.
    """
    # Work out the shape before allocating: list/tuple input has no ``.shape``,
    # so allocating from ``data.shape`` first would make that path unusable.
    if isinstance(data, np.ndarray):
        data_shape = tuple(data.shape)
    elif isinstance(data, (list, tuple)):
        assert len(data) > 0 and isinstance(data[0], np.ndarray)
        data_shape = tuple([len(data)] + list(data[0].shape))
    else:
        raise TypeError("Data has to be either a numpy array or a list")
    data_normalized = np.zeros(data_shape, dtype=np.float32)

    if per_channel:
        n_channels = data_shape[1]

        def _as_per_channel(value, name):
            # Broadcast any scalar (python int/float or numpy scalar) to every
            # channel; otherwise require exactly one value per channel.
            if np.ndim(value) == 0:
                return [value] * n_channels
            assert len(value) == n_channels, (
                f"{name} must have one entry per channel ({n_channels}), got {len(value)}")
            return value

        # Validate/broadcast mean and std independently of each other.
        mean = _as_per_channel(mean, "mean")
        std = _as_per_channel(std, "std")

    for b in range(data_shape[0]):
        if per_channel:
            for c in range(data_shape[1]):
                data_normalized[b][c] = (data[b][c] - mean[c]) / std[c]
        else:
            data_normalized[b] = (data[b] - mean) / std
    return data_normalized
def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_channel=False):
    r"""Clip intensities of each sample (or each channel of each sample) to a
    percentile window, e.g. to suppress unwanted bright/dark structures.

    ``data`` is modified IN PLACE and the same object is returned.

    @Args:
        data: array of shape (b, c, ...); with ``per_channel=False`` it may also
            be a list of per-sample numpy arrays.
        percentile_lower: lower percentile on numpy's 0-100 scale.
        percentile_upper: upper percentile on numpy's 0-100 scale.
        per_channel: compute the percentile window per channel instead of per
            sample.
    @Returns:
        ``data`` itself, after clipping.
    """
    def _clip_in_place(view):
        # Both bounds are computed before either assignment, so the upper bound
        # is not influenced by the lower clipping.
        low = np.percentile(view, percentile_lower)
        high = np.percentile(view, percentile_upper)
        view[view < low] = low
        view[view > high] = high

    for b in range(len(data)):
        if not per_channel:
            _clip_in_place(data[b])
            continue
        for c in range(data.shape[1]):
            _clip_in_place(data[b, c])
    return data
def test(mhd_path=r'xxx.mhd'):
    r"""Visual smoke test for the normalization functions in this module.

    Reads an image with SimpleITK, adds batch and channel axes, applies each
    normalization, prints the resulting min/max and shows three orthogonal
    mid-slices with matplotlib. Needs the ``SimpleITK`` and ``matplotlib``
    packages and a 3-D image file.

    @Args:
        mhd_path: path of the image to load; the default ``'xxx.mhd'`` is only a
            placeholder and must point to a real file.
    """
    import SimpleITK as sitk

    def plot_3view(image_norm):
        # Show the middle z-, y- and x-slice of sample 0 / channel 0 of a
        # (b, c, z, y, x) array side by side.
        import matplotlib.pyplot as plt
        _, _, z, y, x = image_norm.shape
        plt.figure(figsize=(10, 10))
        plt.subplot(131)
        plt.imshow(image_norm[0, 0, z // 2, :, :])
        plt.subplot(132)
        plt.imshow(image_norm[0, 0, :, y // 2, :])
        plt.subplot(133)
        plt.imshow(image_norm[0, 0, :, :, x // 2])
        plt.show()

    sitk_image = sitk.ReadImage(mhd_path)
    # GetArrayFromImage returns the voxels in (z, y, x) order.
    data = sitk.GetArrayFromImage(sitk_image)
    # Add batch and channel axes -> (b=1, c=1, z, y, x).
    data = data[np.newaxis, np.newaxis]

    # 1) Min-max normalization to (0, 1).
    image_norm = range_normalization(data, per_channel=True)
    print(image_norm.min(), image_norm.max())
    plot_3view(image_norm)

    # 2) Zero-mean / unit-variance normalization, jointly and per channel.
    image_norm = zero_mean_unit_variance_normalization(data, False)
    print(image_norm.min(), image_norm.max())
    plot_3view(image_norm)
    image_norm = zero_mean_unit_variance_normalization(data, True)
    print(image_norm.min(), image_norm.max())
    plot_3view(image_norm)

    # 3) Normalization with a fixed, user-supplied mean and std.
    image_norm = mean_std_normalization(data, mean=43., std=52., per_channel=True)
    print(image_norm.min(), image_norm.max())
    plot_3view(image_norm)

    # 4) Percentile clipping; pass a copy because cut_off_outliers clips in
    #    place, then show the original next to the clipped result.
    image_norm = cut_off_outliers(data.copy(), percentile_lower=0.2, percentile_upper=95, per_channel=False)
    plot_3view(data)
    plot_3view(image_norm)
if __name__ == "__main__":
    # Run the visual smoke test only when this file is executed directly,
    # not when it is imported as a module.
    test()
边栏推荐
- 论文理解:“Self-adaptive loss balanced Physics-informed neural networks“
- win32&mfc————win32菜单栏&库
- 活动报名| StreamNative 受邀参与 ITPUB 在线技术沙龙
- 机器学习+深度学习笔记(持续更新~)
- R语言ggplot2可视化:使用ggpubr包的ggdonutchart函数可视化甜甜圈图(donut chart)、为甜甜圈图添加自定义标签(包含文本内容以及数值百分比)、lab.font参数设置标
- (4) FlinkSQL writes socket data to mysql Method 1
- 【JS高级】ES5标准规范之严格模式下的保护对象_09
- KD-SCFNet: More Accurate and Efficient Salient Object Detection Through Knowledge Distillation (ECCV2022)
- a += 1 += 1为什么是错的?
- 客户案例 | 提高银行信用卡客户贡献率
猜你喜欢
使用.NET简单实现一个Redis的高性能克隆版(三)
路由器——交换机——网络交换机:区别
Implementation of FIR filter based on FPGA (1) - using fir1 function design
腾讯,投了个 “离诺贝尔奖最近的华人”
Tsinghua | GLM-130B: An Open Bilingual Pre-training Model
化工行业数字化供应链系统:赋能化工企业高质量发展,促进上下游协同
QtWebassembly遇到的一些报错问题及解决方案
The use of string function, character function, memory function and its analog implementation
Three classic topics in C language: three-step flip method, Young's matrix, and the Euclidean (successive division) algorithm
[Redis] Redis installation and use of client redis-cli (batch operation)
随机推荐
一文搞懂│XSS攻击、SQL注入、CSRF攻击、DDOS攻击、DNS劫持
《预训练周刊》第56期:长文本理解、即时问答、掩码自监督
用 Antlr 重构脚本解释器
基于FPGA的FIR滤波器的实现(1)—采用fir1函数设计
清华|GLM-130B:一个开放的双语预训练模型
itk中生成drr整理
sample函数—R语言
QWebAssembly中文适配
KD-SCFNet:通过知识蒸馏实现更准确、更高效的显著目标检测(ECCV2022)
itk中图像2d-3d配准整理
PC端实用软件推荐
使用shardingjdbc实现读写分离配置
连锁小酒馆第一股,海伦司能否梦圆大排档?
【Rust—LeetCode题解】1.两数之和
window停掉指定端口的进程
如果Controller里有私有的方法,能成功访问吗?
PostgreSQL 用户与schema有什么区别?
Qt 在循环中超时跳出
代码随想录 (Code Thoughts) Notes_Dynamic Programming_322 Coin Change
【os.path】的相关用法(持更)