Gradient Centralization TensorFlow
This Python package implements Gradient Centralization (GC) in TensorFlow. GC is a simple and effective optimization technique for deep neural networks (DNNs), proposed by Yong et al. in the paper "Gradient Centralization: A New Optimization Technique for Deep Neural Networks". It can both speed up the training process and improve the final generalization performance of DNNs.
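For reference, the core operation from the paper can be written as below. Here $w_i \in \mathbb{R}^M$ is the weight vector feeding output unit $i$ and $\mathcal{L}$ is the loss. GC subtracts the mean of each weight vector's gradient, so the centralized gradient of every output unit sums to zero. Assuming Keras's kernel layout, where the last axis indexes output units (Dense: `(in, out)`, Conv2D: `(kh, kw, in, out)`), this corresponds to averaging over every axis except the last.

```latex
\Phi_{GC}\left(\nabla_{w_i}\mathcal{L}\right) = \nabla_{w_i}\mathcal{L} - \mu_{\nabla_{w_i}\mathcal{L}},
\qquad
\mu_{\nabla_{w_i}\mathcal{L}} = \frac{1}{M}\sum_{j=1}^{M} \nabla_{w_{i,j}}\mathcal{L}
```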
Installation
Run the following to install:
pip install gradient-centralization-tf
Usage
gctf.centralized_gradients_for_optimizer
Creates a centralized-gradients function for the specified optimizer.
Arguments:
optimizer: A tf.keras.optimizers.Optimizer object. The optimizer you are using.
Example:
>>> import tensorflow as tf
>>> import gctf
>>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
>>> opt.get_gradients = gctf.centralized_gradients_for_optimizer(opt)
>>> model.compile(optimizer=opt, ...)
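Assigning to `opt.get_gradients` sets an instance attribute, which shadows the class's method for that one optimizer object only. The pure-Python sketch below illustrates that mechanism and the closure pattern such a factory can use. `ToyOptimizer`, `toy_get_centralized_gradients` and `make_get_gradients` are hypothetical stand-ins, not gctf's code, and no real gradients are computed.

```python
from typing import Any, Callable, List


def toy_get_centralized_gradients(optimizer: Any, loss: Any, params: List[Any]) -> List[Any]:
    """Hypothetical stand-in for gctf.get_centralized_gradients.

    A real version would differentiate `loss` w.r.t. `params` and centralize
    the result; this toy only records what it was called with.
    """
    return [("centralized", optimizer.name, loss, p) for p in params]


def make_get_gradients(optimizer: Any) -> Callable[[Any, List[Any]], List[Any]]:
    """Hypothetical factory: bind `optimizer` and expose a get_gradients-style callable."""

    def get_gradients(loss: Any, params: List[Any]) -> List[Any]:
        # Same (loss, params) signature as Optimizer.get_gradients, so it can replace it.
        return toy_get_centralized_gradients(optimizer, loss, params)

    return get_gradients


class ToyOptimizer:
    """Hypothetical stand-in for tf.keras.optimizers.Optimizer."""

    def __init__(self, name: str) -> None:
        self.name = name

    def get_gradients(self, loss: Any, params: List[Any]) -> List[Any]:
        return [("plain", self.name, loss, p) for p in params]


opt = ToyOptimizer("adam")
opt.get_gradients = make_get_gradients(opt)  # instance attribute shadows the class method
other = ToyOptimizer("sgd")                   # a second instance is unaffected

print(opt.get_gradients("loss", ["w"]))    # [('centralized', 'adam', 'loss', 'w')]
print(other.get_gradients("loss", ["w"]))  # [('plain', 'sgd', 'loss', 'w')]
```

This is also why the patch must target the same object that is later passed to `model.compile`.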
gctf.get_centralized_gradients
Computes the centralized gradients.
This function is generally not meant to be used directly unless you are building a custom optimizer, in which case you could point get_gradients to this function. It is a modified version of tf.keras.optimizers.Optimizer.get_gradients.
Arguments:
optimizer: A tf.keras.optimizers.Optimizer object. The optimizer you are using.
loss: Scalar tensor to minimize.
params: List of variables.
Returns:
A gradients tensor.
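The centralization step itself can be sketched in NumPy on a list of already-computed gradients. This is an illustration under stated assumptions, not gctf's implementation: it assumes the Keras kernel layout (last axis = output units), leaves rank-1 gradients such as biases unchanged as the paper applies GC to weight vectors, and, like Keras's `get_gradients`, rejects a missing (`None`) gradient. The function name `centralize_gradients` is hypothetical.

```python
from typing import List, Optional

import numpy as np


def centralize_gradients(grads: List[Optional[np.ndarray]]) -> List[np.ndarray]:
    """Hypothetical sketch: apply gradient centralization to a list of gradients.

    Assumes the last axis indexes output units: Dense kernels are (in, out)
    and Conv2D kernels are (kh, kw, in, out).
    """
    centralized = []
    for grad in grads:
        if grad is None:
            # Assumption: treat a missing gradient as an error, as Keras does.
            raise ValueError("A parameter has no gradient; check that it affects the loss.")
        if grad.ndim > 1:
            # Mean over every axis except the last, kept broadcastable.
            axes = tuple(range(grad.ndim - 1))
            grad = grad - grad.mean(axis=axes, keepdims=True)
        # Rank-1 gradients (e.g. biases) pass through unchanged.
        centralized.append(grad)
    return centralized


kernel_grad = np.array([[1.0, 2.0], [3.0, 6.0]])  # Dense kernel gradient, shape (in=2, out=2)
bias_grad = np.array([0.5, -0.5])
out = centralize_gradients([kernel_grad, bias_grad])
print(out[0])  # each column (one per output unit) now has zero mean
print(out[1])  # bias gradient is unchanged
```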
gctf.optimizers
Prebuilt optimizers updated to implement GC.
This module is built specifically for trying out GC; in most cases you would use gctf.centralized_gradients_for_optimizer instead, though this module is itself implemented with gctf.centralized_gradients_for_optimizer. You can directly use all the optimizers in tf.keras.optimizers, updated for GC.
Example:
>>> model.compile(optimizer=gctf.optimizers.adam(learning_rate=0.01), ...)
>>> model.compile(optimizer=gctf.optimizers.rmsprop(learning_rate=0.01, rho=0.91), ...)
>>> model.compile(optimizer=gctf.optimizers.sgd(), ...)
Returns:
A tf.keras.optimizers.Optimizer object.
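One way such prebuilt factories could be assembled is to wrap an optimizer class so that construction and patching happen in one call. The sketch below is hypothetical (`with_gc`, `ToyAdam` and `toy_patch` are illustrative names, not part of gctf) and uses a toy class so it runs without TensorFlow. In principle, `with_gc(tf.keras.optimizers.Adam, gctf.centralized_gradients_for_optimizer)` would mirror the manual `opt.get_gradients = ...` recipe, but that pairing is an assumption, not how gctf is documented to work.

```python
from typing import Any, Callable


def with_gc(optimizer_cls: type, patch: Callable[[Any], Callable]) -> Callable[..., Any]:
    """Hypothetical helper: turn an optimizer class into a GC-enabled factory.

    `patch(opt)` must return a get_gradients-style callable.
    """

    def factory(**kwargs: Any) -> Any:
        opt = optimizer_cls(**kwargs)   # forward hyperparameters unchanged
        opt.get_gradients = patch(opt)  # override get_gradients on this instance only
        return opt

    return factory


class ToyAdam:
    """Hypothetical stand-in for tf.keras.optimizers.Adam."""

    def __init__(self, learning_rate: float = 0.001) -> None:
        self.learning_rate = learning_rate

    def get_gradients(self, loss: Any, params: Any) -> str:
        return "plain"


def toy_patch(opt: Any) -> Callable[[Any, Any], str]:
    """Hypothetical stand-in for gctf.centralized_gradients_for_optimizer."""
    return lambda loss, params: "centralized"


adam = with_gc(ToyAdam, toy_patch)
opt = adam(learning_rate=0.01)
print(opt.learning_rate, opt.get_gradients(None, []))  # 0.01 centralized
```

Keeping hyperparameters as plain keyword arguments lets each factory accept exactly what the underlying optimizer class accepts.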
Developing gctf
To install gradient-centralization-tf, along with the tools you need to develop and test, run the following in your virtualenv:
git clone git@github.com:Rishit-dagli/Gradient-Centralization-TensorFlow
# or clone your own fork
cd Gradient-Centralization-TensorFlow
pip install -e ".[dev]"