Autolfads-tf2 - A TensorFlow 2.0 implementation of Latent Factor Analysis via Dynamical Systems (LFADS) and AutoLFADS

Overview

autolfads-tf2

A TensorFlow 2.0 implementation of LFADS and AutoLFADS.

Installation

Clone the autolfads-tf2 repo and create and activate a conda environment with Python 3.7. Use conda to install cudatoolkit and cudnn and pip install the lfads_tf2 and tune_tf2 packages with the -e (editable) flag. This will allow you to import these packages anywhere when your environment is activated, while also allowing you to edit the code directly in the repo.

git clone [email protected]:snel-repo/autolfads-tf2.git
cd autolfads-tf2
conda create --name autolfads-tf2 python=3.7
conda activate autolfads-tf2
conda install -c conda-forge cudatoolkit=10.0
conda install -c conda-forge cudnn=7.6
pip install -e lfads-tf2
pip install -e tune-tf2

Usage

Training single models with lfads_tf2

The first step to training an LFADS model is setting the hyperparameter (HP) values. All HPs, their descriptions, and their default values are given in the defaults.py module. Note that these default values are unlikely to work well on your dataset. To overwrite any or all default values, the user must define new values in a YAML file (example in configs/lorenz.yaml).

The lfads_tf2.models.LFADS constructor takes as input the path to the configuration file that overwrites default HP values. The path to the modeled dataset is also specified in the config, so LFADS will load the dataset automatically.

The train function will execute the training loop until the validation loss converges or some other stopping criteria is reached. During training, the model will save various outputs in the folder specified by MODEL_DIR. Console outputs will be saved to train.log, metrics will be saved to train_data.csv, and checkpoints will be saved in lfads_ckpts.

After training, the sample_and_average function can be used to compute firing rate estimates and other intermediate model outputs and save them to posterior_samples.h5 in the MODEL_DIR.

We provide a simple example in example_scripts/train_lfads.py.

Training AutoLFADS models with tune_tf2

The autolfads-tf2 framework uses ray.tune to distribute models over a computing cluster, monitor model performance, and exploit high-performing models and their HPs.

Setting up a ray cluster

If you'll be running AutoLFADS on a single machine, you can skip this section. If you'll be running across multiple machines, you must initialize the cluster using these instructions before you can submit jobs via the Python API.

Fill in the fields indicated by <>'s in the ray_cluster_template.yaml, and save this file somewhere accessible. Ensure that a range of ports is open for communication on all machines that you intend to use (e.g. 10000-10099 in the template). In your autolfads-tf2 environment, start the cluster using ray up <NEW_CLUSTER_CONFIG>. The cluster may take up to a minute to get started. You can test that all machines are in the cluster by ensuring that all IP addresses are printed when running example_scripts/ray_test.py.

Starting an AutoLFADS run

To run AutoLFADS, copy the run_pbt.py script and adjust paths and hyperparameters to your needs. Make sure to only use only as many workers as can fit on the machine(s) at once. If you want to run across multiple machines, make sure to set SINGLE_MACHINE = False in run_pbt.py. To start your PBT run, simply run run_pbt.py. When the run is complete, the best model will be copied to a best_model folder in your PBT run folder. The model will automatically be sampled and averaged and all outputs will be saved to posterior_samples.h5.

References

Keshtkaran MR, Sedler AR, Chowdhury RH, Tandon R, Basrai D, Nguyen SL, Sohn H, Jazayeri M, Miller LE, Pandarinath C. A large-scale neural network training framework for generalized estimation of single-trial population dynamics. bioRxiv. 2021 Jan 1.

Keshtkaran MR, Pandarinath C. Enabling hyperparameter optimization in sequential autoencoders for spiking neural data. Advances in Neural Information Processing Systems. 2019; 32.

