是一种在训练深度神经网络时通过增加计算代价来换取显存优化的技术。它的核心思想是:在反向传播过程中动态重新计算中间激活值(activations),而不是保存所有中间结果。这对于显存受限的场景(如训练大型模型)非常有用。
直接上代码:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint# 1. 定义一个简单的 FFN 模型
class SimpleFFN(nn.Module):def __init__(self, input_dim=128, hidden_dim=256, output_dim=10):super().__init__()self.linear1 = nn.Linear(input_dim, hidden_dim)self.linear2 = nn.Linear(hidden_dim, hidden_dim)self.linear3 = nn.Linear(hidden_dim, output_dim)self.relu = nn.ReLU()def forward(self, x):# 2. 定义一个自定义的前向传播函数(用于 checkpoint)def custom_forward(x):x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.relu(x)x = self.linear3(x)return x# 3. 使用 checkpoint 包装前向传播return checkpoint(custom_forward, x)# 4. 初始化模型和数据
model = SimpleFFN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()# 模拟输入数据
input_data = torch.randn(64, 128) # batch_size=64, input_dim=128
target = torch.randn(64, 10) # 模拟目标输出# 5. 前向传播、损失计算和反向传播
output = model(input_data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
- 在反向传播时,custom_forward 会被重新调用,从输入 x 重新计算中间激活值,从而节省显存。
- 显存占用:仅保存 linear3 的输出和 x,中间激活值在反向传播时动态计算。
- 需要多次前向计算激活值,训练速度可能变慢