This lesson covers transformer training: teaching our model to predict the next token by minimizing cross-entropy loss. In the previous lesson, we assembled all the pieces. Now we will watch the model learn to generate fairy tales using PyTorch.
Transformer Training Loop
Training a language model means minimizing the cross-entropy loss between the model's predicted distribution and the actual next token. At each position in the input sequence, the model outputs a probability distribution over the vocabulary; the loss at that position is the negative log-probability it assigned to the token that actually comes next. This is called “next token prediction” or causal language modeling.
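To make the loss concrete, here is a minimal pure-Python sketch of cross-entropy for next-token prediction. The function names and the toy 4-token vocabulary are assumptions made for illustration; in PyTorch the same quantity is normally computed with a built-in loss function on tensors.

```python
import math

def next_token_loss(logits, target):
    """Cross-entropy at one position: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max before exponentiating, for stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def sequence_loss(logits_per_position, tokens):
    """Mean loss over a sequence; position t is scored against token t+1."""
    targets = tokens[1:]
    losses = [next_token_loss(l, t) for l, t in zip(logits_per_position, targets)]
    return sum(losses) / len(losses)

# Toy setup (hypothetical): 4-token vocabulary, one short sequence.
vocab_size = 4
tokens = [2, 0, 3, 1]

# A model that knows nothing assigns equal logits to every token.
uniform_logits = [[0.0] * vocab_size for _ in tokens[:-1]]
uniform_loss = sequence_loss(uniform_logits, tokens)

# A model that puts most of its mass on the correct next token.
confident_logits = [
    [5.0 if v == nxt else 0.0 for v in range(vocab_size)] for nxt in tokens[1:]
]
confident_loss = sequence_loss(confident_logits, tokens)

print(uniform_loss, confident_loss)
```

A uniform prediction over V tokens always scores ln(V), which gives a useful baseline: a loss well below that value means the model has learned something about the data.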
Run train.py to see the training loop in action. In each epoch, the model processes batches of fairy tale text, computes the loss, backpropagates to obtain gradients, and lets the optimizer update its parameters. The loss should start near 3.2 (uniform guessing over a 24-token vocabulary gives ln 24 ≈ 3.18) and drop steadily to below 1.0 as the model learns patterns.
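The shape of such a loop can be sketched as follows. This is not the lesson's train.py: the stand-in model (an embedding followed by a linear layer rather than a transformer), the synthetic token data, and all sizes and hyperparameters are assumptions chosen so that the example runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, seq_len, batch_size = 24, 8, 16   # assumed sizes, illustration only

# Hypothetical stand-in for the transformer: embedding + linear output head.
model = nn.Sequential(nn.Embedding(vocab_size, 32), nn.Linear(32, vocab_size))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

# Synthetic "text": a repeating token pattern, so there is something to learn.
data = (torch.arange(64 * (seq_len + 1)) % vocab_size).view(64, seq_len + 1)
inputs, targets = data[:, :-1], data[:, 1:]   # targets are inputs shifted by one

epoch_losses = []
for epoch in range(20):
    total, num_batches = 0.0, 0
    for i in range(0, len(inputs), batch_size):
        x, y = inputs[i:i + batch_size], targets[i:i + batch_size]
        logits = model(x)                      # (batch, seq_len, vocab_size)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
        optimizer.zero_grad()                  # clear gradients from the last step
        loss.backward()                        # backpropagation
        optimizer.step()                       # parameter update
        total += loss.item()
        num_batches += 1
    epoch_losses.append(total / num_batches)   # mean training loss for the epoch

print(epoch_losses[0], epoch_losses[-1])
```

The flattening before the loss call is the only non-obvious step: the loss function expects one row of logits per prediction, so the batch and sequence dimensions are merged.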
Optimizer and Learning Rate
Modern LLMs use the Adam optimizer (Kingma & Ba, 2014) or its improved variant AdamW (Loshchilov & Hutter, 2017), which applies weight decay directly to the weights instead of folding it into the gradient that Adam rescales. A cosine learning rate schedule with warmup is standard: the LR ramps up linearly from 0 to the maximum over the first few thousand steps, then decays following a cosine curve. The warmup phase prevents the model from making destructive updates early in training, when the gradients are noisy and Adam's running statistics are still based on very few steps.
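The schedule described above can be written as a small function of the step number. This is a sketch under assumed names and values (`lr_at`, `max_lr`, `warmup_steps`, `total_steps`, `min_lr` are illustrative, not taken from the lesson's code).

```python
import math

def lr_at(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at a given step: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Linear ramp from 0 up to max_lr.
        return max_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    # Cosine curve from max_lr (progress = 0) down to min_lr (progress = 1).
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Assumed example values: peak LR 3e-4, 2,000 warmup steps, 20,000 steps total.
for step in (0, 1000, 2000, 11000, 20000):
    print(step, lr_at(step, max_lr=3e-4, warmup_steps=2000, total_steps=20000))
```

In a training loop, the returned value would be written into the optimizer's parameter groups before each step; PyTorch also provides scheduler classes that can be combined to the same effect.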
What Transformer Training Teaches
Early in training, the model learns token frequencies and common bigrams. After a few epochs, it discovers simple grammatical patterns: articles precede nouns, verbs follow subjects. Given enough data and capacity, it learns longer-range dependencies: character names, story structure, and even quotation patterns.
Training on a single CPU takes about 30 seconds for our tiny fairy tale dataset. Scaling up to GPT-3 (175B parameters trained on 300B tokens) required thousands of GPUs for weeks, a difference of roughly 10 orders of magnitude in compute.
In the next lesson, we will look at how to convert the model’s raw logits into actual text through various sampling strategies.
Batching and Gradient Accumulation
In practice, we process multiple sequences simultaneously in batches. PyTorch’s DataLoader handles this efficiently. For large models, even a single batch may not fit in GPU memory. Gradient accumulation solves this by computing gradients on smaller micro-batches and summing them over several forward and backward passes before taking a single optimizer step. If each micro-batch loss is divided by the number of accumulation steps, the summed gradient matches the gradient of the full batch, so this simulates a larger batch size without requiring proportional memory.
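A minimal PyTorch sketch of gradient accumulation follows. The small linear model, the random data, and the choice of four micro-batches are assumptions for illustration; the point is only to check that the accumulated gradient agrees with the full-batch gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

model = nn.Linear(4, 3)                  # hypothetical stand-in for a large model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)                    # one "large" batch of 8 examples
y = torch.randint(0, 3, (8,))
accum_steps = 4                          # split into 4 micro-batches of 2

# Reference: gradient of the mean loss over the full batch.
model.zero_grad()
F.cross_entropy(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: backward() adds into .grad, so gradients are cleared only once.
optimizer.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    # Divide by accum_steps so the summed gradients form a mean, not a sum.
    loss = F.cross_entropy(model(xb), yb) / accum_steps
    loss.backward()
accum_grad = model.weight.grad.clone()

optimizer.step()                         # one update for the whole effective batch

grads_match = torch.allclose(full_grad, accum_grad, atol=1e-6)
print(grads_match)
```

The equivalence is exact only when the micro-batches have equal size; with a ragged final micro-batch, the per-micro-batch scaling would need to account for the number of examples in each.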
Monitoring Transformer Training
We track the loss on a held-out validation set to detect overfitting. If the training loss continues to decrease but the validation loss starts increasing, the model is memorizing the training data rather than learning patterns that generalize. Techniques like dropout (Srivastava et al., 2014) and weight decay help prevent this. For our tiny fairy tale dataset, the model typically reaches its best validation perplexity after 50 to 100 epochs, then begins to overfit.
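One simple way to act on this signal is to record the validation loss after every epoch and remember where it was lowest. The sketch below assumes a patience-based stopping rule, which the lesson does not prescribe, and uses made-up loss values; perplexity is just the exponential of the mean cross-entropy loss.

```python
import math

def perplexity(mean_loss):
    """Perplexity is exp of the mean cross-entropy loss (in nats)."""
    return math.exp(mean_loss)

def best_epoch(val_losses, patience=3):
    """Return (epoch with lowest validation loss, epoch at which we stop).

    Stops once `patience` epochs have passed without a new best loss.
    """
    best, best_idx = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_idx = loss, epoch
        elif epoch - best_idx >= patience:
            return best_idx, epoch
    return best_idx, len(val_losses) - 1

# Made-up validation losses: improving at first, then rising as the model overfits.
val_losses = [3.0, 2.1, 1.6, 1.4, 1.45, 1.5, 1.6, 1.7]
best_idx, stop_idx = best_epoch(val_losses)
print(best_idx, stop_idx, perplexity(val_losses[best_idx]))
```

In a real run, the model weights would be saved whenever a new best validation loss is seen, so that the checkpoint from the best epoch can be restored after training stops.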