Comments
  • Update lfads-tf2 dependencies for Google Colab compatibility

    Update lfads-tf2 dependencies for Google Colab compatibility

    Summary of changes to setup.py

    • Change pandas==1.0.0 to pandas==1.* to avoid a dependency conflict with google-colab
    • Add PyYAML>=5.1 so that yaml.full_loadworks in lfads-tf2.
    opened by yahiaali 0
  • Are more recent versions of tensorflow/CUDA supported by the package?

    Are more recent versions of tensorflow/CUDA supported by the package?

    Right now the package supports TF 2.0 and CUDA 10.0 which are more than 3 years old. Is there support planned/already established for more recent Tensorflow and CUDA versions?

    Thanks!

    opened by stes 0
  • Error: No 'git' repo detected for 'lfads_tf2'

    Error: No 'git' repo detected for 'lfads_tf2'

    Hello, I am having this issue. I have followed all the installation instructions, and I was wondering why this issue would come up. autolfads-tf2 is cloned using git, and it is inside the git folder. But it seems like train_lfads.py is not loading data. I am using Window 10.

    error

    Thank you so much in advance!

    opened by jinoh5 0
  • Add warnings and assertion to chop functions for bad overlap

    Add warnings and assertion to chop functions for bad overlap

    Add warnings and assertion to chop functions when requested overlap is greater than half of window length

    Addresses https://github.com/snel-repo/autolfads-tf2/issues/2

    opened by raeedcho 0
  •  `merge_chops` is unable to merge when the requested overlap is more than half of the window length

    `merge_chops` is unable to merge when the requested overlap is more than half of the window length

    Without really thinking a whole lot about it, I chopped data to window length 100 and overlap 80, since this would leave at most 20 points of unmodeled data at the end of the trials I'm trying to model. The chopping seems to work totally fine, but when merging the chops together, it seems that the code assumes that the overlap will be at most half the size of the window, and the math to put the chops back together breaks down in weird ways, leading to duplicated data in the final array.

    On further thought, it makes sense to some degree to limit the overlap to be at most half of the window length, since otherwise, data from more than two chops would have to be integrated together to merge everything--if this is the thought process, I think it would be a good idea to put an assertion in both functions that this is the case (or maybe at least an assertion in the merge_chops function and a warning in the chop_data function, since chopping technically works fine).

    If instead it would make sense to be able to merge chops with overlap greater than half the window size, then I think the merge_chops function needs to be reworked to be able to integrate across more than two chops

    opened by raeedcho 0
Releases(v0.1)
Owner
Systems Neural Engineering Lab
Emory University and Georgia Institute of Technology
Systems Neural Engineering Lab
Alpha-IoU: A Family of Power Intersection over Union Losses for Bounding Box Regression

Alpha-IoU: A Family of Power Intersection over Union Losses for Bounding Box Regression YOLOv5 with alpha-IoU losses implemented in PyTorch. Example r

Jacobi(Jiabo He) 147 Dec 05, 2022
A mini library for Policy Gradients with Parameter-based Exploration, with reference implementation of the ClipUp optimizer from NNAISENSE.

PGPElib A mini library for Policy Gradients with Parameter-based Exploration [1] and friends. This library serves as a clean re-implementation of the

NNAISENSE 56 Jan 01, 2023
Yolov5+SlowFast: Realtime Action Detection Based on PytorchVideo

Yolov5+SlowFast: Realtime Action Detection A realtime action detection frame work based on PytorchVideo. Here are some details about our modification:

WuFan 181 Dec 30, 2022
The mini-MusicNet dataset

mini-MusicNet A music-domain dataset for multi-label classification Music transcription is sequence-to-sequence prediction problem: given an audio per

John Thickstun 4 Nov 09, 2022
Self-Supervised CNN-GCN Autoencoder

GCNDepth Self-Supervised CNN-GCN Autoencoder GCNDepth: Self-supervised monocular depth estimation based on graph convolutional network To be published

53 Dec 14, 2022
Continuous Security Group Rule Change Detection & Response at scale

