Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Overview

Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding

by Qiaole Dong*, Chenjie Cao*, Yanwei Fu

Paper and Supplemental Material (arXiv)

LICENSE

Pipeline

Click to expand

The overview of our ZITS. At first, the TSR model is used to restore structures with low resolutions. Then the simple CNN based upsampler is leveraged to upsample edge and line maps. Moreover, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.

TO DO

We have updated weights of TSR!

Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.

  • Releasing inference codes.
  • Releasing pre-trained moodel.
  • Releasing training codes.

Preparation

Click to expand
  1. Preparing the environment:

    as there are some bugs when using GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions via torch1.9.0 for the multi-gpu training

    conda create -n train_env python=3.6
    conda activate train_env
    pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirement.txt
    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
    
  2. For training, MST provide irregular and segmentation masks (download) with different masking rates. And you should define the mask file list before the training as in MST.

  3. Download the pretrained masked wireframe detection model to the './ckpt' fold: LSM-HAWP (MST ICCV2021 retrained from HAWP CVPR2020).

  4. Prepare the wireframes:

    as the MST train the LSM-HAWP in Pytorch 1.3.1 and it causes problem (link) when tested in Pytorch 1.9, we recommand to inference the lines(wireframes) with torch==1.3.1. If the line detection is not based on torch1.3.1, the performance may drop a little.

    conda create -n wireframes_inference_env python=3.6
    conda activate wireframes_inference_env
    pip install torch==1.3.1 torchvision==0.4.2
    pip install -r requirement.txt
    

    then extract wireframes with following code

    python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
    
  5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:

    mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
    wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
    

Eval

Click to expand

Download pretrained models on Places2 here.

Link for BaiduDrive, password:qnm5

Batch Test

For batch test, you need to complete steps 3 and 4 above.

Put the pretrained models to the './ckpt' fold. Then modify the config file according to you image, mask and wireframes path.

Test on 256 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'

Test on 512 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'

Single Image Test

Note: For single image test, environment 'wireframes_inference_env' in step 4 is recommended for a better line detection. This code only supports squared images (or they will be center cropped).

conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
 --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./

Training

Click to expand

⚠️ Warning: The training codes is not fully tested yet after refactoring

Training TSR

python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 12 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 15 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP

Train SSU

We recommend to use the pretrained SSU. You can also train your SSU refered to https://github.com/ewrfcas/StructureUpsampling.

Training LaMa First

python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama

Training FTR

256:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP

256~512:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP

More 1K Results

Click to expand

Acknowledgments

Cite

If you found our program helpful, please consider citing:

@inproceedings{dong2022incremental,
      title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding}, 
      author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}
Owner
Qiaole Dong
Qiaole Dong
Source code for the GPT-2 story generation models in the EMNLP 2020 paper "STORIUM: A Dataset and Evaluation Platform for Human-in-the-Loop Story Generation"

