Autolfads-tf2 - A TensorFlow 2.0 implementation of Latent Factor Analysis via Dynamical Systems (LFADS) and AutoLFADS

Overview

autolfads-tf2

A TensorFlow 2.0 implementation of LFADS and AutoLFADS.

Installation

Clone the autolfads-tf2 repo, then create and activate a conda environment with Python 3.7. Use conda to install cudatoolkit and cudnn, and use pip to install the lfads_tf2 and tune_tf2 packages with the -e (editable) flag. This lets you import these packages from anywhere while the environment is activated, and also lets you edit the code directly in the repo.

git clone git@github.com:snel-repo/autolfads-tf2.git
cd autolfads-tf2
conda create --name autolfads-tf2 python=3.7
conda activate autolfads-tf2
conda install -c conda-forge cudatoolkit=10.0
conda install -c conda-forge cudnn=7.6
pip install -e lfads-tf2
pip install -e tune-tf2

Usage

Training single models with lfads_tf2

The first step to training an LFADS model is setting the hyperparameter (HP) values. All HPs, their descriptions, and their default values are given in the defaults.py module. Note that these default values are unlikely to work well on your dataset. To override any or all default values, define new values in a YAML file (see configs/lorenz.yaml for an example).

The lfads_tf2.models.LFADS constructor takes as input the path to the configuration file that overrides the default HP values. The path to the modeled dataset is also specified in the config, so LFADS will load the dataset automatically.
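The override behavior can be pictured as a recursive merge of your YAML values over the defaults. The sketch below assumes the YAML file has already been parsed into a nested dict; the function name merge_config and every HP name in it are hypothetical placeholders, not the actual keys in defaults.py, and lfads_tf2 may load and merge its config differently.

```python
import copy


def merge_config(defaults: dict, overrides: dict) -> dict:
    """Return a copy of `defaults` with values from `overrides` applied.

    Nested dicts are merged key by key; any other value in `overrides`
    replaces the default outright. Unknown keys raise an error so that
    typos in a config file are caught early.
    """
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if key not in merged:
            raise KeyError(f"Unknown hyperparameter: {key}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical defaults and user overrides (NOT the real defaults.py keys).
DEFAULTS = {
    "MODEL": {"GEN_DIM": 100, "FAC_DIM": 10},
    "TRAIN": {"BATCH_SIZE": 128, "LR": {"INIT": 0.01, "STOP": 1e-5}},
}
USER = {"MODEL": {"FAC_DIM": 3}, "TRAIN": {"LR": {"INIT": 0.004}}}

cfg = merge_config(DEFAULTS, USER)
print(cfg["MODEL"])        # {'GEN_DIM': 100, 'FAC_DIM': 3}
print(cfg["TRAIN"]["LR"])  # {'INIT': 0.004, 'STOP': 1e-05}
```

Rejecting unknown keys is a design choice for the sketch: a misspelled HP in a YAML file then fails loudly instead of being silently ignored.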

The train function executes the training loop until the validation loss converges or another stopping criterion is met. During training, the model saves its outputs to the folder specified by MODEL_DIR: console output is saved to train.log, metrics to train_data.csv, and checkpoints to lfads_ckpts.
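One common way to decide that the validation loss has converged is a patience rule. The following is a generic sketch of that idea, not the stopping logic used by lfads_tf2; should_stop, patience, and min_delta are hypothetical names chosen for illustration.

```python
from typing import Sequence


def should_stop(val_losses: Sequence[float], patience: int = 5,
                min_delta: float = 0.0) -> bool:
    """Return True if the last `patience` epochs did not improve on the
    earlier best validation loss by more than `min_delta`."""
    if patience < 1:
        raise ValueError("patience must be at least 1")
    if len(val_losses) <= patience:
        return False  # not enough history to judge convergence yet
    best_before = min(val_losses[:-patience])
    recent_best = min(val_losses[-patience:])
    # Stop when the improvement (best_before - recent_best) is <= min_delta.
    return recent_best >= best_before - min_delta


history = [1.00, 0.80, 0.70, 0.69, 0.70, 0.71, 0.70, 0.72]
print(should_stop(history, patience=3))  # True: nothing beats 0.69 in the last 3 epochs
print(should_stop(history, patience=5))  # False: 0.69 falls inside the last 5 epochs
```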

After training, the sample_and_average function can be used to compute firing rate estimates and other intermediate model outputs and save them to posterior_samples.h5 in the MODEL_DIR.
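Conceptually, sample-and-average draws several samples from the model's approximate posterior for each trial and averages the resulting outputs, which smooths out sampling noise. The sketch below shows only that averaging step on a toy noisy sampler; sample_and_average_sketch and noisy_sample are hypothetical stand-ins, not the lfads_tf2 function or its signature.

```python
import random
from typing import Callable, List


def sample_and_average_sketch(sample_rates: Callable[[random.Random], List[float]],
                              n_samples: int = 50, seed: int = 0) -> List[float]:
    """Average per-time-bin outputs over `n_samples` stochastic draws."""
    rng = random.Random(seed)
    draws = [sample_rates(rng) for _ in range(n_samples)]
    # zip(*draws) groups the draws by time bin; average each bin.
    return [sum(values) / len(values) for values in zip(*draws)]


# Hypothetical stand-in for one posterior sample: fixed rates plus Gaussian noise.
TRUE_RATES = [2.0, 5.0, 3.0]


def noisy_sample(rng: random.Random) -> List[float]:
    return [rate + rng.gauss(0.0, 1.0) for rate in TRUE_RATES]


if __name__ == "__main__":
    averaged = sample_and_average_sketch(noisy_sample, n_samples=2000)
    print([round(x, 2) for x in averaged])  # close to TRUE_RATES
```

More samples reduce the noise in the averaged estimate at the cost of more forward passes; the number used by lfads_tf2 is set by the package, not by this sketch.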

We provide a simple example in example_scripts/train_lfads.py.

Training AutoLFADS models with tune_tf2

The autolfads-tf2 framework uses ray.tune to distribute models over a computing cluster, monitor model performance, and exploit high-performing models and their HPs.
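The exploit step can be illustrated with a minimal population-based training (PBT) update: workers in the bottom fraction of the population copy the HPs of a randomly chosen top worker and then perturb them. This is a simplified, generic sketch with assumed settings (frac=0.25, perturbation factors 0.8 and 1.2); pbt_step is a hypothetical name, it is not the tune_tf2 implementation or ray.tune's own PBT scheduler, and a real run would also copy the donor's model checkpoint.

```python
import copy
import random
from typing import Dict, List, Optional, Sequence


def pbt_step(population: List[Dict], frac: float = 0.25,
             factors: Sequence[float] = (0.8, 1.2),
             rng: Optional[random.Random] = None) -> List[Dict]:
    """Apply one exploit/explore step in place; return workers ranked by loss.

    Each worker is a dict with a 'loss' (lower is better) and an 'hps' dict
    of numeric hyperparameters.
    """
    rng = rng if rng is not None else random.Random()
    ranked = sorted(population, key=lambda worker: worker["loss"])
    n = max(1, int(len(ranked) * frac))
    top, bottom = ranked[:n], ranked[-n:]
    for worker in bottom:
        donor = rng.choice(top)
        # Exploit: copy the donor's HPs (a real run would also restore its checkpoint).
        copied = copy.deepcopy(donor["hps"])
        # Explore: scale each copied HP by a randomly chosen factor.
        worker["hps"] = {name: value * rng.choice(factors)
                         for name, value in copied.items()}
    return ranked


if __name__ == "__main__":
    population = [{"id": i, "loss": float(i), "hps": {"lr": 0.01 * (i + 1)}}
                  for i in range(8)]
    pbt_step(population, rng=random.Random(0))
    for worker in population:
        print(worker["id"], round(worker["hps"]["lr"], 4))
```

With 8 workers and frac=0.25, the two worst workers (ids 6 and 7) inherit a perturbed learning rate from one of the two best workers, while everyone else keeps their HPs.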

Setting up a ray cluster

If you'll be running AutoLFADS on a single machine, you can skip this section. If you'll be running across multiple machines, you must initialize the cluster as described below before you can submit jobs via the Python API.

Fill in the fields marked with <> in ray_cluster_template.yaml and save the file somewhere accessible. Ensure that a range of ports is open for communication on all machines you intend to use (e.g. 10000-10099 in the template). With your autolfads-tf2 environment activated, start the cluster using ray up <NEW_CLUSTER_CONFIG>. The cluster may take up to a minute to start. To confirm that all machines have joined the cluster, run example_scripts/ray_test.py and check that every machine's IP address is printed.
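Before running ray up, it can help to confirm on each machine that nothing else is already bound to the ports you plan to use. The sketch below is a hypothetical helper (unavailable_ports is not part of autolfads-tf2); it only tests whether each port can be bound locally and does not check firewall rules or connectivity between machines.

```python
import socket
from typing import Iterable, List


def unavailable_ports(ports: Iterable[int], host: str = "0.0.0.0") -> List[int]:
    """Return the ports in `ports` that cannot be bound on this machine.

    Only local bindability is checked; firewall rules between machines
    are NOT tested.
    """
    busy = []
    for port in ports:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((host, port))
            except OSError:
                busy.append(port)  # already in use or otherwise not bindable
    return busy


if __name__ == "__main__":
    # 10000-10099 mirrors the example range mentioned for the template.
    print(unavailable_ports(range(10000, 10100)))
```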

Starting an AutoLFADS run

To run AutoLFADS, copy the run_pbt.py script and adjust the paths and hyperparameters to your needs. Use only as many workers as can fit on the machine(s) at once. If you want to run across multiple machines, set SINGLE_MACHINE = False in run_pbt.py. To start your population-based training (PBT) run, run run_pbt.py. When the run is complete, the best model will be copied to a best_model folder in your PBT run folder. The model will then automatically be sampled and averaged, and all outputs will be saved to posterior_samples.h5.

References

Keshtkaran MR, Sedler AR, Chowdhury RH, Tandon R, Basrai D, Nguyen SL, Sohn H, Jazayeri M, Miller LE, Pandarinath C. A large-scale neural network training framework for generalized estimation of single-trial population dynamics. bioRxiv. 2021 Jan 1.

Keshtkaran MR, Pandarinath C. Enabling hyperparameter optimization in sequential autoencoders for spiking neural data. Advances in Neural Information Processing Systems. 2019; 32.

Comments
  • Update lfads-tf2 dependencies for Google Colab compatibility

    Summary of changes to setup.py

    • Change pandas==1.0.0 to pandas==1.* to avoid a dependency conflict with google-colab
    • Add PyYAML>=5.1 so that yaml.full_load works in lfads-tf2.
    opened by yahiaali
  • Are more recent versions of tensorflow/CUDA supported by the package?

    Right now the package supports TF 2.0 and CUDA 10.0, which are more than 3 years old. Is support planned, or already available, for more recent TensorFlow and CUDA versions?

    Thanks!

    opened by stes
  • Error: No 'git' repo detected for 'lfads_tf2'

    Hello, I am having this issue. I have followed all the installation instructions, and I was wondering why it would come up. autolfads-tf2 was cloned using git and is inside the git folder, but it seems like train_lfads.py is not loading data. I am using Windows 10.

    Thank you so much in advance!

    opened by jinoh5
  • Add warnings and assertion to chop functions for bad overlap

    Add warnings and an assertion to the chop functions when the requested overlap is greater than half of the window length

    Addresses https://github.com/snel-repo/autolfads-tf2/issues/2

    opened by raeedcho
  • `merge_chops` is unable to merge when the requested overlap is more than half of the window length

    Without really thinking a whole lot about it, I chopped data to window length 100 and overlap 80, since this would leave at most 20 points of unmodeled data at the end of the trials I'm trying to model. The chopping seems to work totally fine, but when merging the chops together, it seems that the code assumes that the overlap will be at most half the size of the window, and the math to put the chops back together breaks down in weird ways, leading to duplicated data in the final array.

    On further thought, it makes sense to some degree to limit the overlap to at most half of the window length, since otherwise data from more than two chops would have to be integrated to merge everything. If that is the reasoning, I think it would be a good idea to put an assertion in both functions (or at least an assertion in the merge_chops function and a warning in the chop_data function, since chopping technically works fine).

    If instead it should be possible to merge chops with an overlap greater than half the window size, then I think the merge_chops function needs to be reworked to integrate across more than two chops.

    opened by raeedcho
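To make the overlap limit concrete: with a window of 100 and an overlap of 80, the stride is 20, so a single time point can fall in up to five chops, whereas a merge that blends only neighboring pairs assumes at most two. The sketch below is a simplified 1-D illustration of the change proposed above (a warning when chopping, an error when merging); chop_sequence and merge_chops_pairwise are hypothetical names, not the lfads_tf2 chop_data and merge_chops functions, and the linear blend is an assumed merging rule.

```python
import warnings
from typing import List, Sequence


def chop_sequence(seq: Sequence[float], window: int, overlap: int) -> List[List[float]]:
    """Split `seq` into windows of length `window` that overlap by `overlap` points.

    Trailing points that do not fill a complete window are dropped.
    """
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    if 2 * overlap > window:
        # Chopping still works, but the pairwise merge below cannot undo it.
        warnings.warn("overlap is more than half the window length; "
                      "these chops cannot be merged pairwise", stacklevel=2)
    stride = window - overlap
    return [list(seq[start:start + window])
            for start in range(0, len(seq) - window + 1, stride)]


def merge_chops_pairwise(chops: List[List[float]], overlap: int) -> List[float]:
    """Rebuild a sequence by linearly blending each pair of neighboring chops.

    Assumes every point lies in at most two chops, i.e. overlap <= window / 2.
    """
    if not chops:
        return []
    window = len(chops[0])
    if 2 * overlap > window:
        # The issue proposes an assertion here; an exception is used instead.
        raise ValueError("overlap must be at most half the window length")
    stride = window - overlap
    merged: List[float] = []
    for i, chop in enumerate(chops):
        is_last = i == len(chops) - 1
        # Points owned by this chop alone (its leading overlap was blended already).
        start = overlap if i > 0 else 0
        end = window if is_last else stride
        merged.extend(chop[start:end])
        if not is_last:
            nxt = chops[i + 1]
            for k in range(overlap):
                w = (k + 1) / (overlap + 1)  # weight on the next chop ramps up
                merged.append((1 - w) * chop[stride + k] + w * nxt[k])
    return merged


if __name__ == "__main__":
    data = [float(t) for t in range(10)]
    chops = chop_sequence(data, window=4, overlap=2)
    print(len(chops), len(merge_chops_pairwise(chops, overlap=2)))  # 4 10
```

Without the check in merge_chops_pairwise, an overlap above half the window makes each chop's "owned" slice empty while every blended overlap is still appended, so the output grows longer than the input, which matches the duplicated data described in the issue.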
Releases: v0.1
Owner
Systems Neural Engineering Lab
Emory University and Georgia Institute of Technology