Memory Consumption During LLM Training


Introduction

When we train a Large Language Model (LLM), GPU memory is often the limiting resource. Many people think GPU memory is only used to store the model weights, but during training several other components consume memory as well.

Modern LLMs such as ChatGPT, LLaMA, Gemini and Claude contain billions of parameters. During training, the GPU stores not only the weights but also intermediate results (activations), gradients and optimizer states. For this reason, training a Large Language Model requires a very large amount of GPU memory.

The memory usage of an LLM can be divided into four major components:

GPU Memory
│
├── Model Weights
├── Activation Memory
├── Gradient Memory
└── Optimizer Memory

Each component has a different purpose during training, and together they consume most of the available GPU memory.

Now let's look at each component in detail.

1. Model Weights Memory

Model weights are the trainable parameters of a neural network, and the knowledge the model learns is stored in them. Before training starts, all weights are randomly initialized.

Example:

Weight1 = 0.34
Weight2 = -0.12
Weight3 = 0.98

During training, these weights are updated at every step through backpropagation and the optimizer. When training starts, the model weights are loaded into GPU memory because every layer needs them during the forward pass.

Example:

Input
↓
Weight Matrix
↓
Output
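The Input → Weight Matrix → Output step is a matrix-vector product. Here is a minimal pure-Python sketch of one linear layer; the 2×3 weight values are made up for illustration, and real frameworks run this on the GPU with optimized kernels:

```python
def linear(x, W):
    """Compute y = W x for a weight matrix W (a list of rows) and an input vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]

# Hypothetical weights: 2 outputs x 3 inputs = 6 trainable parameters.
W = [[0.34, -0.12, 0.98],
     [0.50, 0.25, -0.75]]
x = [1.0, 2.0, 3.0]

y = linear(x, W)
print(y)  # approximately [3.04, -1.25]
```

All six weights must be in memory when this layer runs; a single weight matrix in a billion-parameter LLM holds millions of them.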

Every computation in the embedding layers, attention layers and feed-forward networks uses model weights. Larger models contain more weights.

Example:

7 Billion Parameters
↓
7 Billion Weights

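Weight memory grows linearly with the number of parameters. Each parameter takes 4 bytes in FP32 or 2 bytes in FP16/BF16, so a 7-billion-parameter model needs about 28 GB (FP32) or 14 GB (BF16) just for its weights.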
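The arithmetic is simply parameters × bytes per parameter. A small sketch, using decimal gigabytes (1 GB = 10^9 bytes):

```python
# Bytes needed to store one parameter in common training precisions.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}

def weight_memory_gb(num_params, dtype="bf16"):
    """Memory needed just to hold the weights, in decimal gigabytes."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

for dtype in ("fp32", "bf16"):
    print(dtype, weight_memory_gb(7e9, dtype), "GB")  # 28.0 GB and 14.0 GB
```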

2. Activation Memory

Activation memory is one of the largest consumers of GPU memory during training. During the forward pass, every layer produces an output.

Example:

Input
↓
Layer 1 Output
↓
Layer 2 Output
↓
Layer 3 Output
↓
Prediction

These outputs are called activations.

A common misunderstanding is that a layer's output can be removed from memory as soon as the next layer has used it. However, this is not possible, because backpropagation needs these activations later to calculate gradients. For this reason, activations stay in GPU memory until the backward pass has used them.
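Why the activations must be kept becomes clear in a toy example. The sketch below uses a hypothetical 3-layer network of scalar "layers", a_i = tanh(w_i · a_(i-1)), with a squared-error loss. The gradient of each weight multiplies an upstream gradient by that layer's saved input activation, so the forward pass has to keep every activation for the backward pass:

```python
import math

def forward(x, weights):
    """Run the forward pass and keep every activation for the backward pass."""
    activations = [x]  # activations[0] is the input
    for w in weights:
        activations.append(math.tanh(w * activations[-1]))
    return activations

def backward(activations, weights, target):
    """Return dL/dw_i for L = 0.5 * (prediction - target)**2, using stored activations."""
    grads = [0.0] * len(weights)
    upstream = activations[-1] - target  # dL/d(prediction)
    for i in reversed(range(len(weights))):
        a_in, a_out = activations[i], activations[i + 1]
        local = upstream * (1.0 - a_out ** 2)  # chain rule through tanh, needs saved a_out
        grads[i] = local * a_in                 # weight gradient, needs saved input a_in
        upstream = local * weights[i]           # gradient passed to the previous layer
    return grads

weights = [0.5, -1.2, 0.8]  # hypothetical values
acts = forward(0.7, weights)
print("stored activations:", len(acts))
print("gradients:", backward(acts, weights, target=0.3))
```

In a real LLM each "activation" is a large tensor rather than a single number, which is why storing them all is so expensive.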

Example:

Layer 1 → Activation 1
Layer 2 → Activation 2
Layer 3 → Activation 3

Activation memory grows with both batch size and context length. In many cases, activation memory exceeds the memory used by the model weights.
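To make this concrete, one widely cited approximation (Korthikanti et al., 2022, "Reducing Activation Recomputation in Large Transformer Models") estimates the activations stored by one standard transformer layer, in 16-bit precision and without recomputation, as s·b·h·(34 + 5·a·s/h) bytes, where s is the sequence length, b the batch size, h the hidden size and a the number of attention heads. The sketch below evaluates it. The model shape (32 layers, h = 4096, a = 32) is an assumed LLaMA-7B-like configuration, and real models and frameworks will differ from this estimate:

```python
def activation_bytes_per_layer(s, b, h, a):
    """Approximate activation memory of one transformer layer (16-bit, no recomputation).
    s = sequence length, b = batch size, h = hidden size, a = attention heads."""
    # 34*s*b*h covers the linear, MLP, LayerNorm and dropout tensors;
    # 5*a*s*s*b covers the attention score matrices, which grow with s squared.
    return 34 * s * b * h + 5 * a * s * s * b

# Assumed model shape: 32 layers, hidden size 4096, 32 heads.
layers, h, a = 32, 4096, 32
for s, b in [(2048, 1), (2048, 8), (8192, 1)]:
    total = layers * activation_bytes_per_layer(s, b, h, a)
    print(f"seq={s:5d} batch={b}: {total / 1e9:.1f} GB")
```

Under these assumptions, going from batch 1 to batch 8 multiplies the estimate by exactly 8, while quadrupling the sequence length multiplies it by roughly 12, because of the quadratic attention term.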

3. Gradient Memory

After the model makes a prediction, the prediction is compared with the correct answer using a loss function.

Example:

Prediction:
Paris = 20%

Actual:
Paris
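For next-token prediction, LLMs are typically trained with cross-entropy loss: the negative natural log of the probability the model assigned to the correct token. A minimal sketch for the example above, where the only number that matters is the 20% given to "Paris":

```python
import math

def cross_entropy(prob_of_correct_token):
    """Cross-entropy loss for one prediction: -ln(p_correct)."""
    return -math.log(prob_of_correct_token)

print(round(cross_entropy(0.20), 3))  # 1.609: low confidence in the right answer, high loss
print(round(cross_entropy(0.95), 3))  # 0.051: high confidence, loss close to 0
```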

After calculating the loss, backpropagation computes gradients for every trainable weight.

Example:

Weight A Gradient = +0.03
Weight B Gradient = -0.07
Weight C Gradient = +0.01

Gradients tell the optimizer:

  • Which direction the weight should move (opposite to the gradient's sign, to reduce the loss)
  • How much the weight should change (larger gradients lead to larger changes)
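To see how a gradient sets both the direction and the size of a change, here is a plain gradient-descent update (w ← w − lr · g) applied to the example gradients of weights A, B and C. The weight values and the learning rate of 0.1 are made up for illustration, and real LLM training uses AdamW rather than this bare rule:

```python
weights   = {"A": 0.34, "B": -0.12, "C": 0.98}  # hypothetical current values
gradients = {"A": 0.03, "B": -0.07, "C": 0.01}  # gradients from the example above
lr = 0.1  # hypothetical learning rate

# Move each weight against its gradient: a positive gradient lowers the weight,
# a negative gradient raises it, and a larger gradient produces a larger step.
updated = {name: w - lr * gradients[name] for name, w in weights.items()}
print(updated)  # roughly A: 0.337, B: -0.113, C: 0.979
```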

These gradients must be stored in memory until the optimizer updates the weights.

Since every trainable parameter has a gradient, gradient memory is as large as weight memory when both use the same precision. For a 7-billion-parameter model in BF16, that is another 14 GB.

4. Optimizer Memory

Modern LLMs are commonly trained with the AdamW optimizer. AdamW does not use only the current gradient; it also keeps running statistics about previous gradients. For every parameter, training with AdamW typically keeps:

Weight
Gradient
Momentum
Variance

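Momentum is a running average of past gradients and helps the optimizer keep moving in a consistent direction. Variance is a running average of squared gradients and is used to adjust the size of each update. Both values are stored for every parameter, usually in FP32 (4 bytes each), so together they take 8 bytes per parameter, four times the 2 bytes of a BF16 weight. This is why optimizer memory can exceed the memory of the model weights themselves, and it is one of the major reasons why training requires much more memory than inference, which needs neither gradients nor optimizer states.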
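A single AdamW step for one scalar parameter shows where the two extra stored values come from. This is a minimal sketch of the standard update with bias correction and decoupled weight decay; the hyperparameters are common default values assumed for illustration, not settings from any particular model:

```python
import math

def adamw_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single parameter w with gradient g at step t (t starts at 1).
    m and v are the optimizer states that must be kept in memory between steps."""
    m = beta1 * m + (1 - beta1) * g      # momentum: running mean of gradients
    v = beta2 * v + (1 - beta2) * g * g  # variance: running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)         # bias correction for the first steps
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * weight_decay * w        # decoupled weight decay (the "W" in AdamW)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

w, m, v = 0.34, 0.0, 0.0  # a weight and its two optimizer states
w, m, v = adamw_step(w, g=0.03, m=m, v=v, t=1)
print(w, m, v)
```

Every parameter in the model carries its own m and v, so a 7-billion-parameter model stores 14 billion extra optimizer values.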

Why Does a Long Context Length Increase Memory?

Modern LLMs support long context windows such as:

4K Tokens
8K Tokens
32K Tokens
128K Tokens

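As the number of tokens n increases, attention calculations grow quadratically, because every token attends to every other token and each attention head in each layer computes an n × n matrix of attention scores.

Attention Complexity:

$$O(n^2)$$

This means that increasing context length increases activation memory, and especially attention memory, dramatically when the full score matrix is stored: doubling the context length quadruples the size of that matrix. For this reason, long-context training is very expensive.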
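A quick calculation shows how fast the n × n score matrix grows for the context lengths listed above. This sketch assumes 16-bit values and an attention implementation that materializes the full matrix, and counts a single head in a single layer for one sequence:

```python
def attention_scores_bytes(n_tokens, n_heads=1, bytes_per_value=2):
    """Size of the n x n attention score matrix for one sequence in one layer."""
    return n_tokens * n_tokens * n_heads * bytes_per_value

for label, n in [("4K", 4096), ("8K", 8192), ("32K", 32768), ("128K", 131072)]:
    gib = attention_scores_bytes(n) / 2**30
    print(f"{label:>4} tokens: {gib:8.3f} GiB per head per layer")
```

Going from 4K to 128K tokens is 32 times more tokens but a 1024 times larger score matrix: about 32 GiB for one head in one layer.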

Why Does Batch Size Increase Memory?

Batch Size represents how many training examples are processed at the same time.

Example:

Batch Size = 1

Activations are stored for only one example.

Batch Size = 64

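Activations for sixty-four examples are stored simultaneously. Larger batch sizes create more activations and therefore consume more GPU memory, while weights, gradients and optimizer states stay the same size regardless of batch size.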
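The effect is linear: every stored layer output has shape (batch, sequence length, hidden size). A small sketch with an assumed shape of 4096-token sequences, hidden size 4096 and 16-bit values:

```python
def hidden_state_bytes(batch, seq_len, hidden, bytes_per_value=2):
    """Size of one layer output tensor of shape (batch, seq_len, hidden)."""
    return batch * seq_len * hidden * bytes_per_value

# Assumed shape: 4096-token sequences, hidden size 4096, 16-bit values.
for batch in (1, 64):
    mib = hidden_state_bytes(batch, 4096, 4096) / 2**20
    print(f"batch={batch:2d}: {mib:.0f} MiB per stored layer output")  # 32 MiB vs 2048 MiB
```

A transformer stores several such tensors in every layer, so the 64× increase applies across the whole network.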

Memory Breakdown During Training

A simplified memory distribution during LLM training can be represented as:

GPU Memory
│
├── Model Weights
├── Activations
├── Gradients
└── Optimizer States

All four components are required for training. If they do not fit in GPU memory, training fails with an Out Of Memory (OOM) error.
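A rough way to anticipate an OOM error is to add the four components together. This sketch assumes mixed-precision training with AdamW and no sharding across GPUs: 2 bytes each for BF16 weights and gradients plus 12 bytes of FP32 optimizer state (a master copy of the weights, momentum and variance), i.e. 16 bytes per parameter, the figure used in the ZeRO paper. The 80 GB GPU capacity and the 30 GB activation figure are illustrative assumptions, not measurements:

```python
def training_memory_gb(num_params, activation_gb):
    """Rough per-GPU training memory (decimal GB) for mixed-precision AdamW without sharding."""
    weights_gb = num_params * 2 / 1e9     # BF16 weights
    grads_gb = num_params * 2 / 1e9       # BF16 gradients
    optimizer_gb = num_params * 12 / 1e9  # FP32 master weights + momentum + variance
    return {
        "weights": weights_gb,
        "gradients": grads_gb,
        "optimizer": optimizer_gb,
        "activations": activation_gb,
        "total": weights_gb + grads_gb + optimizer_gb + activation_gb,
    }

GPU_CAPACITY_GB = 80  # assumed single-GPU memory
usage = training_memory_gb(7e9, activation_gb=30)  # activation figure is an assumption
for part, gb in usage.items():
    print(f"{part:>11}: {gb:6.1f} GB")
print("fits" if usage["total"] <= GPU_CAPACITY_GB else "out of memory (OOM)")
```

Under these assumptions, even with zero activations the weights, gradients and optimizer states of a 7-billion-parameter model take 112 GB, more than a single 80 GB GPU can hold.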
