Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Overview

Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding

by Qiaole Dong*, Chenjie Cao*, Yanwei Fu

Paper and Supplemental Material (arXiv)

LICENSE

Pipeline

Click to expand

The overview of our ZITS. At first, the TSR model is used to restore structures with low resolutions. Then the simple CNN based upsampler is leveraged to upsample edge and line maps. Moreover, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.

TO DO

We have updated weights of TSR!

Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.

  • Releasing inference codes.
  • Releasing pre-trained moodel.
  • Releasing training codes.

Preparation

Click to expand
  1. Preparing the environment:

    as there are some bugs when using GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions via torch1.9.0 for the multi-gpu training

    conda create -n train_env python=3.6
    conda activate train_env
    pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirement.txt
    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
    
  2. For training, MST provide irregular and segmentation masks (download) with different masking rates. And you should define the mask file list before the training as in MST.

  3. Download the pretrained masked wireframe detection model to the './ckpt' fold: LSM-HAWP (MST ICCV2021 retrained from HAWP CVPR2020).

  4. Prepare the wireframes:

    as the MST train the LSM-HAWP in Pytorch 1.3.1 and it causes problem (link) when tested in Pytorch 1.9, we recommand to inference the lines(wireframes) with torch==1.3.1. If the line detection is not based on torch1.3.1, the performance may drop a little.

    conda create -n wireframes_inference_env python=3.6
    conda activate wireframes_inference_env
    pip install torch==1.3.1 torchvision==0.4.2
    pip install -r requirement.txt
    

    then extract wireframes with following code

    python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
    
  5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:

    mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
    wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
    

Eval

Click to expand

Download pretrained models on Places2 here.

Link for BaiduDrive, password:qnm5

Batch Test

For batch test, you need to complete steps 3 and 4 above.

Put the pretrained models to the './ckpt' fold. Then modify the config file according to you image, mask and wireframes path.

Test on 256 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'

Test on 512 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'

Single Image Test

Note: For single image test, environment 'wireframes_inference_env' in step 4 is recommended for a better line detection. This code only supports squared images (or they will be center cropped).

conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
 --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./

Training

Click to expand

⚠️ Warning: The training codes is not fully tested yet after refactoring

Training TSR

python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 12 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 15 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP

Train SSU

We recommend to use the pretrained SSU. You can also train your SSU refered to https://github.com/ewrfcas/StructureUpsampling.

Training LaMa First

python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama

Training FTR

256:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP

256~512:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP

More 1K Results

Click to expand

Acknowledgments

Cite

If you found our program helpful, please consider citing:

@inproceedings{dong2022incremental,
      title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding}, 
      author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}
Owner
Qiaole Dong
Qiaole Dong
Split your patch similarly to `git add -p` but supporting multiple buckets

split-patch.py This is git add -p on steroids for patches. Given a my.patch you can run ./split-patch.py my.patch You can choose in which bucket to p

102 Oct 06, 2022
AirPose: Multi-View Fusion Network for Aerial 3D Human Pose and Shape Estimation

AirPose AirPose: Multi-View Fusion Network for Aerial 3D Human Pose and Shape Estimation Check the teaser video This repository contains the code of A

Robot Perception Group 41 Dec 05, 2022
GluonMM is a library of transformer models for computer vision and multi-modality research

GluonMM is a library of transformer models for computer vision and multi-modality research. It contains reference implementations of widely adopted baseline models and also research work from Amazon

42 Dec 02, 2022
MassiveSumm: a very large-scale, very multilingual, news summarisation dataset

MassiveSumm: a very large-scale, very multilingual, news summarisation dataset This repository contains links to data and code to fetch and reproduce

Daniel Varab 19 Dec 16, 2022
A library for low-memory inferencing in PyTorch.

Pylomin Pylomin (PYtorch LOw-Memory INference) is a library for low-memory inferencing in PyTorch. Installation ... Usage For example, the following c

3 Oct 26, 2022
Tensorflow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

