今天写一下第二题我的个人解答,这道题源自 NIPS2022 的一篇论文《The alignment property of SGD noise and how it helps select flat minima: A stability analysis》,问题的背景是与梯度下降(GD)相比,随机梯度下降(SGD)在过参数化的神经网络训练过程中可以收敛到一个更加平缓(flat)的极值点,而且平缓程度与神经网络的参数量无关.
T2. 假设 $F(x;w)$ 是一个输出标量的深度神经网络,其中 $x$ 是输入, $w$ 表示权重.假设 $F$ 关于 $w$ 连续可微, 并且对于训练数据 $\{x_j, y_j\}_{j=1}^m$ 过参数化, 即存在 $w^\ast$ 使得对于所有 $j$ 满足 $F(x_j;w^\ast) = y_j$.为了研究训练神经网络时 $w^\ast$ 的局部优化行为,我们考虑线性化神经网络 $\tilde{F}(x;w) = F(x;w^\ast) + (w - w^\ast)^\text{T} \nabla F(x;w^\ast)$,其损失函数为
令 $s$ 表示学习率,梯度下降法为
而随机梯度下降法为
其中噪声项满足 $\mathbb{E} \epsilon_i = 0$ 和 $\mathbb{E} \epsilon_i \epsilon_i^\top = M(w_i)/b$,$b$ 是 mini-batch 的大小.假设协方差矩阵 $\Sigma$ 为
在以下意义上对齐:
对于 $s > 0$ 和所有 $w$ 成立。这里 $|\cdot|_F$ 表示 Frobenius 范数。
(1) 对于梯度下降,证明如果 $\Sigma$ 的谱范数满足
则梯度下降是局部稳定的(即对于所有 $i$, $\text{Loss}(w_i)$ 是有界的).(注意,这蕴含了一个依赖维度的 $|\Sigma|_F \leq \dfrac{2\sqrt{d}}{s}$,其中 $d$ 是 $w$ 的维度).
(2) 对于随机梯度下降,如果 $\mathbb{E}\text{Loss}(w_i)$ 对于所有 $i$ 都有界,则以下独立于维度的不等式必须成立
(1) $Proof$. 由过参数化的条件有,$F(x_j; w^*) = y_j$,那么
对 $\text{Loss}(w_i)$ 求梯度可得
由梯度下降的更新格式,有
对上述两边同时取 $2$ 范数,因为 $I - s\Sigma$ 为对称矩阵,因此一定存在矩阵 $P$ 使得 $I - s\Sigma$ 可以正交对角化,故
由 $\Sigma$ 为半正定矩阵,以及 $\rho(\Sigma) \leq \dfrac{2}{s}$,可得 $I - s\Lambda$ 对角线上元素范围为 $[-1, 1]$,即 $\rho(I - s\Lambda) \leq 1$
所以
由
可得
故梯度下降法是局部稳定的.
第二小问的证明过程参考论文 《The alignment property of SGD noise and how it helps select flat minima: A stability analysis》 中的 Theorem 3.3 的证明
(2) $Proof$. 记 $\theta_i = w_i - w^\ast$,那么有
计算 $\mathbb{E}\text{Loss}(w_{i+1})$,并由 $\mathbb{E}\epsilon_i = 0$ 和 $\mathbb{E}\epsilon_i \epsilon_i^\text{T} = M(w_i)/b$,有
考虑期望 $\mathbb{E}[x^\text{T} A x]$,其中 $\mathbb{E}x = 0$ 和 $\mathbb{E}xx^\text{T} = M(w_i)/b$
对于 $i \neq j$,因为 $x$ 的均值为 $0$,$\mathbb{E}[x_i x_j] = 0$(这里假设各个分量是相互独立的),故
即
记 $r(\theta) = 1 - 2s \dfrac{\theta^\text{T} \Sigma^2 \theta}{\theta^\text{T} \Sigma \theta} + s^2 \dfrac{\theta^\text{T} \Sigma^3 \theta}{\theta^\text{T} \Sigma \theta}$,那么
下证 $r(\theta) \geq 0$,令
并对 $\Sigma$ 进行谱分解
假设 $u$ 在 $e_i$ 这组标准正交基下的表示如下
且由 $|u|_2 = 1$ 得 $\sum_j a_j^2 = 1$,那么
由 $r(\theta_i) \text{Loss}(w_i) \geq 0$ 和 $\dfrac{\text{Tr}(M(w)\Sigma)}{2 \text{Loss}(w)|\Sigma|_F^2} \geq \delta$,有
因为 $\mathbb{E} \text{Loss}(w_{i+1})$ 是有界的,所以
否则,令 $q=\dfrac{s^2}{b} \delta |\Sigma|_F^2$
$\mathbb{E} \text{Loss}(w_{i+1})$ 呈指数增长,矛盾.故 $|\Sigma|_F \leq \dfrac{\sqrt{b\delta}}{s}$,证毕.