Deep Image Matting implementation in PyTorch

Overview

Deep Image Matting

Deep Image Matting paper implementation in PyTorch.

Differences

  1. "fc6" is dropped.
  2. Max pooling keeps its indices, and the decoder reuses them for unpooling.

"fc6" is bulky: with over 100 million parameters, it makes the model hard to converge. I suspect this is why the model in the paper has to be trained stage by stage.

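As a rough illustration of index-based pooling, the sketch below pairs `nn.MaxPool2d(return_indices=True)` with `nn.MaxUnpool2d`. It is not the repository's model code: the tensor values and variable names are made up for the example, and only a single pool/unpool pair is shown.

```python
import torch
import torch.nn as nn

# Encoder side: max pooling that also returns the argmax positions.
pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
# Decoder side: unpooling that writes values back to those positions.
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

# Toy input of shape [N=1, C=1, H=4, W=4] (illustrative values only).
x = torch.tensor([[[[1., 2., 5., 6.],
                    [3., 4., 7., 8.],
                    [9., 10., 13., 14.],
                    [11., 12., 15., 16.]]]])

pooled, indices = pool(x)  # both have shape [1, 1, 2, 2]
# output_size removes any ambiguity about the unpooled spatial shape.
restored = unpool(pooled, indices, output_size=x.shape)
```

Every position that was not a local maximum comes back as zero in `restored`; the convolutions that follow the unpooling step are what fill those positions in.
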
Performance

  • Tested on the Composition-1k testing dataset.
  • Evaluated on the whole image.
  • SAD is normalized by 1000.
  • Input images are normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
  • Trimaps are generated with both erosion and dilation.

Models        SAD   MSE    Download
paper-stage0  59.6  0.019
paper-stage1  54.6  0.017
paper-stage3  50.4  0.014
my-stage0     66.8  0.024  Link
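
The input normalization and the erode/dilate trimap generation listed above can be sketched in NumPy as follows. This is an assumption-laden illustration, not the repository's code: the function names, the `radius` parameter, the square structuring element and the 0/128/255 trimap coding are all chosen for the example, and the project may well use a different library or kernel sizes.

```python
import numpy as np

# Per-channel statistics quoted in the list above.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_image(rgb_uint8):
    """Scale an HxWx3 uint8 RGB image to [0, 1], then standardize per channel."""
    rgb = rgb_uint8.astype(np.float32) / 255.0
    return (rgb - MEAN) / STD


def _shifted_stack(mask, radius):
    """Stack every (2r+1)x(2r+1) shift of a boolean mask (edge-padded)."""
    h, w = mask.shape
    padded = np.pad(mask, radius, mode="edge")
    size = 2 * radius + 1
    return np.stack([padded[dy:dy + h, dx:dx + w]
                     for dy in range(size) for dx in range(size)])


def make_trimap(alpha_uint8, radius=1):
    """Build a trimap: 0 = background, 128 = unknown, 255 = foreground."""
    fg = alpha_uint8 == 255                  # fully opaque pixels
    not_bg = alpha_uint8 > 0                 # anything not fully transparent
    eroded_fg = _shifted_stack(fg, radius).all(axis=0)    # erosion
    dilated = _shifted_stack(not_bg, radius).any(axis=0)  # dilation
    trimap = np.full(alpha_uint8.shape, 128, dtype=np.uint8)
    trimap[eroded_fg] = 255
    trimap[~dilated] = 0
    return trimap


# Toy alpha matte: a 3x3 opaque square in a 7x7 transparent image.
alpha = np.zeros((7, 7), dtype=np.uint8)
alpha[2:5, 2:5] = 255
trimap = make_trimap(alpha, radius=1)
```

With `radius=1`, only the centre pixel of the square stays foreground, a one-pixel band around the square becomes unknown, and the rest is background.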

Dependencies

  • Python 3.5.2
  • PyTorch 1.1.0

Dataset

Adobe Deep Image Matting Dataset

Follow the instructions to contact the author for the dataset.

MSCOCO

Go to MSCOCO to download:

PASCAL VOC

Go to PASCAL VOC to download:

Usage

Data Pre-processing

Extract training images:

$ python pre_process.py

Train

$ python train.py

If you want to visualize during training, run in your terminal:

$ tensorboard --logdir runs

Experimental results

The Composition-1k testing dataset

  1. Test:
$ python test.py

It prints out average SAD and MSE errors when finished.
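
For reference, the two error measures can be computed along these lines. This is a minimal sketch that assumes alpha values in [0, 1], whole-image evaluation and SAD divided by 1000, as stated in the Performance section; the function names are hypothetical and `test.py` may compute the numbers differently.

```python
import numpy as np


def compute_sad(pred, gt):
    """Sum of absolute differences over the whole image, divided by 1000."""
    return float(np.abs(pred - gt).sum() / 1000.0)


def compute_mse(pred, gt):
    """Mean squared error over the whole image."""
    return float(((pred - gt) ** 2).mean())


# Toy example: a 100x100 prediction that is off by 0.5 everywhere.
pred = np.full((100, 100), 0.5, dtype=np.float64)
gt = np.full((100, 100), 1.0, dtype=np.float64)
sad = compute_sad(pred, gt)  # 10000 pixels * 0.5 / 1000
mse = compute_mse(pred, gt)  # 0.5 ** 2
```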

The alphamatting.com dataset

  1. Download the evaluation datasets: Go to the Datasets page and download the evaluation datasets. Make sure you pick the low-resolution dataset.

  2. Extract evaluation images:

$ python extract.py
  3. Evaluate:
$ python eval.py

[Figure: 16 rows of result images; columns: Image, Trimap1, Trimap2, Trimap3]

Demo

Download the pre-trained Deep Image Matting model (Link), then run:

$ python demo.py
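
The traceback quoted in the comments further down shows `demo.py` reading the checkpoint with `torch.load`, which suggests that `BEST_checkpoint.tar` is a PyTorch checkpoint rather than a tar archive, and that a CPU-only machine needs `map_location`. The sketch below illustrates that loading pattern with a stand-in model; the `"model"` key and the tiny `Conv2d` are assumptions made for the example and do not describe the real checkpoint layout.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the real network; the actual model class lives in the repository.
tiny_model = nn.Conv2d(4, 1, kernel_size=3, padding=1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "BEST_checkpoint.tar")
    # torch.save writes its own serialized format; ".tar" here is only a name.
    torch.save({"model": tiny_model.state_dict()}, path)

    # map_location="cpu" remaps stored tensors so loading works without a GPU.
    checkpoint = torch.load(path, map_location="cpu")

# Rebuild the same architecture and restore the saved weights.
restored = nn.Conv2d(4, 1, kernel_size=3, padding=1)
restored.load_state_dict(checkpoint["model"])
```

If the downloaded file fails to load this way as well, the download itself may be incomplete.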

[Figure: 20 rows of demo images; columns: Image/Trimap, Output/GT, New BG/Compose]
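
The "New BG/Compose" column corresponds to placing the extracted foreground on a new background. A minimal sketch of the standard compositing equation I = alpha * F + (1 - alpha) * B is shown below; the function name and the toy arrays are assumptions for illustration, and `demo.py` may implement this step differently.

```python
import numpy as np


def composite(fg, bg, alpha):
    """Blend foreground over background: I = alpha * F + (1 - alpha) * B."""
    alpha = alpha[..., None].astype(np.float32)  # HxW -> HxWx1 for broadcasting
    out = alpha * fg.astype(np.float32) + (1.0 - alpha) * bg.astype(np.float32)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


fg = np.full((2, 2, 3), 200, dtype=np.uint8)  # uniform foreground colour
bg = np.full((2, 2, 3), 100, dtype=np.uint8)  # uniform new background
alpha = np.array([[0.0, 0.5], [1.0, 0.25]], dtype=np.float32)
image = composite(fg, bg, alpha)
```

Pixels with alpha 0 take the background value, pixels with alpha 1 take the foreground value, and everything in between is a linear mix.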

A small donation~

If this project helped you, a small donation is appreciated~




Comments
  • the frozen model named BEST_checkpoint.tar cannot be uncompressed

    When I try to uncompress the frozen model, it shows:

    tar: This does not look like a tar archive
    tar: Skipping to next header
    tar: Exiting with failure status due to previous errors

    This means the .tar file is not complete.

    opened by banrenmasanxing 6
  • my own datasets are all full human body images

    Hi, thanks for your excellent work. I am preparing my own dataset, which consists of thousands of high-resolution images (4000*4000 on average). They are all full human body images. While processing these images I ran into a question: when I crop the trimap (generated from alpha), the crop often lands on regions that do not include hair, such as a foot or a leg. Is it OK to feed these images into the model?

    opened by lfxx 5
  • run demo.py question!

    File "demo.py", line 84, in <module>
      new_bgs = random.sample(new_bgs, 10)
    File "C:\Users\15432\AppData\Local\conda\conda\envs\python34\lib\random.py", line 324, in sample
      raise ValueError("Sample larger than population")
    ValueError: Sample larger than population

    opened by kxcg99 5
  • Invalid BEST_checkpoint.tar ?

    Hi, thank you for the code. I tried to download the pretrained model and extract it, but it doesn't work.

    tar xvf BEST_checkpoint.tar BEST_checkpoint
    

    results in

    tar: Ceci ne ressemble pas à une archive de type « tar »
    tar: On saute à l'en-tête suivant
    tar: BEST_checkpoint : non trouvé dans l'archive
    tar: Arrêt avec code d'échec à cause des erreurs précédentes
    

    Am I doing anything wrong, or is the provided tar not valid? Kind regards

    opened by flocreate 4
  • How can I get the trimaps of my pictures?

    I now have a model and want to use it, but I can't, because I don't have trimaps for my pictures. Is there a script to build the trimaps? How can I get the trimaps of my pictures?

    opened by huangjunxiong11 3
  • cannot unpack the 'BEST_checkpoint.tar'

    I downloaded the file "BEST_checkpoint.tar" successfully, but I can't unpack it: trying to unpack 'BEST_checkpoint.tar' produces an error. Is it my fault, or is the file broken?

    opened by huangjunxiong11 3
  • Demo error

    /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    Traceback (most recent call last):
      File "demo.py", line 69, in <module>
        checkpoint = torch.load(checkpoint)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 368, in load
        return _load(f, map_location, pickle_module)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 542, in _load
        result = unpickler.load()
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 505, in persistent_load
        data_type(size), location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 114, in default_restore_location
        result = fn(storage, location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 95, in _cuda_deserialize
        device = validate_cuda_device(location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 79, in validate_cuda_device
        raise RuntimeError('Attempting to deserialize object on a CUDA '
    RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

    opened by Mlt123 3
  • Deep-Image-Matting-v2 implementation on Android

    Hi, thanks for your work! The output looks awesome. I want to integrate your demo into an Android project. Is it possible to integrate the model into an Android project? If so, how can I do it? Can you please give some suggestions? Thanks in advance.

    opened by charlizesmith 3
  • unable to start training using pretrained weights

    Whenever pre-trained weights are used to train the model on my own dataset, the following error occurs.

    python3 train.py --batch-size 4 --checkpoint checkpoint/BEST_checkpoint.tar

    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    Traceback (most recent call last):
      File "train.py", line 180, in <module>
        main()
      File "train.py", line 176, in main
        train_net(args)
      File "train.py", line 71, in train_net
        logger=logger)
      File "train.py", line 112, in train
        alpha_out = model(img) # [N, 3, 320, 320]
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
        if t.device != self.src_device_obj:
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 539, in __getattr__
        type(self).__name__, name))
    AttributeError: 'DataParallel' object has no attribute 'src_device_obj'

    opened by dev-srikanth 3
  • v2 doesn't perform as well as v1?

    Hi, thanks for your pretrained model! I tested both your v1 and v2 pretrained models. v2 is much faster than v1, but I found it doesn't perform as well as v1. The image: WechatIMG226. The original trimap: test7_tri. The v1 output: WechatIMG225. The v2 output: test7_result.

    Do you know what the problem is?

    Thanks,

    opened by MarSaKi 3
  • Questions about the PyTorch version and a training issue related to the batch size

    Hi,

    Thank you for sharing your PyTorch reimplementation. Could you share the PyTorch version you used for development?

    I am using PyTorch 1.0.1, CUDA 9 and two RTX 2080 Ti cards to run 'train.py', since I see you use the DataParallel module to support multi-GPU training. However, I encountered an error; the traceback is:

    Traceback (most recent call last):
      File "train.py", line 171, in <module>
        main()
      File "train.py", line 167, in main
        train_net(args)
      File "train.py", line 64, in train_net
        logger=logger)
      File "train.py", line 103, in train
        alpha_out = model(img) # [N, 3, 320, 320]
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
        raise output
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
        output = module(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 127, in forward
        up4 = self.up4(up5, indices_4, unpool_shape4)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 87, in forward
        outputs = self.conv(outputs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 43, in forward
        outputs = self.cbr_unit(inputs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 320, in forward
        self.padding, self.dilation, self.groups)
    RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

    I have tested data parallelism using the example here, and it works well.

    opened by wuyujack 3
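
One of the comments above reports `ValueError: Sample larger than population` from `random.sample(new_bgs, 10)`, which Python raises when the list holds fewer than 10 items. A possible guard, sketched here with a hypothetical helper and made-up file names, is to cap the sample size at the list length; whether the project adopted such a change is not stated on this page.

```python
import random


def sample_backgrounds(candidates, k=10):
    """Pick up to k distinct items, never asking for more than exist."""
    return random.sample(candidates, min(k, len(candidates)))


few = ["bg_0.jpg", "bg_1.jpg", "bg_2.jpg"]  # hypothetical file names
picked = sample_backgrounds(few)            # returns all three, in random order
```

Supplying at least 10 background images avoids the error without any code change.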
Owner
Yang Liu
Algorithm engineer