Implements Gradient Centralization and makes it available as a Python package for TensorFlow

Overview

Gradient Centralization TensorFlow

This Python package implements Gradient Centralization in TensorFlow, a simple and effective optimization technique for Deep Neural Networks suggested by Yong et al. in the paper Gradient Centralization: A New Optimization Technique for Deep Neural Networks. It can both speed up the training process and improve the final generalization performance of DNNs.
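
As a rough summary of the operation the paper describes, Gradient Centralization subtracts the mean from the gradient of each weight vector before the optimizer applies it. The notation below is a paraphrase for illustration (W is assumed to be an M x N weight matrix with columns w_i, and L is the loss), not text taken from this package:

```latex
% Assumed notation: W \in \mathbb{R}^{M \times N}, w_i is the i-th column of W,
% and \mathcal{L} is the training loss.
\Phi_{GC}\left(\nabla_{w_i}\mathcal{L}\right)
  = \nabla_{w_i}\mathcal{L} - \mu_{\nabla_{w_i}\mathcal{L}},
\qquad
\mu_{\nabla_{w_i}\mathcal{L}}
  = \frac{1}{M}\sum_{j=1}^{M}\nabla_{w_{i,j}}\mathcal{L}
```

After this step, the gradient of every weight vector has zero mean, while gradients of rank-1 parameters such as biases are left unchanged.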

Installation

Run the following to install:

pip install gradient-centralization-tf

Usage

gctf.centralized_gradients_for_optimizer

Creates a centralized gradients function for a specified optimizer.

Arguments:

  • optimizer: a tf.keras.optimizers.Optimizer object. The optimizer you are using.

Example:

>>> import tensorflow as tf
>>> import gctf
>>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
>>> opt.get_gradients = gctf.centralized_gradients_for_optimizer(opt)
>>> model.compile(optimizer=opt, ...)

gctf.get_centralized_gradients

Computes the centralized gradients.

This function is ideally not meant to be used directly unless you are building a custom optimizer, in which case you could point get_gradients to this function. This is a modified version of tf.keras.optimizers.Optimizer.get_gradients.

Arguments:

  • optimizer: a tf.keras.optimizers.Optimizer object. The optimizer you are using.
  • loss: Scalar tensor to minimize.
  • params: List of variables.

Returns:

A gradients tensor.
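
To make the idea concrete, here is a minimal NumPy sketch of what centralizing a list of gradients could look like. It is not the package's implementation: the function names `centralize_gradient` and `centralized_gradients` are hypothetical, and the choice to average over every axis except the last assumes the Keras weight layout in which the last axis indexes output units.

```python
import numpy as np


def centralize_gradient(grad):
    """Subtract the per-output-unit mean from a gradient of rank > 1.

    Hypothetical helper for illustration; assumes the last axis indexes
    output units, as in Keras Dense and Conv kernels.
    """
    grad = np.asarray(grad, dtype=float)
    if grad.ndim > 1:
        # Average over all axes except the last one.
        axes = tuple(range(grad.ndim - 1))
        grad = grad - grad.mean(axis=axes, keepdims=True)
    # Rank-1 gradients (e.g. biases) are returned unchanged.
    return grad


def centralized_gradients(grads):
    """Apply centralization to each gradient in a list."""
    return [centralize_gradient(g) for g in grads]


# Toy gradients for one dense layer: a 2x2 kernel and a bias vector.
kernel_grad = np.array([[1.0, 2.0], [3.0, 6.0]])
bias_grad = np.array([0.5, -0.5])

new_kernel_grad, new_bias_grad = centralized_gradients([kernel_grad, bias_grad])
print(new_kernel_grad)  # every column now has zero mean
print(new_bias_grad)    # unchanged, because its rank is 1
```

In a real optimizer the same subtraction would be applied to the tensors returned by the gradient computation, before the update rule runs.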

gctf.optimizers

Pre-built optimizers updated to implement GC.

This module is built specifically for trying out GC. In most cases you would use gctf.centralized_gradients_for_optimizer directly, although this module is itself built on gctf.centralized_gradients_for_optimizer. All optimizers in tf.keras.optimizers are available here, updated for GC, and can be used directly.

Example:

>>> model.compile(optimizer = gctf.optimizers.adam(learning_rate = 0.01), ...)
>>> model.compile(optimizer = gctf.optimizers.rmsprop(learning_rate = 0.01, rho = 0.91), ...)
>>> model.compile(optimizer = gctf.optimizers.sgd(), ...)

Returns:

A tf.keras.optimizers.Optimizer object.

Developing gctf

To install gradient-centralization-tf along with the tools you need to develop and test it, run the following in your virtualenv:

git clone git@github.com:Rishit-dagli/Gradient-Centralization-TensorFlow
# or clone your own fork

pip install -e .[dev]

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Comments
  • On Windows, TensorFlow 2.5 gives an error

    On Windows 10 with a Miniconda environment, TensorFlow 2.5 gives an error in the centralized_gradients.py file.

    The solution is to replace import keras.backend as K with import tensorflow.keras.backend as K.

    bug 
    opened by mgezer 5
  • The results in the mnist example are wrong/misleading

    Describe the bug: The results in your Colab IPython notebook are misleading: https://colab.research.google.com/github/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/examples/gctf_mnist.ipynb

    In this example, the model is first trained with a normal Adam optimizer:

    model.compile(optimizer = tf.keras.optimizers.Adam(),
                  loss = 'sparse_categorical_crossentropy',
                  metrics = ['accuracy'])
    
    history_no_gctf = model.fit(training_images, training_labels, epochs=5, callbacks = [time_callback_no_gctf])
    

    Afterwards, the same model is recompiled with gctf.optimizers.adam(). However, recompiling a Keras model does not reset its weights. The model is therefore trained in the first fit call, and the second fit call with the new optimizer continues from those already-trained weights, so of course the results are better.

    This can be fixed by recreating the model for the second run, which only takes a few extra lines:

    import gctf

    # TimeHistory is the timing callback defined earlier in the notebook
    time_callback_gctf = TimeHistory()

    # Recreate the model so that training starts from freshly initialized weights
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        tf.keras.layers.Dense(256, activation=tf.nn.relu),
        tf.keras.layers.Dense(64, activation=tf.nn.relu),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        tf.keras.layers.Dense(256, activation=tf.nn.relu),
        tf.keras.layers.Dense(64, activation=tf.nn.relu),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax),
    ])

    model.compile(optimizer=gctf.optimizers.adam(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history_gctf = model.fit(training_images, training_labels, epochs=5, callbacks=[time_callback_gctf])
    

    However, then the results are not better than without gctf:

    Type                   Execution time    Accuracy      Loss
    -------------------  ----------------  ----------  --------
    Model without gctf:           24.7659    0.88825   0.305801
    Model with gctf               24.7881    0.889567  0.30812
    

    Could you please clarify what happens here? I tried the gctf.optimizers.adam() optimizer in my own research and it didn't change the results at all. Now, seeing that it doesn't work in the example constructed here either makes me question the results of the paper.

    To Reproduce: Execute the Colab notebook given in the repository: https://colab.research.google.com/github/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/examples/gctf_mnist.ipynb

    Expected behavior: For a fair comparison, both models should start from a random initialization; the second model should not start from the already-trained weights.

    Looking forward to a swift explanation.

    Best, Max

    question 
    opened by themasterlink 2
  • Wider dependency requirements

    The package currently requires tensorflow ~= 2.4.0 and keras ~= 2.4.0 to be installed. This turns out to be problematic for people who have custom installations of TensorFlow, so a wider requirement could be set up.

    enhancement 
    opened by Rishit-dagli 1
  • Release 0.0.3

    This release includes some fixes and improvements

    ✅ Bug Fixes / Improvements

    • Allow wider versions for TensorFlow and Keras while installing the package (#14)
    • Fix incorrect usage example in docstrings and description for centralized_gradients_for_optimizer (#13)
    • Add clear aims for each of the examples of using gctf (#15)
    • Update PyPI classifiers to clearly show the aims of this project. This should not change the way you use this package (#18)
    • Add clear instructions for using this with custom optimizers, i.e. directly use get_centralized_gradients; however, a complete example has not been pushed for the reasons mentioned in the issue (#16)
    opened by Rishit-dagli 0
  • Add an "About The Examples" section

    Add an "About The Examples" section which contains a summary of the usage example notebooks and links to run them on Binder and Colab.


    Close #15

    opened by Rishit-dagli 0
  • Update relevant pypi classifiers

    Add PyPI classifiers for:

    • Development status
    • Intended Audience
    • Topic

    Also added the Programming Language :: Python :: 3 :: Only classifier.


    Closes #18

    opened by Rishit-dagli 0
  • Update pypi classifiers

    I am specifically thinking of adding three more categories of pypi classifiers:

    • Development status
    • Intended Audience
    • Topic

    Apart from this, I also think it would be great to add the Programming Language :: Python :: 3 :: Only classifier so that the audience knows this package is intended for Python 3 only.

    opened by Rishit-dagli 0
  • Add an "About the examples" section

    It would be great to write an "About the examples" section that briefly describes what the example notebooks aim to achieve and show.

    documentation 
    opened by Rishit-dagli 0
  • Error in usage example for gctf.centralized_gradients_for_optimizer

    I noticed that the docstring for gctf.centralized_gradients_for_optimizer has an error in its example usage section. The example creates an Adam optimizer instance and saves it to opt; however, centralized_gradients_for_optimizer is then applied to optimizer, which does not exist, so running the example would result in an error.

    documentation 
    opened by Rishit-dagli 0
  • [ImgBot] Optimize images

    opened by imgbot[bot] 0
  • [ImgBot] Optimize images

    opened by imgbot[bot] 0
Releases (v0.0.3)
  • v0.0.3 (Mar 11, 2021)

    This release includes some fixes and improvements

    ✅ Bug Fixes / Improvements

    • Allow wider versions for TensorFlow and Keras while installing the package (#14)
    • Fix incorrect usage example in docstrings and description for centralized_gradients_for_optimizer (#13)
    • Add clear aims for each of the examples of using gctf (#15)
    • Update PyPI classifiers to clearly show the aims of this project. This should not change the way you use this package (#18)
    • Add clear instructions for using this with custom optimizers, i.e. directly use get_centralized_gradients; however, a complete example has not been pushed for the reasons mentioned in the issue (#16)
  • v0.0.2 (Feb 21, 2021)

    This release includes some fixes and improvements

    ✅ Bug Fixes / Improvements

    • Fix the issue of supporting multiple modules
    • Fix multiple typos.
  • v0.0.1 (Feb 20, 2021)

Owner
Rishit Dagli
High School, Ted-X, Ted-Ed speaker|Mentor, TFUG Mumbai|International Speaker|Microsoft Student Ambassador|#ExploreML Facilitator