这是一个unet-pytorch的源码,可以训练自己的模型

Overview

Unet:U-Net: Convolutional Networks for Biomedical Image Segmentation目标检测模型在Pytorch当中的实现


目录

  1. 性能情况 Performance
  2. 所需环境 Environment
  3. 注意事项 Attention
  4. 文件下载 Download
  5. 预测步骤 How2predict
  6. 训练步骤 How2train
  7. miou计算 miou
  8. 参考资料 Reference

性能情况

unet并不适合VOC此类数据集,其更适合特征少,需要浅层特征的医药数据集之类的。

训练数据集 权值文件名称 测试数据集 输入图片大小 mIOU
VOC12+SBD unet_voc.pth VOC-Val12 512x512 55.11

所需环境

torch==1.2.0
torchvision==0.4.0

注意事项

unet_voc.pth是基于VOC拓展数据集训练的。
unet_medical.pth是使用示例的细胞分割数据集训练的。
在使用时需要注意区分。

文件下载

训练所需的unet_voc.pth和unet_medical.pth可在百度网盘中下载。
链接: https://pan.baidu.com/s/1AUBpqsSgamoQGEYpNjJg7A 提取码: i3ck

VOC拓展数据集的百度网盘如下:
链接: https://pan.baidu.com/s/1BrR7AUM1XJvPWjKMIy2uEw 提取码: vszf

预测步骤

一、使用预训练权重

a、VOC预训练权重

  1. 下载完库后解压,如果想要利用voc训练好的权重进行预测,在百度网盘或者release下载unet_voc.pth,放入model_data,运行即可预测。
img/street.jpg
  1. 利用video.py可进行摄像头检测。

b、医药预训练权重

  1. 下载完库后解压,如果想要利用医药数据集训练好的权重进行预测,在百度网盘或者release下载unet_medical.pth,放入model_data,修改unet.py中的model_path和num_classes;
_defaults = {
    "model_path"        : 'model_data/unet_voc.pth',
    "model_image_size"  : (512, 512, 3),
    "num_classes"       : 21,
    "cuda"              : True,
    #--------------------------------#
    #   blend参数用于控制是否
    #   让识别结果和原图混合
    #--------------------------------#
    "blend"             : True
}
  1. 运行即可预测。
img/cell.png

二、使用自己训练的权重

  1. 按照训练步骤训练。
  2. 在unet.py文件里面,在如下部分修改model_path、backbone和num_classes使其对应训练好的文件;model_path对应logs文件夹下面的权值文件
_defaults = {
    "model_path"        : 'model_data/unet_voc.pth',
    "model_image_size"  : (512, 512, 3),
    "num_classes"       : 21,
    "cuda"              : True,
    #--------------------------------#
    #   blend参数用于控制是否
    #   让识别结果和原图混合
    #--------------------------------#
    "blend"             : True
}
  1. 运行predict.py,输入
img/street.jpg
  1. 利用video.py可进行摄像头检测。

训练步骤

一、训练voc数据集

  1. 将我提供的voc数据集放入VOCdevkit中(无需运行voc2unet.py)。
  2. 在train.py中设置对应参数,默认参数已经对应voc数据集所需要的参数了,所以只要修改backbone和model_path即可。
  3. 运行train.py进行训练。

二、训练自己的数据集

  1. 本文使用VOC格式进行训练。
  2. 训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的SegmentationClass中。
  3. 训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
  4. 在训练前利用voc2unet.py文件生成对应的txt。
  5. 注意修改train.py的num_classes为分类个数+1。
  6. 运行train.py即可开始训练。

三、训练医药数据集

  1. 下载VGG的预训练权重到model_data下面。
  2. 按照默认参数运行train_medical.py即可开始训练。

miou计算

参考miou计算视频和博客。

Reference

https://github.com/ggyyzm/pytorch_segmentation
https://github.com/bonlime/keras-deeplab-v3-plus

