【深度学习模型训练】从链式法则到显存开销分析

3 小时前
/

【深度学习模型训练】从链式法则到显存开销分析

本文以一个典型的三层多层感知机(MLP)为例,梳理神经网络前向传播与反向传播的数学基础,并在此基础上,从系统视角剖析单次训练迭代中的 GPU 显存占用情况及相应的显存优化策略。

一、 数学基础:基于链式法则的反向传播

一个三层 MLP 模型,其参数矩阵分布为 W1,W2,W3W_1, W_2, W_3,输入为 xx,真实标签为 yy

1. 前向传播 (Forward Propagation)

前向传播的核心是计算每一层的中间激活值(Activations)以及最终的预测值 y^\hat{y}。为保持表达严谨且简洁,省略偏置项并合并激活函数,各层计算过程可表示为:

  • 第一层: h1=f1(W1,x)h_1 = f_1(W_1, x)
  • 第二层: h2=f2(W2,h1)h_2 = f_2(W_2, h_1)
  • 第三层: y^=f3(W3,h2)\hat{y} = f_3(W_3, h_2)

从全局视角来看,整个前向传播是一个深度的嵌套复合函数,其最终的损失函数(Loss)计算如下:

L=Loss(f3(W3,f2(W2,f1(W1,x))))L = \text{Loss}(f_3(W_3, f_2(W_2, f_1(W_1, x))))

2. 反向传播 (Backpropagation)

训练的核心目标是最小化误差 LL。我们需要求解损失 LL 对各层参数 WW 的偏导数(梯度),即量化参数微小变化对最终误差的边际影响。该过程严格依赖微积分中的链式法则(Chain Rule)

  • 第三层梯度: 误差直接对 W3W_3 求导,等于总误差对预测值的偏导乘以预测值对 W3W_3 的偏导。

    LW3=Ly^y^W3\frac{\partial L}{\partial W_3} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial W_3}
  • 第二层梯度: 误差需先传导至第二层的输出 h2h_2,再对 W2W_2 求导。

    LW2=Ly^y^h2传导至 h2 的误差h2W2\frac{\partial L}{\partial W_2} = \underbrace{\frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial h_2}}_{\text{传导至 } h_2 \text{ 的误差}} \cdot \frac{\partial h_2}{\partial W_2}
  • 第一层梯度: 同理,误差依次反向传播至第一层。

    LW1=Ly^y^h2h2h1h1W1\frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial h_2} \cdot \frac{\partial h_2}{\partial h_1} \cdot \frac{\partial h_1}{\partial W_1}

3. 参数更新 (Weight Update)

获取各层梯度后,通过梯度下降(Gradient Descent)算法,结合学习率 η\eta 对模型参数进行迭代更新:

  • W3W3ηLW3W_3 \leftarrow W_3 - \eta \cdot \frac{\partial L}{\partial W_3}
  • W2W2ηLW2W_2 \leftarrow W_2 - \eta \cdot \frac{\partial L}{\partial W_2}
  • W1W1ηLW1W_1 \leftarrow W_1 - \eta \cdot \frac{\partial L}{\partial W_1}

二、 系统视角:单次迭代的 GPU 显存剖析

在上述数学过程转化为工程代码并在 GPU 上执行时,显存(VRAM)的占用是制约模型规模的核心瓶颈。在一次 Forward + Backward 循环中,GPU 显存主要被以下两类数据占据:

1. 静态与输入数据

  • 模型参数 (Model Parameters): 即上述的 W1,W2,W3W_1, W_2, W_3
  • 输入与标签 (Inputs & Labels): 维度通常为 Batch_Size ×\times 数据维度的大小。

2. 动态生成数据(训练时激增的开销)

  • 中间激活值 (Activations, HH): 在计算链式法则时,反向传播需要用到前向传播的中间结果(如 h1,h2h_1, h_2)。因此这些状态必须驻留在显存中,其占用大小与 Batch_Size 呈严格的线性正相关关系。
  • 参数梯度 (Gradients, W\nabla W): 大小与模型参数 WW 完全一致。
  • 优化器状态 (Optimizer States): 如果使用 Adam 等具有动量机制的优化器,需要额外记录每个参数的过去一阶动量(Momentum)和二阶方差(Variance)。并为了保证精度使用 32位,这将消耗 4×4 \times 模型参数大小的额外显存。

以目前最热门的模型llama、qwen等,最小规模 7B 至少 14GB,一次训练单次batch size ×6=84GB\times 6=84GB,已经 超出一张A100的大小,所以如今大模型训练设计出了各种显存优化策略。

三、 显存优化:存算置换与高效训练策略

针对上述显存瓶颈,工程上常采用“以计算时间换取显存空间”或“降低数值精度”的策略。

1. 核心“存算置换”策略概览

(以下提及的存算置换策略,其底层原理与工程实现机制将在后续的专栏文章中进行详细深度解析。)

  • 小批次 + 梯度累加 (Small Batch + Gradient Accumulation): 在时间维度上拆分大 Batch,通过多次小 Batch 的前向/反向传播累加梯度,绕过单次激活值过大的显存限制。(但依旧会 OOM)
  • ZeRO-Offload (算时加载): 将优化器状态或梯度等暂时卸载至 CPU 内存(RAM),在需要计算时再通过 PCIe 调度至 GPU显存。
  • 梯度检查点 (Gradient Checkpointing): 前向传播时主动丢弃部分中间激活值 HH,在反向传播经过该层时重新计算。这是一种典型的以增加计算量(约30%)换取显著显存节省的策略。

2. 其他正交优化技术

工业实践上更为常用

  • 混合精度与量化 (Mixed Precision / Quantization): 将传统的 FP32 运算降阶为 FP16/BF16,甚至 INT8/INT4,成倍削减显存占用并提升计算吞吐。
  • 参数高效微调 (PEFT, 如 LoRA): 在微调阶段冻结主干网络(不保存其激活值与优化器状态),仅注入极少量的可训练参数矩阵,将训练显存开销降低至全量微调的零头。

使用社交账号登录

  • Loading...
  • Loading...
  • Loading...
  • Loading...
  • Loading...