PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Overview

PyTorch SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

DOI

Blog post with full documentation: Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Image of SimCLR Arch

See also PyTorch Implementation for BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning.

Installation

$ conda env create --name simclr --file env.yml
$ conda activate simclr
$ python run.py

Config file

Before running SimCLR, make sure you choose the correct running configurations. You can change the running configurations by passing keyword arguments to the run.py file.

$ python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100 

If you want to run it on CPU (for debugging purposes) use the --disable-cuda option.

For 16-bit precision GPU training, there NO need to to install NVIDIA apex. Just use the --fp16_precision flag and this implementation will use Pytorch built in AMP training.

Feature Evaluation

Feature evaluation is done using a linear model protocol.

First, we learned features using SimCLR on the STL10 unsupervised set. Then, we train a linear classifier on top of the frozen features from SimCLR. The linear model is trained on features extracted from the STL10 train set and evaluated on the STL10 test set.

Check the Open In Colab notebook for reproducibility.

Note that SimCLR benefits from longer training.

Linear Classification Dataset Feature Extractor Architecture Feature dimensionality Projection Head dimensionality Epochs Top1 %
Logistic Regression (Adam) STL10 SimCLR ResNet-18 512 128 100 74.45
Logistic Regression (Adam) CIFAR10 SimCLR ResNet-18 512 128 100 69.82
Logistic Regression (Adam) STL10 SimCLR ResNet-50 2048 128 50 70.075
Comments
  • A question about the

    A question about the "labels"

    Hi! I have a question about the definition of "labels" in the script "simclr.py".

    On line 54 of "simclr.py", the authors defined:

    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

    So all the entries of "labels" are all zeros. But I think according to the paper, there should be an entry as 1 for the positive pair?

    Thanks in advance for your reply!

    opened by kekehia123 6
  • size of tensors in cosine_simiarity function

    size of tensors in cosine_simiarity function

    Hi , I'm trying to understand the code in : loss/nt_xent.py

    we are sending "representations" on both arguments

        def forward(self, zis, zjs):
            representations = torch.cat([zjs, zis], dim=0)
            similarity_matrix = self.similarity_function(representations, representations)
    

    But when receiving it in cosine_similarity func somehow the sizes are: (N, 1, C) and y shape: (1, 2N, C), how can it be double if you sent the same argument

        def _cosine_simililarity(self, x, y):
            # x shape: (N, 1, C)
            # y shape: (1, 2N, C)
            # v shape: (N, 2N)
            v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
            return v
    

    Thanks for your help.

    opened by BattashB 5
  • How do i train the SimCLR model with my local dataset?

    How do i train the SimCLR model with my local dataset?

    Dear researcher, Thank you for the open-source code you provided, it is of great help to me for understanding contrastive learning. But I still have some confusion when training the SimCLR model with my local dataset, could you give me some guidance or tips? I would appreciate it if you could reply to this issue.

    opened by bestalllen 4
  • Question about CE Loss

    Question about CE Loss

    Hello,

    Thanks for sharing the code, nice implementation.

    The way you calculate the loss by using a mask is quite brilliant. But I have a question.

    logits = torch.cat((positives, negatives), dim=1) So if I'm not wrong, the first column of logits is positive and the rest are negatives.

    labels = torch.zeros(2 * self.batch_size).to(self.device).long() But your labels are all zeros, which means no matter positive or negative, the similarity should low.

    So I wonder is the first column of labels supposed to be 1 instead of 0.

    Thanks for your help.

    opened by WShijun1991 4
  • Issue with batch-size

    Issue with batch-size

    In function info_nce_loss, the line 28, creates labels based on batch_size and on other side we have STL10 dataset which has 100,000 images which is divisible by batch_size of 32 and having batch_size like 128 or 64 gives a remainder of 32.

    Having batch_size != 32, causes error in line 42, because the similarity matrix will based on features and labels will be based on batch size.

    For instance, if the batch size = 128, the remaining images in the dataset in the last iter of data_loader is 32. Since we create two variant of each image we'll have 64 images. Now we have 128 x 2 = 256 labels from line 28, and we'll have similarity matrix of (64 x 128, 128 x 64) => (64 x 64) but with mask (256 x 256) causing "dimension mismatch"

    Solution: Change Line 28 as below

    labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(self.args.n_views)], dim=0)

    image

    opened by Mayurji 3
  • 'CosineAnnealingLR' never works with the wrong position of 'scheduler.step()'

    'CosineAnnealingLR' never works with the wrong position of 'scheduler.step()'

    Considering the setting in 'scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)',I think 'scheduler.step()' should be called every step in 'for (xis,xjs),_ in train_loader'. Otherwise,lr will nerver change until 'len(train_loader)' epochs but not steps

    opened by GuohongLi 3
  • Is it something wrong with the training model for CIFAR-10 experiments?

    Is it something wrong with the training model for CIFAR-10 experiments?

    Hi,

    I find that the ResNet20 model for CIFAR-10 experiments is not fully correct. The head conv structure should be modified (stride=1 and no pooling,) because the image size of CIFAR-10 is very small.

    opened by timqqt 2
  • GPU utilization rate is low

    GPU utilization rate is low

    Hi, thanks for the code!

    When I tried to run it on single GPU (v-100), the utilizaiton rate is very low (~0-10%) even if I increase num_worker. Would you know why this happens and how to solve it? Thanks!

    opened by LiJunnan1992 2
  • Why cos_sim after L2 norm?

    Why cos_sim after L2 norm?

    Hi, This code is really useful for me. Thanks! But I got a question about the NT-Xent loss. I noticed that you use L2 norm on z and then use cos_similarity after that. But cos_similarity already contain the function of l2 norm. Why use L2 norm first?

    opened by BoPang1996 2
  • NT_Xent Loss function: all negatives are not being used?

    NT_Xent Loss function: all negatives are not being used?

    Hi @sthalles , Thank you for sharing your code!

    Pl correct me if I am wrong: I see that in line loss/nt_xent.py line 57 (below) you are not computing contrastive loss for all negative pairs as you are reshaping total negatives in 2D array i.e. only a part of negative pairs are being used for a single positive pair, right? :

    _negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)_
    _logits = torch.cat((positives, negatives), dim=1)_
    

    Hope to hear from you soon.

    -Ishan

    opened by DAVEISHAN 2
  • Validation Loss calculation

    Validation Loss calculation

    First of all, thank you for your great work!

    Method _validate in simclr.py will raise ZeroDivisionError at line 148 if the validation data loader performs only one iteration (since counter starts from 0).

    opened by alessiamarcolini 2
  • evaluation code batch_size & validation process

    evaluation code batch_size & validation process

    I'm really appreciated about your good work :) I left a question because I got confused while studying through your great code.

    First, I wonder why you used "batch_size=batch_size*2" differently from train_loader in the test_loader part of the file "mini_batch_logistic_regression_valuator.ipynb". Is it related to creating 2 views when doing data augmentation?

    Also, in the last cell of this file, I'm confused whether the second "for" (of the two "for") in the large epoch "for" statement corresponds to the test process or the validation process. I thought it was a test process, because loss update, backpropagation, optimization, etc. were done only in the first "for", and the second yield only accuracy, but is that right? Or I'm confused if the second "for" is a validating process because the first "for" and the second "for" are going together in the entire epoch processing.

    opened by YejinS 0
  • Review Training | Fine-Tune | Test details

    Review Training | Fine-Tune | Test details

    Hi, I just want to check all the experiments details and make sure I didn't miss any part(?

    1. Training Phase : use SimCLR (two encoder branches) to train on ImageNet for 1000 epochs to get a init pretrained weights.
    2. Fine-Tuned : load the init pretrained weights on the resnet18(50/101/...) with freezed parameters and concate with a linear classifier, and train the classifier with CIFAR10/STL10 training dataset for 100 epochs.
    3. Test Phase : freeze all the encoder, classifier parameters, and test on the CIFAR10/STL10 testing dataset.

    Is this the way how you get the top1 acc in the README?

    opened by Howeng98 0
  • Confusion matrix

    Confusion matrix

    Does anyone know how to add the confusion matrix in this code? After I added it according to the online one, something went wrong. I don't know what went wrong in my code.I can't solve it. please help help me! Thanks. def confusion_matrix(output, labels, conf_matrix):

    preds = torch.argmax(output, dim=-1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix
    
    opened by here101 0
  • batch size affect

    batch size affect

    Hi, I'm trying to experiment with CIFAR-10 with the default hyper-params, and it seems to yield a better score when using smaller batch size (e.g. 72% with batch size 256 yet 78% with batch size 128). Anyone in the same situation, here?

    opened by VietHoang1512 1
  •  ModuleNotFoundError: No module named 'torch.cuda'

    ModuleNotFoundError: No module named 'torch.cuda'

    I am using pythion 3.7 on Win10, Anaconda Jupyter. I have successfully installed torch-1.10.0+cu113 torchaudio-0.10.0+cu113 torchvision-0.11.1+cu113. When trying to import torch , I get ModuleNotFoundError: No module named 'torch.cuda' Detailed error:

    ModuleNotFoundError                       Traceback (most recent call last)
    <ipython-input-1-bfd2c657fa76> in <module>
          1 import numpy as np
          2 import pandas as pd
    ----> 3 import torch
          4 import torch.nn as nn
          5 from sklearn.model_selection import train_test_split
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\__init__.py in <module>
        603 
        604 # Shared memory manager needs to know the exact location of manager executable
    --> 605 _C._initExtension(manager_path())
        606 del manager_path
        607 
    
    ModuleNotFoundError: No module named 'torch.cuda'
    

    I found posts for similar error No module named 'torch.cuda.amp'. However, any of the suggested solutions worked. Please advise.

    opened by m-bor 0
Releases(v1.0.1)
SCI-AIDE : High-fidelity Few-shot Histopathology Image Synthesis for Rare Cancer Diagnosis

SCI-AIDE : High-fidelity Few-shot Histopathology Image Synthesis for Rare Cancer Diagnosis Pretrained Models In this work, we created synthetic tissue

Emirhan Kurtuluş 1 Feb 07, 2022
All-in-one Docker container that allows a user to explore Nautobot in a lab environment.

Nautobot Lab This container is not for production use! Nautobot Lab is an all-in-one Docker container that allows a user to quickly get an instance of

Nautobot 29 Sep 16, 2022
Fuse radar and camera for detection

SAF-FCOS: Spatial Attention Fusion for Obstacle Detection using MmWave Radar and Vision Sensor This project hosts the code for implementing the SAF-FC

ChangShuo 18 Jan 01, 2023
Curriculum Domain Adaptation for Semantic Segmentation of Urban Scenes, ICCV 2017

AdaptationSeg This is the Python reference implementation of AdaptionSeg proposed in "Curriculum Domain Adaptation for Semantic Segmentation of Urban

Yang Zhang 128 Oct 19, 2022
PyTorch implementation of Interpretable Explanations of Black Boxes by Meaningful Perturbation

PyTorch implementation of Interpretable Explanations of Black Boxes by Meaningful Perturbation The paper: https://arxiv.org/abs/1704.03296 What makes

Jacob Gildenblat 322 Dec 17, 2022
Using this you can control your PC/Laptop volume by Hand Gestures (pinch-in, pinch-out) created with Python.

Hand Gesture Volume Controller Using this you can control your PC/Laptop volume by Hand Gestures (pinch-in, pinch-out). Code Firstly I have created a

Tejas Prajapati 16 Sep 11, 2021
The official implementation of VAENAR-TTS, a VAE based non-autoregressive TTS model.

VAENAR-TTS This repo contains code accompanying the paper "VAENAR-TTS: Variational Auto-Encoder based Non-AutoRegressive Text-to-Speech Synthesis". Sa

THUHCSI 138 Oct 28, 2022
基于Paddle框架的fcanet复现

fcanet-Paddle 基于Paddle框架的fcanet复现 fcanet 本项目基于paddlepaddle框架复现fcanet,并参加百度第三届论文复现赛,将在2021年5月15日比赛完后提供AIStudio链接~敬请期待 参考项目: frazerlin-fcanet 数据准备 本项目已挂

QuanHao Guo 7 Mar 07, 2022
Python script that allows you to automatically setup your Growtopia server.

AutoSetup Python script that allows you to automatically setup your Growtopia server. How To Use Firstly, install all the required modules that used i

Aspire 3 Mar 06, 2022
Colossal-AI: A Unified Deep Learning System for Large-Scale Parallel Training

ColossalAI An integrated large-scale model training system with efficient parallelization techniques. arXiv: Colossal-AI: A Unified Deep Learning Syst

HPC-AI Tech 7.9k Jan 08, 2023
Implementing a simplified copy of Shazam application from scratch using MinHashing and LSH.

Building Shazam from scratch In this repository we tried to implement a simplified copy of the Shazam application able to tell you the name of a song

Arturo Ghinassi 0 Nov 17, 2022
Official pytorch implementation for Learning to Listen: Modeling Non-Deterministic Dyadic Facial Motion (CVPR 2022)

Learning to Listen: Modeling Non-Deterministic Dyadic Facial Motion This repository contains a pytorch implementation of "Learning to Listen: Modeling

50 Dec 17, 2022
Implementation of the "Point 4D Transformer Networks for Spatio-Temporal Modeling in Point Cloud Videos" paper.

Point 4D Transformer Networks for Spatio-Temporal Modeling in Point Cloud Videos Introduction Point cloud videos exhibit irregularities and lack of or

Hehe Fan 101 Dec 29, 2022
DVG-Face: Dual Variational Generation for Heterogeneous Face Recognition, TPAMI 2021

DVG-Face: Dual Variational Generation for HFR This repo is a PyTorch implementation of DVG-Face: Dual Variational Generation for Heterogeneous Face Re

52 Dec 30, 2022
Code repository for paper `Skeleton Merger: an Unsupervised Aligned Keypoint Detector`.

Skeleton Merger Skeleton Merger, an Unsupervised Aligned Keypoint Detector. The paper is available at https://arxiv.org/abs/2103.10814. A map of the r

北海若 48 Nov 14, 2022
Ppq - A powerful offline neural network quantization tool with custimized IR

PPL Quantization Tool(PPL 量化工具) PPL Quantization Tool (PPQ) is a powerful offlin

605 Jan 03, 2023
GNNAdvisor: An Efficient Runtime System for GNN Acceleration on GPUs

GNNAdvisor: An Efficient Runtime System for GNN Acceleration on GPUs [Paper, Slides, Video Talk] at USENIX OSDI'21 @inproceedings{GNNAdvisor, title=

YUKE WANG 47 Jan 03, 2023
A testcase generation tool for Persistent Memory Programs.

PMFuzz PMFuzz is a testcase generation tool to generate high-value tests cases for PM testing tools (XFDetector, PMDebugger, PMTest and Pmemcheck) If

Systems Research at ShiftLab 14 Jul 24, 2022
Technical Indicators implemented in Python only using Numpy-Pandas as Magic - Very Very Fast! Very tiny! Stock Market Financial Technical Analysis Python library . Quant Trading automation or cryptocoin exchange

MyTT Technical Indicators implemented in Python only using Numpy-Pandas as Magic - Very Very Fast! to Stock Market Financial Technical Analysis Python

dev 34 Dec 27, 2022
TakeInfoatNistforICS - Take Information in NIST NVD for ICS

Take Information in NIST NVD for ICS This project developed with Python. When yo

5 Sep 05, 2022