欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 会展 > pytorch checkpointing

pytorch checkpointing

2025/5/6 20:54:50 来源:https://blog.csdn.net/qq_45812220/article/details/147714208  浏览:    关键词:pytorch checkpointing

是一种在训练深度神经网络时通过增加计算代价来换取显存优化的技术。它的核心思想是:在反向传播过程中动态重新计算中间激活值(activations),而不是保存所有中间结果。这对于显存受限的场景(如训练大型模型)非常有用。

直接上代码:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint# 1. 定义一个简单的 FFN 模型
class SimpleFFN(nn.Module):def __init__(self, input_dim=128, hidden_dim=256, output_dim=10):super().__init__()self.linear1 = nn.Linear(input_dim, hidden_dim)self.linear2 = nn.Linear(hidden_dim, hidden_dim)self.linear3 = nn.Linear(hidden_dim, output_dim)self.relu = nn.ReLU()def forward(self, x):# 2. 定义一个自定义的前向传播函数(用于 checkpoint)def custom_forward(x):x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.relu(x)x = self.linear3(x)return x# 3. 使用 checkpoint 包装前向传播return checkpoint(custom_forward, x)# 4. 初始化模型和数据
model = SimpleFFN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()# 模拟输入数据
input_data = torch.randn(64, 128)  # batch_size=64, input_dim=128
target = torch.randn(64, 10)       # 模拟目标输出# 5. 前向传播、损失计算和反向传播
output = model(input_data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
  • 在反向传播时,custom_forward 会被重新调用,从输入 x 重新计算中间激活值,从而节省显存。
  • 显存占用:仅保存 linear3 的输出和 x,中间激活值在反向传播时动态计算。
  • 需要多次前向计算激活值,训练速度可能变慢

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词