An implementation of the Shuffle BatchNorm technique described in He et al., Momentum Contrast for Unsupervised Visual Representation Learning, 2019, Section 3.3 "Shuffling BN".
Implemented with torch 1.3.1. It works with PyTorch DistributedDataParallel with 1 process per GPU, so you need at least 2 GPUs to use the ShuffleBatchNorm layer.
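For reference, batch normalization is commonly written as follows. This is the form given in the PyTorch documentation for the torch.nn.BatchNorm layers; the symbols are not taken from this repository.

```latex
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
```

Here E[x] and Var[x] are the per-channel mean and variance computed over the mini-batch, ε is a small constant for numerical stability, and γ and β are the learnable affine parameters.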
BatchNorm normalizes each channel using the mean and variance of the current mini-batch. The ShuffleBatchNorm layer shuffles these mini-batch statistics (mean and variance) across multiple GPUs to avoid information leakage. This prevents the model from "cheating" when training with a contrastive loss whose contrast is obtained within the mini-batch.
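The idea of normalizing with another worker's statistics can be illustrated without any GPUs. The following is a minimal, framework-free sketch in plain Python, not the code of this repository: the two "workers" are simulated with lists, the helper names (`batch_stats`, `shuffle_stats`, `normalize`) are hypothetical, and the fixed rotation is an assumption standing in for whatever permutation and communication the real layer uses.

```python
from statistics import fmean, pvariance

EPS = 1e-5  # small constant for numerical stability, as in standard BatchNorm


def batch_stats(batch):
    """Return (mean, biased variance) of one worker's local mini-batch."""
    return fmean(batch), pvariance(batch)


def shuffle_stats(stats):
    """Rotate the per-worker statistics by one position.

    Assumption: a fixed rotation is used here only for illustration; the
    point is that no worker ends up with its own statistics.
    """
    return stats[1:] + stats[:1]


def normalize(batch, mean, var):
    """Normalize a batch with the given (possibly foreign) statistics."""
    return [(x - mean) / (var + EPS) ** 0.5 for x in batch]


# Two simulated workers, each holding a single-channel local mini-batch.
local_batches = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]

stats = [batch_stats(b) for b in local_batches]  # computed locally
shuffled = shuffle_stats(stats)                  # exchanged across workers

# Each worker normalizes its own data with the statistics it received.
outputs = [normalize(b, m, v) for b, (m, v) in zip(local_batches, shuffled)]
```

With two workers the rotation is a plain swap, so worker 0 normalizes with worker 1's mean and variance and vice versa.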
How to use?
The implementation mimics the design of SyncBatchNorm. To use ShuffleBatchNorm, create your model first and then convert all torch.nn.BatchNormND layers into ShuffleBatchNorm with the following function:
```python
from shuffle_batchnorm import ShuffleBatchNorm

# ...
model = Model()  # with BN layers
model = ShuffleBatchNorm.convert_shuffle_batchnorm(model)
```
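For readers curious how such a conversion can work, here is a rough sketch of the recursive replacement pattern that `torch.nn.SyncBatchNorm.convert_sync_batchnorm` follows. It is an illustration under assumptions, not this repository's implementation: `PlaceholderShuffleBN` and `convert_bn` are hypothetical names, the placeholder behaves like an ordinary `nn.BatchNorm2d`, and only 2D BatchNorm layers are handled.

```python
import torch.nn as nn


class PlaceholderShuffleBN(nn.BatchNorm2d):
    """Hypothetical stand-in for the real ShuffleBatchNorm layer."""


def convert_bn(module):
    """Recursively replace every nn.BatchNorm2d with PlaceholderShuffleBN."""
    converted = module
    if isinstance(module, nn.BatchNorm2d) and not isinstance(
        module, PlaceholderShuffleBN
    ):
        converted = PlaceholderShuffleBN(
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        # Carry over learned parameters and running statistics.
        converted.load_state_dict(module.state_dict())
    # Descend into children and swap each one for its converted version.
    for name, child in module.named_children():
        converted.add_module(name, convert_bn(child))
    return converted


# Usage on a toy model (CPU only, no distributed setup required).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = convert_bn(model)
```

Converting after construction, rather than editing the model definition, keeps existing architectures unchanged and lets the same model run with or without the shuffled layer.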
See main.py for a complete example.
```
$ python main.py --gpu 0,1 --shuffle --epochs 10
=> Spawning 2 distributed workers ...
mean before shuffle: tensor([-0.2478, 0.1704, 0.0640, -0.2732], device='cuda:0')
mean before shuffle: tensor([-0.4012, -0.1913, -0.0553, -0.1917], device='cuda:1')
mean after shuffle: tensor([-0.4012, -0.1913, -0.0553, -0.1917], device='cuda:0')
mean after shuffle: tensor([-0.2478, 0.1704, 0.0640, -0.2732], device='cuda:1')
[9/10] Loss 0.6868 ================================================
[9/10] Loss 0.7908 ================================================
```
If you find bugs, please create an issue. Bug reports are very welcome!
- Doesn't work when training with multiple nodes; will fix soon.