Neural network pruning for finding a sparse computational model for controlling a biological motor task.

Overview

MothPruning

Scientific Overview

Originally inspired by biological nervous systems, deep neural networks (DNNs) are powerful computational tools for modeling complex systems. DNNs are used in a diversity of domains and have helped solve some of the most intractable problems in physics, biology, and computer science. Despite their prevalence, the use of DNNs as a modeling tool comes with some major downsides. DNNs are highly overparameterized, which often makes them difficult to generalize and interpret, as well as computationally expensive. Unlike DNNs, which are often trained until they reach the highest accuracy possible, biological networks have to balance performance with robustness to a noisy and dynamic environment. Biological neural systems use a variety of mechanisms to promote specialized and efficient pathways capable of performing complex tasks in the presence of noise. One such mechanism, synaptic pruning, plays a significant role in refining task-specific behaviors. Synaptic pruning results in a more sparsely connected network that can still perform complex cognitive and motor tasks. Here, we draw inspiration from biology and use DNNs and the method of neural network pruning to find a sparse computational model for controlling a biological motor task.

In this work, we use the inertial dynamics model in [2] to simulate examples of M. sexta (Manduca sexta) hovering flight. These data are used to train a DNN to learn the controllers for hovering. Drawing inspiration from pruning in biological neural systems, we sparsify the network using neural network pruning. Here, we prune weights based simply on their magnitudes, removing the weights closest to zero. Insects must maneuver through high-noise environments to accomplish controlled flight. It is often assumed that there is a trade-off between perfect flight control and robustness to noise, and that the sensory data may be limited by the signal-to-noise ratio. Thus, the network need not be trained to the most accurate model, since in practice noise prevents high-fidelity models from exhibiting their underlying accuracy. Rather, we seek the sparsest model capable of performing the task in the noisy environment. We employ two methods for neural network pruning: manually setting weights to zero, or using binary masking layers. Furthermore, the DNN is pruned sequentially: groups of weights are removed gradually from the network, with retraining in between successive prunes, until a target sparsity is reached. Monte Carlo simulations are also used to quantify the statistical distribution of network weights during pruning, given random initialization of the network weights.

For more information, please see our paper [1].


Project Description

The deep, fully-connected neural network was constructed with ten input variables and seven output variables. The initial and final state-space conditions are the inputs to the network: six initial-condition variables (subscript i) and four final-condition variables (subscript f). The network predicts the control variables and the final derivatives of the state space in its output layer: three control variables (two of them indexed by x and y) and four final-state derivatives (subscript f).

After the fully-connected network is trained to a minimum error, we use neural network pruning to promote sparsity between the network layers. A target sparsity (the percentage of network weights to be pruned) is specified, and the smallest-magnitude weights are forced to zero. The network is then retrained until a minimum error is reached. This process is repeated until most of the weights have been pruned from the network.
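The prune-and-retrain cycle described above can be illustrated with a minimal NumPy sketch. It is not the repository's code: the function name `magnitude_mask`, the layer sizes, the sparsity schedule, and the choice to rank weights within each layer (rather than across the whole network) are all assumptions made for illustration.

```python
import numpy as np

def magnitude_mask(w, sparsity):
    """Return a binary mask that zeroes the `sparsity` fraction of entries
    of `w` with the smallest absolute value (ranked within this layer)."""
    n_prune = int(round(sparsity * w.size))
    mask = np.ones(w.size, dtype=w.dtype)
    # Indices of the n_prune smallest-magnitude entries. Weights that were
    # already pruned have magnitude zero, so they are selected first and the
    # sparsity is cumulative across successive prunes.
    prune_idx = np.argsort(np.abs(w).ravel())[:n_prune]
    mask[prune_idx] = 0
    return mask.reshape(w.shape)

rng = np.random.default_rng(0)
# Hypothetical layer sizes: 10 inputs -> 16 hidden units -> 7 outputs.
weights = [rng.normal(size=(10, 16)), rng.normal(size=(16, 7))]

for sparsity in (0.5, 0.75, 0.9):  # hypothetical pruning schedule
    masks = [magnitude_mask(w, sparsity) for w in weights]
    weights = [w * m for w, m in zip(weights, masks)]
    # ... retrain here, keeping the masked weights at zero ...
    print(sparsity, [round(1.0 - float(m.mean()), 3) for m in masks])
```

Each pass prints the target sparsity next to the fraction of weights actually removed in each layer; the two differ slightly only when the layer size times the target is not a whole number.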

The training and pruning protocols were developed using Keras with the TensorFlow backend. To scale up training for the statistical analysis of many networks, the training and pruning protocols were parallelized using the JAX framework.

To ensure that weights remain pruned during retraining, we used the pruning functionality of the TensorFlow Model Optimization Toolkit, which contains functions for pruning deep neural networks. In the Model Optimization Toolkit, pruning is achieved through binary masking layers that are multiplied element-wise with each weight matrix in the network.

JAX, however, does not come with a pruning toolkit, so for the parallelized version pruning with binary masking matrices was coded directly into the training loop.
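As a rough illustration of masking inside a training loop, the sketch below applies a fixed binary mask after every gradient step so that pruned weights stay at exactly zero. It uses NumPy and a single linear layer as a stand-in for the JAX training code; the function `masked_sgd_step`, the random data, the arbitrary mask, and the learning rate are all hypothetical.

```python
import numpy as np

def masked_sgd_step(w, mask, grad, lr=0.05):
    """One gradient step that keeps pruned entries (mask == 0) at zero."""
    return (w - lr * grad) * mask

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))   # hypothetical batch with 10 inputs
y = rng.normal(size=(64, 7))    # hypothetical targets with 7 outputs
w = rng.normal(size=(10, 7))

# Arbitrary mask for illustration; in practice it would come from
# magnitude-based pruning of the trained weights.
mask = (rng.random(size=w.shape) > 0.5).astype(w.dtype)
w = w * mask

for _ in range(100):
    pred = x @ w
    grad = 2.0 * x.T @ (pred - y) / len(x)  # gradient of the squared error
    w = masked_sgd_step(w, mask, grad)

print(bool(np.all(w[mask == 0] == 0)))
```

Multiplying by the mask after the update, rather than only masking the gradient, also guards against optimizers whose update rule could move a zeroed weight (for example through momentum or weight decay terms).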

Installation

Create a new conda environment with the tools for generating data and training the network. Note that this environment requires a GPU and the correct NVIDIA drivers.

conda env create -f environment_ODE_DL.yml

Create kernelspec (so you can see this kernel in JupyterLab).

conda activate [environment name]
python -m ipykernel install --user --name [environment name]
conda deactivate

To install JAX and Flax, please follow the instructions in the JAX GitHub repository.

Data

To use the TensorFlow version of this code, you need to generate simulations of moth hovering for the data. The JAX version (multi-network train and prune) has data provided in this repository.

cd MothMachineLearning/Underactuated/GenerateData

Then use 010_OneTorqueParallelSims.ipynb to generate the simulations.

How to use

The following guide walks through the process of training and pruning many networks in parallel using the JAX framework. The TensorFlow code is also provided for experimentation and visualization.

Step 1: Train networks

cd MothMachineLearning/Underactuated/TrainNetwork/multiNetPrune/

First, we train and prune the desired number of networks in parallel using the JAX framework. Choose the number of networks you wish to train and prune in parallel by adjusting the numParallel parameter. You can also define the number of layers, units, and other hyperparameters. Use the command

python3 step1_train.py

to train and prune the networks in parallel.
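One common way to evaluate many networks at once is to stack their parameters along a leading axis and run a single batched operation (in JAX this is typically done with `jax.vmap`). Whether step1_train.py is organized this way is an assumption; the NumPy sketch below only illustrates the idea, with hypothetical layer sizes, a hypothetical tanh activation, and no bias terms.

```python
import numpy as np

num_parallel = 4  # plays the role of the numParallel parameter
rng = np.random.default_rng(0)

# The leading axis indexes the independently initialized networks.
w1 = rng.normal(size=(num_parallel, 10, 16))  # 10 inputs -> 16 hidden units
w2 = rng.normal(size=(num_parallel, 16, 7))   # 16 hidden units -> 7 outputs
x = rng.normal(size=(32, 10))                 # one batch shared by all networks

# n: network, b: batch element, i/j/k: layer units.
hidden = np.tanh(np.einsum("bi,nij->nbj", x, w1))
out = np.einsum("nbj,njk->nbk", hidden, w2)
print(out.shape)  # (4, 32, 7)
```

Because every network has its own slice of each stacked array, per-network pruning masks can be stacked and applied in the same way.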

Step 2: Evaluate at prunes

Next, the networks need to be evaluated at each prune. Use the command

python3 step2_pruneEval.py

to evaluate the networks at each prune.

Step 3: Pre-process networks

This code prepares the networks for sparse network identification (explained in the next step). It essentially just reorganizes the data. Open and run step3_preprocess.ipynb to preprocess, making sure to change modeltimestamp and the file names to the correct ones for your run.

Step 4: Find sparse networks

This code finds the optimally sparse networks. For each network, the most pruned version whose loss is below a specified threshold (here 0.001) is kept. For example, the image below shows a single network that has gone through the sequential pruning process, with the red line marking the defined threshold. For this example, the optimally sparse network is the one pruned by 94% (i.e. 6% of the original weights remain).

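[Figure: loss of a single network over the sequential pruning process, with the loss threshold marked by a red line.]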

The sparse networks are collected and saved to a file called sparseNetworks.pkl. Open and run step4_findSparse.ipynb, making sure to change modeltimestamp and the file names to the correct ones for your run.

Note that if a network has no prune with a loss below the threshold, it is skipped and not included in the list sparseNetworks. For example, if you trained and pruned 10 networks and 3 of them had no prune with a loss below 0.001, the list sparseNetworks will have length 7.
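The selection rule from this step can be sketched in a few lines of plain Python. This is not the notebook's code: the function name `find_sparse_networks`, the data layout (a list of (percent pruned, loss) pairs per network), and the loss values are made up for illustration.

```python
def find_sparse_networks(prune_losses, threshold=0.001):
    """For each network, keep the most pruned version whose loss is below
    `threshold`. Networks with no such version are skipped.

    prune_losses: one list per network of (percent_pruned, loss) pairs.
    Returns a list of (network_index, percent_pruned) tuples.
    """
    sparse = []
    for idx, history in enumerate(prune_losses):
        passing = [pct for pct, loss in history if loss < threshold]
        if passing:
            sparse.append((idx, max(passing)))
    return sparse

# Made-up pruning histories for two networks.
histories = [
    [(0, 0.0002), (50, 0.0004), (94, 0.0009), (98, 0.02)],
    [(0, 0.005), (50, 0.004), (94, 0.03)],  # never below the threshold
]
print(find_sparse_networks(histories))  # [(0, 94)]
```

The second network never reaches a loss below the threshold, so it is absent from the result, which is why the returned list can be shorter than the number of networks trained.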

References

[1] Zahn, O., Bustamante, Jr J., Switzer, C., Daniel, T., and Kutz, J. N. (2022). Pruning deep neural networks generates a sparse, bio-inspired nonlinear controller for insect flight.

[2] Bustamante, Jr J., Ahmed, M., Deora, T., Fabien, B., and Daniel, T. (2021). Abdominal movements in insect flight reshape the role of non-aerodynamic structures for flight maneuverability. J. Integrative and Comparative Biology. In revision.

Owner
Olivia Thomas
Physics graduate student at the University of Washington