Xavier

2025-10-19

数值稳定性和模型初始化

到目前为止,我们实现的每个模型都是根据某个预先指定的分布来初始化模型的参数。 有人会认为初始化方案是理所当然的,忽略了如何做出这些选择的细节。甚至有人可能会觉得,初始化方案的选择并不是特别重要。 相反,初始化方案的选择在神经网络学习中起着举足轻重的作用, 它对保持数值稳定性至关重要。 此外,这些初始化方案的选择可以与非线性激活函数的选择有趣的结合在一起。 我们选择哪个函数以及如何初始化参数可以决定优化算法收敛的速度有多快。 糟糕选择可能会导致我们在训练时遇到梯度爆炸或梯度消失。

梯度消失和梯度爆炸

考虑到一个具有L层,输入x输出o的深层网络。每一层的l 由变换$f_l$ 定义, 该变换的参数为权重$W^{(l)}$, 其隐藏变量是$h^{(l)}$(令$h^{(0)} = x$)。 我们的网络可以表示为:

\[h^{(l)} = f_l(h^{(l-1)}) \,\,, 因此o = f_l ....f_l(x).\]

如果所有隐藏变量和输入都是向量,我们可以将o关于任何一组参数$W^{(l)}$ 的梯度写为下式:

\[\partial_{w^{(l)}}o = \underbrace{\partial_{h^{(L-1)}}h^{(L)}}_{M^{(L)}\overset{\text{def}}{=}} .... \underbrace{\partial_{h^{(l)}}h^{(l+1)}}_{M^{(l+1)}\overset{\text{def}}{=}} \underbrace{\partial_{W^{(l)}}h^{(l)}}_{V^{(l)}\overset{\text{def}}{=}}\]

换言之,该梯度是$L-l$ 个矩阵$M^{(L)}…..M^{(l+1)}$ 与梯度向量$v^{(l)}$ 的乘积, 因此,我们容易受到数值下溢问题的影响. 当将太多的概率乘在一起时,这些问题经常会出现。 在处理概率时,一个常见的技巧是切换到对数空间, 即将数值表示的压力从尾数转移到指数。 不幸的是,上面的问题更为严重: 最初,$M^{(L)}$矩阵 可能具有各种各样的特征值。 他们可能很小,也可能很大; 他们的乘积可能非常大,也可能非常小。

不稳定梯度带来的风险不止在于数值表示; 不稳定梯度也威胁到我们优化算法的稳定性。 我们可能面临一些问题。 要么是梯度爆炸(gradient exploding)问题: 参数更新过大,破坏了模型的稳定收敛; 要么是梯度消失(gradient vanishing)问题: 参数更新过小,在每次更新时几乎不会移动,导致模型无法学习。

梯度消失

曾经的sigmoid 函数$1/(1+exp(-x))$ 很流行, 因为它类似于阈值函数,然而 它却是导致梯度消失问题的一个常见原因,让我们看看sigmoid函数为什么会导致梯度消失:

%matplotlib inline
import torch
from d2l import torch as d2l

x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
y = torch.sigmoid(x)
y.backward(torch.ones_like(x))

d2l.plot(x.detach().numpy(), [y.detach().numpy(), x.grad.numpy()],
         legend=['sigmoid', 'gradient'], figsize=(4.5, 2.5))

正如上图,如果sigmoid 函数的输入很大或者很小时,它的梯度都是消失。此外,当反向传播通过许多层时,除非我们在刚刚好的地方, 这些地方sigmoid函数的输入接近于零,否则整个乘积的梯度可能会消失。 当我们的网络有很多层时,除非我们很小心,否则在某一层可能会切断梯度。 事实上,这个问题曾经困扰着深度网络的训练。 因此,更稳定的ReLU系列函数已经成为从业者的默认选择。

梯度爆炸

我们生成100个高斯随机矩阵,并将他们与初始矩阵相乘。对于我们选择的尺度(方差 $\sigma^2 = 1 $), 矩阵乘积发生爆炸。当这种情况是由于深度网络的初始化所导致时,我们没有机会让梯度下降优化器收敛。

M = torch.normal(0, 1, size=(4,4))
print('一个矩阵 \n',M)
for i in range(100):
    M = torch.mm(M,torch.normal(0, 1, size=(4, 4)))

print('乘以100个矩阵后\n', M)

输出

一个矩阵 
 tensor([[ 0.3798,  2.0855, -0.6250,  0.2529],
        [-1.5336, -1.6261, -0.5801, -0.3945],
        [-0.1637,  0.8572,  0.7904, -1.5200],
        [-1.6464,  1.4212,  0.4079, -1.2139]])
