Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Overview

Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding

by Qiaole Dong*, Chenjie Cao*, Yanwei Fu

Paper and Supplemental Material (arXiv)

LICENSE

Pipeline

Click to expand

The overview of our ZITS. At first, the TSR model is used to restore structures with low resolutions. Then the simple CNN based upsampler is leveraged to upsample edge and line maps. Moreover, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.

TO DO

We have updated weights of TSR!

Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.

  • Releasing inference codes.
  • Releasing pre-trained moodel.
  • Releasing training codes.

Preparation

Click to expand
  1. Preparing the environment:

    as there are some bugs when using GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions via torch1.9.0 for the multi-gpu training

    conda create -n train_env python=3.6
    conda activate train_env
    pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirement.txt
    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
    
  2. For training, MST provide irregular and segmentation masks (download) with different masking rates. And you should define the mask file list before the training as in MST.

  3. Download the pretrained masked wireframe detection model to the './ckpt' fold: LSM-HAWP (MST ICCV2021 retrained from HAWP CVPR2020).

  4. Prepare the wireframes:

    as the MST train the LSM-HAWP in Pytorch 1.3.1 and it causes problem (link) when tested in Pytorch 1.9, we recommand to inference the lines(wireframes) with torch==1.3.1. If the line detection is not based on torch1.3.1, the performance may drop a little.

    conda create -n wireframes_inference_env python=3.6
    conda activate wireframes_inference_env
    pip install torch==1.3.1 torchvision==0.4.2
    pip install -r requirement.txt
    

    then extract wireframes with following code

    python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
    
  5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:

    mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
    wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
    

Eval

Click to expand

Download pretrained models on Places2 here.

Link for BaiduDrive, password:qnm5

Batch Test

For batch test, you need to complete steps 3 and 4 above.

Put the pretrained models to the './ckpt' fold. Then modify the config file according to you image, mask and wireframes path.

Test on 256 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'

Test on 512 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'

Single Image Test

Note: For single image test, environment 'wireframes_inference_env' in step 4 is recommended for a better line detection. This code only supports squared images (or they will be center cropped).

conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
 --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./

Training

Click to expand

⚠️ Warning: The training codes is not fully tested yet after refactoring

Training TSR

python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 12 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 15 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP

Train SSU

We recommend to use the pretrained SSU. You can also train your SSU refered to https://github.com/ewrfcas/StructureUpsampling.

Training LaMa First

python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama

Training FTR

256:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP

256~512:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP

More 1K Results

Click to expand

Acknowledgments

Cite

If you found our program helpful, please consider citing:

@inproceedings{dong2022incremental,
      title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding}, 
      author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}
Owner
Qiaole Dong
Qiaole Dong
Mmdetection3d Noted - MMDetection3D is an open source object detection toolbox based on PyTorch

MMDetection3D is an open source object detection toolbox based on PyTorch

Jiangjingwen 13 Jan 06, 2023
Code for CVPR2019 Towards Natural and Accurate Future Motion Prediction of Humans and Animals

Motion prediction with Hierarchical Motion Recurrent Network Introduction This work concerns motion prediction of articulate objects such as human, fi

Shuang Wu 85 Dec 11, 2022
Learning to Reconstruct 3D Manhattan Wireframes from a Single Image

Learning to Reconstruct 3D Manhattan Wireframes From a Single Image This repository contains the PyTorch implementation of the paper: Yichao Zhou, Hao

Yichao Zhou 50 Dec 27, 2022
Deeprl - Standard DQN and dueling network for simple games

DeepRL This code implements the standard deep Q-learning and dueling network with experience replay (memory buffer) for playing simple games. DQN algo

Yao Zhou 6 Apr 12, 2020
[CVPR 2021] Unsupervised Degradation Representation Learning for Blind Super-Resolution

DASR Pytorch implementation of "Unsupervised Degradation Representation Learning for Blind Super-Resolution", CVPR 2021 [arXiv] Overview Requirements

Longguang Wang 318 Dec 24, 2022
Official implementation of the ICML2021 paper "Elastic Graph Neural Networks"

