Neural network pruning for finding a sparse computational model that controls a biological motor task.


MothPruning

Scientific Overview

Originally inspired by biological nervous systems, deep neural networks (DNNs) are powerful computational tools for modeling complex systems. DNNs are used across a diversity of domains and have helped solve some of the most intractable problems in physics, biology, and computer science. Despite their prevalence, the use of DNNs as a modeling tool comes with major downsides: DNNs are highly overparameterized, which often makes them difficult to interpret, prone to poor generalization, and computationally expensive to train and evaluate. Unlike DNNs, which are often trained until they reach the highest accuracy possible, biological networks must balance performance against robustness to a noisy and dynamic environment. Biological neural systems use a variety of mechanisms to promote specialized, efficient pathways capable of performing complex tasks in the presence of noise. One such mechanism, synaptic pruning, plays a significant role in refining task-specific behaviors; it results in a more sparsely connected network that can still perform complex cognitive and motor tasks. Here, we draw inspiration from biology and use DNNs together with neural network pruning to find a sparse computational model for controlling a biological motor task.

In this work, we use the inertial dynamics model of [2] to simulate examples of M. sexta hovering flight. These data are used to train a DNN to learn the controllers for hovering. Drawing inspiration from pruning in biological neural systems, we sparsify the network using neural network pruning, removing weights based simply on their magnitudes: the weights closest to zero are pruned first. Insects must maneuver through high-noise environments to accomplish controlled flight. It is often assumed that there is a trade-off between perfect flight control and robustness to noise, and that sensory data are limited by their signal-to-noise ratio. The network therefore need not be trained to the most accurate possible model, since in practice noise prevents high-fidelity models from realizing their underlying accuracy. Rather, we seek the sparsest model capable of performing the task in a noisy environment. We employed two methods for neural network pruning: manually setting weights to zero, or using binary masking layers. Furthermore, the DNN is pruned sequentially: groups of weights are removed gradually from the network, with retraining between successive prunes, until a target sparsity is reached. Monte Carlo simulations are also used to quantify the statistical distribution of network weights during pruning under random initialization of the network weights.

For more information, please see our paper [1].


Project Description

The deep, fully-connected neural network was constructed with ten input variables and seven output variables. The inputs to the network are the initial and final state-space conditions (subscripted i and f, respectively). The output layer predicts the control variables and the final derivatives of the state space.
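
For reference, a minimal Keras sketch of a fully-connected network with this input/output shape is given below; the hidden-layer count and widths are illustrative placeholders, not the architecture used in the paper.

```python
import tensorflow as tf

# Fully-connected network: ten state-space inputs, seven outputs
# (control variables plus final state derivatives). Hidden-layer
# sizes here are illustrative, not the paper's architecture.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(10,)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(7),  # linear outputs for regression
])
model.compile(optimizer="adam", loss="mse")
```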

After the fully-connected network is trained to a minimum error, we use neural network pruning to promote sparsity between the network layers. A target sparsity (the percentage of pruned network weights) is specified, and the smallest-magnitude weights are forced to zero. The network is then retrained until a minimum error is reached. This process is repeated until most of the weights have been pruned from the network.
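
The pruning step for a single weight matrix can be sketched as follows (a minimal NumPy illustration of magnitude pruning, not the repository's implementation):

```python
import numpy as np

def magnitude_prune(weights, target_sparsity):
    """Zero out the fraction `target_sparsity` of entries in a weight
    matrix with the smallest magnitudes, returning the pruned weights
    and the corresponding binary mask."""
    threshold = np.percentile(np.abs(weights), 100.0 * target_sparsity)
    mask = (np.abs(weights) > threshold).astype(weights.dtype)
    return weights * mask, mask

# Example: prune 90% of a random weight matrix.
W = np.random.randn(64, 64)
W_pruned, mask = magnitude_prune(W, 0.90)
```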

The training and pruning protocols were developed using Keras with the TensorFlow backend. To scale up training for the statistical analysis of many networks, the training and pruning protocols were parallelized using the Jax framework.

To ensure weights remain pruned during retraining, we used the pruning functionality of the TensorFlow Model Optimization Toolkit. In this toolkit, pruning is achieved through binary masking layers that are multiplied element-wise with each weight matrix in the network.
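
A typical invocation of the toolkit looks roughly like the sketch below, assuming model is the trained Keras network from above; the 50% sparsity target, epoch count, and placeholder data are illustrative, not the schedule used in this work.

```python
import numpy as np
import tensorflow_model_optimization as tfmot

# Placeholder training data with the network's input/output shapes.
x_train = np.random.rand(256, 10).astype("float32")
y_train = np.random.rand(256, 7).astype("float32")

# Wrap the trained model so low-magnitude weights are masked to zero
# during retraining. The 50% sparsity target is a placeholder.
pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
    target_sparsity=0.50, begin_step=0
)
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=pruning_schedule
)
pruned_model.compile(optimizer="adam", loss="mse")

# UpdatePruningStep keeps the binary masks in sync during training;
# strip_pruning removes the mask wrappers once retraining is done.
pruned_model.fit(
    x_train, y_train, epochs=10,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
)
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
```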

Jax, however, does not provide a pruning toolkit, so pruning via binary masking matrices was coded directly into the training loop.
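
A minimal sketch of how such masks can be folded into a Jax training step is shown below; the two-layer network, parameter layout, and names are illustrative assumptions, not the repository's code.

```python
import jax
import jax.numpy as jnp

def forward(params, masks, x):
    # Multiply each weight matrix by its binary mask on the forward
    # pass so pruned connections contribute nothing.
    (W1, b1), (W2, b2) = params
    m1, m2 = masks
    h = jnp.tanh(x @ (W1 * m1) + b1)
    return h @ (W2 * m2) + b2

def loss_fn(params, masks, x, y):
    pred = forward(params, masks, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, masks, x, y, lr=1e-3):
    grads = jax.grad(loss_fn)(params, masks, x, y)
    # Masking the weight updates keeps pruned entries pinned at zero
    # throughout retraining.
    (W1, b1), (W2, b2) = params
    (gW1, gb1), (gW2, gb2) = grads
    m1, m2 = masks
    return (
        (W1 - lr * gW1 * m1, b1 - lr * gb1),
        (W2 - lr * gW2 * m2, b2 - lr * gb2),
    )
```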

Installation

Create a new conda environment with tools for generating data and training the network (note that this environment requires a GPU and the correct NVIDIA drivers).

conda env create -f environment_ODE_DL.yml

Create kernelspec (so you can see this kernel in JupyterLab).

conda activate [environment name]
python -m ipykernel install --user --name [environment name]
conda deactivate

To install Jax and Flax, please follow the instructions on the Jax GitHub.

Data

To use the TensorFlow version of this code, you need to generate simulations of moth hovering to serve as the data. The Jax version (multi-network train and prune) has data provided in this repository.

cd MothMachineLearning/Underactuated/GenerateData

and use 010_OneTorqueParallelSims.ipynb to generate the simulations.

How to use

The following guide walks through the process of training and pruning many networks in parallel using the Jax framework. However, the TensorFlow code is also provided for experimentation and visualization.

Step 1: Train networks

cd MothMachineLearning/Underactuated/TrainNetwork/multiNetPrune/

First we train and prune the desired number of networks in parallel using the Jax framework. Choose the number of networks you wish to train/prune in parallel by adjusting the numParallel parameter. You can also define the number of layers, units, and other hyperparameters. Use the command

python3 step1_train.py

to train and prune the networks in parallel.
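
Conceptually, parallel training stacks each network's parameters along a leading axis and vectorizes the update with jax.vmap. The sketch below illustrates the idea; the initializer, shapes, and sizes are hypothetical, not those of step1_train.py.

```python
import jax

num_parallel = 10  # mirrors the numParallel parameter in step1_train.py

def init_params(key):
    # Hypothetical initializer for a single network's weights.
    k1, k2 = jax.random.split(key)
    return {
        "W1": 0.1 * jax.random.normal(k1, (10, 64)),
        "W2": 0.1 * jax.random.normal(k2, (64, 7)),
    }

# One random seed per network yields the Monte Carlo ensemble: the
# stacked parameters carry a leading axis of size num_parallel.
keys = jax.random.split(jax.random.PRNGKey(0), num_parallel)
stacked_params = jax.vmap(init_params)(keys)

# A single-network update can then be vectorized over the ensemble,
# with every network seeing the same training batch, e.g.:
# batched_step = jax.vmap(single_net_step, in_axes=(0, None, None))
```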

Step 2: Evaluate at prunes

Next, the networks need to be evaluated at each prune. Use the command

python3 step2_pruneEval.py

to evaluate the networks at each prune.

Step 3: Pre-process networks

This code prepares the networks for sparse network identification (explained in the next step). It essentially just reorganizes the data. Open and run step3_preprocess.ipynb to preprocess, making sure to change modeltimestamp and the file names to the correct ones for your run.

Step 4: Find sparse networks

This code finds the optimally sparse networks. For each network, the most pruned version whose loss is below a specified threshold (here 0.001) is kept. For example, the image below shows a single network that has gone through the sequential pruning process; the red line marks the loss threshold. For this example, the optimally sparse network is the one pruned by 94% (i.e., 6% of the original weights remain).

[Figure: loss at each prune for a single network, with a red line marking the 0.001 loss threshold.]

The sparse networks are collected and saved to a file called sparseNetworks.pkl. Open and run step4_findSparse.ipynb, making sure to change modeltimestamp and the file names to the correct ones for your run.

Note that if a network does not have a single prune that is below the loss threshold, it will be skipped and not included in the list of sparseNetworks. For example, if you trained and pruned 10 networks and 3 did not have a prune below a loss of 0.001, the list sparseNetworks will be length 7.
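
The selection logic amounts to scanning each network's prune history for the highest sparsity whose retrained loss stays under the threshold. A schematic version is below; the data layout and the all_histories name are hypothetical (see step4_findSparse.ipynb for the actual implementation):

```python
import pickle

LOSS_THRESHOLD = 0.001

def find_sparse(prune_history):
    """prune_history: list of (sparsity, loss, weights) tuples for one
    network, sorted by increasing sparsity. Returns the most pruned
    entry whose loss is under the threshold, or None if none qualify."""
    best = None
    for sparsity, loss, weights in prune_history:
        if loss < LOSS_THRESHOLD:
            best = (sparsity, loss, weights)
    return best

all_histories = []  # placeholder: the per-network histories from step 3

# Networks with no prune under the threshold are skipped, so the
# saved list may be shorter than the number of trained networks.
sparse_networks = [s for h in all_histories
                   if (s := find_sparse(h)) is not None]
with open("sparseNetworks.pkl", "wb") as f:
    pickle.dump(sparse_networks, f)
```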

References

[1] Zahn, O., Bustamante Jr., J., Switzer, C., Daniel, T., and Kutz, J. N. (2022). Pruning deep neural networks generates a sparse, bio-inspired nonlinear controller for insect flight.

[2] Bustamante Jr., J., Ahmed, M., Deora, T., Fabien, B., and Daniel, T. (2021). Abdominal movements in insect flight reshape the role of non-aerodynamic structures for flight maneuverability. Integrative and Comparative Biology. In revision.