You might also like...
Comments
  • 询问一下预训练的问题

    询问一下预训练的问题

    你好,打扰了。我是想问下主干模型是指的是在下采样过程中使用的vgg吗?如果我不改变上采样是不是就不用使用imagenet训练。然后注销掉model_path=‘’ 以及 if model_path !=‘’这段。然后使用自己的数据集去进行训练。 谢谢大佬!!!!!!。实际上大佬你的voc的权重文件是不是为二次预训练的数据。 不好意思,语言表达能力不行。俺不晓得这样说大佬明不明白。

    opened by Nine9844 5
  • 训练一段时间后,CE loss变为NAN

    训练一段时间后,CE loss变为NAN

    您好,看了您的教程我试着自己搭建了一个U-Net模型,并采用Dice + CE loss作为损失函数,但在迭代几十个epoch后,我的CE loss返回了NAN值,反馈的结果是 ‘Function 'LogSoftmaxBackward' returned nan values in its 0th output.’ 同样的数据在您源码上运行没有出现这个问题,请问您是否知道些解决方法?

    opened by Breeze-Zero 2
  • 为啥在dataloader第40行转换的array的shape和cv2不一样呢

    为啥在dataloader第40行转换的array的shape和cv2不一样呢

    我使用json_to_dataset.py转化mask后尝试使用代码查看shape import cv2 import numpy as np from PIL import Image

    file = '/home/fut/Downloads/unet-pytorch-main/mydata/masks/ID_1110_json.png' img = cv2.imread(file, cv2.IMREAD_UNCHANGED) print(img.shape)

    pil = Image.open(file) img2 = np.array(pil) print(img2.shape) 结果会是: (800, 800, 3) (800, 800) 为什么PIL读取后通道就没了,正是因为这个原因你的项目会很好跑起来。

    opened by futureflsl 1
  • from tqdm import tqdm 报错

    from tqdm import tqdm 报错

    import os import time

    import numpy as np import torch import torch.backends.cudnn as cudnn import torch.optim as optim from torch.utils.data import DataLoader from tqdm import tqdm

    opened by Luke-Wei 1
Releases(v3.0)
  • v3.0(Apr 22, 2022)

    重要更新

    • 支持step、cos学习率下降法。
    • 支持adam、sgd优化器选择。
    • 支持不同预测模式的选择,单张图片预测、文件夹预测、视频预测、图片裁剪。
    • 更新summary.py文件,用于观看网络结构。
    • 增加了多GPU训练。
    Source code(tar.gz)
    Source code(zip)
  • v2.2(Mar 4, 2022)

    重要更新

    • 更新train.py文件,增加了大量的注释,增加多个可调整参数。
    • 更新predict.py文件,增加了大量的注释,增加fps、视频预测、批量预测等功能。
    • 更新unet.py文件,增加了大量的注释,增加先验框选择、置信度、非极大抑制等参数。
    • 合并get_dr_txt.py、get_gt_txt.py和get_map.py文件,通过一个文件来实现数据集的评估。
    • 更新voc_annotation.py文件,增加多个可调整参数。
    • 更新callback.py文件,防止多线程错误。
    • 更新summary.py文件,用于观看网络结构。
    Source code(tar.gz)
    Source code(zip)
Owner
Bubbliiiing
Bubbliiiing
Pytorch implementations of the paper Value Functions Factorization with Latent State Information Sharing in Decentralized Multi-Agent Policy Gradients

LSF-SAC Pytorch implementations of the paper Value Functions Factorization with Latent State Information Sharing in Decentralized Multi-Agent Policy G

Hanhan 2 Aug 14, 2022
Robbing the FED: Directly Obtaining Private Data in Federated Learning with Modified Models

Robbing the FED: Directly Obtaining Private Data in Federated Learning with Modified Models This repo contains a barebones implementation for the atta

16 Dec 04, 2022
Scribble-Supervised LiDAR Semantic Segmentation, CVPR 2022 (ORAL)

