Activation Checkpointing
Introduction to Activation Checkpointing
Activation Checkpointing is a sub-linear memory optimization technique proposed in 2016 by Tianqi Chen's team in the paper Training Deep Nets with Sublinear Memory Cost, aiming to reduce memory usage during training. Its basic principle is to trade time for space: after analyzing the computational graph, some intermediate activations that are not immediately needed are discarded during the forward pass to save memory, and are restored with additional forward computation when the backward pass needs them.
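As a rough illustration of this trade-off, the framework-free Python sketch below (a toy example only, not OneFlow's actual implementation) treats a checkpointed segment as a list of functions: the forward pass keeps only the segment's input, and the backward pass re-runs the segment to recover the discarded activations.
# Conceptual sketch only: a "segment" is a list of functions applied in order.
def plain_forward(segment, x):
    # Without checkpointing: every intermediate activation is kept for backward.
    acts = [x]
    for fn in segment:
        acts.append(fn(acts[-1]))
    return acts[-1], acts          # output + all saved activations

def checkpointed_forward(segment, x):
    # With checkpointing: only the segment input (the checkpoint) is kept.
    out = x
    for fn in segment:
        out = fn(out)
    return out, [x]                # output + just the input

def recompute_for_backward(segment, saved):
    # Backward pass: re-run the forward computation from the checkpoint to
    # recover the discarded activations, then use them as usual.
    acts = [saved[0]]
    for fn in segment:
        acts.append(fn(acts[-1]))
    return acts

segment = [lambda v: v * 2, lambda v: v + 1, lambda v: v ** 2]
_, kept = checkpointed_forward(segment, 3)      # keeps only [3]
print(recompute_for_backward(segment, kept))    # [3, 6, 7, 49]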
OneFlow's static graph module nn.Graph
already supports Activation Checkpointing. This article introduces how to enable it during training.
Example of using Activation Checkpointing
First, we define a simple model, together with its loss function and optimizer, in exactly the same way as before.
import oneflow as flow
import oneflow.nn as nn

DEVICE = "cuda" if flow.cuda.is_available() else "cpu"
print("Using {} device".format(DEVICE))

# Define the model in two parts so that activation checkpointing
# can be enabled for each part separately later.
model_part1 = nn.Sequential(
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU()
)
model_part1 = model_part1.to(DEVICE)
model_part1.train()

model_part2 = nn.Sequential(
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10)
)
model_part2 = model_part2.to(DEVICE)
model_part2.train()

loss_fn = nn.CrossEntropyLoss().to(DEVICE)
optimizer = flow.optim.SGD([{'params': model_part1.parameters()},
                            {'params': model_part2.parameters()}],
                           lr=1e-3)
To enable activation checkpointing, you only need to call the .to(nn.graph.GraphModule) method on an Eager model member (i.e., an nn.Module object) to obtain an nn.graph.GraphModule object, and then set its .activation_checkpointing attribute to True. For more details on this API, please refer to: activation_checkpointing. For each nn.Module with activation checkpointing turned on, its input activations will be preserved, while the other intermediate activations will be recomputed when they are needed during backpropagation.
class CustomGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.model_part1 = model_part1
        self.model_part2 = model_part2
        # Turn on activation checkpointing for the two consecutive nn.Module parts
        self.model_part1.to(nn.graph.GraphModule).activation_checkpointing = True
        self.model_part2.to(nn.graph.GraphModule).activation_checkpointing = True
        self.loss_fn = loss_fn
        self.add_optimizer(optimizer)

    def build(self, x, y):
        y_pred = self.model_part2(self.model_part1(x))
        loss = self.loss_fn(y_pred, y)
        loss.backward()
        return y_pred, loss
Then, you can run training and other operations as usual.
graph_model = CustomGraph()

for _ in range(100):
    x = flow.randn(128, 256).to(DEVICE)
    y = flow.ones(128, 1, dtype=flow.int64).to(DEVICE)
    graph_model(x, y)
    # Other code...
Comparative Experiment on BERT Model
To verify the actual effect of Activation Checkpointing, we can run a comparative experiment on the BERT model. We can directly use the BERT model provided by libai; to enable Activation Checkpointing, we only need to set train.activation_checkpoint.enabled to True in the configuration file.
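For example, assuming the default configs/bert_large_pretrain.py shipped with libai (a Python-based config in which a train object is already defined), the change is a single assignment; the exact surrounding contents depend on your libai version.
# Sketch: enable activation checkpointing in the libai config file
# (assumes the config already defines the `train` namespace).
train.activation_checkpoint.enabled = True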
First, get the data ready according to Prepare the Data and the Vocab. For simplicity, we use a single device for training (the GPU used in the experimental environment is an NVIDIA GeForce RTX 3090 with 24268 MB of memory):
time python tools/train_net.py --config-file configs/bert_large_pretrain.py
The time
command at the beginning of the whole command measures how long the training process takes.
The experimental results are as follows:
| Activation Checkpointing Enabled | Average Memory Usage | Time Spent |
| --- | --- | --- |
| No | 9141 MB | 25 minutes 16 seconds |
| Yes | 5978 MB | 33 minutes 36 seconds |
We can see from the table above that Activation Checkpointing significantly reduces memory usage during training, while the time spent increases because of the additional forward computation. Overall, Activation Checkpointing is a very effective option when GPU memory is insufficient.