Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch

Overview

Enformer - Pytorch (wip)

Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch. The original TensorFlow Sonnet code can be found here.

Citations

@article {Avsec2021.04.07.438649,
    author  = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
    title   = {Effective gene expression prediction from sequence by integrating long-range interactions},
    elocation-id = {2021.04.07.438649},
    year    = {2021},
    doi     = {10.1101/2021.04.07.438649},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
    eprint  = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
    journal = {bioRxiv}
}
Comments
  • Using EleutherAI/enformer-official-rough PyTorch implementation to just get human output head

    Hi @lucidrains,

    Thank you so much for your efforts in releasing the PyTorch version of the Enformer model! I am really excited to use it for my particular implementation.

    I was wondering if it is possible to use the pre-trained huggingface model to just get the human output head. The reason is that inference takes a few minutes, and since I just need human data, this will help make my implementation a bit smoother. Is there a way to do this elegantly with the current codebase, or would I need to rewrite some functions to allow for this? From what I have seen so far it doesn't seem that this modularity is possible yet.

    The way I have set up my inference currently is as follows:

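    import pandas as pd
    import torch
    from enformer_pytorch import Enformer

    class EnformerInference:
        def __init__(self, data_path: str, model_path: str = "EleutherAI/enformer-official-rough"):
            # run on the GPU when one is available, otherwise fall back to the CPU
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # the model has to live on the same device as the inputs passed to it
            self.model = Enformer.from_pretrained(model_path).to(self.device)
            self.model.eval()
            # EnformerDataLoader is the author's own helper (not part of enformer-pytorch);
            # it returns a one-hot encoded torch.Tensor representation of the sequence of interest
            self.data = EnformerDataLoader(pd.read_csv(data_path, sep="\t"))

        @torch.no_grad()  # inference only: skip building the autograd graph
        def forward(self, x: torch.Tensor) -> dict:
            # the model returns a dict of outputs keyed by head name (e.g. "human", "mouse")
            return self.model(x.to(self.device))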
    

    Any guidance on this would be greatly appreciated, thank you!
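Since the output heads are per-species projections on top of a shared trunk, requesting only the human head likely saves little time: the trunk (convolutions plus transformer layers) accounts for most of the compute and runs once either way. The toy module below is a minimal sketch of the pattern only; `TrunkWithHeads` and its `head` argument are hypothetical names, not enformer-pytorch's code, and whether the installed `Enformer.forward` offers a similar option is worth checking in its source.

```python
from typing import Optional

import torch
from torch import nn


class TrunkWithHeads(nn.Module):
    """Hypothetical toy model: one shared trunk feeding several per-species heads."""

    def __init__(self, dim: int, output_heads: dict):
        super().__init__()
        # stand-in for the expensive shared part (conv tower + transformer in Enformer)
        self.trunk = nn.Sequential(nn.Linear(4, dim), nn.GELU())
        # one cheap linear projection per species, keyed by name
        self.heads = nn.ModuleDict(
            {name: nn.Linear(dim, n_tracks) for name, n_tracks in output_heads.items()}
        )

    def forward(self, x: torch.Tensor, head: Optional[str] = None):
        emb = self.trunk(x)  # runs once, however many heads are requested
        if head is not None:
            return self.heads[head](emb)  # only the requested projection
        return {name: proj(emb) for name, proj in self.heads.items()}


model = TrunkWithHeads(dim=8, output_heads={"human": 5, "mouse": 3}).eval()
x = torch.randn(2, 10, 4)  # (batch, length, one-hot bases)
with torch.no_grad():
    human = model(x, head="human")
print(human.shape)  # torch.Size([2, 10, 5])
```

Skipping a head removes one small matrix multiply; the trunk cost is unchanged.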

    opened by aaronwtr 4
  • Host weights on HuggingFace hub

    Hi Phil Wang,

    I created a little demo of how you can easily load pre-trained weights from the HuggingFace hub into your Enformer model. I basically followed this guide, which Sylvain (@sgugger) wrote recently. It's a new feature that lets you push model weights to the hub and load them into any custom PyTorch/TF/Flax model.

    From this PR, you can do (after pip install enformer-pytorch):

    from enformer_pytorch import Enformer
    
    model = Enformer.from_pretrained("nielsr/enformer-preview")
    

    If you consent, then I'll transfer all weights to the eleutherai organization on the hub, such that you can do from_pretrained("eleutherai/enformer-preview").

    The weights are hosted here: https://huggingface.co/nielsr/enformer-preview. As you can see in the "files and versions" tab, it contains a pytorch_model.bin file, which has a size of about 1GB. You can also load the other variant, as follows:

    model = Enformer.from_pretrained("nielsr/enformer-corr_coef_obj")
    

    To make it work, the only thing that is required is encapsulating all hyperparameters regarding the model architecture into a separate EnformerConfig object (which I've defined in config_enformer.py). It can be instantiated as follows:

    from enformer_pytorch import EnformerConfig
    
    config = EnformerConfig(
        dim = 1536,
        depth = 11,
        heads = 8,
        output_heads = dict(human = 5313, mouse = 1643),
        target_length = 896,
    )
    

    To initialize an Enformer model with randomly initialized weights, you can do:

    from enformer_pytorch import Enformer
    
    model = Enformer(config)
    

    There's no need for the config.yml and model_loader.py files anymore, as these are now handled by HuggingFace :)
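The benefit of moving the hyperparameters into a config object is that the architecture can be rebuilt from a small serializable record stored next to the weights (Hugging Face writes it as `config.json`). A minimal pure-Python sketch of that round trip, using the values above; `ToyEnformerConfig` is a hypothetical dataclass, not the PR's `EnformerConfig`.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ToyEnformerConfig:
    """Hypothetical stand-in for EnformerConfig: architecture hyperparameters only."""

    dim: int = 1536
    depth: int = 11
    heads: int = 8
    output_heads: dict = field(default_factory=lambda: {"human": 5313, "mouse": 1643})
    target_length: int = 896

    def to_json(self) -> str:
        # serialize every field, the way a config file would be stored beside the weights
        return json.dumps(asdict(self), sort_keys=True)

    @classmethod
    def from_json(cls, text: str) -> "ToyEnformerConfig":
        # rebuild the config; a model constructor could then consume it
        return cls(**json.loads(text))


config = ToyEnformerConfig()
restored = ToyEnformerConfig.from_json(config.to_json())
print(restored == config)  # True
```

Keeping the weights file and the config separate is what lets one class load several checkpoints that differ only in their hyperparameters.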

    Let me know what you think about it :)

    Kind regards,

    Niels

    To do:

    • [x] upload remaining checkpoints to the hub
    • [x] transfer checkpoints to the eleutherai organization
    • [x] remove config.yml and model_loading.py scripts
    • [x] update README
    opened by NielsRogge 4
  • Minor potential typo in `FastaInterval` class

    Hello, first off thanks so much for this incredible repository, it's greatly accelerating a project I am working on!

    I've been using the GenomeIntervalDataset class and noticed a minor potential typo in the FastaInterval class when I tried to fetch a sequence with a negative start position and got an empty tensor back. It looks like there is logic for clipping the start position at 0 and padding the sequence here https://github.com/lucidrains/enformer-pytorch/blob/ab29196d535802c8a04929534c5860fb55d06056/enformer_pytorch/data.py#L137-L143 but it wasn't being used in my case because it sits inside the if clause above it, which I wasn't triggering: https://github.com/lucidrains/enformer-pytorch/blob/ab29196d535802c8a04929534c5860fb55d06056/enformer_pytorch/data.py#LL128C9-L128C82. If I unindent that logic, everything works fine for me.

    If it was unintentional to have the clipping inside that if clause I'd be happy to submit a trivial PR to fix the indentation.
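The behaviour described above, clipping a negative start at 0 and padding so the result keeps the requested length, can be sketched on a plain string. This is a minimal illustration, not the FastaInterval code: `fetch_interval` is a hypothetical helper, and padding with `"N"` is an assumption.

```python
def fetch_interval(chrom_seq: str, start: int, end: int, pad_char: str = "N") -> str:
    """Return the bases in [start, end), padding with pad_char outside the sequence.

    The result always has length end - start, even when the interval runs off
    either end of chrom_seq.
    """
    if start > end:
        raise ValueError("start must not exceed end")
    n = len(chrom_seq)
    # count the positions in [start, end) that fall before 0 or at/after n
    left_pad = max(0, min(end, 0) - start)
    right_pad = max(0, end - max(start, n))
    # clip before slicing: a raw negative start would be read as an offset from
    # the end of the string, which for a long chromosome usually gives ""
    lo, hi = max(start, 0), min(end, n)
    core = chrom_seq[lo:hi] if lo < hi else ""
    return pad_char * left_pad + core + pad_char * right_pad


print(fetch_interval("ACGTACGT", -3, 4))  # NNNACGT
```

The clipping has to run whenever the interval crosses a boundary, which is why it cannot live only inside an unrelated branch.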

    Thanks again for all your work

    opened by sofroniewn 2
  • example data files

    Hi, the README says to use the sequences.bed and hg38.ml.fa files to build the GenomeIntervalDataset, but I can't find these example data files. Could you provide links to them? Thanks!

    opened by yingyuan830 2
  • Why do we need Residual here when the conv block already has a residual connection?

    The conv block is wrapped in Residual here: https://github.com/lucidrains/enformer-pytorch/blob/1cbbe860bbd3ce8c26cee3de149d4fcdba508d95/enformer_pytorch/modeling_enformer.py#L318

    but the conv block already contains a residual connection here: https://github.com/lucidrains/enformer-pytorch/blob/1cbbe860bbd3ce8c26cee3de149d4fcdba508d95/enformer_pytorch/modeling_enformer.py#L226
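Why the question matters: if the wrapped block already returns f(x) + x, wrapping it in Residual again gives f(x) + 2x, so the identity path is counted twice. The scalar sketch below shows only that arithmetic; `residual` and `conv_like` are hypothetical stand-ins, and whether the linked conv block really contains its own skip connection should be checked in the source.

```python
from typing import Callable

Fn = Callable[[float], float]


def residual(fn: Fn) -> Fn:
    """Wrap fn so the wrapped block returns fn(x) + x (a skip connection)."""
    return lambda x: fn(x) + x


def conv_like(x: float) -> float:
    # hypothetical stand-in for the conv block's learned transformation
    return 0.5 * x


single = residual(conv_like)            # x -> 0.5x + x
double = residual(residual(conv_like))  # x -> (0.5x + x) + x

print(single(2.0), double(2.0))  # 3.0 5.0
```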

    opened by inspirit 2
  • Add base_model_prefix

    This PR fixes the from_pretrained method by adding base_model_prefix, as this makes sure weights are properly loaded from the hub.

    Kudos to @sgugger for finding the bug.

    opened by NielsRogge 2
  • How to load the pre-trained Enformer model?

    Hi, I encountered a problem when trying to load the pre-trained enformer model.

    from enformer_pytorch import Enformer
    model = Enformer.from_pretrained("EleutherAI/enformer-preview")

    AttributeError                            Traceback (most recent call last)
    Input In [3]
          1 from enformer_pytorch import Enformer
    ----> 2 model = Enformer.from_pretrained("EleutherAI/enformer-preview")

    AttributeError: type object 'Enformer' has no attribute 'from_pretrained'

    opened by yzJiang9 2
  • enformer TF pretrained weights

    Hello!

    Thanks for this wonderful resource. I was wondering whether you could point me to how to obtain the model weights for the original TF version of Enformer, or to the weights themselves if they are stored somewhere easily accessible. I see the model on TF Hub but am not sure exactly how to extract the weights; I seem to be running into issues, possibly because the original code is Sonnet-based and the model is always loaded as a custom user object.

    Much appreciated!

    opened by naumanjaved 1
  • AttentionPool bug?

    Looking at the AttentionPool class, did you mean to have

    self.pool_fn = Rearrange('b d (n p) -> b d n p', p = self.pool_size)
    

    instead of

    self.pool_fn = Rearrange('b d (n p) -> b d n p', p = 2)
    

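    Here's the full class:

    import torch
    import torch.nn.functional as F
    from torch import nn
    from einops.layers.torch import Rearrange

    class AttentionPool(nn.Module):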
        def __init__(self, dim, pool_size = 2):
            super().__init__()
            self.pool_size = pool_size
            self.pool_fn = Rearrange('b d (n p) -> b d n p', p = 2)
            self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)
    
        def forward(self, x):
            b, _, n = x.shape
            remainder = n % self.pool_size
            needs_padding = remainder > 0
    
            if needs_padding:
                x = F.pad(x, (0, remainder), value = 0)
                mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
                mask = F.pad(mask, (0, remainder), value = True)
    
            x = self.pool_fn(x)
            logits = self.to_attn_logits(x)
    
            if needs_padding:
                mask_value = -torch.finfo(logits.dtype).max
                logits = logits.masked_fill(self.pool_fn(mask), mask_value)
    
            attn = logits.softmax(dim = -1)
    
            return (x * attn).sum(dim = -1)
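A sketch of the class with the window size taken from `pool_size`, as suggested above. It also pads by `(pool_size - remainder) % pool_size` rather than by `remainder`: padding by `remainder` only reaches a multiple of `pool_size` when `pool_size` is 2 (a length of 7 with `pool_size = 3` would be padded to 8). To stay self-contained it swaps the einops `Rearrange` for a plain `reshape` and builds the padding mask from positions; this is an illustration, not the repository's fix.

```python
import torch
import torch.nn.functional as F
from torch import nn


class AttentionPool(nn.Module):
    """Attention-weighted pooling over non-overlapping windows of length pool_size."""

    def __init__(self, dim: int, pool_size: int = 2):
        super().__init__()
        self.pool_size = pool_size
        # 1x1 conv: one attention logit per channel, window and position in the window
        self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, d, n = x.shape
        p = self.pool_size
        # pad the length up to the next multiple of p (0 if already divisible)
        pad = (p - n % p) % p
        x = F.pad(x, (0, pad), value=0.0)
        # (b, d, n + pad) -> (b, d, windows, p); the window size comes from p
        x = x.reshape(b, d, -1, p)
        logits = self.to_attn_logits(x)
        if pad > 0:
            # True at padded positions, which then receive (numerically) zero attention
            mask = (torch.arange(n + pad, device=x.device) >= n).reshape(1, 1, -1, p)
            logits = logits.masked_fill(mask, -torch.finfo(logits.dtype).max)
        attn = logits.softmax(dim=-1)
        return (x * attn).sum(dim=-1)  # (b, d, windows)


pool = AttentionPool(dim=4, pool_size=3)
x = torch.randn(2, 4, 7)  # length 7 is padded to 9
out = pool(x)
print(out.shape)  # torch.Size([2, 4, 3])
```

Because padding is always shorter than one window, every window keeps at least one real position, so the masked softmax never divides over padding alone.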
    
    opened by cmlakhan 1
  • Colab notebook for computing the correlation across different basenji2 dataset splits.

    New features:

    1. Colab notebook for computing correlations across the different basenji2 dataset splits.
    2. Pytorch metric for computing the mean of per-channel correlations properly aggregated across a region set.
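"Properly aggregated across a region set" most plausibly means computing each channel's correlation over all regions pooled, rather than averaging per-region correlations, which generally gives a different number. One way to do that, sketched here under the assumption that predictions and targets arrive in batches shaped (batch, length, channels), is to keep running sums per channel and derive Pearson's r at the end. `StreamingPearsonPerChannel` is a hypothetical name; this is not the notebook's code.

```python
import torch


class StreamingPearsonPerChannel:
    """Per-channel Pearson r accumulated over every region seen, then averaged."""

    def __init__(self, n_channels: int):
        def zeros() -> torch.Tensor:
            return torch.zeros(n_channels, dtype=torch.float64)

        self.count = 0
        self.sum_x, self.sum_y = zeros(), zeros()
        self.sum_xx, self.sum_yy, self.sum_xy = zeros(), zeros(), zeros()

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # flatten (batch, length, channels) to (positions, channels); accumulate in float64
        x = preds.detach().double().reshape(-1, preds.shape[-1])
        y = target.detach().double().reshape(-1, target.shape[-1])
        self.count += x.shape[0]
        self.sum_x += x.sum(0)
        self.sum_y += y.sum(0)
        self.sum_xx += (x * x).sum(0)
        self.sum_yy += (y * y).sum(0)
        self.sum_xy += (x * y).sum(0)

    def per_channel(self) -> torch.Tensor:
        mean_x = self.sum_x / self.count
        mean_y = self.sum_y / self.count
        cov = self.sum_xy / self.count - mean_x * mean_y
        var_x = self.sum_xx / self.count - mean_x**2
        var_y = self.sum_yy / self.count - mean_y**2
        return cov / torch.sqrt(var_x * var_y)

    def compute(self) -> torch.Tensor:
        # mean of the per-channel correlations
        return self.per_channel().mean()


torch.manual_seed(0)
target = torch.rand(4, 16, 3)  # (regions, positions, channels)
preds = target + 0.1 * torch.randn(4, 16, 3)
metric = StreamingPearsonPerChannel(n_channels=3)
for i in range(0, 4, 2):  # two batches of two regions each
    metric.update(preds[i:i + 2], target[i:i + 2])
print(metric.compute())
```

Only six running sums per channel are kept, so the metric can cover a whole dataset split without storing every prediction.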
    opened by jstjohn 0
  • Computing Contribution Scores

    From the paper:

    To better understand what sequence elements Enformer is utilizing when making predictions, we computed two different gene expression contribution scores: input gradients (gradient × input) and attention weights

    I was just wondering how to compute input gradients and fetch the attention matrix for the given input. I'm not well versed with PyTorch, so I'm sorry if this is a noob question.
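Gradient × input can be computed with autograd: make the one-hot input require gradients, reduce the model output to a scalar (for instance one track summed over some positions), backpropagate, and multiply the gradient by the input so that each position keeps only the score of the base actually present. A minimal sketch follows; `gradient_times_input` and `target_fn` are hypothetical names, and the toy linear model only stands in for Enformer, whose output is a dictionary keyed by head name (such as `human`), so `target_fn` would pick a head and track from it. Attention weights are not covered here.

```python
import torch
import torch.nn.functional as F
from torch import nn


def gradient_times_input(model: nn.Module, one_hot: torch.Tensor, target_fn) -> torch.Tensor:
    """Gradient x input score per position for a (batch, length, 4) one-hot input."""
    x = one_hot.detach().clone().float().requires_grad_(True)
    score = target_fn(model(x))  # reduce the model output to a single scalar
    (grad,) = torch.autograd.grad(score, x)
    # multiply by the input so each position keeps the gradient of the base present there
    return (grad * x).sum(dim=-1).detach()


toy = nn.Linear(4, 1, bias=False)  # hypothetical stand-in for Enformer
with torch.no_grad():
    toy.weight.copy_(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))  # one weight per base
seq = F.one_hot(torch.tensor([[0, 3, 1]]), num_classes=4)  # bases A, T, C (ACGT order assumed)
scores = gradient_times_input(toy, seq, target_fn=lambda out: out.sum())
print(scores)  # tensor([[1., 4., 2.]])
```

With the linear toy, each position's score is simply the weight of its base, which makes the result easy to check by hand.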

    opened by Prakash2403 0
  • Models in training splits

    Hey,

    Is there a way of getting the models trained in each training set, as mentioned in the "Model training and evaluation" paragraph of the Enformer paper?

    Thanks!

    opened by luciabarb 0
  • metric for enformer

    Hello, can I ask how you found that the human Pearson R is 0.625 for validation and 0.65 for test? I couldn't find this information in the paper. Is it recorded anywhere else?

    opened by Rachel66666 0
  • error loading enformer package

    I am trying to install the enformer package but seem to be getting the following error:

    >>> import torch
    >>> from enformer_pytorch import Enformer
    Traceback (most recent call last):
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 905, in _get_module
        return importlib.import_module("." + module_name, self.__name__)
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/importlib/__init__.py", line 127, in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
      File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/modeling_utils.py", line 76, in <module>
        from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
    ImportError: cannot import name 'dispatch_model' from 'accelerate' (/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/accelerate/__init__.py)
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/enformer_pytorch/__init__.py", line 2, in <module>
        from enformer_pytorch.modeling_enformer import Enformer, SEQUENCE_LENGTH, AttentionPool
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/enformer_pytorch/modeling_enformer.py", line 14, in <module>
        from transformers import PreTrainedModel
      File "<frozen importlib._bootstrap>", line 1055, in _handle_fromlist
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 895, in __getattr__
        module = self._get_module(self._class_to_module[name])
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 907, in _get_module
        raise RuntimeError(
    RuntimeError: Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
    cannot import name 'dispatch_model' from 'accelerate' (/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/accelerate/__init__.py)
    

    I cloned an existing PyTorch environment in Conda (CUDA 11.1, torch 1.10) and then pip-installed the Hugging Face packages and the enformer package:

    pip install transformers
    pip install datasets
    pip install accelerate
    pip install tokenizers
    pip install enformer-pytorch
    

    Any idea why I'm getting this error?
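The traceback shows `transformers.modeling_utils` trying to import `dispatch_model` from `accelerate` and failing, which suggests the installed accelerate is likely older than the one this transformers release expects (a cloned environment may already have carried an older copy). Before upgrading, it helps to see which versions were actually resolved. A small standard-library sketch; `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata


def installed_versions(*packages: str) -> dict:
    """Map each distribution name to its installed version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print(installed_versions("torch", "transformers", "accelerate", "enformer-pytorch"))
```

If accelerate turns out to be old, `pip install --upgrade accelerate` in the same environment is the natural next step.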

    opened by cmlakhan 1
Releases(0.5.6)
Owner
Phil Wang
Working with Attention. It's all we need