Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch

Overview

Enformer - Pytorch (wip)

Implementation of Enformer, DeepMind's attention network for predicting gene expression, in Pytorch. The original TensorFlow Sonnet code can be found here.

Citations

@article {Avsec2021.04.07.438649,
    author  = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
    title   = {Effective gene expression prediction from sequence by integrating long-range interactions},
    elocation-id = {2021.04.07.438649},
    year    = {2021},
    doi     = {10.1101/2021.04.07.438649},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
    eprint  = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
    journal = {bioRxiv}
}
Comments
  • Using EleutherAI/enformer-official-rough PyTorch implementation to just get human output head

    Hi @lucidrains,

    Thank you so much for your efforts in releasing the PyTorch version of the Enformer model! I am really excited to use it for my particular implementation.

    I was wondering if it is possible to use the pre-trained huggingface model to just get the human output head. The reason is that inference takes a few minutes, and since I just need human data, this will help make my implementation a bit smoother. Is there a way to do this elegantly with the current codebase, or would I need to rewrite some functions to allow for this? From what I have seen so far it doesn't seem that this modularity is possible yet.

    The way I have set up my inference currently is as follows:

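    import pandas as pd
    import torch
    from enformer_pytorch import Enformer

    class EnformerInference:
        def __init__(self, data_path: str, model_path="EleutherAI/enformer-official-rough"):
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # move the weights to the same device as the inputs and switch to eval mode
            self.model = Enformer.from_pretrained(model_path).to(self.device).eval()
            # EnformerDataLoader is user-defined (not shown); it returns a one-hot
            # encoded torch.Tensor representation of the sequence of interest
            self.data = EnformerDataLoader(pd.read_csv(data_path, sep="\t"))

        @torch.no_grad()
        def forward(self, x: torch.Tensor) -> dict:
            # the model returns one prediction tensor per output head, e.g. out["human"]
            return self.model(x.to(self.device))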

    Any guidance on this would be greatly appreciated, thank you!

    opened by aaronwtr 4
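One consideration behind this question: every output head reads the same trunk embeddings, so skipping the mouse head only skips a final projection while the trunk still has to run. The NumPy sketch below is a toy, not enformer-pytorch code. It assumes each head is a linear layer over the trunk's 2 × dim features followed by softplus; `apply_heads` and `head_param_count` are hypothetical names.

```python
import numpy as np


def softplus(z):
    # numerically stable log(1 + exp(z))
    return np.logaddexp(0.0, z)


def apply_heads(embeddings, heads, only=None):
    """Apply output heads to shared trunk embeddings.

    embeddings: (positions, features) array produced once by the trunk.
    heads: dict mapping head name -> (weight (features, tracks), bias (tracks,)).
    only: optional head name; when given, the other heads are skipped.
    """
    names = [only] if only is not None else list(heads)
    return {name: softplus(embeddings @ heads[name][0] + heads[name][1]) for name in names}


def head_param_count(in_features, tracks):
    # weights plus biases of a single linear projection (assumed head shape)
    return in_features * tracks + tracks


rng = np.random.default_rng(0)
emb = rng.standard_normal((8, 6))  # toy trunk output: 8 positions, 6 features
heads = {
    "human": (rng.standard_normal((6, 5)), np.zeros(5)),
    "mouse": (rng.standard_normal((6, 3)), np.zeros(3)),
}
human_only = apply_heads(emb, heads, only="human")
print({name: out.shape for name, out in human_only.items()})  # {'human': (8, 5)}
```

Under the same assumption (dim = 1536, so 3072 input features), `head_param_count(3072, 1643)` gives 5,048,939 parameters for the mouse head, roughly 20 MB in float32, against a checkpoint of about 1GB (per the weight-hosting thread). That suggests most inference time is spent in the shared trunk rather than in the heads.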
  • Host weights on HuggingFace hub

    Hi Phil Wang,

    I created a little demo of how you can easily load pre-trained weights from the HuggingFace hub into your Enformer model. I basically followed this guide, which Sylvain (@sgugger) wrote recently. It's a new feature that lets you push model weights to the hub and load them into any custom PyTorch/TF/Flax model.

    From this PR, you can do (after pip install enformer-pytorch):

    from enformer_pytorch import Enformer
    
    model = Enformer.from_pretrained("nielsr/enformer-preview")
    

    If you consent, then I'll transfer all weights to the eleutherai organization on the hub, such that you can do from_pretrained("eleutherai/enformer-preview").

    The weights are hosted here: https://huggingface.co/nielsr/enformer-preview. As you can see in the "files and versions" tab, it contains a pytorch_model.bin file, which has a size of about 1GB. You can also load the other variant, as follows:

    model = Enformer.from_pretrained("nielsr/enformer-corr_coef_obj")
    

    To make it work, the only thing that is required is encapsulating all hyperparameters regarding the model architecture into a separate EnformerConfig object (which I've defined in config_enformer.py). It can be instantiated as follows:

    from enformer_pytorch import EnformerConfig
    
    config = EnformerConfig(
        dim = 1536,
        depth = 11,
        heads = 8,
        output_heads = dict(human = 5313, mouse = 1643),
        target_length = 896,
    )
    

    To initialize an Enformer model with randomly initialized weights, you can do:

    from enformer_pytorch import Enformer
    
    model = Enformer(config)
    

    There's no need for the config.yml and model_loader.py files anymore, as these are now handled by HuggingFace :)

    Let me know what you think about it :)

    Kind regards,

    Niels

    To do:

    • [x] upload remaining checkpoints to the hub
    • [x] transfer checkpoints to the eleutherai organization
    • [x] remove config.yml and model_loading.py scripts
    • [x] update README
    opened by NielsRogge 4
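The idea in this PR is that every architectural hyperparameter lives in one config object. That object can be serialized next to the weights (Hugging Face stores it as a config.json file), so a model can be rebuilt before its weights are loaded. Below is a minimal standard-library sketch of that round trip. It uses a hypothetical `ToyEnformerConfig` whose field names mirror the EnformerConfig example above; it is not the real EnformerConfig.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ToyEnformerConfig:
    # hypothetical stand-in; field names and defaults copied from the PR's example
    dim: int = 1536
    depth: int = 11
    heads: int = 8
    output_heads: dict = field(default_factory=lambda: {"human": 5313, "mouse": 1643})
    target_length: int = 896

    def to_json(self) -> str:
        # everything needed to rebuild the architecture, as plain JSON
        return json.dumps(asdict(self), indent=2, sort_keys=True)

    @classmethod
    def from_json(cls, text: str) -> "ToyEnformerConfig":
        return cls(**json.loads(text))


config = ToyEnformerConfig()
restored = ToyEnformerConfig.from_json(config.to_json())
print(restored == config)  # True
```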
  • Minor potential typo in `FastaInterval` class

    Hello, first off thanks so much for this incredible repository, it's greatly accelerating a project I am working on!

    I've been using the GenomeIntervalDataset class and noticed a minor potential typo in the FastaInterval class when I tried to fetch a sequence with a negative start position and got an empty tensor back. It looks like there is logic for clipping the start position at 0 and padding the sequence here https://github.com/lucidrains/enformer-pytorch/blob/ab29196d535802c8a04929534c5860fb55d06056/enformer_pytorch/data.py#L137-L143 but it wasn't being used in my case, because it sits inside the if clause above, which I wasn't triggering https://github.com/lucidrains/enformer-pytorch/blob/ab29196d535802c8a04929534c5860fb55d06056/enformer_pytorch/data.py#LL128C9-L128C82. When I unindented that logic, everything worked fine for me.

    If it was unintentional to have the clipping inside that if clause I'd be happy to submit a trivial PR to fix the indentation.

    Thanks again for all your work

    opened by sofroniewn 2
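The empty result is consistent with how Python slicing treats negative indices. A negative start counts from the end of the sequence, so `seq[-3:5]` on an 8-base sequence becomes `seq[5:5]` and returns nothing. Below is a minimal sketch of the clip-and-pad behaviour the issue describes. It works on a plain Python string rather than a FASTA file; `fetch_interval` is a hypothetical helper, not the repository's FastaInterval.

```python
def fetch_interval(chrom_seq: str, start: int, end: int, pad_char: str = "N") -> str:
    """Return chrom_seq[start:end] (0-based, end-exclusive), padding with pad_char
    wherever the interval runs off either end of the sequence, so the result
    always has length end - start (assumes start <= end)."""
    n = len(chrom_seq)
    left_pad = max(min(end, 0) - start, 0)    # bases requested before position 0
    right_pad = max(end - max(start, n), 0)   # bases requested past the last position
    core_start = min(max(start, 0), n)        # clip into [0, n] so slicing never wraps
    core_end = min(max(end, 0), n)
    return pad_char * left_pad + chrom_seq[core_start:core_end] + pad_char * right_pad


seq = "ACGTACGT"
print(repr(seq[-3:5]))             # '' because the negative start wraps around to index 5
print(fetch_interval(seq, -3, 5))  # NNNACGTA
print(fetch_interval(seq, 6, 10))  # GTNN
```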
  • example data files

    Hi, in the README you mention using sequences.bed and hg38.ml.fa files to build the GenomeIntervalDataset, but I can't find these example data files. Could you provide links to them? Thanks!

    opened by yingyuan830 2
  • Why do we need Residual here when there is already a residual connection inside the conv block?

    We wrap the conv block inside Residual here: https://github.com/lucidrains/enformer-pytorch/blob/1cbbe860bbd3ce8c26cee3de149d4fcdba508d95/enformer_pytorch/modeling_enformer.py#L318

    while there is already a residual connection inside the conv block here: https://github.com/lucidrains/enformer-pytorch/blob/1cbbe860bbd3ce8c26cee3de149d4fcdba508d95/enformer_pytorch/modeling_enformer.py#L226

    opened by inspirit 2
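The concern can be stated concretely. A Residual wrapper returns fn(x) + x, so if fn already ended with its own skip connection, wrapping it again would add the input twice, giving f(x) + 2x. Whether that happens depends on what the linked lines actually contain, which this toy does not settle. If the conv block is only normalization, activation and convolution, the outer Residual is the sole skip path. `Residual` and `f` here are minimal stand-ins, not the repository's classes.

```python
class Residual:
    """Minimal stand-in for a residual wrapper: returns fn(x) + x."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x) + x


def f(x):
    # toy transformation standing in for norm -> activation -> conv
    return 3 * x


single = Residual(f)            # f(x) + x
double = Residual(Residual(f))  # (f(x) + x) + x: the identity path counted twice

print(single(2), double(2))  # 8 10
```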
  • Add base_model_prefix

    This PR fixes the from_pretrained method by adding base_model_prefix, as this makes sure weights are properly loaded from the hub.

    Kudos to @sgugger for finding the bug.

    opened by NielsRogge 2
  • How to load the pre-trained Enformer model?

    Hi, I encountered a problem when trying to load the pre-trained enformer model.

    from enformer_pytorch import Enformer
    model = Enformer.from_pretrained("EleutherAI/enformer-preview")

    AttributeError                            Traceback (most recent call last)
    Input In [3], in
          1 from enformer_pytorch import Enformer
    ----> 2 model = Enformer.from_pretrained("EleutherAI/enformer-preview")

    AttributeError: type object 'Enformer' has no attribute 'from_pretrained'

    opened by yzJiang9 2
  • enformer TF pretrained weights

    Hello!

    Thanks for this wonderful resource. I was wondering whether you could point me to how to obtain the model weights for the original TF version of Enformer, or to the actual weights if they are stored somewhere easily accessible. I see the model on TF Hub but am not sure exactly how to extract the weights; I seem to be running into issues, possibly because the original code is Sonnet-based and the model is always loaded as a custom user object.

    Much appreciated!

    opened by naumanjaved 1
  • AttentionPool bug?

    Looking at the AttentionPool class, did you mean to have

    self.pool_fn = Rearrange('b d (n p) -> b d n p', p = self.pool_size)
    

    instead of

    self.pool_fn = Rearrange('b d (n p) -> b d n p', p = 2)
    

    Here's the full class

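    import torch
    import torch.nn.functional as F
    from einops.layers.torch import Rearrange
    from torch import nn

    class AttentionPool(nn.Module):
        def __init__(self, dim, pool_size = 2):
            super().__init__()
            self.pool_size = pool_size
            # reshape with the configured pool size (the line quoted above hardcodes p = 2)
            self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)
            self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)

        def forward(self, x):
            b, _, n = x.shape
            remainder = n % self.pool_size
            needs_padding = remainder > 0

            if needs_padding:
                # pad up to the next multiple of pool_size; padding by `remainder`
                # is only correct when pool_size == 2
                pad = self.pool_size - remainder
                x = F.pad(x, (0, pad), value = 0)
                mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
                mask = F.pad(mask, (0, pad), value = True)

            x = self.pool_fn(x)
            logits = self.to_attn_logits(x)

            if needs_padding:
                # padded positions get a very negative logit so softmax ignores them
                mask_value = -torch.finfo(logits.dtype).max
                logits = logits.masked_fill(self.pool_fn(mask), mask_value)

            attn = logits.softmax(dim = -1)

            return (x * attn).sum(dim = -1)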
    
    opened by cmlakhan 1
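The pooling logic can be checked without PyTorch. The NumPy sketch below is a re-implementation for illustration, not the repository's AttentionPool. It reshapes with the configured pool size and pads by (-n) % pool_size, the distance to the next multiple of pool_size. Padding by the remainder n % pool_size instead only coincides with this when pool_size is 2: with n = 5 and pool_size = 3 it would pad to 7, which cannot be split into windows of 3. As an assumption of this sketch, the 1 × 1 convolution is modelled as a plain (dim, dim) matrix.

```python
import numpy as np


def attention_pool(x, w, pool_size):
    """Attention pooling over non-overlapping windows.

    x: (dim, n) features; w: (dim, dim) matrix playing the role of the 1x1 conv
    that produces attention logits. Returns (dim, ceil(n / pool_size)).
    """
    dim, n = x.shape
    pad = (-n) % pool_size                  # distance to the next multiple of pool_size
    x = np.pad(x, ((0, 0), (0, pad)))       # zero-pad on the right
    mask = np.arange(n + pad) >= n          # True at padded positions
    windows = (n + pad) // pool_size

    xp = x.reshape(dim, windows, pool_size)  # 'd (n p) -> d n p'
    logits = np.einsum("ed,dwp->ewp", w, xp)
    # padded positions get the most negative finite value so softmax ignores them
    logits = np.where(mask.reshape(windows, pool_size), -np.finfo(logits.dtype).max, logits)

    logits = logits - logits.max(axis=-1, keepdims=True)  # stabilise exp
    attn = np.exp(logits)
    attn /= attn.sum(axis=-1, keepdims=True)
    return (xp * attn).sum(axis=-1)


x = np.arange(5.0).reshape(1, 5)  # one channel, five positions
w = np.zeros((1, 1))              # zero logits -> uniform weights over real positions
print(attention_pool(x, w, pool_size=3))  # [[1.  3.5]]
```

With zero logits the result is plain mean pooling over the real positions of each window, which makes it easy to confirm that padded positions receive no weight.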
  • Colab notebook for computing the correlation across different basenji2 dataset splits.

    New features:

    1. Colab notebook for computing correlations across the different basenji2 dataset splits.
    2. Pytorch metric for computing the mean of per-channel correlations properly aggregated across a region set.
    opened by jstjohn 0
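One reading of "properly aggregated across a region set" is to accumulate sufficient statistics (counts, sums, sums of squares and cross-products) over every region. A single Pearson r per channel is then computed from the pooled totals and averaged over channels, instead of averaging noisy per-region correlations. The NumPy sketch below follows that reading. `PooledPearson` is a hypothetical name, and this is not the metric added in the PR.

```python
import numpy as np


class PooledPearson:
    """Per-channel Pearson r accumulated over many regions.

    Each update takes (positions, channels) predictions and targets; statistics
    are summed across all updates before any correlation is computed.
    """

    def __init__(self, num_channels):
        self.count = 0
        self.sum_x = np.zeros(num_channels)
        self.sum_y = np.zeros(num_channels)
        self.sum_xx = np.zeros(num_channels)
        self.sum_yy = np.zeros(num_channels)
        self.sum_xy = np.zeros(num_channels)

    def update(self, pred, target):
        self.count += pred.shape[0]
        self.sum_x += pred.sum(axis=0)
        self.sum_y += target.sum(axis=0)
        self.sum_xx += (pred ** 2).sum(axis=0)
        self.sum_yy += (target ** 2).sum(axis=0)
        self.sum_xy += (pred * target).sum(axis=0)

    def per_channel(self):
        # n * covariance and n * variances; the factors of n cancel in the ratio
        n = self.count
        cov = self.sum_xy - self.sum_x * self.sum_y / n
        var_x = self.sum_xx - self.sum_x ** 2 / n
        var_y = self.sum_yy - self.sum_y ** 2 / n
        return cov / np.sqrt(var_x * var_y)

    def mean(self):
        return float(self.per_channel().mean())


rng = np.random.default_rng(0)
regions = [(rng.standard_normal((50, 3)), rng.standard_normal((50, 3))) for _ in range(4)]
metric = PooledPearson(num_channels=3)
for pred, noise in regions:
    metric.update(pred, 0.5 * pred + noise)  # targets partly correlated with predictions
print(metric.per_channel().shape)  # (3,)
```

The one-pass sums are simple but can lose precision when values are large relative to their spread; accumulating in float64, as here, reduces that risk.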
  • Computing Contribution Scores

    From the paper:

    To better understand what sequence elements Enformer is utilizing when making predictions, we computed two different gene expression contribution scores: input gradients (gradient × input) and attention weights

    I was just wondering how to compute input gradients and fetch the attention matrix for a given input. I'm not well versed in PyTorch, so I'm sorry if this is a noob question.

    opened by Prakash2403 0
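Gradient × input multiplies the gradient of a chosen output with respect to the one-hot input by the input itself, so each position keeps only the gradient entry of the base actually present. In PyTorch this is typically done in three steps: call `requires_grad_()` on the one-hot tensor, call `backward()` on a scalar taken from the chosen output track, and multiply `x.grad` by `x`. To stay runnable without model weights, the NumPy version below uses a hypothetical toy model whose gradient is written out by hand and checked against finite differences. It says nothing about Enformer's attention weights.

```python
import numpy as np

BASES = "ACGT"


def one_hot(seq):
    # (length, 4) one-hot encoding in A, C, G, T order
    out = np.zeros((len(seq), 4))
    out[np.arange(len(seq)), [BASES.index(b) for b in seq]] = 1.0
    return out


def toy_model(x, w):
    # hypothetical scalar "expression" output: tanh of a weighted sum
    return np.tanh(np.sum(x * w))


def toy_model_grad(x, w):
    # d toy_model / d x, derived by hand: (1 - tanh(s)^2) * w
    return (1.0 - np.tanh(np.sum(x * w)) ** 2) * w


def gradient_times_input(x, w):
    # per-position contribution score: sum over the base axis of grad * input
    return (toy_model_grad(x, w) * x).sum(axis=-1)


rng = np.random.default_rng(0)
x = one_hot("ACGTTG")
w = rng.standard_normal((6, 4)) * 0.3
scores = gradient_times_input(x, w)
print(scores.shape)  # (6,)
```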
  • Models in training splits

    Hey,

    Is there a way of getting the models trained in each training set, as mentioned in the "Model training and evaluation" paragraph of the Enformer paper?

    Thanks!

    opened by luciabarb 0
  • metric for enformer

    Hello, may I ask how you found that the human Pearson R is 0.625 for validation and 0.65 for test? I couldn't find this information in the paper. Is it recorded anywhere else?

    opened by Rachel66666 0
  • error loading enformer package

    I am trying to install the enformer package but am getting the following error:

    >>> import torch
    >>> from enformer_pytorch import Enformer
    Traceback (most recent call last):
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 905, in _get_module
        return importlib.import_module("." + module_name, self.__name__)
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/importlib/__init__.py", line 127, in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
      File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/modeling_utils.py", line 76, in <module>
        from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
    ImportError: cannot import name 'dispatch_model' from 'accelerate' (/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/accelerate/__init__.py)
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/enformer_pytorch/__init__.py", line 2, in <module>
        from enformer_pytorch.modeling_enformer import Enformer, SEQUENCE_LENGTH, AttentionPool
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/enformer_pytorch/modeling_enformer.py", line 14, in <module>
        from transformers import PreTrainedModel
      File "<frozen importlib._bootstrap>", line 1055, in _handle_fromlist
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 895, in __getattr__
        module = self._get_module(self._class_to_module[name])
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 907, in _get_module
        raise RuntimeError(
    RuntimeError: Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
    cannot import name 'dispatch_model' from 'accelerate' (/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/accelerate/__init__.py)
    

    I simply cloned an existing PyTorch environment in Conda (CUDA 11.1, torch 1.10) and then pip-installed the Hugging Face packages and the enformer package:

    pip install transformers
    pip install datasets
    pip install accelerate
    pip install tokenizers
    pip install enformer-pytorch
    

    Any idea why I'm getting this error?

    opened by cmlakhan 1
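The traceback shows `transformers.modeling_utils` failing to import `dispatch_model` from `accelerate`. This suggests the installed accelerate predates the version this transformers release expects; upgrading both together (`pip install -U transformers accelerate`) is a likely fix, though the thread does not confirm it. The `from_pretrained` AttributeError above may similarly point to an enformer-pytorch release that predates the Hugging Face integration. Below is a small standard-library helper that reports what is actually installed in the active environment; `installed_versions` is a hypothetical name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(*dist_names):
    """Map each pip distribution name to its installed version, or None if absent."""
    found = {}
    for name in dist_names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions("torch", "transformers", "accelerate", "enformer-pytorch").items():
        print(f"{name}: {ver or 'not installed'}")
```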
Releases(0.5.6)
Owner
Phil Wang
Working with Attention. It's all we need