一、论文概述
2015
年Google
提出的Batch Normalization
- 训练深层的神经网络很复杂,因为训练时每一层输入的分布在变化,导致训练过程中的饱和,称这种现象为:
internal covariate shift
- 需要降低学习率Learning Rate和注意参数的初始化
- 论文中提出的方法是对于每一个小的训练
batch
都进行标准化(正态化)- 允许使用较大的学习率
- 不必太关心初始化的问题
- 同时一些例子中不需要使用
Dropout
方法避免过拟合 - 此方法在
ImageNet classification
比赛中获得4.82% top-5
的测试错误率
二、BN
思路
1、问题
如果输入数据是白化的(whitened),网络会更快的收敛
- 白化目的是降低数据的冗余性和特征的相关性,例如通过线性变换使数据为0均值和单位方差
并非直接标准化每一层那么简单,如果不考虑归一化的影响,可能会降低梯度下降的影响
- 标准化与某个样本和所有样本都有关系
- 解决上面的问题,我们希望对于任何参数值,都要满足想要的分布;$$\widehat x Norm(x,\chi )$$
- 对于反向传播,需要计算:${\partial Norm(x,\chi )} \over {\partial x}$和${\partial Norm({x},\chi )} \over {\partial \chi }$
- 这样做的计算代价是非常大的,因为需要计算x的协方差矩阵
- 然后白化操作:$${x - E[x]} \over {\sqrt {Cov[x]} }$$
- 上面两种都不行或是不好,进而得到了BN的方法
- 既然白化每一层的输入代价非常大,我们可以进行简化
2、简化1
- 标准化特征的每一个维度而不是去标准化所有的特征,这样就不用求协方差矩阵了
- 例如
d
维的输入:$$x = ({x^{(1)}},{x^{(2)}}, \cdots ,{x^{(d)}})$$ - 标准化操作:
$${\widehat x^k} = {x^{(k) - E[x^{(k)}]} \over {\sqrt {Var[x^{(k)}]} }}$$
- 例如
需要注意的是标准化操作可能会降低数据的表达能力,例如我们之前提到的Sigmoid函数,标准化之后均值为0,方差为1,数据就会落在近似线性的函数区域内,这样激活函数的意义就不明显
所以对于每个标准化之后的$\widehat x^{(k)}$,对应一对参数:${\gamma ^{(k)}},{\beta ^{(k)}}$ ,然后令:${y^{(k)}} = {\gamma ^{(k)}}{\widehat x^{(k)}} + {\beta ^{(k)}}$
- 从式子来看就是对标准化的数据进行缩放和平移,不至于使数据落在线性区域内,增加数据的表达能力(式子中如果:${\gamma ^{(k)}} = \sqrt {Var[x^{(k)}]}, {\beta ^{(k)}} = E[x^{(k)}]$ ,就会使恢复到原来的值了)
- 但是这里还是使用的全部的数据集,但是如果使用随机梯度下降,可以选取一个batch进行训练
3、简化2
- 第二种简化就是使用
mini-batch
进行随机梯度下降
- 注意这里使用
mini-batch
也是标准化每一个维度上的特征,而不是所有的特征一起,因为若果mini-batch
中的数据量小于特征的维度时,会产生奇异协方差矩阵, 对应的行列式的值为0,非满秩 - 假设
mini-batch
大小为m
的B
- $B = \{ {x_{1 \ldots m}}\}$对应的变换操作为:$$B{N_{\gamma ,\beta }}:{x_{1 \ldots m}} \to {y_{1 \ldots m}}$$
- 作者给出的批标准化的算法如下:
- 算法中的
ε
是一个很小的常量,为了保证数值的稳定性(就是防止除数为0)
4、反向传播求梯度:
- 因为:$$y^{(k)} = \gamma ^{(k)}\widehat x^{(k)} + \beta ^{(k)}$$
- 所以:$${\partial l \over \partial \widehat x_i} = {\partial l \over \partial y_i}\gamma $$
因为:$$\widehat x_i = {x_i - \mu _B \over {\sqrt {\sigma _B^2 + \varepsilon } }}$$
- 所以:$${\partial l \over \partial \sigma _B^2} = \sum\limits_{i=1}^m {\partial l \over \partial \widehat x_i} (x_i- \mu_B) {-1 \over 2}(\sigma_B^2 + \varepsilon)^{-{3\over2}}$$
$${\partial l \over \partial u_B} = \sum\limits_{i = 1}^m {\partial l \over \partial \widehat x_i} { - 1 \over \sqrt {\sigma _B^2 + \varepsilon }}$$
- 所以:$${\partial l \over \partial \sigma _B^2} = \sum\limits_{i=1}^m {\partial l \over \partial \widehat x_i} (x_i- \mu_B) {-1 \over 2}(\sigma_B^2 + \varepsilon)^{-{3\over2}}$$
因为:${\mu _B} = {1 \over m}\sum\limits_{i = 1}^m $和$\sigma _B^2 = {1 \over m}\sum\limits_{i = 1}^m {({x_i}} - {\mu _B}{)^2}$
- 所以:$${\partial l \over \partial x_i} = {\partial l \over \partial \widehat x_i}{1 \over \sqrt {\sigma _B^2 + \varepsilon } } + {\partial l \over \partial \sigma _B^2}{2(x_i - \mu _B) \over m} + {\partial l \over \partial u_B}{1 \over m}$$
- 所以:$${\partial l \over \partial \gamma } = \sum\limits_{i = 1}^m {\partial l \over \partial y_i} {\widehat x_i}$$
$${\partial l \over \partial \beta } = \sum\limits_{i = 1}^m {\partial l \over \partial y_i} $$
- 所以:$${\partial l \over \partial \gamma } = \sum\limits_{i = 1}^m {\partial l \over \partial y_i} {\widehat x_i}$$
- 所以:$${\partial l \over \partial x_i} = {\partial l \over \partial \widehat x_i}{1 \over \sqrt {\sigma _B^2 + \varepsilon } } + {\partial l \over \partial \sigma _B^2}{2(x_i - \mu _B) \over m} + {\partial l \over \partial u_B}{1 \over m}$$
- 对于BN变换是可微分的,随着网络的训练,网络层可以持续学到输入的分布。
三、BN
网络的训练和推断(预测)
1、预测的问题
- 按照
BN
方法,输入数据x
会经过变化得到BN(x)
,然后可以通过随机梯度下降进行训练,标准化是在mini-batch
上所以是非常高效的。 - 但是对于推断我们希望输出只取决于输入,而对于输入只有一个实例数据,无法得到
mini-batch
的其他实例,就无法求对应的均值和方差了。
2、解决方法
- 可以通过从所有训练实例中获得的统计量来代替
mini-batch
中m
个训练实例获得统计量均值和方差- 比如我们机器学习算法,在训练集上进行了标准化,在测试集上的标准化操作时利用的训练集上的数据(
Standarscaler
中的mean
和variance
)
- 比如我们机器学习算法,在训练集上进行了标准化,在测试集上的标准化操作时利用的训练集上的数据(
- 我们对每个
mini-batch
做标准化,可以对记住每个mini-batch
的B,然后得到全局统计量 - $$E[x] \leftarrow E_B[{\mu _B}]$$
- $$Var[x] \leftarrow {m \over {m - 1}}E_B[\sigma _B^2]$$(这里方差采用的是无偏方差估计,所以是
m-1
) - 所以推断采用
BN
的方式为:
$$\eqalign{
& y = \gamma {x - E(x) \over \sqrt {Var[x] + \varepsilon }} + \beta \cr
& \quad= {\gamma \over \sqrt {Var[x] + \varepsilon }}x + (\beta - {\gamma E[x] \over \sqrt {Var[x] + \varepsilon }})} $$3、完整算法
- 作者给出的完整算法:
四、实验
- 最后给出的实验可以看出使用BN的方式训练精准度很高而且很稳定。
- 本文链接: http://lawlite.me/2017/01/09/论文记录-Batch-Normalization/
- 版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 3.0 许可协议 。转载请注明出处!