陌路茶色/

torch.nn.modules.batchnorm.py

BN的具体实现对应到的是torch.batch_norm,具体调用可能需要阅读C代码,后续有需求继续深入。

BN原理

屏幕快照 2020-05-04 下午4.38.37.png

函数

主要是BatchNorm1d(_BatchNorm)和BatchNorm2d(_BatchNorm),_BatchNorm(_NormBase)的forward中会调用F.batch_norm()函数,_NormBase(Module)的__init__()函数中会根据affine和track_running_stats来初始化参数。

affine为true表示$γ$和$β$会在训练中学习,否则被置为1和0。
track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值(当前的方差和均值会依赖前面batch的均值和方差,momentum参数用在此处,涉及滑动平均算法),而不只是仅仅依赖与当前输入的batch的统计特性。
num_features参数表示的是通道数,对应的是第二维度的值,即在初始化BatchNorm1d和BatchNorm2d时,传入的数值应该等于输入tensor的第二维大小,BatchNorm1d对应2D和3D的tensor,BatchNorm2d对应的是4D的tensor。

SyncBatchNorm(_BatchNorm)实现原理

先说一下BN,机器学习要求数据独立同分布,独立需要做白化操作,同分布即归一化,深度学习中借鉴使用BN,仅做了归一化,但是为了保证归一化不会影响模型的效果,会加上两个可变参数。而同步BN会在更大的数据集上进行BN,这样模型相对于在单卡上进行BN更容易收敛,目前这个实现只适合单机多卡。
参考中提到了同步BN的实现原理,如下图所示:
屏幕快照 2020-05-16 下午9.31.00.png

References

Pytorch的BatchNorm层使用中容易出现的问题
跨卡同步 Batch Normalization

留下一条评论

暂无评论