Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!

Overview

Robust Video Matting (RVM)

Official repository for the paper Robust High-Resolution Video Matting with Temporal Guidance. RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform matting in real time on any video without additional inputs. It achieves 76 FPS at 4K and 104 FPS at HD on an Nvidia GTX 1080 Ti GPU. The project was developed at ByteDance Inc.


News

  • [Aug 25 2021] Source code and pretrained models are published.
  • [Jul 27 2021] Paper is accepted by WACV 2022.

Showreel

Watch the showreel video (YouTube, Bilibili) to see the model's performance.

All footage in the video is available in Google Drive and Baidu Pan (code: tb3w).


Demo

  • Webcam Demo: Run the model live in your browser. Visualize recurrent states.
  • Colab Demo: Test our model on your own videos with free GPU.

Download

We recommend MobileNetv3 models for most use cases. ResNet50 models are the larger variant with small performance improvements. Our model is available on various inference frameworks. See inference documentation for more instructions.

| Framework | Download | Notes |
|---|---|---|
| PyTorch | rvm_mobilenetv3.pth, rvm_resnet50.pth | Official weights for PyTorch. Doc |
| TorchHub | Nothing to download. | Easiest way to use our model in your PyTorch project. Doc |
| TorchScript | rvm_mobilenetv3_fp32.torchscript, rvm_mobilenetv3_fp16.torchscript, rvm_resnet50_fp32.torchscript, rvm_resnet50_fp16.torchscript | For inference on mobile, consider exporting int8 quantized models yourself. Doc |
| ONNX | rvm_mobilenetv3_fp32.onnx, rvm_mobilenetv3_fp16.onnx, rvm_resnet50_fp32.onnx, rvm_resnet50_fp16.onnx | Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. Doc, Exporter |
| TensorFlow | rvm_mobilenetv3_tf.zip, rvm_resnet50_tf.zip | TensorFlow 2 SavedModel. Doc |
| TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | Run the model on the web. Demo, Starter Code |
| CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel, rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel | CoreML does not support dynamic resolution. You can export other resolutions yourself. Models require iOS 13+. s denotes downsample_ratio. Doc, Exporter |
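
Because each CoreML file is tied to one resolution and one downsample_ratio, it can help to read those settings back from the file name. The following is a minimal sketch, not part of this repo: the helper name `parse_coreml_name` and the regular expression are assumptions inferred only from the four file names listed above.

```python
import re

# Hypothetical pattern, inferred from the CoreML file names listed above.
COREML_NAME = re.compile(
    r"rvm_(?P<backbone>[a-z0-9]+)"
    r"_(?P<width>\d+)x(?P<height>\d+)"
    r"_s(?P<ratio>\d+(?:\.\d+)?)"
    r"_(?P<precision>fp16|int8)\.mlmodel"
)


def parse_coreml_name(filename):
    """Split a CoreML model file name into its fixed settings."""
    match = COREML_NAME.fullmatch(filename)
    if match is None:
        raise ValueError(f"unrecognized CoreML model name: {filename}")
    return {
        "backbone": match["backbone"],
        "width": int(match["width"]),
        "height": int(match["height"]),
        "downsample_ratio": float(match["ratio"]),  # the "s" field
        "precision": match["precision"],
    }


info = parse_coreml_name("rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel")
print(info)
```

A lookup like this could be used to pick the file whose fixed resolution matches the camera or video being processed.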

All models are available in Google Drive and Baidu Pan (code: gym7).


PyTorch Example

  1. Install dependencies:
pip install -r requirements_inference.txt
  2. Load the model:
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
  3. To convert videos, we provide a simple conversion API:
from inference import convert_video

convert_video(
    model,                           # The model, can be on any device (cpu or cuda).
    input_source='input.mp4',        # A video file or an image sequence directory.
    output_type='video',             # Choose "video" or "png_sequence"
    output_composition='output.mp4', # File path if video; directory path if png sequence.
    output_video_mbps=4,             # Output video mbps. Not needed for png sequence.
    downsample_ratio=None,           # A hyperparameter to adjust or use None for auto.
    seq_chunk=12,                    # Process n frames at once for better parallelism.
)
  4. Or write your own inference code:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                       # Initial recurrent states.
downsample_ratio = 0.25                                # Adjust based on your video.

with torch.no_grad():
    for src in DataLoader(reader):                     # RGB tensor normalized to 0 ~ 1.
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        com = fgr * pha + bgr * (1 - pha)              # Composite to green background. 
        writer.write(com)                              # Write frame.
  5. The models and converter API are also available through TorchHub.
# Load the model.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"

# Converter API.
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")

Please see the inference documentation for details on the downsample_ratio hyperparameter, more converter arguments, and more advanced usage.
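
The two ideas in the custom inference loop, cycling the recurrent states and compositing with `com = fgr * pha + bgr * (1 - pha)`, can be tried without a GPU. The sketch below is illustrative only: `stub_model` is a hypothetical stand-in that mimics the call signature used above, not the real MattingNetwork, and each "frame" is a single RGB pixel.

```python
def composite(fgr, pha, bgr):
    """Blend one RGB pixel: com = fgr * pha + bgr * (1 - pha)."""
    return tuple(f * pha + b * (1 - pha) for f, b in zip(fgr, bgr))


def stub_model(src, r1, r2, r3, r4, downsample_ratio):
    """Hypothetical stand-in for the network: returns (fgr, pha, r1, r2, r3, r4)."""
    step = 0 if r1 is None else r1 + 1    # toy "memory": index of the current frame
    pha = 1.0 if src[0] > 0.5 else 0.0    # toy rule, not a real alpha matte
    return src, pha, step, step, step, step


GREEN = (0.47, 1.0, 0.6)                  # same background color as above
rec = [None] * 4                          # initial recurrent states
frames = [(0.9, 0.8, 0.7), (0.1, 0.2, 0.3), (0.6, 0.5, 0.4)]

outputs = []
for src in frames:
    # The starred target collects the four returned states for the next call.
    fgr, pha, *rec = stub_model(src, *rec, 0.25)
    outputs.append(composite(fgr, pha, GREEN))

print(rec)      # states carried across all three frames
print(outputs)  # opaque pixels keep the foreground, transparent ones turn green
```

Resetting `rec` to `[None] * 4` between unrelated clips keeps state from one video out of the next; that is a design suggestion under this sketch's assumptions, not a documented requirement.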


Training and Evaluation

Please refer to the training documentation to train and evaluate your own model.


Speed

Speed is measured with inference_speed_test.py for reference.

| GPU | dType | HD (1920x1080) | 4K (3840x2160) |
|---|---|---|---|
| RTX 3090 | FP16 | 172 FPS | 154 FPS |
| RTX 2060 Super | FP16 | 134 FPS | 108 FPS |
| GTX 1080 Ti | FP32 | 104 FPS | 74 FPS |
  • Note 1: HD uses downsample_ratio=0.25, 4K uses downsample_ratio=0.125. All tests use batch size 1 and frame chunk 1.
  • Note 2: GPUs before Turing architecture do not support FP16 inference, so GTX 1080 Ti uses FP32.
  • Note 3: We only measure tensor throughput. The provided video conversion script in this repo is expected to be much slower, because it does not utilize hardware video encoding/decoding and does not have the tensor transfer done on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to PyNvCodec.
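
The settings in Note 1 can be checked with a little arithmetic. This is a rough sketch with hypothetical helper names; it assumes downsample_ratio scales the width and height of the frame by the same factor.

```python
def internal_resolution(width, height, downsample_ratio):
    """Size of the downsampled frame, assuming both sides scale by the same ratio."""
    return round(width * downsample_ratio), round(height * downsample_ratio)


def frame_budget_ms(fps):
    """Average time available per frame at a given throughput."""
    return 1000.0 / fps


hd = internal_resolution(1920, 1080, 0.25)     # HD setting from Note 1
uhd = internal_resolution(3840, 2160, 0.125)   # 4K setting from Note 1
print(hd, uhd)
print(round(frame_budget_ms(104), 2))          # GTX 1080 Ti, HD column
```

Under that assumption both settings reduce the frame to the same 480x270 size, and 104 FPS corresponds to roughly 9.6 ms per frame.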

Comments
  • [Questions] - Training Procedure

    Hi,

    I have some questions about the training procedure:

    1. In the paper, you mention training Stage 1 for 15 epochs, while the instructions in the code use 20 epochs. Is there a reason for this change? Will the results be similar?
    2. I could not get access to Distinctions-646; I had no reply from the authors/maintainers of the dataset. Based on your indicated file structure, I've built a similar dataset. This adds uncertainty to the quality of my training, but it is a risk I am willing to take. To have a point of comparison (stages 1-3 do not depend on this dataset), would you mind sharing your partial PyTorch training weights (stage1/epoch19.pth, stage2/epoch21.pth, and stage3/epoch22.pth)?
    3. What is the minimum resolution you used for the background images during training?

    For the 3rd time, thank you very much for your contribution to the field. It was a brilliant work. Looking forward to your future work.

    opened by SamHSlva 39
  • hardsigmoid replacement

    I've been trying to export an onnx model replacing the hardsigmoid operator.

    I have modified the site-packages/torch/onnx/symbolic_opset9.py file this way:

    @parse_args("v")
    def hardswish(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    @parse_args("v")
    def hardsigmoid(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    But I am not sure at all whether this is the right way to replace them with primitive ops.

    When I export the ONNX model with this change, I still get the error "OnnxImportException: Unknown type HardSigmoid encountered while parsing layer 396" from the inference engine I am trying to use.

    opened by livingbeams 15
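
For reference, the arithmetic behind this question can be written with primitive operations only. The sketch below is plain Python math, not ONNX export code and not the maintainers' answer. It assumes the PyTorch definition hardsigmoid(x) = relu6(x + 3) / 6, which needs only Add, Clip, and Div, with one extra Mul for hardswish.

```python
def hardsigmoid(x):
    """relu6(x + 3) / 6, written as Add -> Clip(0, 6) -> Div."""
    shifted = x + 3.0                       # Add
    clipped = min(max(shifted, 0.0), 6.0)   # Clip to [0, 6]
    return clipped / 6.0                    # Div


def hardswish(x):
    """x * hardsigmoid(x), one extra Mul on top of hardsigmoid."""
    return x * hardsigmoid(x)


for value in (-4.0, -3.0, 0.0, 3.0, 4.0):
    print(value, hardsigmoid(value), hardswish(value))
```

Note that the snippet quoted in the question still emits a HardSigmoid node inside both functions, which would explain why the importer keeps reporting that operator. An export that avoids it would have to emit only the primitive nodes, assuming the target engine supports them.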
  • VideoMatte240K-HD

    If I am going to train stage 3 and stage 4, the VideoMatte240K-HD data will be used. Is it right to change the following paths from VideoMatte240K_JPEG_SD to VideoMatte240K_JPEG_HD?

    'videomatte': {
        'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
        'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
    },

    opened by FengMu1995 7
  • Add Unity example to README?

    Hey there, I just ported RVM to Unity using NatML, an open-source machine learning runtime. I have questions:

    1. Can I make a PR to add a link into the README to a Unity example project demonstrating using RVM?
    2. I published the model under my account on NatML Hub. Would you be interested in signing up on Hub, so that I can transfer the model to you?

    Here's the model on NatML Hub:

    @natsuite/robust-video-matting

    opened by olokobayusuf 7
  •  Some questions about training

    1. How can I eliminate or reduce the edge flickering problem? Would increasing the sequence length with --seq-length-lr help?
    2. I only have composite images and no foreground images. Is it possible to remove foreground training and the foreground loss, or is there a better way?
    3. How important is foreground prediction for matting?

    Looking forward to your reply

    opened by zhanghongyong123456 5
  • Not Issue 👉 Few questions

    First of all, thank you for working on this project! It looks much stronger than BMV2!

    1. Will it work on Anaconda and Windows 10 just like BMV2 does (without being more complicated)?

    2. Will it support the same hardware, or does it need a much more powerful CPU/GPU compared to BMV2?

    3. Can you please tell me when you will release it again? I missed the first release, so I can't test it because it's still offline. It would be very nice to have it this week, if possible.

    Thanks ahead for the answers, please keep up the good work! ❤

    opened by AlonDan 5
  • Synchronization issues between inferred mask and original video

    Hello! Thanks for the code. I have had some timing issues between the inferred mask output and the original video. I made this comparison by converting both my original video and the output mask video to frames. I obtained the same number of frames from both, so the difference may be caused by a misconfiguration on my side. My original video is 30 fps and 1080x1920. If you have a suggestion, I would appreciate it.

    opened by italosalgado14 4
  • Weird results when use Segmentation Pass for inference

    https://github.com/PeterL1n/RobustVideoMatting/blob/f8a26e27198a93a94bfd06e96b8d5a34d0660f81/inference.py#L127

    I changed this line to use Segmentation Pass. (use the pretrained weights rvm_mobilenetv3.pth)

    pha, *rec = model(src, *rec, segmentation_pass=True)
    fgr = src * pha
    

    But I got weird mask results, something like this, why?

    seg_pass_alpha

    opened by luuil 4
  • Beginner questions about model results

    Thank you for your hard work. I have two questions:

    1. Besides changing the value of downsample_ratio to adjust the accuracy of the matting, which other parameters can be changed to affect the results?

    2. Does this project have higher requirements for graphics cards? Does the type of graphics card affect the final results?

    I currently have a project that runs the model, but the results are not ideal. Thanks again!

    opened by yinjia823 4
  • FP16 is slower than FP32

    I ran inference tests with the pre-trained ONNX models (in Python, not C++), using only the onnxruntime, cv2, and numpy libraries. The models were downloaded from https://github.com/PeterL1n/RobustVideoMatting/releases/: rvm_mobilenetv3_fp32.onnx and rvm_mobilenetv3_fp16.onnx

    Inference runs on a 1080x1920 video with downsample_ratio=0.25. FP32 takes about 170 ms per frame, while FP16 takes about 240 ms. Why is FP16 so slow?

    I have set up the inputs correctly: src, r1i, r2i, r3i, and r4i are np.array([[[[]]]], dtype=np.float32 or 16), and downsample_ratio is always np.array([0.25], dtype=float32).

    I use a CPU (Intel i5) for inference. Is it so slow because the CPU does not support FP16 operations?

    opened by ZachL1 3
  • foreground prediction details

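    Hello, I have a question about foreground prediction. In the official web demo, the foreground image predicted by the model contains background details from the input image (non-person pixels) in addition to the foreground (the person). However, the model I trained myself predicts a foreground image that contains only the person and no background details from the input image. I did not change any detail of the official code; the only difference is that I used background images I collected myself. At first I suspected that the foreground loss covered all pixels (alpha of any value, rather than only alpha greater than 0 as described in the paper), but after checking the code I found no problem, and it is consistent with the paper. What causes this difference? Thank you.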

    opened by li-wenquan 3
  • Add Replicate demo and

    Hey @PeterL1n ! 👋

    Thanks for this wonderful project!

    This pull request makes it possible to run your model inside a Docker environment, which makes it easier for other people to run it. We're using an open source tool called Cog to make this process easier.

    This also means we can make a web page where other people can run your model! View it here: https://replicate.com/arielreplicate/robust_video_matting

    Replicate also has an API, so people can easily run your model from their code:

    import replicate
    model = replicate.models.get("arielreplicate/robust_video_matting")
    model.predict(...)
    

    If you'd like to modify the Replicate page, let me know and I can transfer ownership to your account.

    In case you're wondering who I am, I'm from Replicate, where we're trying to make machine learning reproducible. We got frustrated that we couldn't run all the really interesting ML work being done. So, we're going round implementing models we like. 😊

    opened by ArielReplicate 0
  • Support of NCNN version

    Hello, how are you? Thanks for contributing to this project. I tried to use an NCNN version of this model in C++ on Windows. There are several GitHub repos that use an NCNN model in C++, but I cannot run them because there are issues when extracting the output data from the model. The models are also old, so it is impossible to run them. Could you support an NCNN version here?

    opened by rose-jinyang 0
  • Problem with exporting alpha-mask on the Replicate/COG version

    I tried both the replicate page and local COG variants. When predicting with alpha-mask, this error is a constant:

    FileNotFoundError: [Errno 2] No such file or directory: 'alpha-mask.mp4'

    Exporting with green-screen and foreground-mask works, however. Maybe this is an issue with mp4 not supporting alpha-transparent video, so it fails?

    opened by SomeOrdinaryDude 0
  • A question about src_sm in the model.py

    I see the following line in mobilenetv3.py:

    x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    and these lines in model.py:

    f1, f2, f3, f4 = self.backbone(src_sm)
    ...
    hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

    This means the input src_sm of the decoder has not been normalized. Is that your intention?

    opened by FengMu1995 0
  • Performance using grayscale images

    Hey there,

    thanks for this amazing tool! Does anybody know how well it performs on grayscale images? I want to use it in the dark, and I have an infrared camera.

    If retraining is required, how many GPU hours do you think are necessary?

    :)

    opened by bytosaur 1
  • Slow inference and low GPU use.

    inference.py is running at ~4.2 it/s. It barely loads my RTX 2060 (0-13% utilization). The inference_speed_test script gives me ~33.2 it/s with the same model and video settings. Changing --workers on the convert_video() function did nothing. Am I missing something? How can I run inference faster using the full potential of the hardware?

    Thanks.

    opened by sharp-trickster 0