ElasticGNN This repository includes the official implementation of ElasticGNN in the paper "Elastic Graph Neural Networks" [ICML 2021]. Xiaorui Liu, W

liuxiaorui 34 Dec 04, 2022
Do you like Quick, Draw? Well what if you could train/predict doodles drawn inside Streamlit? Also draws lines, circles and boxes over background images for annotation.

Streamlit - Drawable Canvas Streamlit component which provides a sketching canvas using Fabric.js. Features Draw freely, lines, circles, boxes and pol

Fanilo Andrianasolo 325 Dec 28, 2022
Python package to add text to images, textures and different backgrounds

nider Python package for text images generation and watermarking Free software: MIT license Documentation: https://nider.readthedocs.io. nider is an a

Vladyslav Ovchynnykov 131 Dec 30, 2022
DenseCLIP: Language-Guided Dense Prediction with Context-Aware Prompting

DenseCLIP: Language-Guided Dense Prediction with Context-Aware Prompting Created by Yongming Rao*, Wenliang Zhao*, Guangyi Chen, Yansong Tang, Zheng Z

Yongming Rao 322 Dec 31, 2022
Extending JAX with custom C++ and CUDA code

Extending JAX with custom C++ and CUDA code This repository is meant as a tutorial demonstrating the infrastructure required to provide custom ops in

Dan Foreman-Mackey 237 Dec 23, 2022
Deep learning PyTorch library for time series forecasting, classification, and anomaly detection

Deep learning for time series forecasting Flow forecast is an open-source deep learning for time series forecasting framework. It provides all the lat

AIStream 1.2k Jan 04, 2023
This is a computer vision based implementation of the popular childhood game 'Hand Cricket/Odd or Even' in python

Hand Cricket Table of Content Overview Installation Game rules Project Details Future scope Overview This is a computer vision based implementation of

Abhinav R Nayak 6 Jan 12, 2022
Using some basic methods to show linkages and transformations of robotic arms

roboticArmVisualizer Python GUI application to create custom linkages and adjust joint angles. In the future, I plan to add 2d inverse kinematics solv

Sandesh Banskota 1 Nov 19, 2021
Contrastive Learning for Many-to-many Multilingual Neural Machine Translation(mCOLT/mRASP2), ACL2021

Contrastive Learning for Many-to-many Multilingual Neural Machine Translation(mCOLT/mRASP2), ACL2021 The code for training mCOLT/mRASP2, a multilingua

104 Jan 01, 2023
A generalist algorithm for cell and nucleus segmentation.

Cellpose | A generalist algorithm for cell and nucleus segmentation. Cellpose was written by Carsen Stringer and Marius Pachitariu. To learn about Cel

MouseLand 733 Dec 29, 2022
Multi-Task Learning as a Bargaining Game

Nash-MTL Official implementation of "Multi-Task Learning as a Bargaining Game". Setup environment conda create -n nashmtl python=3.9.7 conda activate

Aviv Navon 87 Dec 26, 2022
《Deep Single Portrait Image Relighting》(ICCV 2019)

Ratio Image Based Rendering for Deep Single-Image Portrait Relighting [Project Page] This is part of the Deep Portrait Relighting project. If you find

62 Dec 21, 2022
An University Project of Quera Web Crawling.

WebCrawlerProject An University Project of Quera Web Crawling. خزشگر اینستاگرام در این پروژه شما باید با استفاده از کتابخانه های زیر یک خزشگر اینستاگر

Mahdi 3 Aug 12, 2022
SFD implement with pytorch

S³FD: Single Shot Scale-invariant Face Detector A PyTorch Implementation of Single Shot Scale-invariant Face Detector Description Meanwhile train hand

Jun Li 251 Dec 22, 2022
Like ThreeJS but for Python and based on wgpu

pygfx A render engine, inspired by ThreeJS, but for Python and targeting Vulkan/Metal/DX12 (via wgpu). Introduction This is a Python render engine bui

139 Jan 07, 2023