Deep Image Matting implementation in PyTorch

Overview

Deep Image Matting

Deep Image Matting paper implementation in PyTorch.

Differences

  1. "fc6" is dropped.
  2. Indices pooling.

"fc6" is clumpy, over 100 millions parameters, makes the model hard to converge. I guess it is the reason why the model (paper) has to be trained stagewisely.

Performance

  • The Composition-1k testing dataset.
  • Evaluate with whole image.
  • SAD normalized by 1000.
  • Input image is normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
  • Both erode and dialte to generate trimap.
Models SAD MSE Download
paper-stage0 59.6 0.019
paper-stage1 54.6 0.017
paper-stage3 50.4 0.014
my-stage0 66.8 0.024 Link

Dependencies

  • Python 3.5.2
  • PyTorch 1.1.0

Dataset

Adobe Deep Image Matting Dataset

Follow the instruction to contact author for the dataset.

MSCOCO

Go to MSCOCO to download:

PASCAL VOC

Go to PASCAL VOC to download:

Usage

Data Pre-processing

Extract training images:

$ python pre_process.py

Train

$ python train.py

If you want to visualize during training, run in your terminal:

$ tensorboard --logdir runs

Experimental results

The Composition-1k testing dataset

  1. Test:
$ python test.py

It prints out average SAD and MSE errors when finished.

The alphamatting.com dataset

  1. Download the evaluation datasets: Go to the Datasets page and download the evaluation datasets. Make sure you pick the low-resolution dataset.

  2. Extract evaluation images:

$ python extract.py
  1. Evaluate:
$ python eval.py

Click to view whole images:

Image Trimap1 Trimap2 Trimap3
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image

Demo

Download pre-trained Deep Image Matting Link then run:

$ python demo.py
Image/Trimap Output/GT New BG/Compose
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image

小小的赞助~

Sample

若对您有帮助可给予小小的赞助~




