Usable Implementation of "Bootstrap Your Own Latent" self-supervised learning, from Deepmind, in Pytorch

Overview

Bootstrap Your Own Latent (BYOL), in Pytorch

Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning and without having to designate negative pairs.

This repository offers a module that can easily wrap any image-based neural network (residual network, discriminator, policy network) so that it can immediately start benefiting from unlabelled image data.

Update 1: There is now new evidence that batch normalization is key to making this technique work well.

Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting the claim that batch statistics are needed for BYOL to work.

Update 3: Finally, we have some analysis for why this works

Yannic Kilcher's excellent explanation

Now go save your organization from having to pay for labels :)

Install

$ pip install byol-pytorch

Usage

Simply plug in your neural network, specifying (1) the image dimensions and (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
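To make the hidden_layer argument more concrete, here is a minimal sketch of one common way to capture an intermediate layer's output in PyTorch: a forward hook. The toy network and the names captured and save_output are hypothetical, and this illustrates the general mechanism only; it is not necessarily how byol-pytorch retrieves the latent internally.

```python
import torch
from torch import nn

# Hypothetical toy network standing in for a real backbone such as a ResNet.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

captured = {}

def save_output(module, inputs, output):
    # Called by PyTorch after the hooked module's forward pass.
    captured["latent"] = output

# Select the layer by index, analogous to passing hidden_layer = -2.
children = list(net.children())
handle = children[-2].register_forward_hook(save_output)

with torch.no_grad():
    logits = net(torch.randn(4, 3, 32, 32))

handle.remove()  # detach the hook once it is no longer needed

latent = captured["latent"]  # output of the second-to-last layer (Flatten)
```

In this sketch the second-to-last layer is the Flatten module, so latent holds one 8-dimensional vector per image, while logits holds the final 10-dimensional output.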

BYOL → SimSiam

A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the use_momentum flag to False. If you go this route, you no longer need to invoke update_moving_average, as shown in the example below.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False       # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

Advanced

While the hyperparameters have already been set to what the paper has found optimal, you can change them with extra keyword arguments to the base wrapper class.

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)
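For intuition about moving_average_decay, the following is a minimal sketch of an exponential moving average update, with plain floats standing in for parameter tensors. It also includes the cosine schedule that the BYOL paper describes for annealing the decay toward 1 (τ = 1 − (1 − τbase) · (cos(πk/K) + 1)/2). The function names ema_update and cosine_decay_schedule are hypothetical, and the README does not say whether the library applies such a schedule, so treat this as an illustration rather than the library's code.

```python
import math

def ema_update(target, online, decay):
    """Return decay * target + (1 - decay) * online, element-wise."""
    return [decay * t + (1.0 - decay) * o for t, o in zip(target, online)]

def cosine_decay_schedule(step, total_steps, base_decay=0.99):
    """tau = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2."""
    cosine_term = (math.cos(math.pi * step / total_steps) + 1.0) / 2.0
    return 1.0 - (1.0 - base_decay) * cosine_term

# Toy "weights": plain floats stand in for the encoder parameters.
target_weights = [0.0, 1.0]
online_weights = [1.0, 3.0]

total_steps = 100
for step in range(total_steps):
    # The decay starts at base_decay and rises toward 1 over training,
    # so the target encoder changes more and more slowly.
    decay = cosine_decay_schedule(step, total_steps)
    target_weights = ema_update(target_weights, online_weights, decay)
```

With a decay close to 1, the target weights drift only slowly toward the online weights, which is the intended stabilizing effect.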

By default, this library will use the augmentations from the SimCLR paper (which are also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the augment_fn keyword.

import kornia
from torch import nn

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)

In the paper, they seem to ensure that one of the augmentations has a higher Gaussian blur probability than the other. You can also adjust this to your heart's delight.

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)

To fetch the embeddings or the projections, simply pass the return_embedding = True flag when calling the BYOL learner instance.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)

Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.

https://github.com/lucidrains/pixel-level-contrastive-learning

Citation

@misc{grill2020bootstrap,
    title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year = {2020},
    eprint = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{chen2020exploring,
    title={Exploring Simple Siamese Representation Learning}, 
    author={Xinlei Chen and Kaiming He},
    year={2020},
    eprint={2011.10566},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
Comments
  • Negative Loss, Transfer Learning/Fine-Tuning Question

    Hi! Thanks for sharing this repo -- really clean and easy to use.

    When training using the PyTorch Lightning script from the repo, my loss is negative (and gets more negative over time). Is this expected?

    I'm curious to know if you've fine-tuned a pretrained model using this BYOL implementation, as the README example suggests. If yes, how were the results? Any intuition regarding how many epochs to fine-tune for?

    Thanks!

    opened by rsomani95 13
  • AssertionError: hidden layer never emitted an output with multi-gpu training

    I tried your library with a WideResnet40-2 model and used layer_index=-2.

    The lightning example works fine on a single GPU, but I got the error with multiple GPUs.

    opened by reactivetype 7
  • How to transfer the trained ckpt to pytorch.pth model?

    I used the example script to train a model and got a ckpt file, but how can I extract the trained resnet50.pth instead of the whole SelfSupervisedLearner? Sorry, I am new to the PyTorch Lightning library. What I want is the self-supervised resnet50.pth, because I want it to replace the original ImageNet-pretrained one. Thank you a lot.

    opened by knaffe 5
  • Training loss decreased and then increased

    Hi, I used your example on my own data. The training loss decreased and then increased after 100 epochs, which is weird. Have you encountered a similar situation? Is the model hard to train? The batch size is 128/256, lr is 0.1/0.2, and weight_decay is 1e-6.

    opened by easonyang1996 4
  • Can't load ckpt

    I used byol-pytorch-master/examples/lightning/train.py to generate a ckpt locally after training, but loading the ckpt raises errors. How should I load it? Thanks a lot!

    opened by AndrewTal 4
  • BYOL uses different augmentations for view1 and view2

    opened by OlivierDehaene 4
  • Transferring results on Cifar and other datasets

    Thanks for open-sourcing this!

    I notice a large gap between BYOL and SimCLR when transferring to downstream datasets: e.g., SimCLR reaches 71.6% on CIFAR-100, while BYOL reaches 78.4%.

    I understand that this might depend on the downstream training protocol. Could you provide sample code for that, especially for the L-BFGS-optimized logistic regressor?

    opened by jacobswan1 4
  • The saved network is same as the initial one?

    Firstly, thank you so much for this clean implementation!!

    The self-supervised training process looks good, but the saved (i.e. improved) model is exactly the same as the initial one on my side. Have you observed the same problem?

    The code I tested:

    import torch
    from net.byol import BYOL
    from torchvision import models
     
           
    resnet = models.resnet50(pretrained=True)
    param_1 = resnet.parameters()
    
    learner = BYOL(
        resnet,
        image_size = 256,
        hidden_layer = 'avgpool'
    )
    
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
    
    def sample_unlabelled_images():
        return torch.randn(20, 3, 256, 256)
    
    for _ in range(2):
        images = sample_unlabelled_images()
        loss = learner(images)
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average() # update moving average of target encoder
    
    # save your improved network
    torch.save(resnet.state_dict(), './checkpoints/improved-net.pt')
    
    # restore the model      
    resnet2 = models.resnet50()
    resnet2.load_state_dict(torch.load('./checkpoints/improved-net.pt'))
    param_2 = resnet2.parameters()
    
    # test whether two models are the same 
    for p1, p2 in zip(param_1, param_2):
        if p1.data.ne(p2.data).sum() > 0:
            print('They are different.')
    print('They are same.')
    
    opened by KimMeen 3
  • the maximum batch size can only be set to 32

    When I run the code on a 2080 Ti GPU with 10 GB of memory, the maximum batch size I can set is 32. Is there anything in the code that takes up a lot of video memory?

    opened by cuixianheng 3
  • Pretrained network

    Hi, thanks for sharing the code and making it so easy to use. I see in the example you set resnet = models.resnet50(pretrained=True). Is this what is done in the paper? Shouldn't self-supervised-learned networks be trained from scratch?

    Thanks again, P.

    opened by pmorerio 3
  • Singleton Class Members

    Forgive me for my unfamiliarity with software design, but I'm wondering why it is necessary to write a singleton wrapper for projector and target_encoder. Is there any disadvantage of initializing them in __init__?

    opened by wentaoyuan 3
  • Increase EMA-parameter during training

    Hi, I noticed that the EMA-parameter (called beta in the code, τ in the paper) is not updated during training. In the paper they describe that they increase τ from the start value to 1 during training: "Specifically, we set τ = 1 − (1 − τbase) · (cos(πk/K) + 1)/2 with k the current training step and K the maximum number of training steps." This makes a huge difference to the validation loss at the end of the training.

    opened by Benjamin-Hansson 1
  • Why the loss is different from BYOL authors'

    I found that the loss differs from the one described in the BYOL paper, which should be an L2 loss, and I didn't find an explanation. The loss in this repo is a cosine loss, and I just want to know why. BTW, thanks for this great repo!

    opened by Jing-XING 2
  • How to cluster/predict images?

    Hi, I have trained using the examples given with pytorch-lightning. I couldn't find code to cluster images after training. How can I find which image falls in which cluster? Is there any predictor API?

    opened by laxmimerit 1
  • BN layer weights and biases are not updated

    Thanks for sharing this repo, great work!

    I trained BYOL on my data and noticed that the weights and biases for BN layers are not updated on the saved model. I used resnet18 without pretrained weights resnet = models.resnet50(pretrained=False). After training for multiple epochs, the saved model has bn1.weight all equal to 1.0 and bn1.bias all equal to 0.0 .

    Is this the expected behavior or am I missing something? Appreciate your response!

    opened by kregmi 1
  •  Warning: grad and param do not obey the gradient layout contract.

    Has anybody gotten a similar warning when using it?

    Warning: grad and param do not obey the gradient layout contract. This is not an error, but may impair performance. grad.sizes() = [512, 256, 1, 1], strides() = [256, 1, 1, 1] param.sizes() = [512, 256, 1, 1], strides() = [256, 1, 256, 256] (function operator())

    opened by mohaEs 3
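
On the loss question raised in the comments above: the squared L2 distance between two L2-normalized vectors equals 2 − 2 · (cosine similarity), so a normalized L2 loss and a cosine loss of that form are the same quantity. The sketch below checks this numerically in plain Python. The helper names are hypothetical, and it assumes the cosine loss in question has the form 2 − 2 · cos; it is not taken from the library's source.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def normalized_l2_loss(q, z):
    """Squared L2 distance between the L2-normalized vectors."""
    q_hat, z_hat = l2_normalize(q), l2_normalize(z)
    return sum((a - b) ** 2 for a, b in zip(q_hat, z_hat))

def cosine_loss(q, z):
    """2 - 2 * cosine similarity of q and z."""
    dot = sum(a * b for a, b in zip(q, z))
    norm_q = math.sqrt(sum(a * a for a in q))
    norm_z = math.sqrt(sum(b * b for b in z))
    return 2.0 - 2.0 * dot / (norm_q * norm_z)

# Arbitrary example vectors standing in for a prediction and a target projection.
prediction = [0.3, -1.2, 2.0]
target = [1.0, 0.5, -0.7]

difference = abs(normalized_l2_loss(prediction, target) - cosine_loss(prediction, target))
```

The two functions agree up to floating-point error, and both range from 0 (vectors pointing the same way) to 4 (vectors pointing in opposite directions).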
Releases(0.6.0)
Owner
Phil Wang
Working with Attention. It's all we need