Whale's Blog

记录一下阶段所学(Machine Learning, Optimization, Mathematics, etc),吐槽一下生活

0%

2024阿里巴巴数学竞赛决赛(应用与计算赛道第2题)

今天写一下第二题我的个人解答,这道题源自NIPS2022的一篇论文《The alignment property of SGD noise and how it helps select flat minima: A stability analysis》,问题的背景是与梯度下降(GD)相比,随机梯度下降(SGD)在过参数化的神经网络训练过程当中可以收敛到一个更加平缓(flat)的极值点,而且平缓程度与神经网络的参数量无关。

T2. 假设 $ F(x; w) $ 是一个输出标量的深度神经网络,其中 $ x $ 是输入,$ w $ 表示权重。假设 $ F $ 关于 $ w $ 连续可微,并且对于训练数据 $\{ x_j, y_j \}_{j=1}^n $ 过参数化,即存在 $ w^\ast $ 使得对所有 $ j $ 满足 $ F(x_j; w^\ast) = y_j $。为了研究训练神经网络时 $ w^* $ 的局部优化行为,我们考虑线性化神经网络 $ \widetilde{F}(x; w) = F(x; w^\ast) + (w - w^\ast)^T \nabla F(x; w^\ast) $,其损失函数为

令 $ s $ 表示学习率,梯度下降法为

而随机梯度下降法为

其中噪声项满足 $ \mathbb{E} \epsilon_i = 0 $ 和 $ \mathbb{E} \epsilon_i \epsilon_i^\top = M(w_i)/b $,$ b $ 是 mini-batch 的大小。假设协方差矩阵 $ \Sigma $ 为

在以下意义上对齐:

对于 $ s > 0 $ 和所有 $ w $ 成立。这里 $| \cdot |_F$ 表示 Frobenius 范数。

(1) 对于梯度下降,证明如果 $ \Sigma $ 的谱范数满足

则梯度下降是局部稳定的(即对于所有 $ i $,$\text{Loss}(w_i)$ 是有界的)。(注意,这蕴含了一个依赖维度的 $\Vert \Sigma \Vert_F \leq \frac{2\sqrt{d}}{s} $,其中 $ d $ 是 $w$ 的维度)。

(2) 对于随机梯度下降,如果 $\mathbb{E}\text{Loss}(w_i)$ 对于所有 $ i $ 都有界,则以下独立于维度的不等式必须成立

(1) $Proof.$ 由过参数化的条件有, $F(x_j; w^\ast) = y_j$,那么

对 $Loss(w_i)$ 求梯度可得

由梯度下降法的更新格式,有

对上式两边同时取 $2$ 范数,因为 $I-s\Sigma$ 为对称矩阵,因此一定存在矩阵 $P$ 使得 $I-s\Sigma$ 可以正交对角化,故

由 $\Sigma$ 为半正定矩阵,以及 $\rho (\Sigma) \leq \frac{2}{s}$,可得 $I-s\Lambda$ 对角线上元素范围为 $[-1,1]$,即 $\rho (I-s\Lambda) \leq 1$

所以

可得

故梯度下降法是局部稳定的.

第二小问的证明过程参考论文 《The alignment property of SGD noise and how it helps select flat minima: A stability analysis》中的 Theorem 3.3 的证明
(2) $Proof.$ 记 $\theta_i = w_i - w^\ast$,那么有

计算$\mathbb{E}Loss(w_{i+1})$,并由 $\mathbb{E}\epsilon_i = 0$ 和 $\mathbb{E} \epsilon_i \epsilon_i^\top = M(w_i)/b$,有

考虑期望 $ \mathbb{E}[x^T A x] $,其中$\mathbb{E}x = 0$ 和 $\mathbb{E} x x^\top = M(w_i)/b$

对于 $ i \neq j $,因为 $ x $ 的均值为 0, $ E[x_i x_j] = 0 $(这里假设各个分量是相互独立的),故

记 $r(\theta) = 1 - 2s \frac{\theta^T \Sigma^2 \theta}{\theta^T \Sigma \theta} + s^2 \frac{\theta^T \Sigma^3 \theta}{\theta^T \Sigma \theta}$,那么

下证 $r(\theta) \geq 0$,令

并对 $\Sigma$ 进行谱分解

假设 $u$ 在 ${e_i}$ 这组标准正交基下的表示如下

且由 $||u||_2 = 1$ 得 $\sum_j a_j^2 = 1$,那么

由 $r(\theta_i)Loss(w_i) \geq 0$ 和 $\frac{\text{Tr}(M(w) \Sigma)}{2 \text{Loss}(w) | \Sigma |_F^2} \geq \delta$,有

因为 $\mathbb{E}Loss(w_{i+1})$ 是有界的,所以

否则,令 $q = \frac{s^2}{b} \delta ||\Sigma||_F^2$

$\mathbb{E}Loss(w_{i+1})$ 呈指数增长,故 $| \Sigma |_F \leq \frac{\sqrt{b \delta}}{s}$,证毕.