Comments
  • the frozen model named BEST_checkpoint.tar cannot be uncompressed

    the frozen model named BEST_checkpoint.tar cannot be uncompressed

    when I try to uncompress the frozen model it shows

    tar: This does not look like a tar archive tar: Skipping to next header tar: Exiting with failure status due to previous errors

    this means the .tar file is not complete

    opened by banrenmasanxing 6
  • my own datasets are all full human body images

    my own datasets are all full human body images

    Hi,thanks for your excellent work.Now i prepare my own datasets.This datasets are consists of thounds of high resolution image(average 4000*4000).They are all full human body images.When i process these images,i meet a questions: When i crop the trimap(generated from alpha),often crop some places which are not include hair.Such as foot,leg.Is it ok to input these images into [email protected]

    opened by lfxx 5
  • run demo.py question!

    run demo.py question!

    File "demo.py", line 84, in new_bgs = random.sample(new_bgs, 10) File "C:\Users\15432\AppData\Local\conda\conda\envs\python34\lib\random.py", line 324, in sample raise ValueError("Sample larger than population") ValueError: Sample larger than population

    opened by kxcg99 5
  • Invalid BEST_checkpoint.tar ?

    Invalid BEST_checkpoint.tar ?

    Hi, thank you for the code. I tried to download the pretrained model and extract it but it dosnt work.

    tar xvf BEST_checkpoint.tar BEST_checkpoint
    

    results in

    tar: Ceci ne ressemble pas à une archive de type « tar »
    tar: On saute à l'en-tête suivant
    tar: BEST_checkpoint : non trouvé dans l'archive
    tar: Arrêt avec code d'échec à cause des erreurs précédentes
    

    anything i'm doing the wrong way ? or the provided tar is not valid ? kind reards

    opened by flocreate 4
  • How can i get the Trimaps of my pictures?

    How can i get the Trimaps of my pictures?

    Now, I got a model, I want to use it but I can't, because I have not the Trimaps of my pictures. Are there the script of code to build the Trimaps? How can i get the Trimaps of my pictures?

    opened by huangjunxiong11 3
  • can not unpack the 'BEST_checkpoint.tar'

    can not unpack the 'BEST_checkpoint.tar'

    When i download the file "BEST_checkpoint.tar" successfully, i can't unpack it. Actually, when i try to unpack 'BEST_checkpoint.tar', it make an error. Is it my fault , or, Is the file mistaken?

    opened by huangjunxiong11 3
  • Demo error

    Demo error

    /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) Traceback (most recent call last): File "demo.py", line 69, in checkpoint = torch.load(checkpoint) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 368, in load return _load(f, map_location, pickle_module) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 542, in _load result = unpickler.load() File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 505, in persistent_load data_type(size), location) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 114, in default_restore_location result = fn(storage, location) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 95, in _cuda_deserialize device = validate_cuda_device(location) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 79, in validate_cuda_device raise RuntimeError('Attempting to deserialize object on a CUDA ' RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

    opened by Mlt123 3
  • Deep-Image-Matting-v2 implemetation on Android

    Deep-Image-Matting-v2 implemetation on Android

    Hi, Thanks for you work! its looking awesome output. I want to integrate your demo into android project. Is it possible to integrate model into android Project? If it possible, then How can i integrate this model into android project? Can you please give some suggestions? Thanks in advance.

    opened by charlizesmith 3
  • unable to start training using pretrained weigths

    unable to start training using pretrained weigths

    whenever pre-trained weights are used for training the model using own dataset, the following error is occurring.

    python3 train.py --batch-size 4 --checkpoint checkpoint/BEST_checkpoint.tar

    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) Traceback (most recent call last): File "train.py", line 180, in main() File "train.py", line 176, in main train_net(args) File "train.py", line 71, in train_net logger=logger) File "train.py", line 112, in train alpha_out = model(img) # [N, 3, 320, 320] File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 493, in call result = self.forward(*input, **kwargs) File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 143, in forward if t.device != self.src_device_obj: File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 539, in getattr type(self).name, name)) AttributeError: 'DataParallel' object has no attribute 'src_device_obj'

    opened by dev-srikanth 3
  • v2 didn't performance well as v1?

    v2 didn't performance well as v1?

    Hi, thanks for your pretrained model! I test both your v1 pretrained model and v2 pretrained model , v2 is much faster than v1 , but I found it didn't performance well as v1. the image: WechatIMG226 the origin tri map: test7_tri the v1 output: WechatIMG225 the v2 output: test7_result

    do you know what's the problem?

    Thanks,

    opened by MarSaKi 3
  • Questions about the PyTorch version and an issue in training regarding to the batch size

    Questions about the PyTorch version and an issue in training regarding to the batch size

    Hi,

    Thank you for sharing your PyTorch version of reimplementation. Would you like to share the PyTorch version you used to development?

    I am using PyTorch 1.0.1, CUDA 9, two RTX 2080 Ti to run the 'train.py' since I see you use Data Parallel module to support multi-GPUs training. However, I encountered and the trackbacks are here:

    Traceback (most recent call last): File "train.py", line 171, in main() File "train.py", line 167, in main train_net(args) File "train.py", line 64, in train_net logger=logger) File "train.py", line 103, in train alpha_out = model(img) # [N, 3, 320, 320] File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward outputs = self.parallel_apply(replicas, inputs, kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply raise output File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker output = module(*input, **kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 127, in forward up4 = self.up4(up5, indices_4, unpool_shape4) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 87, in forward outputs = self.conv(outputs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 43, in forward outputs = self.cbr_unit(inputs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward input = module(input) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 320, in forward self.padding, self.dilation, self.groups) RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

    I have tested the DATA PARALLELISM using the example here and it works well.

    opened by wuyujack 3
Owner
Yang Liu
Algorithm engineer
Yang Liu
Moment-DETR code and QVHighlights dataset

Moment-DETR QVHighlights: Detecting Moments and Highlights in Videos via Natural Language Queries Jie Lei, Tamara L. Berg, Mohit Bansal For dataset de

Jie Lei 雷杰 133 Dec 22, 2022
Supervision Exists Everywhere: A Data Efficient Contrastive Language-Image Pre-training Paradigm

DeCLIP Supervision Exists Everywhere: A Data Efficient Contrastive Language-Image Pre-training Paradigm. Our paper is available in arxiv Updates ** Ou

Sense-GVT 470 Dec 30, 2022
Official codes: Self-Supervised Learning by Estimating Twin Class Distribution

