Understanding the LLM Training Flow: From Batch Loading to Weight Updates


Introduction

When we talk about Large Language Models (LLMs) like ChatGPT, Llama, and Gemini, we tend to assume the model already has all its information. Initially, however, the model has no knowledge at all. At the beginning of training, all the weights are random, and over a very large number of iterations the model learns language, facts, reasoning, and patterns from its training data.

The actual learning process of an LLM is based on a training loop that repeats the same sequence of steps many times:

Batch Load
↓
Embedding
↓
Attention
↓
FFN
↓
Prediction
↓
Cross Entropy Loss
↓
Backpropagation
↓
Gradients
↓
Gradient Clipping
↓
AdamW
↓
Weight Update
↓
Next Batch

Now let's look at each of these steps.

1. Batch Loading:

During training, we do not feed the whole dataset to the model at once. Instead, we divide it into small groups called batches. The main purposes of batch loading are to use memory efficiently, keep the GPU continuously busy, and make training fast and stable.

Example:

The capital of France is Paris.
AI is transforming the world.
Machine Learning is a subset of AI.

During training, a DataLoader groups examples like these into batches and loads them onto the GPU.
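The idea of batching can be sketched in a few lines of plain Python. This is only an illustration, not how a real DataLoader is implemented: the helper name make_batches and the batch size of 2 are assumptions made for this example, and real loaders also shuffle, tokenize, and pad the data.

```python
# Minimal sketch of batching: split a dataset into fixed-size groups.
# `make_batches` is a hypothetical helper written for this example.
def make_batches(examples, batch_size):
    """Yield consecutive slices of `examples` with at most `batch_size` items each."""
    for start in range(0, len(examples), batch_size):
        yield examples[start:start + batch_size]

dataset = [
    "The capital of France is Paris.",
    "AI is transforming the world.",
    "Machine Learning is a subset of AI.",
]

# With batch_size=2 the three sentences form one full batch and one smaller batch.
batches = list(make_batches(dataset, batch_size=2))
for i, batch in enumerate(batches):
    print(f"batch {i}: {batch}")
```

The last batch is smaller because three examples do not divide evenly by two.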

2. Embedding:

The model cannot work with raw token IDs directly.

Example:

[464, 3139, 286, 4881]

These are just numbers. The embedding layer converts each token ID into a vector. In simple words, the embedding layer turns tokens into a numerical form the model can work with.

464
↓
[0.12, -0.54, 1.28, ...]

Now the model has a meaningful numerical representation that it can process.
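An embedding layer is essentially a lookup table with one row per token ID. Here is a minimal sketch of that idea. The vocabulary size, the vector length of 4, and the random values are all assumptions chosen to keep the example small; in a real model the table is much larger and its values are learned during training.

```python
import random

random.seed(0)  # fixed seed so the sketch is reproducible

VOCAB_SIZE = 5000  # hypothetical vocabulary size
EMBED_DIM = 4      # hypothetical vector length, kept tiny for readability

# One row of EMBED_DIM numbers per token ID, randomly initialised.
embedding_table = [
    [random.uniform(-1.0, 1.0) for _ in range(EMBED_DIM)]
    for _ in range(VOCAB_SIZE)
]

def embed(token_ids):
    """Look up the vector stored for each token ID."""
    return [embedding_table[token_id] for token_id in token_ids]

token_ids = [464, 3139, 286, 4881]
vectors = embed(token_ids)
print(len(vectors), "vectors of length", len(vectors[0]))
```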

3. Self-Attention:

Self-attention is one of the most important components of Transformers.

Example:

The capital of France is Paris.

Here the attention mechanism can learn to link:

France ↔ Paris
capital ↔ Paris

In other words, it captures how the different words in a sentence relate to each other. This helps the model understand context and capture long-range relationships.
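The core computation can be sketched as scaled dot-product attention. This is a simplified illustration: the 2-dimensional toy vectors are made up, and the same vectors are used as queries, keys, and values. A real Transformer first passes the inputs through learned projection matrices, and language models also mask out future tokens.

```python
import math

def softmax(xs):
    """Turn a list of scores into weights that sum to 1."""
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Scaled dot-product attention over lists of equal-length vectors."""
    d = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        # Similarity of this query with every key, scaled by sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        # Output is the weighted average of the value vectors.
        out = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights

# Toy token vectors (made up for illustration).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = attention(x, x, x)
for row in weights:
    print([round(w, 3) for w in row])
```

Each printed row shows how strongly one token attends to every token in the sequence.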

4. Feed Forward Network (FFN):

The self-attention mechanism finds the relationships between tokens. But understanding relationships alone is not sufficient; the model also has to process this information, and that work is done by the FFN. In simple words, attention collects the information, and the FFN processes and transforms it.
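A feed-forward block can be sketched as two linear layers with a non-linearity in between, applied to each token vector separately. The weights below are made up for illustration, and ReLU is an assumption of this example; real models learn their weights and often use other activation functions.

```python
def relu(x):
    """Keep positive values, replace negative values with zero."""
    return max(0.0, x)

def linear(x, weight, bias):
    """Compute W x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def feed_forward(x, w1, b1, w2, b2):
    """Expand to a hidden layer, apply ReLU, then project back."""
    hidden = [relu(h) for h in linear(x, w1, b1)]
    return linear(hidden, w2, b2)

# Hypothetical weights: 2 inputs -> 4 hidden units -> 2 outputs.
w1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
b1 = [0.0, 0.0, 0.0, 0.0]
w2 = [[1.0, 0.0, -1.0, 0.0], [0.0, 1.0, 0.0, -1.0]]
b2 = [0.0, 0.0]

token_vector = [0.5, -2.0]
out = feed_forward(token_vector, w1, b1, w2, b2)
print(out)
```

The hidden layer is wider than the input, which mirrors the usual expand-then-project shape of an FFN.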

5. Prediction:

Now the model predicts the next token. Let's understand it with an example.

Input:

The capital of France is

Prediction:

Paris = 85%
London = 7%
Rome = 4%
Delhi = 2%

The model assigns a probability to every token in its vocabulary. When generating text, the most likely token can then be chosen.
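Internally, the model first produces a raw score (a logit) for each token, and a softmax turns those scores into probabilities. Here is a small sketch of that step. The four-word vocabulary and the logit values are invented for this example and are not taken from any real model.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical vocabulary and logits for "The capital of France is".
vocab = ["Paris", "London", "Rome", "Delhi"]
logits = [4.0, 1.5, 1.0, 0.3]

probs = softmax(logits)
for token, p in zip(vocab, probs):
    print(f"{token} = {p:.1%}")

# Pick the token with the highest probability.
best_token = vocab[probs.index(max(probs))]
print("Most likely:", best_token)
```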

6. Cross Entropy Loss:

In this step, the model's prediction is compared with the actual answer.

Actual:

Paris

Prediction:

Paris = 85%

Here the loss is low.

If the model instead assigns the probability:

Paris = 2%

Then the loss will be high.

Cross-entropy loss tells the model how inaccurate its prediction is.
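For a single prediction, cross-entropy loss is the negative natural logarithm of the probability given to the correct token. A quick sketch with the two probabilities from the example above (the function name is just chosen for this illustration):

```python
import math

def cross_entropy(prob_of_correct_token):
    """Loss for one prediction: -ln(probability of the correct token)."""
    return -math.log(prob_of_correct_token)

low_loss = cross_entropy(0.85)   # model is fairly confident in "Paris"
high_loss = cross_entropy(0.02)  # model gives "Paris" only 2%

print(f"loss at 85%: {low_loss:.3f}")   # about 0.163
print(f"loss at  2%: {high_loss:.3f}")  # about 3.912
```

The loss grows quickly as the probability of the correct token approaches zero.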

7. Backpropagation:

After the loss is calculated, the model works out which weights contributed to the wrong prediction. Backpropagation uses the chain rule to calculate the contribution of each parameter.

Flow:

Loss
↓
Backward Pass
↓
Gradient Calculation

This stage is the foundation of learning.
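To see the chain rule at work, here is a sketch on a deliberately tiny model with one weight and one bias. It uses a squared-error loss instead of cross entropy to keep the arithmetic short, and all numbers are made up. Frameworks do the same thing automatically for billions of parameters.

```python
def forward(w, b, x, target):
    """Tiny model: prediction = w * x + b, loss = (prediction - target)^2."""
    pred = w * x + b
    loss = (pred - target) ** 2
    return pred, loss

def backward(w, b, x, target):
    """Apply the chain rule by hand to get d(loss)/dw and d(loss)/db."""
    pred, _ = forward(w, b, x, target)
    dloss_dpred = 2.0 * (pred - target)  # derivative of the squared error
    dloss_dw = dloss_dpred * x           # because d(pred)/dw = x
    dloss_db = dloss_dpred * 1.0         # because d(pred)/db = 1
    return dloss_dw, dloss_db

w, b, x, target = 0.5, 0.1, 2.0, 3.0
grad_w, grad_b = backward(w, b, x, target)
print("grad_w:", grad_w, "grad_b:", grad_b)

# Sanity check: compare with a numerical estimate of the same gradient.
eps = 1e-6
numeric_grad_w = (forward(w + eps, b, x, target)[1]
                  - forward(w - eps, b, x, target)[1]) / (2 * eps)
print("numeric grad_w:", numeric_grad_w)
```

The hand-derived gradient and the numerical estimate should agree closely, which is a common way to check a backward pass.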

8. Gradient:

A gradient tells the optimizer two things about each weight:

In which direction should the weight move?
How much should it move?

Example:

Weight A Gradient = +0.3
Weight B Gradient = -0.1

Gradients are the learning signals for the model.
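The sign of a gradient gives the direction and its size gives the strength. A simple way to see this is a plain gradient descent step, where each weight moves against its gradient. The starting weights and the learning rate below are assumptions made for this sketch.

```python
learning_rate = 0.01  # hypothetical step size

weights = {"A": 0.25, "B": 0.40}   # hypothetical starting values
gradients = {"A": 0.3, "B": -0.1}  # gradients from the example above

# Plain gradient descent: move each weight against its gradient.
updated = {
    name: weights[name] - learning_rate * gradients[name]
    for name in weights
}

for name in weights:
    direction = "decreases" if updated[name] < weights[name] else "increases"
    print(f"Weight {name}: {weights[name]} -> {updated[name]:.3f} ({direction})")
```

Weight A has a positive gradient, so it is pushed down. Weight B has a negative gradient, so it is pushed up, and by a smaller amount because its gradient is smaller.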

9. Gradient Clipping:

Sometimes gradients become very large, which can make training unstable.

Example:

Gradient = 10000

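To solve this problem, we use gradient clipping. In PyTorch this is done with clip_grad_norm_, where model is the network being trained:

clip_grad_norm_(model.parameters(), max_norm=1.0)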
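What clipping by norm does can be sketched in plain Python: measure the overall size of all gradients together, and if it exceeds the limit, scale every gradient down by the same factor. This is a simplified illustration and not PyTorch's actual implementation. The helper name and the gradient values are made up.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients down together if their combined L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# Hypothetical gradients, one of them very large.
grads = [10000.0, 0.0, -7500.0]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))

print("norm before:", norm_before)
print("clipped gradients:", [round(g, 4) for g in clipped])
print("norm after:", round(norm_after, 4))
```

Clipping keeps the direction of the gradients the same and only shrinks their size.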

10. AdamW:

AdamW is one of the most popular optimizers for modern LLMs. Its job is:

  • To analyze the gradient.
  • To consider the past learning history.
  • To perform a smart weight update.

AdamW is more advanced and efficient than simple gradient descent, which is why many modern models, such as GPT and Llama, use it. Its weight decay also helps keep the weights in a safe range.
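Here is a sketch of one AdamW step for a single weight, following the standard update rule: keep a running average of the gradient (m) and of the squared gradient (v), correct both for their start at zero, apply weight decay directly to the weight, then take the step. The hyperparameter values and the function name are assumptions chosen for this example.

```python
import math

def adamw_step(w, grad, m, v, t, lr=0.003, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar weight. Returns (w, m, v)."""
    # Running averages: this is the "past learning history".
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction, because m and v start at zero.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay: shrink the weight directly.
    w = w - lr * weight_decay * w
    # Adaptive step: scaled by the recent size of the gradient.
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# Hypothetical starting state: weight 0.25, no history yet.
w, m, v = 0.25, 0.0, 0.0
w, m, v = adamw_step(w, grad=0.3, m=m, v=v, t=1)
print(round(w, 4))
```

With these assumed settings, the first step moves the weight from 0.25 to roughly 0.247.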

11. Weight Update:

Here AdamW updates the weights.

Before:

Weight = 0.25

After:

Weight = 0.247

This change might look small, but with billions of parameters and a very large number of updates, the model gradually becomes more capable. The knowledge is stored in the weights.
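To see how many small updates add up, here is a toy training loop with a single weight. It is only a sketch: the data follows the made-up rule y = 2x, the model is pred = w * x, and plain gradient descent with an assumed learning rate stands in for AdamW to keep the code short.

```python
# Toy data following the hypothetical rule y = 2 * x.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]

w = 0.0               # start with no "knowledge"
learning_rate = 0.05  # hypothetical step size
losses = []

for step in range(50):
    loss = 0.0
    grad = 0.0
    for x, y in data:
        pred = w * x
        loss += (pred - y) ** 2 / len(data)      # mean squared error
        grad += 2 * (pred - y) * x / len(data)   # its gradient w.r.t. w
    losses.append(loss)
    w -= learning_rate * grad                    # small weight update

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.8f}")
print(f"learned weight: {w:.4f}")
```

Each individual update is small, but after repeating the loop the weight ends up close to 2, the value that fits the data.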
