A concise but complete implementation of CLIP with various experimental improvements from recent papers

Overview

x-clip (wip)

Install

$ pip install x-clip

Usage

import torch
from x_clip import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = True   # whether to use fine-grained contrastive learning (FILIP)
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()

Citations

@misc{radford2021learning,
    title   = {Learning Transferable Visual Models From Natural Language Supervision}, 
    author  = {Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
    year    = {2021},
    eprint  = {2103.00020},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yao2021filip,
    title   = {FILIP: Fine-grained Interactive Language-Image Pre-Training}, 
    author  = {Lewei Yao and Runhui Huang and Lu Hou and Guansong Lu and Minzhe Niu and Hang Xu and Xiaodan Liang and Zhenguo Li and Xin Jiang and Chunjing Xu},
    year    = {2021},
    eprint  = {2111.07783},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Model forward outputs to text/image similarity score

    Any insight on how to take the image/text embeddings (or the nominal model forward output) and compute a simple similarity score, as done in the Hugging Face implementation?

    In the original paper I see that the dot products of the image/text encoder outputs were used, but here I was having trouble with the dimensions of the outputs.

    opened by paulcjh 12
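
For the similarity question above, a minimal sketch in plain PyTorch may help. It assumes you already have one latent vector per text and per image with matching dimensions. The tensors below are random stand-ins, and `similarity_scores` and its `temperature` argument are hypothetical names, not part of x-clip.

```python
import torch
import torch.nn.functional as F

def similarity_scores(text_latents, image_latents, temperature=1.0):
    """Return a (num_texts, num_images) matrix of scaled cosine similarities."""
    # L2-normalize so the dot product becomes a cosine similarity
    text_latents = F.normalize(text_latents, dim=-1)
    image_latents = F.normalize(image_latents, dim=-1)
    return text_latents @ image_latents.t() / temperature

# stand-in latents (assumption: in practice these come from the two encoders)
text_latents = torch.randn(4, 512)
image_latents = torch.randn(4, 512)

sim = similarity_scores(text_latents, image_latents)  # shape (4, 4)
probs_per_text = sim.softmax(dim=-1)                  # each row sums to 1
```

Entry `sim[i, j]` scores text `i` against image `j`. Applying a softmax over a row turns the scores into a distribution over images for that text.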
  • Using different encoders in CLIP

    Hi, I am wondering whether it is possible to use different encoders in CLIP, for example a ResNet instead of a ViT for images. Is it also possible to replace the text encoder with a features encoder? If I have a vector of features for a given image and I want to use x-clip, how should I do that? I wrote a code example that doesn't seem to work. Here is what I did:

    import torch
    from x_clip import CLIP
    import torch.nn as nn
    from torchvision import models
    
    class Image_Encoder(torch.nn.Module):
        #output size is (bs,512)
        def __init__(self):
            super(Image_Encoder, self).__init__()
            self.model_pre = models.resnet18(pretrained=False)
            self.base=nn.Sequential(*list(self.model_pre.children()))
            self.base[0]=nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            self.resnet=self.base[:-1]
    
        def forward(self, x):
            out=self.resnet(x).squeeze()
            return out
    
    
    class features_encoder(torch.nn.Module):
        #output size is (bs,512)
        def __init__(self):
            super(features_encoder, self).__init__()
            self.model =nn.Linear(2048,512)
    
        def forward(self, x):
            out=self.model(x)
            return out
    
    images_encoder=Image_Encoder()
    features_encoder=features_encoder()
    
    clip = CLIP(
        image_encoder = images_encoder,
        text_encoder = features_encoder,
        dim_image = 512,
        dim_text = 512,
        dim_latent = 512
    )
    
    features= torch.randn(4,2048)
    images = torch.randn(4, 3, 256, 256)
    
    loss = clip(features, images, return_loss = True)
    loss.backward()
    

    but I got the following error: forward() takes 2 positional arguments but 3 were given

    Thanks

    opened by ethancohen123 8
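
One possible reading of the error above, offered as an assumption rather than something the thread confirms, is that the custom encoder is called with an extra positional argument (for example a mask) that its `forward` does not accept. A minimal sketch of an adapter that tolerates extra arguments follows. `IgnoreExtraArgs` is a hypothetical helper, not part of x-clip.

```python
import torch
import torch.nn as nn

class IgnoreExtraArgs(nn.Module):
    """Hypothetical adapter: passes only the first positional argument to the wrapped encoder."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, x, *args, **kwargs):
        # anything after `x` (for example a mask) is dropped
        return self.encoder(x)

features_enc = IgnoreExtraArgs(nn.Linear(2048, 512))

features = torch.randn(4, 2048)
mask = torch.ones(4, dtype=torch.bool)

out = features_enc(features, mask)  # the extra argument no longer raises a TypeError
```

Whether the output shape of such an encoder is what CLIP expects downstream is a separate question that this sketch does not address.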
  • Visual ssl with channels different than 3

    Hi, there seems to be a bug when using visual SSL with a number of channels other than 3. I think the error comes from the visual SSL type, around row 280, here:

    # send a mock image tensor to instantiate parameters
    self.forward(torch.randn(1, 3, image_size, image_size))

    opened by ethancohen123 4
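
The hardcoded `3` in the mock tensor above could be replaced by a parameter. The following is a minimal sketch under that assumption. `MockInitWrapper` and its `channels` argument are hypothetical and only illustrate the idea, not the library's actual code.

```python
import torch
import torch.nn as nn

class MockInitWrapper(nn.Module):
    """Hypothetical wrapper that runs one mock forward pass with a configurable channel count."""

    def __init__(self, net, image_size, channels=3):
        super().__init__()
        self.net = net
        # send a mock image tensor through the network, using `channels` instead of a fixed 3
        with torch.no_grad():
            self.forward(torch.randn(1, channels, image_size, image_size))

    def forward(self, x):
        return self.net(x)

# a tiny single-channel network as a stand-in for the visual encoder
net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

wrapper = MockInitWrapper(net, image_size=64, channels=1)
out = wrapper(torch.randn(2, 1, 64, 64))  # shape (2, 8)
```

With `channels` left at 3, constructing the wrapper around this single-channel network would fail in the mock forward pass, which mirrors the reported bug.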
  • Allow other types of visual SSL when initiating CLIP

    In the following code as part of CLIP.__init__

            if use_visual_ssl:
                if visual_ssl_type == 'simsiam':
                    ssl_type = SimSiam
                elif visual_ssl_type == 'simclr':
                    ssl_type = partial(SimCLR, temperature = simclr_temperature)
                else:
                    raise ValueError(f'unknown visual_ssl_type')
    
                self.visual_ssl = ssl_type(
                    self.visual_transformer,
                    image_size = visual_image_size,
                    hidden_layer = visual_ssl_hidden_layer
                )
    

    the visual self-supervised learning is hardcoded. I would suggest changing this to accept the visual SSL module as an argument when instantiating CLIP to allow flexibility in the same manner as it does for the image encoder and text encoder.

    Example:

    barlow = BarlowTwins(augmentation_fns)
    clip = CLIP(..., visual_ssl=barlow)
    
    opened by Froskekongen 4
  • Extract Text and Image Latents

    Hi, in the current implementation we can only extract the text and image embeddings (by setting return_encodings=True), which are obtained before the latent linear layers are applied. Wouldn't it be better to add an option to extract the latent embeddings? This also matters because, with the current code, it is impossible to extract the similarity matrix between a batch of images and a batch of text.

    opened by mmsamiei 2
  • NaN with mock data

    Hi lucidrains,

    Try this and the loss becomes NaN within 100 steps (latest GitHub code). The loss looks fine before the NaN appears.

    import torch
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True    
    torch.backends.cudnn.benchmark = True
    
    import random
    import numpy as np
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    num_text_tokens = 10000
    batch_sz = 12
    text_seq_len = 256
    visual_image_size = 256
    
    # mock data
    
    data_sz = 1000
    all_text = torch.randint(0, num_text_tokens, (data_sz, text_seq_len)).cuda()
    all_images = torch.randn(data_sz, 3, visual_image_size, visual_image_size).cuda()
    
    text = torch.zeros((batch_sz, text_seq_len), dtype=torch.long).cuda()
    images = torch.zeros((batch_sz, 3, visual_image_size, visual_image_size)).cuda()
    
    ##########################################################################################
    
    import wandb
    import datetime
    wandb.init(project="Test", name=datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), save_code=False)
    
    from x_clip import CLIP
    
    clip = CLIP(
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = num_text_tokens,
        text_enc_depth = 6,
        text_seq_len = text_seq_len,
        text_heads = 8,
        visual_enc_depth = 6,
        visual_image_size = visual_image_size,
        visual_patch_size = 32,
        visual_heads = 8,
        use_all_token_embeds = False,           # whether to use fine-grained contrastive learning (FILIP)
        decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
        extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
        use_visual_ssl = True,                  # whether to do self supervised learning on images
        visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
        use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
        text_ssl_loss_weight = 0.05,            # weight for text MLM loss
        image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
    ).cuda()
    
    optimizer = torch.optim.Adam(clip.parameters(), lr=1e-4, betas=(0.9, 0.99))
    
    for step in range(999999):
        for i in range(batch_sz):
            data_id = random.randrange(0, data_sz - 1)
            text[i] = all_text[data_id]
            images[i] = all_images[data_id]
    
        loss = clip(
            text,
            images,
            freeze_image_encoder = False,   # whether to freeze image encoder if using a pretrained image net, proposed by LiT paper
            return_loss = True              # needs to be set to True to return contrastive loss
        )
        clip.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(clip.parameters(), 1.0)
        optimizer.step()
    
        now_loss = loss.item()
        wandb.log({"loss": now_loss}, step = step)
        print(step, now_loss)
    
        if 'nan' in str(now_loss):
            break
    
    opened by BlinkDL 1
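
The script above detects NaN by searching the string form of the loss. A slightly more robust check, sketched here with a hypothetical helper name `loss_is_finite`, uses `torch.isfinite`, which also catches infinities and can be run before the backward pass.

```python
import torch

def loss_is_finite(loss):
    """Return True if every element of the loss tensor is finite (neither NaN nor inf)."""
    return bool(torch.isfinite(loss).all())

good = torch.tensor(0.5)
bad_nan = torch.tensor(float('nan'))
bad_inf = torch.tensor(float('inf'))

checks = [loss_is_finite(t) for t in (good, bad_nan, bad_inf)]
```

In a training loop, one would typically test the loss right after the forward pass and stop, or skip the optimizer step, when the check fails.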
  • Unable to train to convergence (small dataset)

    Hi, nice work with x-clip. I'm hoping to play around with it and eventually combine it into your DALLE2 work.

    I'm currently having some trouble training on roughly 30k image-text pairs. The loss eventually goes negative and starts producing NaNs. I've dropped the learning rate (to 1e-4) and I'm clipping gradients (max_norm=0.5).

    Any thoughts on what are sane training params/configs on such a small dataset using x-clip?

    opened by jacobwjs 9
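
On the negative loss mentioned above: a plain softmax cross-entropy contrastive loss cannot go below zero, whereas a decoupled objective that removes the positive pair from the denominator can. Whether that objective was enabled in this report is not stated, so the following is only an illustrative sketch in plain PyTorch. Both functions are simplified stand-ins, not x-clip's implementation.

```python
import torch
import torch.nn.functional as F

def info_nce(sim):
    """Standard contrastive loss: cross entropy with the diagonal entries as targets."""
    targets = torch.arange(sim.shape[0])
    return F.cross_entropy(sim, targets)

def decoupled_loss(sim):
    """Decoupled variant: the positive pair is masked out of the denominator."""
    pos = sim.diagonal()
    eye = torch.eye(sim.shape[0], dtype=torch.bool)
    neg = sim.masked_fill(eye, float('-inf'))
    return (-pos + neg.logsumexp(dim=-1)).mean()

# a well-separated similarity matrix: high on the diagonal, low elsewhere
sim = torch.full((4, 4), -5.0)
sim.fill_diagonal_(5.0)

standard = info_nce(sim)         # non-negative by construction
decoupled = decoupled_loss(sim)  # negative once positives dominate negatives
```

Under this reading, a negative value is not by itself a sign of divergence. The NaNs are a separate symptom that would still need to be traced.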
Releases(0.12.0)
Owner
Phil Wang
Working with Attention. It's all we need