Introduction Get notified of Security Group Changes across all AWS Accounts & Regions in an AWS Organization, with the ability to respond/revert those

Raajhesh Kannaa Chidambaram 3 Aug 13, 2022
Atomistic Line Graph Neural Network

Table of Contents Introduction Installation Examples Pre-trained models Quick start using colab JARVIS-ALIGNN webapp Peformances on a few datasets Use

National Institute of Standards and Technology 91 Dec 30, 2022
Author's PyTorch implementation of TD3 for OpenAI gym tasks

Addressing Function Approximation Error in Actor-Critic Methods PyTorch implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3). If y

Scott Fujimoto 1.3k Dec 25, 2022
Doing the asl sign language classification on static images using graph neural networks.

SignLangGNN When GNNs 💜 MediaPipe. This is a starter project where I tried to implement some traditional image classification problem i.e. the ASL si

10 Nov 09, 2022
DISTIL: Deep dIverSified inTeractIve Learning.

DISTIL: Deep dIverSified inTeractIve Learning. An active/inter-active learning library built on py-torch for reducing labeling costs.

decile-team 110 Dec 06, 2022
TransMVSNet: Global Context-aware Multi-view Stereo Network with Transformers.

TransMVSNet This repository contains the official implementation of the paper: "TransMVSNet: Global Context-aware Multi-view Stereo Network with Trans

旷视研究院 3D 组 155 Dec 29, 2022
Location-Sensitive Visual Recognition with Cross-IOU Loss

The trained models are temporarily unavailable, but you can train the code using reasonable computational resource. Location-Sensitive Visual Recognit

Kaiwen Duan 146 Dec 25, 2022
minimizer-space de Bruijn graphs (mdBG) for whole genome assembly

rust-mdbg: Minimizer-space de Bruijn graphs (mdBG) for whole-genome assembly rust-mdbg is an ultra-fast minimizer-space de Bruijn graph (mdBG) impleme

Barış Ekim 148 Dec 01, 2022
PySlowFast: video understanding codebase from FAIR for reproducing state-of-the-art video models.

PySlowFast PySlowFast is an open source video understanding codebase from FAIR that provides state-of-the-art video classification models with efficie

Meta Research 5.3k Jan 03, 2023
FLVIS: Feedback Loop Based Visual Initial SLAM

FLVIS Feedback Loop Based Visual Inertial SLAM 1-Video EuRoC DataSet MH_05 Handheld Test in Lab FlVIS on UAV Platform 2-Relevent Publication: Under Re

UAV Lab - HKPolyU 182 Dec 04, 2022
PyTorch implementation of CloudWalk's recent work DenseBody

densebody_pytorch PyTorch implementation of CloudWalk's recent paper DenseBody. Note: For most recent updates, please check out the dev branch. Update

Lingbo Yang 401 Nov 19, 2022
TRACER: Extreme Attention Guided Salient Object Tracing Network implementation in PyTorch

TRACER: Extreme Attention Guided Salient Object Tracing Network This paper was accepted at AAAI 2022 SA poster session. Datasets All datasets are avai

Karel 118 Dec 29, 2022
For visualizing the dair-v2x-i dataset

3D Detection & Tracking Viewer The project is based on hailanyi/3D-Detection-Tracking-Viewer and is modified, you can find the original version of the

34 Dec 29, 2022
Proto-RL: Reinforcement Learning with Prototypical Representations

Proto-RL: Reinforcement Learning with Prototypical Representations This is a PyTorch implementation of Proto-RL from Reinforcement Learning with Proto

Denis Yarats 74 Dec 06, 2022
This project generates news headlines using a Long Short-Term Memory (LSTM) neural network.

News Headlines Generator bunnysaini/Generate-Headlines Goal This project aims to generate news headlines using a Long Short-Term Memory (LSTM) neural

Bunny Saini 1 Jan 24, 2022