Scribble-Supervised LiDAR Semantic Segmentation Dataset and code release for the paper Scribble-Supervised LiDAR Semantic Segmentation, CVPR 2022 (ORA

102 Dec 25, 2022
An Artificial Intelligence trying to drive a car by itself on a user created map

An Artificial Intelligence trying to drive a car by itself on a user created map

Akhil Sahukaru 17 Jan 13, 2022
python library for invisible image watermark (blind image watermark)

invisible-watermark invisible-watermark is a python library and command line tool for creating invisible watermark over image.(aka. blink image waterm

Shield Mountain 572 Jan 07, 2023
Research on Tabular Deep Learning (Python package & papers)

Research on Tabular Deep Learning For paper implementations, see the section "Papers and projects". rtdl is a PyTorch-based package providing a user-f

Yura Gorishniy 510 Dec 30, 2022
[ICCV'21] NEAT: Neural Attention Fields for End-to-End Autonomous Driving

NEAT: Neural Attention Fields for End-to-End Autonomous Driving Paper | Supplementary | Video | Poster | Blog This repository is for the ICCV 2021 pap

254 Jan 02, 2023
Meta Self-learning for Multi-Source Domain Adaptation: A Benchmark

Meta Self-Learning for Multi-Source Domain Adaptation: A Benchmark Project | Arxiv | YouTube | | Abstract In recent years, deep learning-based methods

CVSM Group - email: <a href=[email protected]"> 188 Dec 12, 2022
A pytorch implementation of the ACL2019 paper "Simple and Effective Text Matching with Richer Alignment Features".

RE2 This is a pytorch implementation of the ACL 2019 paper "Simple and Effective Text Matching with Richer Alignment Features". The original Tensorflo

287 Dec 21, 2022
Code for Transformer Hawkes Process, ICML 2020.

Transformer Hawkes Process Source code for Transformer Hawkes Process (ICML 2020). Run the code Dependencies Python 3.7. Anaconda contains all the req

Simiao Zuo 111 Dec 26, 2022
My implementation of Fully Convolutional Neural Networks in Keras

Keras-FCN This repository contains my implementation of Fully Convolutional Networks in Keras (Tensorflow backend). Currently, semantic segmentation c

The Duy Nguyen 15 Jan 13, 2020
Do you like Quick, Draw? Well what if you could train/predict doodles drawn inside Streamlit? Also draws lines, circles and boxes over background images for annotation.

Streamlit - Drawable Canvas Streamlit component which provides a sketching canvas using Fabric.js. Features Draw freely, lines, circles, boxes and pol

Fanilo Andrianasolo 325 Dec 28, 2022
Automatic tool focused on deriving metallicities of open clusters

metalcode Automatic tool focused on deriving metallicities of open clusters. Based on the method described in Pöhnl & Paunzen (2010, https://ui.adsabs

2 Dec 13, 2021
Neighborhood Contrastive Learning for Novel Class Discovery

Neighborhood Contrastive Learning for Novel Class Discovery This repository contains the official implementation of our paper: Neighborhood Contrastiv

Zhun Zhong 56 Dec 09, 2022
This is the code for CVPR 2021 oral paper: Jigsaw Clustering for Unsupervised Visual Representation Learning

JigsawClustering Jigsaw Clustering for Unsupervised Visual Representation Learning Pengguang Chen, Shu Liu, Jiaya Jia Introduction This project provid

DV Lab 73 Sep 18, 2022
Easy to use Audio Tagging in PyTorch

Audio Classification, Tagging & Sound Event Detection in PyTorch Progress: Fine-tune on audio classification Fine-tune on audio tagging Fine-tune on s

sithu3 15 Dec 22, 2022
Erpnext app for make employee salary on payroll entry based on one or more project with percentage for all project equal 100 %

Project Payroll this app for make payroll for employee based on projects like project on 30 % and project 2 70 % as account dimension it makes genral

Ibrahim Morghim 8 Jan 02, 2023
CUAD

Contract Understanding Atticus Dataset This repository contains code for the Contract Understanding Atticus Dataset (CUAD), a dataset for legal contra

The Atticus Project 273 Dec 17, 2022
The authors' implementation of Unsupervised Adversarial Learning of 3D Human Pose from 2D Joint Locations

Unsupervised Adversarial Learning of 3D Human Pose from 2D Joint Locations This is the authors' implementation of Unsupervised Adversarial Learning of

Dwango Media Village 140 Dec 07, 2022
Face Detection & Age Gender & Expression & Recognition

Face Detection & Age Gender & Expression & Recognition

Sajjad Ayobi 188 Dec 28, 2022