Storium GPT-2 Models This is the official repository for the GPT-2 models described in the EMNLP 2020 paper [STORIUM: A Dataset and Evaluation Platfor

Nader Akoury 27 Dec 20, 2022
public repo for ESTER dataset and modeling (EMNLP'21)

Project / Paper Introduction This is the project repo for our EMNLP'21 paper: https://arxiv.org/abs/2104.08350 Here, we provide brief descriptions of

PlusLab 19 Oct 27, 2022
Data Augmentation with Variational Autoencoders

Documentation Pyraug This library provides a way to perform Data Augmentation using Variational Autoencoders in a reliable way even in challenging con

112 Nov 30, 2022
Code for the RA-L (ICRA) 2021 paper "SeqNet: Learning Descriptors for Sequence-Based Hierarchical Place Recognition"

SeqNet: Learning Descriptors for Sequence-Based Hierarchical Place Recognition [ArXiv+Supplementary] [IEEE Xplore RA-L 2021] [ICRA 2021 YouTube Video]

Sourav Garg 63 Dec 12, 2022
Food Drinks and groceries Images Multi Lingual (FooDI-ML) dataset.

Food Drinks and groceries Images Multi Lingual (FooDI-ML) dataset.

41 Jan 04, 2023
Pytorch reimplement of the paper "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction" ACL2020. The original code is written in keras.

CasRel-pytorch-reimplement Pytorch reimplement of the paper "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction" ACL2020. The o

longlongman 170 Dec 01, 2022
Python utility to generate filesystem content for Obsidian.

Security Vault Generator Quickly parse, format, and output common frameworks/content for Obsidian.md. There is a strong focus on MITRE ATT&CK because

Justin Angel 73 Dec 02, 2022
This project aims at providing a concise, easy-to-use, modifiable reference implementation for semantic segmentation models using PyTorch.

Semantic Segmentation on PyTorch (include FCN, PSPNet, Deeplabv3, Deeplabv3+, DANet, DenseASPP, BiSeNet, EncNet, DUNet, ICNet, ENet, OCNet, CCNet, PSANet, CGNet, ESPNet, LEDNet, DFANet)

2.4k Jan 08, 2023
A implemetation of the LRCN in mxnet

A implemetation of the LRCN in mxnet ##Abstract LRCN is a combination of CNN and RNN ##Installation Download UCF101 dataset ./avi2jpg.sh to split the

44 Aug 25, 2022
Cookiecutter PyTorch Lightning

Cookiecutter PyTorch Lightning Instructions # install cookiecutter pip install cookiecutter

Mazen 8 Nov 06, 2022
Implementation of paper "Decision-based Black-box Attack Against Vision Transformers via Patch-wise Adversarial Removal"

Patch-wise Adversarial Removal Implementation of paper "Decision-based Black-box Attack Against Vision Transformers via Patch-wise Adversarial Removal

4 Oct 12, 2022
[ICLR2021oral] Rethinking Architecture Selection in Differentiable NAS

DARTS-PT Code accompanying the paper ICLR'2021: Rethinking Architecture Selection in Differentiable NAS Ruochen Wang, Minhao Cheng, Xiangning Chen, Xi

Ruochen Wang 86 Dec 27, 2022
Multi-Task Pre-Training for Plug-and-Play Task-Oriented Dialogue System

Multi-Task Pre-Training for Plug-and-Play Task-Oriented Dialogue System Authors: Yixuan Su, Lei Shu, Elman Mansimov, Arshit Gupta, Deng Cai, Yi-An Lai

Amazon Web Services - Labs 123 Dec 23, 2022
Where2Act: From Pixels to Actions for Articulated 3D Objects

Where2Act: From Pixels to Actions for Articulated 3D Objects The Proposed Where2Act Task. Given as input an articulated 3D object, we learn to propose

Kaichun Mo 69 Nov 28, 2022
PERIN is Permutation-Invariant Semantic Parser developed for MRP 2020

PERIN: Permutation-invariant Semantic Parsing David Samuel & Milan Straka Charles University Faculty of Mathematics and Physics Institute of Formal an

ÚFAL 40 Jan 04, 2023
[ECCV2020] Content-Consistent Matching for Domain Adaptive Semantic Segmentation

[ECCV20] Content-Consistent Matching for Domain Adaptive Semantic Segmentation This is a PyTorch implementation of CCM. News: GTA-4K list is available

Guangrui Li 88 Aug 25, 2022
Free like Freedom

This is all very much a work in progress! More to come! ( We're working on it though! Stay tuned!) Installation Open an Anaconda Prompt (in Windows, o

2.3k Jan 04, 2023
Official implementation of the paper WAV2CLIP: LEARNING ROBUST AUDIO REPRESENTATIONS FROM CLIP

Wav2CLIP 🚧 WIP 🚧 Official implementation of the paper WAV2CLIP: LEARNING ROBUST AUDIO REPRESENTATIONS FROM CLIP πŸ“„ πŸ”— Ho-Hsiang Wu, Prem Seetharaman

Descript 240 Dec 13, 2022
Online Multi-Granularity Distillation for GAN Compression (ICCV2021)

Online Multi-Granularity Distillation for GAN Compression (ICCV2021) This repository contains the pytorch codes and trained models described in the IC

Bytedance Inc. 299 Dec 16, 2022
This repository provides data for the VAW dataset as described in the CVPR 2021 paper titled "Learning to Predict Visual Attributes in the Wild"

Visual Attributes in the Wild (VAW) This repository provides data for the VAW dataset as described in the CVPR 2021 Paper: Learning to Predict Visual

Adobe Research 36 Dec 30, 2022