当前位置:网站首页>Complete image segmentation efficiently based on MindSpore and realize Dice!
Complete image segmentation efficiently based on MindSpore and realize Dice!
2022-08-05 10:01:00 【Ascension MindSpore】

Dice Introduction and implementation of coefficients
DiceCoefficient principle
DiceIt is the most frequently used metric in medical image competitions,It is an ensemble similarity measure,通常用于计算两个样本的相似度,The value threshold is[0, 1].Often used for image segmentation in medical images,The best result of segmentation is 1,The worst time result is 0.
Dice系数计算公式如下:

当然DiceThere is also another expression,is used in the confusion matrixTP,FP,FN来表达:
The principle of this formula is shown in the figure below:

MindSpore代码实现
先简单介绍一下MindSpore——新一代AI开源计算框架.创新编程范式,AIScientists and engineers more易使用,便于开放式创新;This computational framework satisfies终端、边缘计算、云全场景需求,能更好保护数据隐私;可开源,形成广阔应用生态.
2020年3月28日,华为在开发者大会2020上宣布,全场景AI计算框架MindSpore在码云正式开源.MindSpore着重提升易用性并降低AI开发者的开发门槛,MindSpore原生适应每个场景包括端、边缘和云,并能够在按需协同的基础上,通过实现AI算法即代码,使开发态变得更加友好,显著减少模型开发时间,降低模型开发门槛.
通过MindSpore自身的技术创新及MindSpore与华为昇腾AI处理器的协同优化,实现了Efficient operation,大大提高了计算性能;MindSpore也支持GPU、CPU等其它处理器.
"""Dice"""
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class Dice(Metric):
def __init__(self, smooth=1e-5):
super(Dice, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self._dice_coeff_sum = 0
self._samples_num = 0
self.clear()
def clear(self):
# Yes to clear historical data
self._dice_coeff_sum = 0
self._samples_num = 0
def update(self, *inputs):
# 更新输入数据,y_pred和y,The data entry type can beTensor,lisy或numpy,维度必须相等
if len(inputs) != 2:
raise ValueError('Dice need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
# 将数据进行转换,统一转换为numpy
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
if y_pred.shape != y.shape:
raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
'the shape of y is {}.'.format(y_pred.shape, y.shape))
# Seek the intersection first,利用dotThe corresponding points are multiplied and added together
intersection = np.dot(y_pred.flatten(), y.flatten())
# 求并集,先将输入shapeAll pulled to one dimension,Then do point multiplication respectively,The two inputs are then added together
unionset = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
# 利用公式进行计算,加smooth是为了防止分母为0,避免当pred和true都为0时,分子被0除的问题,同时减少过拟合
single_dice_coeff = 2 * float(intersection) / float(unionset + self.smooth)
# The coefficients for each batch are accumulated
self._dice_coeff_sum += single_dice_coeff
def eval(self):
# 进行计算
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num)使用方法如下:
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics Dice
metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
dice = metric.eval()
print(dice)
0.20467791371802546每个batch(两组数据)进行计算的时候如下:
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics Dice
metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
x1= Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y1 = Tensor(np.array([[1, 0], [1, 1], [1, 0]]))
metric.update(x1, y1)
avg_dice = metric.eval()
print(dice)Dice Loss 介绍及实现
Dice Loss原理
Dice Loss 原理是在 Dice Calculated on the basis of coefficients,用1去减Dice系数
This is the case where there is only one image per batch in the binary classification,当一个批次有N张图片时,可以将图片压缩为一维向量,如下图:

