This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes the batch statistics globally across GPUs instead of locally on each device. SyncBN is important when input images are large and multiple GPUs must be used to increase the effective minibatch size for training.
- Unlike Pytorch-Encoding, you don't need a custom nn.DataParallel
- Unlike Inplace-ABN, you can simply replace your nn.BatchNorm2d with this module, since it does not perform in-place operations
- You can plug it into an arbitrary module written in PyTorch to enable Synchronized BatchNorm
- The backward computation is rewritten and tested against the behavior of nn.BatchNorm2d
For PyTorch, please refer to https://pytorch.org/
NOTE: The code is tested only with PyTorch v0.4.0, CUDA 9.1.85/cuDNN 7.1.4 on Ubuntu 16.04.
(It can also be compiled and run on the Jetson TX2, but it won't work as multi-GPU synchronized BN.)
To install all dependencies using pip, run:

```
pip install -U -r requirements.txt
```
Run make_ext.sh to build the extension.
Please refer to test.py for testing the difference between this implementation and nn.BatchNorm2d. Example usage:
```python
import torch
import torch.nn as nn
from modules import nn as NN

num_gpu = torch.cuda.device_count()
model = nn.Sequential(
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
    nn.ReLU(inplace=True),
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
).cuda()
model = nn.DataParallel(model, device_ids=range(num_gpu))

x = torch.rand(num_gpu, 3, 2, 2).cuda()
z = model(x)
```
Forward pass:

1. compute the partial sums $\sum_i x_i$ and $\sum_i x_i^2$ on each gpu
2. gather all partial sums from the workers to the master and compute the global statistics $\mu = \frac{1}{N}\sum_i x_i$ and $\sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2$, where $N$ is the total number of elements per channel
3. share the above global stats back to all gpus, and update running_mean and running_var by a moving average using the global stats
4. forward batchnorm using the global stats by $y_i = \gamma \, \hat{x}_i + \beta$ with $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$, where $\gamma$ is the weight parameter and $\beta$ is the bias parameter
5. save $\hat{x}_i$ for the backward pass
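The forward statistics steps above can be sketched in plain NumPy, simulating each GPU's local batch with a separate array. This is an illustrative sketch only, not the repo's CUDA implementation; the function name `sync_bn_forward_stats` and all variable names are made up for the example.

```python
import numpy as np

def sync_bn_forward_stats(per_gpu_batches, eps=1e-5):
    """Simulate the SyncBN forward statistics for one channel.

    Each "GPU" computes only sum(x) and sum(x^2) over its local batch;
    the master reduces them into a global mean and biased variance,
    which every device then uses for normalization.
    """
    # step 1: per-device partial sums
    sums = [b.sum() for b in per_gpu_batches]
    sqsums = [(b * b).sum() for b in per_gpu_batches]
    n = sum(b.size for b in per_gpu_batches)
    # step 2: reduce on the master and derive global stats
    mean = sum(sums) / n
    var = sum(sqsums) / n - mean ** 2          # E[x^2] - E[x]^2
    # step 4: normalize every local batch with the shared global stats
    xhat = [(b - mean) / np.sqrt(var + eps) for b in per_gpu_batches]
    return mean, var, xhat

# The result matches BN statistics computed over the concatenated batch.
batches = [np.random.randn(4, 2, 2) for _ in range(3)]   # 3 simulated "GPUs"
mean, var, _ = sync_bn_forward_stats(batches)
full = np.concatenate([b.ravel() for b in batches])
assert np.allclose(mean, full.mean()) and np.allclose(var, full.var())
```

Note that only two scalars per channel cross device boundaries, which is why the sum/sum-of-squares formulation is preferred over gathering the raw activations.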
Backward pass:

1. compute the partial sums $\sum_i \frac{\partial L}{\partial y_i}$ and $\sum_i \frac{\partial L}{\partial y_i}\hat{x}_i$ on each gpu
2. gather them at the master node, sum them into global totals, and normalize with $N$, where $N$ is the total number of elements per channel; the resulting global sums are then shared among all gpus
3. compute the gradients using the global stats:

$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\,\gamma$$

$$\frac{\partial L}{\partial x_i} = \frac{1}{\sqrt{\sigma^2+\epsilon}}\left(\frac{\partial L}{\partial \hat{x}_i} - \frac{1}{N}\sum_j \frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i \cdot \frac{1}{N}\sum_j \frac{\partial L}{\partial \hat{x}_j}\hat{x}_j\right)$$

with $\frac{\partial L}{\partial \gamma} = \sum_i \frac{\partial L}{\partial y_i}\hat{x}_i$ and $\frac{\partial L}{\partial \beta} = \sum_i \frac{\partial L}{\partial y_i}$.
Note that in the implementation the normalization with N is already performed at step (2), so the code does not match the equation above term by term, but it is mathematically equivalent.
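The backward steps can likewise be sketched in NumPy and checked against a finite-difference gradient. Again this is an illustrative sketch under the same assumptions as above (one channel, biased variance, simulated "GPUs"); `sync_bn_backward` and its variable names are invented for the example, and the division by $N$ is folded into the reduced sums as described in the note.

```python
import numpy as np

def sync_bn_backward(x_parts, grad_y_parts, gamma, eps=1e-5):
    """Sketch of the SyncBN backward pass for one channel.

    Each "GPU" holds its slice of x and of the incoming gradient dL/dy;
    only two scalar sums per channel are exchanged. The real module saves
    xhat from the forward pass instead of recomputing it as done here.
    """
    n = sum(p.size for p in x_parts)
    mean = sum(p.sum() for p in x_parts) / n
    var = sum((p * p).sum() for p in x_parts) / n - mean ** 2
    inv_std = 1.0 / np.sqrt(var + eps)
    xhat = [(p - mean) * inv_std for p in x_parts]
    # steps 1-2: per-device partial sums, reduced and normalized with N
    sum_g = sum(g.sum() for g in grad_y_parts) / n
    sum_gx = sum((g * h).sum() for g, h in zip(grad_y_parts, xhat)) / n
    # step 3: gradients from the shared global sums
    return [gamma * inv_std * (g - sum_g - h * sum_gx)
            for g, h in zip(grad_y_parts, xhat)]

# Finite-difference check of dL/dx for the loss L = sum(w * y).
rng = np.random.default_rng(0)
x = rng.standard_normal(12)
w = rng.standard_normal(12)          # then dL/dy = w
gamma = 1.5
grad = np.concatenate(sync_bn_backward(np.split(x, 3), np.split(w, 3), gamma))

def loss(x):
    xhat = (x - x.mean()) / np.sqrt(x.var() + 1e-5)
    return (w * (gamma * xhat)).sum()

num = np.array([(loss(x + 1e-6 * np.eye(12)[i]) - loss(x - 1e-6 * np.eye(12)[i])) / 2e-6
                for i in range(12)])
assert np.allclose(grad, num, atol=1e-4)
```

The check confirms that the reduced-sum formulation produces the same gradient as differentiating the batchnorm expression directly.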