Single Image Super-Resolution with EDSR, WDSR and SRGAN A Tensorflow 2.x based implementation of Enhanced Deep Residual Networks for Single Image Supe

Martin Krasser 1.3k Jan 06, 2023
Official PyTorch implementation for FastDPM, a fast sampling algorithm for diffusion probabilistic models

Official PyTorch implementation for "On Fast Sampling of Diffusion Probabilistic Models". FastDPM generation on CIFAR-10, CelebA, and LSUN datasets. S

Zhifeng Kong 68 Dec 26, 2022
Code for "Offline Meta-Reinforcement Learning with Advantage Weighting" [ICML 2021]

Offline Meta-Reinforcement Learning with Advantage Weighting (MACAW) MACAW code used for the experiments in the ICML 2021 paper. Installing the enviro

Eric Mitchell 28 Jan 01, 2023
Example of semantic segmentation in Keras

keras-semantic-segmentation-example Example of semantic segmentation in Keras Single class example: Generated data: random ellipse with random color o

53 Mar 23, 2022
Non-stationary GP package written from scratch in PyTorch

NSGP-Torch Examples gpytorch model with skgpytorch # Import packages import torch from regdata import NonStat2D from gpytorch.kernels import RBFKernel

Zeel B Patel 1 Mar 06, 2022
Trading environnement for RL agents, backtesting and training.

TradzQAI Trading environnement for RL agents, backtesting and training. Live session with coinbasepro-python is finaly arrived ! Available sessions: L

Tony Denion 164 Oct 30, 2022
Bottom-up attention model for image captioning and VQA, based on Faster R-CNN and Visual Genome

bottom-up-attention This code implements a bottom-up attention model, based on multi-gpu training of Faster R-CNN with ResNet-101, using object and at

Peter Anderson 1.3k Jan 09, 2023
A pytorch implementation of Pytorch-Sketch-RNN

Pytorch-Sketch-RNN A pytorch implementation of https://arxiv.org/abs/1704.03477 In order to draw other things than cats, you will find more drawing da

Alexis David Jacq 172 Dec 12, 2022
Extract MNIST handwritten digits dataset binary file into bmp images

MNIST-dataset-extractor Extract MNIST handwritten digits dataset binary file into bmp images More info at http://yann.lecun.com/exdb/mnist/ Dependenci

Omar Mostafa 6 May 24, 2021
The code for "Deep Level Set for Box-supervised Instance Segmentation in Aerial Images".

Deep Levelset for Box-supervised Instance Segmentation in Aerial Images Wentong Li, Yijie Chen, Wenyu Liu, Jianke Zhu* Any questions or discussions ar

sunshine.lwt 112 Jan 05, 2023
Surrogate-Assisted Genetic Algorithm for Wrapper Feature Selection

SAGA Surrogate-Assisted Genetic Algorithm for Wrapper Feature Selection Please refer to the Jupyter notebook (Example.ipynb) for an example of using t

9 Dec 28, 2022
Lucid Sonic Dreams syncs GAN-generated visuals to music.

Lucid Sonic Dreams Lucid Sonic Dreams syncs GAN-generated visuals to music. By default, it uses NVLabs StyleGAN2, with pre-trained models lifted from

731 Jan 02, 2023
The official implementation of the Hybrid Self-Attention NEAT algorithm

PUREPLES - Pure Python Library for ES-HyperNEAT About This is a library of evolutionary algorithms with a focus on neuroevolution, implemented in pure

Adrian Westh 91 Dec 12, 2022
AoT is a system for automatically generating off-target test harness by using build information.

AoT: Auto off-Target Automatically generating off-target test harness by using build information. Brought to you by the Mobile Security Team at Samsun

Samsung 10 Oct 19, 2022
Privacy-Preserving Machine Learning (PPML) Tutorial Presented at PyConDE 2022

PPML: Machine Learning on Data you cannot see Repository for the tutorial on Privacy-Preserving Machine Learning (PPML) presented at PyConDE 2022 Abst

Valerio Maggio 10 Aug 16, 2022