乘以100个矩阵后
 tensor([[ 1.2040e+22, -7.8007e+21,  1.5275e+22,  1.6664e+22],
        [-9.6011e+21,  6.2204e+21, -1.2180e+22, -1.3288e+22],
        [ 1.2644e+22, -8.1919e+21,  1.6041e+22,  1.7499e+22],
        [ 1.4404e+22, -9.3321e+21,  1.8274e+22,  1.9935e+22]])

打破对称性

神经网络设计中的另一个问题是其参数化所固有的对称性。 假设我们有一个简单的多层感知机,它有一个隐藏层和两个隐藏单元。 在这种情况下,我们可以对第一层的权重W进行重排列, 并且同样对输出层的权重进行重排列,可以获得相同的函数。 第一个隐藏单元与第二个隐藏单元没有什么特别的区别。 换句话说,我们在每一层的隐藏单元之间具有排列对称性。

假设输出层将上述两个隐藏单元的多层感知机转换为仅一个输出单元。 想象一下,如果我们将隐藏层的所有参数初始化为W=c, c为常量,会发生什么? 在这种情况下,在前向传播期间,两个隐藏单元采用相同的输入和参数, 产生相同的激活,该激活被送到输出单元。 在反向传播期间,根据参数W对输出单元进行微分, 得到一个梯度,其元素都取相同的值。 因此,在基于梯度的迭代(例如,小批量随机梯度下降)之后, W的所有元素仍然采用相同的值。 这样的迭代永远不会打破对称性,我们可能永远也无法实现网络的表达能力。 隐藏层的行为就好像只有一个单元。 请注意,虽然小批量随机梯度下降不会打破这种对称性,但暂退法正则化可以。

解决(或至少减轻)上述问题的一种方法是进行参数初始化, 优化期间的注意和适当的正则化也可以进一步提高稳定性。

我们使用正态分布来初始化权重值。如果我们不指定初始化方法, 框架将使用默认的随机初始化方法,对于中等难度的问题,这种方法通常很有效。

Xavier 初始化

让我们看看某些没有非线性的全连接层的输出(例如,隐藏变量)$o_i$的尺度分布,对于该层$n_{in}$ 输入$x_j$ 及其相关的权重$w_{ij}$, 输出由下式给出

\[o_i = \sum_{j=1}^{n_{ij}}w_{ij}x_{j}\]

权重$w_{ij}$ 都是同一分布中独立抽取的。此外,假设该分布具有零均值和方差$\sigma ^2$, 这不并意味着分布必须是高斯的,只是均值和方差需要存在;同时我们假设层$x_j$的输入也具有零均值和方差$\gamma^2 $, 并且他们独立于$w_{ij}$ ;此情况下,我们按如下方式计算$o_i$的平均值和方差:

\[E[O_i] = \sum_{j=1}^{n_{in}}E[w_{ij}x_{j}] = \sum_{j=1}^{n_{in}}E[w_{ij}]E[x_{j}] = 0\] \[Var[o_i] = E[o_i^2] - (E[o_i])^2 =\sum_{j=1}^{n_{in}}E[w_{ij}^2x_{j}^2] - 0 = \sum_{j=1}^{n_{in}}E[w_{ij}^2]E[x_{j}^2] = n_{in}\sigma^2\gamma^2\]

保持方差不变的一种方法是设置$n_{in}\sigma^2 = 1$。 现在考虑反向传播过程,我们面临着类似的问题,尽管梯度是从更靠近输出的层传播的。 使用与前向传播相同的推断,我们可以看到,除非$n_{out}\sigma^2 = 1$, 否则梯度的方差可能会增大,其中$n_{out}$是该层的输出的数量。 这使得我们进退两难:我们不可能同时满足这两个条件。 相反,我们只需满足:

\[\frac{1}{2}(n_{in} + n_{out})\sigma^2 = 1 或等价于\sigma = \sqrt{\frac{2}{n_{in}+n_{out}}}\]

这就是现在标准且实用的Xavier初始化的基础。 通常,Xavier初始化从均值为零,方差 $\sigma^2 = \frac{2}{n_{in}+n_{out}}$的高斯分布中采样权重, 我们也可以将其改为选择从均匀分布中抽取权重时的方差。 注意均匀分布U(-a, a)的方差为$\frac{a^2}{3}$。 将$\frac{a^2}{3}$代入到$\sigma^2 $的条件中,将得到初始化值域:

\[U(-\sqrt{\frac{6}{n_{in}+n_{out}}}, \sqrt{\frac{6}{n_{in}+n_{out}}})\]