TWIST: Self-Supervised Learning by Estimating Twin Class Distributions Codes and pretrained models for TWIST: @article{wang2021self, title={Self-Sup

Bytedance Inc. 85 Dec 15, 2022
Human pose estimation from video plays a critical role in various applications such as quantifying physical exercises, sign language recognition, and full-body gesture control.

Pose Detection Project Description: Human pose estimation from video plays a critical role in various applications such as quantifying physical exerci

Hassan Shahzad 2 Jan 17, 2022
Exploiting a Zoo of Checkpoints for Unseen Tasks

Exploiting a Zoo of Checkpoints for Unseen Tasks This repo includes code to reproduce all results in the above Neurips paper, authored by Jiaji Huang,

Baidu Research 8 Sep 06, 2022
PyTorch implementations of the NeRF model described in "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis"

PyTorch NeRF and pixelNeRF NeRF: Tiny NeRF: pixelNeRF: This repository contains minimal PyTorch implementations of the NeRF model described in "NeRF:

Michael A. Alcorn 178 Dec 20, 2022
Simple reference implementation of GraphSAGE.

Reference PyTorch GraphSAGE Implementation Author: William L. Hamilton Basic reference PyTorch implementation of GraphSAGE. This reference implementat

William L Hamilton 861 Jan 06, 2023
Source code for Fixed-Point GAN for Cloud Detection

FCD: Fixed-Point GAN for Cloud Detection PyTorch source code of Nyborg & Assent (2020). Abstract The detection of clouds in satellite images is an ess

Joachim Nyborg 8 Dec 22, 2022
Robot Hacking Manual (RHM). From robotics to cybersecurity. Papers, notes and writeups from a journey into robot cybersecurity.

RHM: Robot Hacking Manual Download in PDF RHM v0.4 ┃ Read online The Robot Hacking Manual (RHM) is an introductory series about cybersecurity for robo

Víctor Mayoral Vilches 233 Dec 30, 2022
Demonstrational Session git repo for H SAF User Workshop (28/1)

5th H SAF User Workshop The 5th H SAF User Workshop supported by EUMeTrain will be held in online in January 24-28 2022. This repository contains inst

H SAF 4 Aug 04, 2022
This is the source code of the 1st place solution for segmentation task (with Dice 90.32%) in 2021 CCF BDCI challenge.

1st place solution in CCF BDCI 2021 ULSEG challenge This is the source code of the 1st place solution for ultrasound image angioma segmentation task (

Chenxu Peng 30 Nov 22, 2022
Code for A Volumetric Transformer for Accurate 3D Tumor Segmentation

VT-UNet This repo contains the supported pytorch code and configuration files to reproduce 3D medical image segmentaion results of VT-UNet. Environmen

Himashi Amanda Peiris 114 Dec 20, 2022
Exe-to-xlsm - Simple script to create VBscript of exe and inject to xlsm

🎁 Exe To Office Executable file injection to Office documents: .xlsm, .docm, .p

3 Jan 25, 2022
Official PyTorch implementation of CAPTRA: CAtegory-level Pose Tracking for Rigid and Articulated Objects from Point Clouds

CAPTRA: CAtegory-level Pose Tracking for Rigid and Articulated Objects from Point Clouds Introduction This is the official PyTorch implementation of o

Yijia Weng 96 Dec 07, 2022
Official pytorch implementation of the paper: "SinGAN: Learning a Generative Model from a Single Natural Image"

SinGAN Project | Arxiv | CVF | Supplementary materials | Talk (ICCV`19) Official pytorch implementation of the paper: "SinGAN: Learning a Generative M

Tamar Rott Shaham 3.2k Dec 25, 2022
Data Engineering ZoomCamp

Data Engineering ZoomCamp I'm partaking in a Data Engineering Bootcamp / Zoomcamp and will be tracking my progress here. I can't promise these notes w

Aaron 61 Jan 06, 2023
Generic image compressor for machine learning. Pytorch code for our paper "Lossy compression for lossless prediction".

Lossy Compression for Lossless Prediction Using: Training: This repostiory contains our implementation of the paper: Lossy Compression for Lossless Pr

Yann Dubois 84 Jan 02, 2023
ALL Snow Removed: Single Image Desnowing Algorithm Using Hierarchical Dual-tree Complex Wavelet Representation and Contradict Channel Loss (HDCWNet)

ALL Snow Removed: Single Image Desnowing Algorithm Using Hierarchical Dual-tree Complex Wavelet Representation and Contradict Channel Loss (HDCWNet) (

Wei-Ting Chen 49 Dec 27, 2022
PyTorch implementation for paper "Full-Body Visual Self-Modeling of Robot Morphologies".

Full-Body Visual Self-Modeling of Robot Morphologies Boyuan Chen, Robert Kwiatkowskig, Carl Vondrick, Hod Lipson Columbia University Project Website |

Boyuan Chen 32 Jan 02, 2023
Bounding Wasserstein distance with couplings

BoundWasserstein These scripts reproduce the results of the article Bounding Wasserstein distance with couplings by Niloy Biswas and Lester Mackey. ar

Niloy Biswas 1 Jan 11, 2022