对应的label也会相应变化,最后一起计算N张图片的Dice系数和Dice Loss.
MindSpore 二分类 DiceLoss 代码实现
class DiceLoss(_Loss):
def __init__(self, smooth=1e-5):
super(DiceLoss, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.reshape = P.Reshape()
def construct(self, logits, label):
# Dimension check,维度必须相等.(输入必须是tensor)
_check_shape(logits.shape, label.shape)
# 求交集,和dicecoefficients in the same way
intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
# 求并集,和dicecoefficients in the same way
unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
# 利用公式进行计算
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
dice_loss = 1 - single_dice_coeff / label.shape[0]
return dice_loss.mean()
@constexpr
def _check_shape(logits_shape, label_shape):
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)使用方法如下:
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
loss = nn.DiceLoss(smooth=1e-5)
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
[0.7953220862819745]MindSpore 多分类 MultiClassDiceLoss 代码实现
在MindSporeThere are various loss functions to choose from in semantic segmentation,However, the most commonly used loss function is to use cross entropy.
class MultiClassDiceLoss(_Loss):
def __init__(self, weights=None, ignore_indiex=None, activation=A.Softmax(axis=1)):
super(MultiClassDiceLoss, self).__init__()
# 利用Dice系数
self.binarydiceloss = DiceLoss(smooth=1e-5)
# 权重是一个Tensor,Should be the same dimension as the number of categories:Tensor of shape `[num_classes, dim]`.
self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
# The ordinal number of the category to ignore
self.ignore_indiex = ignore_indiex if ignore_indiex is None else \
validator.check_value_type("ignore_indiex", ignore_indiex, [int])
# 使用激活函数
self.activation = A.get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, Cell):
raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
self.activation_flag = self.activation is not None
self.reshape = P.Reshape()
def construct(self, logits, label):
# Dimension check,维度必须相等.(输入必须是tensor)
_check_shape(logits.shape, label.shape)
# 先定义一个loss,初始值为0
total_loss = 0
# 如果使用激活函数
if self.activation_flag:
logits = self.activation(logits)
# Iterates by the first number of the dimension of the label
for i in range(label.shape[1]):
if i != self.ignore_indiex:
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
if self.weights is not None:
_check_weights(self.weights, label)
dice_loss *= self.weights[i]
total_loss += dice_loss
return total_loss/label.shape[1]使用方法如下:
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
[0.7761003]Dice Loss 存在的问题
训练误差曲线非常混乱,很难看出关于收敛的信息.尽管可以检查在验证集上的误差来避开此问题.
边栏推荐
猜你喜欢

egg框架使用(一)

Bias lock/light lock/heavy lock lock is healthier. How is locking and unlocking accomplished?

Tanabata romantic date without overtime, RPA robot helps you get the job done

哪位大佬有20年4月或者1月的11G GI和ojvm补丁呀,帮忙发下?

leetcode: 529. Minesweeper Game

深度学习21天——卷积神经网络(CNN):天气识别(第5天)

基于MindSpore高效完成图像分割,实现Dice!

mysql进阶(二十七)数据库索引原理

Seata source code analysis: initialization process of TM RM client

Marketing Suggestions | You have an August marketing calendar to check! Suggest a collection!
随机推荐
19.服务器端会话技术Session
蚁剑webshell动态加密连接分析与实践
2022.8.3
shell脚本实例
MySQL advanced (twenty-seven) database index principle
正则表达式replaceFirst()方法具有什么功能呢?
轩辕实验室丨欧盟EVITA项目预研 第一章(四)
Marketing Suggestions | You have an August marketing calendar to check! Suggest a collection!
NowCoderTOP35-40——持续更新ing
告白数字化转型时代:麦聪软件以最简单的方式让企业把数据用起来
为什么sys_class 里显示的很多表的 RELTABLESPACE 值为 0 ?
dotnet OpenXML 解析 PPT 图表 面积图入门
Development common manual link sharing
百年北欧奢华家电品牌ASKO智能三温区酒柜臻献七夕,共品珍馐爱意
MySQL内部函数介绍
仿SBUS与串口数据固定转换
项目成本控制如何帮助项目成功?
无题十四
攻防世界-PWN-new_easypwn
公众号如何运维?公众号运维专业团队