# Activation Checkpointing

## Introduction to Activation Checkpointing

Activation Checkpointing is a sub-linear memory optimization technique proposed in 2016 by Tianqi Chen et al. in the paper "Training Deep Nets with Sublinear Memory Cost". Its goal is to reduce memory usage during training. The basic principle is to trade time for space: after analyzing the computational graph, some intermediate activations produced in the forward pass that are not needed right away are freed to reduce memory usage, and they are restored by additional forward computation when the backward pass needs them.
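To make the time-for-space trade-off concrete, here is a deliberately simplified cost model. It is an assumption for illustration, not OneFlow's internal planner: a linear chain of `n` layers, where each layer output costs one unit of memory, is split into segments of `k` layers. Only each segment's input is kept during the forward pass, and one segment at a time is recomputed during the backward pass. Peak memory is then roughly `ceil(n/k) + k`, which is smallest near `k = sqrt(n)`; this is the O(sqrt(n)) regime discussed in the paper. The function names are hypothetical.

```python
import math


# Simplified cost model (assumption): every layer output costs one memory unit,
# and the network is a plain chain of layers with no branches.
def peak_activations(num_layers: int, segment_size: int) -> int:
    """Approximate peak number of stored activations when a linear chain
    is checkpointed every `segment_size` layers."""
    if segment_size <= 0:
        raise ValueError("segment_size must be positive")
    # Segment inputs kept alive from the forward pass until backward reaches them.
    kept_checkpoints = math.ceil(num_layers / segment_size)
    # While backpropagating through one segment, its activations are recomputed,
    # so at most `segment_size` extra activations are alive at the same time.
    recomputed = min(segment_size, num_layers)
    return kept_checkpoints + recomputed


def best_segment_size(num_layers: int) -> int:
    """Segment size that minimizes the approximate peak memory."""
    return min(range(1, num_layers + 1),
               key=lambda k: peak_activations(num_layers, k))


if __name__ == "__main__":
    n = 100
    k = best_segment_size(n)
    print(f"no checkpointing: {n} activations stored")
    print(f"segment size {k}: about {peak_activations(n, k)} activations stored")
```

In this model, 100 stored activations shrink to about 20. The price is that every layer inside a segment runs its forward computation a second time, which is roughly one extra forward pass over the whole chain.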

OneFlow's static graph module `nn.Graph` already supports Activation Checkpointing. This article introduces how to turn it on during training.

## Example of using Activation Checkpointing

First, we define a simple model, a loss function, and an optimizer in exactly the same way as before.

```python
import oneflow as flow
import oneflow.nn as nn

DEVICE = "cuda" if flow.cuda.is_available() else "cpu"
print("Using {} device".format(DEVICE))

model_part1 = nn.Sequential(
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU()
)
model_part1 = model_part1.to(DEVICE)
model_part1.train()

model_part2 = nn.Sequential(
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10)
)
model_part2 = model_part2.to(DEVICE)
model_part2.train()

loss_fn = nn.CrossEntropyLoss().to(DEVICE)
optimizer = flow.optim.SGD([{'params': model_part1.parameters()},
                            {'params': model_part2.parameters()}],
                           lr=1e-3)
```


To turn on activation checkpointing, call `.to(nn.graph.GraphModule)` on an Eager model member of the graph (i.e. an `nn.Module` object) to obtain its `nn.graph.GraphModule` object, and then set `.activation_checkpointing = True` on that `nn.graph.GraphModule`. For more details on this API, refer to the activation_checkpointing documentation. For each `nn.Module` with activation checkpointing turned on, its input activations are preserved, while its other intermediate activations are recomputed when they are needed during backpropagation.

```python
class CustomGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.model_part1 = model_part1
        self.model_part2 = model_part2
        # Turn on activation checkpointing on two consecutive nn.Module objects
        self.model_part1.to(nn.graph.GraphModule).activation_checkpointing = True
        self.model_part2.to(nn.graph.GraphModule).activation_checkpointing = True
        self.loss_fn = loss_fn
        # Register the optimizer so that the backward pass updates the parameters
        self.add_optimizer(optimizer)

    def build(self, x, y):
        y_pred = self.model_part2(self.model_part1(x))
        loss = self.loss_fn(y_pred, y)
        loss.backward()
        return y_pred, loss
```


Then, you can start training and other operations as usual.

```python
graph_model = CustomGraph()

for _ in range(100):
    x = flow.randn(128, 256).to(DEVICE)
    # CrossEntropyLoss expects class-index targets of shape (batch_size,)
    y = flow.ones(128, dtype=flow.int64).to(DEVICE)
    graph_model(x, y)
    # Other code...
```
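Recomputation changes when activations are computed, not the gradients themselves. The following framework-free sketch illustrates this with plain Python on scalars. It is not OneFlow's implementation: the toy `relu(w * x)` layer and the helper names `layer`, `layer_grad`, `grad_store_all`, and `grad_checkpointed` are all hypothetical. As with a checkpointed `nn.Module`, each segment keeps only its input activation during the forward pass, and its internal activations are recomputed during the backward pass.

```python
from typing import Dict, List, Tuple

# Toy, framework-free sketch; all names here are hypothetical, not OneFlow APIs.


def layer(w: float, x: float) -> float:
    """One toy layer: y = relu(w * x)."""
    return max(0.0, w * x)


def layer_grad(w: float, x: float, grad_out: float) -> float:
    """Gradient of relu(w * x) with respect to its input x."""
    return grad_out * w if w * x > 0 else 0.0


def grad_store_all(ws: List[float], x: float) -> Tuple[float, int]:
    """Baseline: keep every activation; return (d out / d x, activations kept)."""
    acts = [x]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    grad = 1.0
    for i in reversed(range(len(ws))):
        grad = layer_grad(ws[i], acts[i], grad)
    return grad, len(acts)


def grad_checkpointed(ws: List[float], x: float,
                      segment_size: int) -> Tuple[float, int]:
    """Checkpointed: keep only each segment's input, recompute the rest."""
    checkpoints: Dict[int, float] = {}
    h = x
    for i, w in enumerate(ws):
        if i % segment_size == 0:
            checkpoints[i] = h  # preserved input activation of this segment
        h = layer(w, h)  # overwriting h drops non-checkpoint activations
    grad = 1.0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment_size, len(ws))
        # Extra forward computation: rebuild the inputs of layers start..end-1.
        seg_acts = [checkpoints[start]]
        for i in range(start, end - 1):
            seg_acts.append(layer(ws[i], seg_acts[-1]))
        for i in reversed(range(start, end)):
            grad = layer_grad(ws[i], seg_acts[i - start], grad)
    return grad, len(checkpoints)


if __name__ == "__main__":
    weights = [1.1, 0.9, 1.2, 0.8, 1.05, 0.95, 1.3, 0.7]
    g_full, kept_full = grad_store_all(weights, 2.0)
    g_ckpt, kept_ckpt = grad_checkpointed(weights, 2.0, segment_size=4)
    print(g_full == g_ckpt, kept_full, kept_ckpt)  # True 9 2
```

Both versions multiply the same local gradients in the same order, so the results match exactly. The checkpointed version keeps 2 activations between the passes instead of 9, at the cost of re-running layers inside each segment.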


## Comparative Experiment on BERT Model

To verify the actual effect of Activation Checkpointing, we can run a comparative experiment on BERT, directly using the BERT model provided by LiBai. To turn on Activation Checkpointing, we only need to set `train.activation_checkpoint.enabled` to `True` in the configuration file.

First, prepare the data by following Prepare the Data and the Vocab. For simplicity, we train on a single device (the GPU in the experimental environment is an NVIDIA GeForce RTX 3090 with 24268 MB of memory):

```shell
time python tools/train_net.py --config-file configs/bert_large_pretrain.py
```


Prefixing the whole command with `time` measures the total time spent on the training process.

The experimental results are as follows:

| Activation Checkpointing Enabled | Average Memory Usage | Time Spent |
| --- | --- | --- |
| No | 9141 MB | 25 minutes 16 seconds |
| Yes | 5978 MB | 33 minutes 36 seconds |

As the table above shows, Activation Checkpointing significantly reduces memory usage during training, while the training time increases because of the additional forward computation. Overall, Activation Checkpointing is a very effective solution when GPU memory is insufficient.
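The size of the trade-off in the BERT experiment can be checked with a few lines of arithmetic, using only the numbers reported in the results table: average memory usage falls by about 35%, while training time grows by about 33%.

```python
def percent_change(before: float, after: float) -> float:
    """Relative change from `before` to `after`, in percent."""
    return (after - before) / before * 100


def to_seconds(minutes: int, seconds: int) -> int:
    """Convert a 'minutes + seconds' duration to seconds."""
    return minutes * 60 + seconds


# Values taken from the results table (checkpointing off -> on).
memory_change = percent_change(9141, 5978)  # average memory usage, MB
time_change = percent_change(to_seconds(25, 16), to_seconds(33, 36))

print(f"memory: {memory_change:+.1f}%")  # memory: -34.6%
print(f"time:   {time_change:+.1f}%")    # time:   +33.0%
```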