pytorch-syncbn
This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes global statistics across GPUs instead of statistics computed locally on each GPU. SyncBN is becoming important for tasks where the input images are large and multiple GPUs must be used to increase the minibatch size for training.
Remarks
- Unlike Pytorch-Encoding, you don't need a custom nn.DataParallel
- Unlike Inplace-ABN, you can simply replace your nn.BatchNorm2d with this module, since it does not mark tensors for in-place operation
- You can plug it into any module written in PyTorch to enable Synchronized BatchNorm
- The backward computation is rewritten and tested against the behavior of nn.BatchNorm2d
Requirements
For PyTorch, please refer to https://pytorch.org/
NOTE: The code has been tested only with PyTorch v0.4.0 and CUDA 9.1.85 / cuDNN 7.1.4 on Ubuntu 16.04.
(It can also be compiled and run on the Jetson TX2, but it won't work as multi-GPU synchronized BN there.)
To install all dependencies using pip, run:
pip install -U -r requirements.txt
Build
Use make_ext.sh to build the extension. For example:
PYTHON_CMD=python3 ./make_ext.sh
Usage
Please refer to test.py for a test of the difference between nn.BatchNorm2d and modules.nn.BatchNorm2d.
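import torch
import torch.nn as nn          # standard PyTorch layers
from modules import nn as NN   # synchronized BatchNorm from this repository

num_gpu = torch.cuda.device_count()

# NN.BatchNorm2d is used in place of nn.BatchNorm2d
model = nn.Sequential(
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
    nn.ReLU(inplace=True),
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
).cuda()
model = nn.DataParallel(model, device_ids=list(range(num_gpu)))

# One sample per GPU: DataParallel splits the batch along dim 0
x = torch.rand(num_gpu, 3, 2, 2).cuda()
z = model(x)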
Math
Forward
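1. Compute $\sum x_i$ and $\sum x_i^2$ on each GPU.
2. Gather all $\sum x_i$ and $\sum x_i^2$ from the workers to the master and compute the global statistics $\mu$ and $\sigma^2$, where
   $$\mu = \frac{\sum x_i}{N}$$
   and
   $$\sigma^2 = \frac{\sum x_i^2 - \mu \sum x_i}{N}$$
   with $N$ the total number of elements per channel across all GPUs. These global statistics are then shared with all GPUs, and running_mean and running_var are updated by a moving average using the global statistics.
3. Run the batchnorm forward pass using the global statistics:
   $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
   and then
   $$y_i = \gamma \cdot \hat{x}_i + \beta$$
   where $\gamma$ is the weight parameter and $\beta$ is the bias parameter.
4. Save $x, \gamma, \beta, \mu, \sigma^2$ for the backward pass.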
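The forward steps above can be sketched in plain Python for a single channel. This is only an illustration, not the extension's actual code: the nested lists stand in for per-GPU tensor shards, and the helper names (local_sums, global_stats, normalize) are hypothetical. It assumes the per-GPU quantities are the sum of x and the sum of x squared, as in standard batch normalization.

```python
import math

def local_sums(shard):
    """Step 1: per-GPU partial sums (sum of x, sum of x^2, element count)."""
    return sum(shard), sum(v * v for v in shard), len(shard)

def global_stats(partials):
    """Step 2: the master adds the partial sums, then derives the mean and
    the biased variance as (sum(x^2) - mean * sum(x)) / N."""
    total = sum(p[0] for p in partials)
    total_sq = sum(p[1] for p in partials)
    n = sum(p[2] for p in partials)
    mean = total / n
    var = (total_sq - mean * total) / n
    return mean, var

def normalize(shard, mean, var, gamma=1.0, beta=0.0, eps=1e-5):
    """Step 3: each GPU normalizes its own shard with the shared statistics."""
    inv_std = 1.0 / math.sqrt(var + eps)
    return [gamma * (v - mean) * inv_std + beta for v in shard]

# Two hypothetical "GPUs", each holding part of one channel of the minibatch.
shards = [[1.0, 2.0, 3.0], [10.0, 11.0, 12.0, 13.0, 14.0]]
mean, var = global_stats([local_sums(s) for s in shards])
outputs = [normalize(s, mean, var) for s in shards]
print(mean, var)  # 8.25 24.9375, identical to the stats of the concatenated batch
```

Only three numbers per channel travel between devices, which is why the statistics can be synchronized without gathering the activations themselves. Local statistics would differ sharply here: the first shard alone has mean 2.0, the second 12.0.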
Backward
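1. Restore the saved $x, \gamma, \beta, \mu, \sigma^2$.
2. Compute the following sums on each GPU:
   $$\sum_{i=1}^{N_j} \frac{dJ}{dy_i}$$
   and
   $$\sum_{i=1}^{N_j} \frac{dJ}{dy_i} \cdot \hat{x}_i$$
   where $j$ indexes the GPUs and $N_j$ is the number of elements per channel on GPU $j$. Then gather them at the master node to form the global sums, and normalize by $N$, where $N$ is the total number of elements per channel. The global sums are then shared among all GPUs.
3. Compute the gradients $\frac{dJ}{dx_i}$, $\frac{dJ}{d\gamma}$ and $\frac{dJ}{d\beta}$ using the global sums, where
   $$\frac{dJ}{d\gamma} = \sum_{i=1}^{N} \frac{dJ}{dy_i} \cdot \hat{x}_i$$
   and
   $$\frac{dJ}{d\beta} = \sum_{i=1}^{N} \frac{dJ}{dy_i}$$
   and finally,
   $$\frac{dJ}{dx_i} = \frac{dJ}{d\hat{x}_i}\frac{d\hat{x}_i}{dx_i} + \frac{dJ}{d\mu}\frac{d\mu}{dx_i} + \frac{dJ}{d\sigma^2}\frac{d\sigma^2}{dx_i}$$
   $$= \frac{1}{N\sqrt{\sigma^2 + \epsilon}} \left( N \frac{dJ}{d\hat{x}_i} - \sum_{k=1}^{N} \frac{dJ}{d\hat{x}_k} - \hat{x}_i \sum_{k=1}^{N} \frac{dJ}{d\hat{x}_k} \hat{x}_k \right)$$
   with $\frac{dJ}{d\hat{x}_i} = \gamma \cdot \frac{dJ}{dy_i}$.
Note that in the implementation the normalization by N is performed in step (2), so the equations above and the implementation are not literally identical, but they are mathematically equivalent.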
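The backward steps can be sketched the same way for a single channel. This is a minimal illustration under stated assumptions, not the extension's code: plain lists stand in for per-GPU shards, the helper names (forward, local_grad_sums, backward) are hypothetical, and the per-GPU sums are assumed to be the sum of dJ/dy and the sum of dJ/dy times x_hat, as in the standard batch normalization gradient. Unlike the implementation, the sketch applies the division by N in the last step.

```python
import math

EPS = 1e-5

def forward(flat_x, gamma, beta):
    """Reference forward pass over the whole (already gathered) batch."""
    n = len(flat_x)
    mean = sum(flat_x) / n
    var = sum((v - mean) ** 2 for v in flat_x) / n
    inv_std = 1.0 / math.sqrt(var + EPS)
    return [gamma * (v - mean) * inv_std + beta for v in flat_x], mean, var

def local_grad_sums(shard_x, shard_dy, mean, var):
    """Step 2: per-GPU sums of dJ/dy and of dJ/dy * x_hat."""
    inv_std = 1.0 / math.sqrt(var + EPS)
    sum_dy = sum(shard_dy)
    sum_dy_xhat = sum(dy * (x - mean) * inv_std
                      for x, dy in zip(shard_x, shard_dy))
    return sum_dy, sum_dy_xhat

def backward(shards_x, shards_dy, mean, var, gamma):
    """Step 3: combine the per-GPU sums and compute dx, dgamma, dbeta."""
    n = sum(len(s) for s in shards_x)
    partials = [local_grad_sums(x, dy, mean, var)
                for x, dy in zip(shards_x, shards_dy)]
    sum_dy = sum(p[0] for p in partials)        # equals dJ/dbeta
    sum_dy_xhat = sum(p[1] for p in partials)   # equals dJ/dgamma
    inv_std = 1.0 / math.sqrt(var + EPS)
    dx = [
        [gamma * inv_std / n
         * (n * dy - sum_dy - (x - mean) * inv_std * sum_dy_xhat)
         for x, dy in zip(shard_x, shard_dy)]
        for shard_x, shard_dy in zip(shards_x, shards_dy)
    ]
    return dx, sum_dy_xhat, sum_dy

# Two hypothetical "GPUs" with made-up inputs and upstream gradients.
shards_x = [[0.5, -1.0, 2.0], [3.0, 0.25]]
shards_dy = [[0.1, -0.2, 0.3], [0.4, -0.5]]
gamma, beta = 1.5, 0.2
flat_x = [v for s in shards_x for v in s]
_, mean, var = forward(flat_x, gamma, beta)
dx, dgamma, dbeta = backward(shards_x, shards_dy, mean, var, gamma)
```

A finite-difference check on the loss J = sum(dy_i * y_i) is a simple way to confirm that dx from the synchronized sums matches the gradient of an unsharded batch norm.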