GPT and more from scratch with comments

NLP
transformers
deep learning
code
analysis
Author

Andrej Muhic

Published

May 28, 2024

All original credit goes to the companion notebook for the Zero To Hero video on GPT by Andrej Karpathy. The code was modified and heavily commented.

Transformer architecture

This blog post is a collection of resources meant to make the transformer architecture, and the main deep learning concepts behind it, easier to understand in depth. Let us start with the graph of a multi-head causal transformer with learnable positional embeddings. The fully annotated image was generated with the following setup:

# hyperparameters
batch_size = 1024 # how many independent sequences will we process in parallel?
block_size = 128 # what is the maximum context length for predictions?
n_embd = 64
n_head = 2
n_layer = 2
max_iters = 10000
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200

Creating a graph of the model

from torchview import draw_graph
import torchlens as tl
from pathlib import Path

#device='meta' -> no memory is consumed for visualization
model = GPT2LanguageModel()
m = model.to(device)
#the number of parameters in the model
print(sum(p.numel() for p in m.parameters()) / 1e6, 'M parameters')

xb, yb = get_batch('train')
print(xb.shape)
print(yb.shape)
# use the same device as the model and the batch instead of hard-coding 'cuda'
model_graph = draw_graph(model, input_data=(xb, yb), depth=5, expand_nested=True, save_graph=False,
                         directory=str(Path('~').expanduser()), device=device)
print(model_graph)
graph = model_graph.visual_graph
fmt = 'svg'
graph_svg = graph.pipe(format=fmt).decode('utf-8')  # render to bytes, then decode into an SVG string
file = str(Path('~').expanduser()) + '/multihead_transformer_with_learnable_positional_encoding.' + fmt
print(file)
with open(file, 'wt') as f:
    f.write(graph_svg)

Multihead transformer with learnable positional encoding

If you can reason about the model just by looking at the picture, you can skip most of the details and dive right in. The code tries to be well documented, with plenty of links to resources that help to look at things from different perspectives. If you are a beginner in deep learning but have a solid mathematical background, I also provide a collection of resources that can help you get up to speed quickly. I have been using this mainly as a scratchpad to note my ideas and questions over the years. The view that I like the most is transformers as GNNs with additional positional encoding, as I was attracted by the nice symmetry, group and even categorical abstract motivation of deep learning in general. I would also like to point to extrakt.AI for one nice use case of information extraction, and to thank Jan Rupnik and AiLab@JSI specifically for several stimulating discussions. I also recommend checking Soniox for interesting applications in the deeper understanding of audio.

Cleaned up code resources

For cleaned up code see Karpathy:

Key techniques that make deep learning easier to use

  • Adam/AdamW optimizer
  • Transformers
  • Residual connections
  • Dropout
  • Layer/batch normalization
  • Automatic differentiation improvements
  • Polished libraries that are easy to use, PyTorch
  • Hardware acceleration optimized for the architectures

Strict run time named dimension checking

Tracking progress of your runs

Optimization, AdamW state of the art

Troubles with regular techniques

Visualizing the network

Activations

Useful tricks

Einsum formulas for humans

Broadcasting

Sanity checks and data

  • Inspect your data before you start training
  • Overfit on a small piece of data using the simplest possible model before trying to do something fancy
  • Getting better and more data will usually bring larger improvements than using a superior model
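
The overfitting check can be sketched in a few lines of plain Python. The example below is only an illustration under simple assumptions (a two-parameter linear model, full-batch gradient descent, and a made-up four-point dataset): if the training loop is correct, the loss on such a tiny, perfectly fittable dataset should go to roughly zero.

```python
def fit_line(xs, ys, lr=0.05, steps=2000):
    """Fit y = w * x + b by full-batch gradient descent on mean squared error."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        grad_w = 2.0 * sum(r * x for r, x in zip(residuals, xs)) / n
        grad_b = 2.0 * sum(residuals) / n
        w -= lr * grad_w
        b -= lr * grad_b
    loss = sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n
    return w, b, loss


# Four points that lie exactly on y = 2x + 1, so a correct loop must reach ~0 loss.
tiny_w, tiny_b, tiny_loss = fit_line([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0])
print(tiny_w, tiny_b, tiny_loss)
```

If the loss stalls well above zero on data like this, the bug is more likely in the pipeline than in the model capacity.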

NLP specific info

Tokenization

Alternative architectures and improvements

Embeddings

Various sources

Stricter typing

Code
# Let us do some stricter typing
%config Completer.use_jedi = False
# from sklearn.model_selection import train_test_split
import typing
from typing import TYPE_CHECKING, Any, Optional
import torch

# Strict run time dimension checking
# https://kidger.site/thoughts/jaxtyping/
# https://docs.kidger.site/jaxtyping/api/array/
# https://github.com/patrick-kidger/jaxtyping
# https://github.com/agronholm/typeguard

# Tracking the progress: 
#                        https://github.com/aimhubio/aim
#                        https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html
#                        https://wandb.ai/site

# Something on order of distributions in cross entropy and KL
# https://agustinus.kristia.de/techblog/2016/12/21/forward-reverse-kl/

# Visualizing the network
# https://github.com/szagoruyko/pytorchviz
# https://github.com/mert-kurttutan/torchview nicer graph

# Optimization, AdamW state of the art
# https://www.ruder.io/optimizing-gradient-descent/#adam
# https://iclr-blogposts.github.io/2023/blog/2023/adamw/ as proximal operator
# https://arxiv.org/abs/2404.04454 Implicit Bias of AdamW: ℓ∞ Norm Constrained Optimization
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# exponentially weighted averages of the gradient and of the squared gradient ("variance of gradients"), bias-corrected because they start at 0
# + decoupled weight decay, a trick that does better than L2 regularization
# Line search and quasi-Newton methods:
# Not feasible for deep learning: too costly to store and harder to parallelize!
# https://en.wikipedia.org/wiki/Limited-memory_BFGS approximates inverse of Hessian implicitly and even this is too costly
# https://en.wikipedia.org/wiki/Wolfe_conditions
# https://en.wikipedia.org/wiki/Backtracking_line_search more advanced step size

# Activations
# https://www.ai-contentlab.com/2023/03/swishglu-activation-function.html
# https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html

# Einsum formulas for humans, thanks for the hint guys!
# https://einops.rocks/api/einsum/

# Broadcasting
# https://numpy.org/doc/stable/user/basics.broadcasting.html My advice: if you are not sure, broadcast manually and test.
# Strictly test dimensions with something like jaxtyping

# key enabling techniques that made deep learning easier to work with
# - adam/adamw optimizer
# - residual connections
# - dropout
# - layer/batch normalization
# - automatic differentiation

from jaxtyping import Float, Int64

Array: typing.TypeAlias = torch.Tensor
Long: typing.TypeAlias = Int64

Loading the data

Code
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
from pathlib import Path
save_path = Path('~').expanduser() / f'.cache/input.txt'
if not save_path.exists():
    !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O {save_path}

Inspecting the data

Code
# read it in to inspect it
with save_path.open('r', encoding='utf-8') as f:
    text = f.read()
Code
print("length of dataset in characters: ", len(text))
length of dataset in characters:  1115394
Code
# let's look at the first 1000 characters
print(text[:1000])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.

Unique characters as tokens, 65 vocabulary size

Code
# here are all the unique characters that occur in this text
chars: list[str] = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65

Mapping from characters to integers

Code
# create a mapping from characters to integers
stoi: dict[str, int] = {ch: i for i, ch in enumerate(chars)}
itos: dict[int, str] = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string

# In practice one would use something like [Byte Pair Encoding](https://github.com/karpathy/minbpe) or [WordPiece, SentencePiece](https://huggingface.co/docs/transformers/en/tokenizer_summary)

print(encode("hii there"))
print(decode(encode("hii there")))
[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there
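
One limitation of this mapping is worth making explicit: it only knows characters that were present when the vocabulary was built, so `encode` raises `KeyError` on anything else. The sketch below rebuilds the same tokenizer on a toy string; the `build_tokenizer` helper and the sample strings are assumptions for illustration and are not part of the original notebook.

```python
def build_tokenizer(corpus: str):
    """Build character-level encode/decode functions from a corpus (hypothetical helper)."""
    chars = sorted(set(corpus))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    def encode(s: str) -> list[int]:
        return [stoi[c] for c in s]

    def decode(ids: list[int]) -> str:
        return ''.join(itos[i] for i in ids)

    return encode, decode, len(chars)


encode_toy, decode_toy, toy_vocab_size = build_tokenizer("hii there")
ids = encode_toy("hii there")
print(toy_vocab_size, ids)
print(decode_toy(ids))

try:
    encode_toy("hello")  # 'l' and 'o' never occurred in the corpus
except KeyError as err:
    print("unknown character:", err)
```

Subword tokenizers such as Byte Pair Encoding sidestep this by falling back to smaller units.
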
Code
# let's now encode the entire text dataset and store it into a torch.Tensor
import torch  # we use PyTorch: https://pytorch.org

data = torch.tensor(encode(text), dtype=torch.int64)
print(data.shape, data.dtype)
print(data[:1000])  # the 1000 characters we looked at earlier will look like this to the GPT
torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 46, 47, 43, 44,  1, 43, 52, 43, 51, 63,
         1, 58, 53,  1, 58, 46, 43,  1, 54, 43, 53, 54, 50, 43,  8,  0,  0, 13,
        50, 50, 10,  0, 35, 43,  1, 49, 52, 53, 61,  5, 58,  6,  1, 61, 43,  1,
        49, 52, 53, 61,  5, 58,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47, 58,
        47, 64, 43, 52, 10,  0, 24, 43, 58,  1, 59, 57,  1, 49, 47, 50, 50,  1,
        46, 47, 51,  6,  1, 39, 52, 42,  1, 61, 43,  5, 50, 50,  1, 46, 39, 60,
        43,  1, 41, 53, 56, 52,  1, 39, 58,  1, 53, 59, 56,  1, 53, 61, 52,  1,
        54, 56, 47, 41, 43,  8,  0, 21, 57,  5, 58,  1, 39,  1, 60, 43, 56, 42,
        47, 41, 58, 12,  0,  0, 13, 50, 50, 10,  0, 26, 53,  1, 51, 53, 56, 43,
         1, 58, 39, 50, 49, 47, 52, 45,  1, 53, 52,  5, 58, 11,  1, 50, 43, 58,
         1, 47, 58,  1, 40, 43,  1, 42, 53, 52, 43, 10,  1, 39, 61, 39, 63,  6,
         1, 39, 61, 39, 63,  2,  0,  0, 31, 43, 41, 53, 52, 42,  1, 15, 47, 58,
        47, 64, 43, 52, 10,  0, 27, 52, 43,  1, 61, 53, 56, 42,  6,  1, 45, 53,
        53, 42,  1, 41, 47, 58, 47, 64, 43, 52, 57,  8,  0,  0, 18, 47, 56, 57,
        58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 35, 43,  1, 39, 56, 43,  1,
        39, 41, 41, 53, 59, 52, 58, 43, 42,  1, 54, 53, 53, 56,  1, 41, 47, 58,
        47, 64, 43, 52, 57,  6,  1, 58, 46, 43,  1, 54, 39, 58, 56, 47, 41, 47,
        39, 52, 57,  1, 45, 53, 53, 42,  8,  0, 35, 46, 39, 58,  1, 39, 59, 58,
        46, 53, 56, 47, 58, 63,  1, 57, 59, 56, 44, 43, 47, 58, 57,  1, 53, 52,
         1, 61, 53, 59, 50, 42,  1, 56, 43, 50, 47, 43, 60, 43,  1, 59, 57, 10,
         1, 47, 44,  1, 58, 46, 43, 63,  0, 61, 53, 59, 50, 42,  1, 63, 47, 43,
        50, 42,  1, 59, 57,  1, 40, 59, 58,  1, 58, 46, 43,  1, 57, 59, 54, 43,
        56, 44, 50, 59, 47, 58, 63,  6,  1, 61, 46, 47, 50, 43,  1, 47, 58,  1,
        61, 43, 56, 43,  0, 61, 46, 53, 50, 43, 57, 53, 51, 43,  6,  1, 61, 43,
         1, 51, 47, 45, 46, 58,  1, 45, 59, 43, 57, 57,  1, 58, 46, 43, 63,  1,
        56, 43, 50, 47, 43, 60, 43, 42,  1, 59, 57,  1, 46, 59, 51, 39, 52, 43,
        50, 63, 11,  0, 40, 59, 58,  1, 58, 46, 43, 63,  1, 58, 46, 47, 52, 49,
         1, 61, 43,  1, 39, 56, 43,  1, 58, 53, 53,  1, 42, 43, 39, 56, 10,  1,
        58, 46, 43,  1, 50, 43, 39, 52, 52, 43, 57, 57,  1, 58, 46, 39, 58,  0,
        39, 44, 44, 50, 47, 41, 58, 57,  1, 59, 57,  6,  1, 58, 46, 43,  1, 53,
        40, 48, 43, 41, 58,  1, 53, 44,  1, 53, 59, 56,  1, 51, 47, 57, 43, 56,
        63,  6,  1, 47, 57,  1, 39, 57,  1, 39, 52,  0, 47, 52, 60, 43, 52, 58,
        53, 56, 63,  1, 58, 53,  1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,
        57, 43,  1, 58, 46, 43, 47, 56,  1, 39, 40, 59, 52, 42, 39, 52, 41, 43,
        11,  1, 53, 59, 56,  0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43,  1, 47,
        57,  1, 39,  1, 45, 39, 47, 52,  1, 58, 53,  1, 58, 46, 43, 51,  1, 24,
        43, 58,  1, 59, 57,  1, 56, 43, 60, 43, 52, 45, 43,  1, 58, 46, 47, 57,
         1, 61, 47, 58, 46,  0, 53, 59, 56,  1, 54, 47, 49, 43, 57,  6,  1, 43,
        56, 43,  1, 61, 43,  1, 40, 43, 41, 53, 51, 43,  1, 56, 39, 49, 43, 57,
        10,  1, 44, 53, 56,  1, 58, 46, 43,  1, 45, 53, 42, 57,  1, 49, 52, 53,
        61,  1, 21,  0, 57, 54, 43, 39, 49,  1, 58, 46, 47, 57,  1, 47, 52,  1,
        46, 59, 52, 45, 43, 56,  1, 44, 53, 56,  1, 40, 56, 43, 39, 42,  6,  1,
        52, 53, 58,  1, 47, 52,  1, 58, 46, 47, 56, 57, 58,  1, 44, 53, 56,  1,
        56, 43, 60, 43, 52, 45, 43,  8,  0,  0])

Train/validation split

Code
# Let's now split up the data into train and validation sets
n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
Code
block_size = 8
# First block and the target
train_data[:block_size + 1]
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
Code
# This is what we want to learn for max_context_size = block_size
# It is important that we also train on sequences shorter than block_size
# When creating batches in practice we may need a padding token, and must decide whether to pad left or right!
# Also, how to shuffle when the data does not fit in memory: https://blog.janestreet.com/how-to-shuffle-a-big-dataset/
x = train_data[:block_size]
y = train_data[1:block_size + 1]
for t in range(block_size):
    context = x[:t + 1]
    target = y[t]
    print(f"when input is {context} the target: {target}")
when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58

Batch generation

Code
torch.manual_seed(1337)
batch_size = 4  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?


def get_batch(split_kind: str) -> tuple[Int64[Array, "n_batches block_size"], Int64[Array, "n_batches block_size"]]:
    # generate a small batch of data of inputs x and targets y
    data = train_data if split_kind == 'train' else val_data
    # Random starting indices of blocks, notice that blocks can overlap
    # To do something like this for real: 
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y


xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size):  # batch dimension
    for t in range(block_size):  # time dimension
        context = xb[b, :t + 1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")
inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target: 46
when input is [44, 53, 56, 1, 58, 46] the target: 39
when input is [44, 53, 56, 1, 58, 46, 39] the target: 58
when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1
when input is [52] the target: 58
when input is [52, 58] the target: 1
when input is [52, 58, 1] the target: 58
when input is [52, 58, 1, 58] the target: 46
when input is [52, 58, 1, 58, 46] the target: 39
when input is [52, 58, 1, 58, 46, 39] the target: 58
when input is [52, 58, 1, 58, 46, 39, 58] the target: 1
when input is [52, 58, 1, 58, 46, 39, 58, 1] the target: 46
when input is [25] the target: 17
when input is [25, 17] the target: 27
when input is [25, 17, 27] the target: 10
when input is [25, 17, 27, 10] the target: 0
when input is [25, 17, 27, 10, 0] the target: 21
when input is [25, 17, 27, 10, 0, 21] the target: 1
when input is [25, 17, 27, 10, 0, 21, 1] the target: 54
when input is [25, 17, 27, 10, 0, 21, 1, 54] the target: 39
Code
print(xb)  # our input to the transformer
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])

Bigram language model

What is the most likely next token given the current token? We are targeting the probability or logit of the next token, given the current token.
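
Before the neural version, a count-based sketch of the same idea may help: estimate P(next | current) by counting adjacent character pairs and normalizing per current character. The toy corpus and the helper name are assumptions for illustration; the model in this section learns a comparable lookup table by gradient descent instead of counting.

```python
from collections import Counter, defaultdict


def bigram_probs(corpus: str) -> dict[str, dict[str, float]]:
    """Estimate P(next char | current char) from adjacent-pair counts."""
    counts: dict[str, Counter] = defaultdict(Counter)
    for cur, nxt in zip(corpus, corpus[1:]):
        counts[cur][nxt] += 1
    return {
        cur: {nxt: c / sum(following.values()) for nxt, c in following.items()}
        for cur, following in counts.items()
    }


toy_probs = bigram_probs("abababac")
print(toy_probs["a"])  # distribution over the characters that follow 'a'
most_likely = max(toy_probs["a"], key=toy_probs["a"].get)
print(most_likely)
```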

Code
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)


class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size: int):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx: Long[torch.Tensor, "batch_dim context_dim"],
                targets: Optional[Long[torch.Tensor, "batch_dim context_dim"]] = None):

        # idx and targets are both (B,T) tensor of integers
        # logits are floating point scores, not integer indices
        logits: Float[torch.Tensor, "batch_dim context_dim latent_dim"] = self.token_embedding_table(
            idx)  # (B,T,C=vocab_size)

        if targets is None:
            return logits, None
        else:
            # Note that here strictly speaking this does not fix batch size explicitly to B
            B, T, C = logits.shape  # (B,T,C=vocab_size)
            # Just a hack to avoid transposing, cross_entropy expects B x C x T in batched mode
            # This converts into non batched mode
            logits: Float[torch.Tensor, "batch_dim*context_dim latent_dim"] = logits.view(B * T, C)
            targets: Long[torch.Tensor, "batch_dim*context_dim"] = targets.view(B * T)
            # https://agustinus.kristia.de/techblog/2016/12/21/forward-reverse-kl/
            loss: Float[torch.Tensor, ""] = F.cross_entropy(logits, targets)
            return logits, loss

    def generate(self, idx: Long[torch.Tensor, "batch_dim context_dim"], max_new_tokens: int):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, vocab_size)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
            # What can go wrong here? and it is not handled at all
        return idx


m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)
print(loss.shape)
print(decode(m.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))
torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)
torch.Size([])

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3
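
A quick sanity check on the initial loss, as a worked computation: a model that assigns equal probability to each of the 65 characters has cross entropy ln(65) ≈ 4.17 nats. The freshly initialized bigram model above reports 4.8786, which is higher, presumably because its random logits are far from uniform. The helper name is illustrative.

```python
import math


def uniform_cross_entropy(vocab_size: int) -> float:
    """Cross entropy (in nats) of a uniform prediction over vocab_size classes."""
    return -math.log(1.0 / vocab_size)


baseline = uniform_cross_entropy(65)
print(f"{baseline:.4f}")  # 4.1744
```

A trained model should end up clearly below this baseline.
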

Optimization using AdamW and cross entropy

Code
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
# exponentially weighted averages of the gradient and of its "variance" + a trick (decoupled weight decay)
# Line-search and quasi-Newton alternatives are not feasible for deep learning: too costly to store and harder to parallelize!
# https://en.wikipedia.org/wiki/Limited-memory_BFGS approximates inverse of Hessian implicitly and even this is too costly
# https://en.wikipedia.org/wiki/Wolfe_conditions
# https://en.wikipedia.org/wiki/Backtracking_line_search more advanced step size
Code
batch_size = 32
n_steps = 1_000
for steps in range(n_steps):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())
Code
print(decode(m.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))
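
To make the optimizer less of a black box, here is a sketch of a single-parameter AdamW update in plain Python. It follows the update rule documented for `torch.optim.AdamW`: exponential moving averages of the gradient and of the squared gradient, bias correction for their zero initialization, and weight decay applied directly to the parameter rather than through the gradient. The quadratic objective and the hyperparameter values are assumptions chosen only for illustration.

```python
import math


def adamw_minimize(grad_fn, p, steps, lr=0.1, beta1=0.9, beta2=0.999,
                   eps=1e-8, weight_decay=0.01):
    """Run AdamW on a single scalar parameter p and return the final value."""
    m = 0.0  # moving average of the gradient
    v = 0.0  # moving average of the squared gradient
    for t in range(1, steps + 1):
        g = grad_fn(p)
        p -= lr * weight_decay * p        # decoupled weight decay
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)      # bias correction: m and v start at 0
        v_hat = v / (1 - beta2 ** t)
        p -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return p


# Toy objective f(p) = (p - 3)^2 with gradient 2 * (p - 3).
grad = lambda p: 2.0 * (p - 3.0)
p_one_step = adamw_minimize(grad, p=0.0, steps=1)
p_final = adamw_minimize(grad, p=0.0, steps=300)
print(p_one_step, p_final)
```

The first step has a size close to `lr` regardless of the gradient's scale, because after bias correction the ratio `m_hat / sqrt(v_hat)` is ±1. With weight decay enabled, the final value settles slightly below the unregularized minimum at 3.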

The mathematical formulation of (causal) self-attention

Code
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)
# Causal attention does not take into account future information
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)
a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
Code
# consider the following toy example:

torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
x.shape
torch.Size([4, 8, 2])

Manual aggregation

Code
# We want $xbow[b,t] = mean_{i<=t} x[b,i]$
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]  # (t+1,C)
        xbow[b, t] = torch.mean(xprev, 0)

Matrix multiply for weighted aggregation

Code
# version 2: using matrix multiply for a weighted aggregation
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x  # (T, T) broadcast to (B, T, T) @ (B, T, C) ----> (B, T, C)
# float32 rounding differs between the two computations, so compare with a looser relative tolerance
torch.allclose(xbow, xbow2, rtol=1e-4)
True

Softmax for weighted aggregation

Code
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
# This seems strange at first glance but exp(-inf) = 0 and it is well defined, otherwise training would get broken constantly
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
# float32 rounding differs between the two computations, so compare with a looser relative tolerance
torch.allclose(xbow, xbow3, rtol=1e-4)
True
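
The same masking trick can be checked by hand in plain Python: scores above the diagonal are set to -inf, and softmax turns them into exact zeros while the remaining equal scores become a uniform average over the past. A minimal sketch, with helper names chosen for illustration:

```python
import math


def softmax(row: list[float]) -> list[float]:
    """Numerically stable softmax; exp(-inf) evaluates to 0.0."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_uniform_weights(T: int) -> list[list[float]]:
    """Softmax of an all-zero score matrix with future positions masked to -inf."""
    neg_inf = float('-inf')
    scores = [[0.0 if j <= i else neg_inf for j in range(T)] for i in range(T)]
    return [softmax(row) for row in scores]


weights = causal_uniform_weights(3)
for row in weights:
    print([round(w, 4) for w in row])
```

For T = 3 this reproduces the matrix `a` from the toy example at the start of this section.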

Self attention

\[\texttt{Attention}(Q, K, V) = \texttt{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,\] where \(d_k\) is the dimension of the key vector \(k\) and query vector \(q\). The causal version can only take into account current and past tokens, so we need to mask all future tokens so that they do not influence the prediction of the next token. This can be accomplished with a lower-triangular mask.
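
A dependency-free sketch of the formula for a single sequence may make the shapes concrete. Q, K and V are lists of row vectors (one row per token), the scores are divided by sqrt(d_k), and future positions are masked before the softmax. The function names and the tiny inputs are assumptions for illustration; the PyTorch code in this post does the same with batched tensors.

```python
import math


def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a lower-triangular (causal) mask."""
    d_k = len(K[0])
    T = len(Q)
    out, weights = [], []
    for i in range(T):
        # token i may only look at tokens 0..i; later positions get -inf
        scores = [
            sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d_k)
            if j <= i else float('-inf')
            for j in range(T)
        ]
        w = softmax(scores)
        weights.append(w)
        # weighted sum of the value rows
        out.append([sum(w[j] * V[j][c] for j in range(T)) for c in range(len(V[0]))])
    return out, weights


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = causal_attention(Q, K, V)
print(out[0])      # the first token can only attend to itself
print(weights[1])  # the second token splits its weight over tokens 0 and 1
```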

Multi head attention

\[\texttt{MultiHead}(Q, K, V) = \texttt{Concat}(\texttt{head}_1, ..., \texttt{head}_h)W^O,\] where \[\texttt{head}_i = \texttt{Attention}(Q W^Q_i, K W^K_i, V W^V_i).\]

Modified from ML equations in LaTeX. As a note, the vectors are rows, not columns as in Matlab.
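
The split-and-concatenate structure can be sketched without any framework. In the simplified version below (an assumption made to keep it short), each projection W^Q_i, W^K_i, W^V_i just selects a contiguous slice of head_size = n_embd / n_head channels and W^O is the identity; a real model learns all of these matrices. Causal masking is included by letting token i look only at tokens 0..i.

```python
import math


def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_head(x):
    """Causal self-attention where queries, keys and values are all x itself."""
    T, d = len(x), len(x[0])
    out = []
    for i in range(T):
        scores = [sum(a * b for a, b in zip(x[i], x[j])) / math.sqrt(d) for j in range(i + 1)]
        w = softmax(scores)
        out.append([sum(w[j] * x[j][c] for j in range(i + 1)) for c in range(d)])
    return out


def multi_head(x, n_head):
    """Split channels into n_head slices, attend per slice, concatenate the results."""
    n_embd = len(x[0])
    assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
    head_size = n_embd // n_head
    heads = []
    for h in range(n_head):
        lo, hi = h * head_size, (h + 1) * head_size
        heads.append(causal_head([row[lo:hi] for row in x]))  # (T, head_size)
    # Concat(head_1, ..., head_h): back to (T, n_embd); W^O is the identity here
    return [sum((head[t] for head in heads), []) for t in range(len(x))]


x = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0], [1.0, 1.0, 0.0, 2.0]]
y = multi_head(x, n_head=2)
print(len(y), len(y[0]))  # same (T, n_embd) shape as the input
print(y[0])
```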

Code
# version 4: self-attention!
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)  # (B, T, 16)
q = query(x)  # (B, T, 16)
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T,T))
# Full matrix multiplication is faster, just block what we do not need
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v
# out = wei @ x

out.shape
torch.Size([4, 8, 16])
Code
wei[0]
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Notes:

  • Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
  • There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
  • Each example across the batch dimension is of course processed completely independently; examples never “talk” to each other.
  • In an “encoder” attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a “decoder” attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
  • “self-attention” just means that the keys and values are produced from the same source as queries. In “cross-attention”, the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
  • “Scaled” attention additionally multiplies wei by 1/sqrt(head_size). This makes it so that when the inputs Q, K have unit variance, wei will have unit variance too, and Softmax will stay diffuse and not saturate too much. Illustration below.
Code
# head_size is the dimension of "latent space"
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1) * head_size ** -0.5
Code
k.var()
tensor(1.0449)
Code
q.var()
tensor(1.0700)
Code
wei.var()
tensor(1.0918)
Code
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)
tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
Code
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]) * 8, dim=-1)  # gets too peaky, converges to one-hot
tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])
Code
class LayerNorm1d:  # (used to be BatchNorm1d)

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        # calculate the forward pass
        xmean = x.mean(1, keepdim=True)  # per-example mean over the feature dimension
        xvar = x.var(1, keepdim=True)  # per-example variance over the feature dimension
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]


torch.manual_seed(1337)
module = LayerNorm1d(100)
# batch_norm_1d = nn.BatchNorm1d(100)
# x_normalized_torch = batch_norm_1d(x)

x = torch.randn(32, 100)  # batch size 32 of 100-dimensional vectors
x_normalized = module(x)
x_normalized.shape
torch.Size([32, 100])
Code
x[:, 0].mean(), x[:, 0].std()  # mean,std of one feature across all batch inputs
(tensor(0.1392), tensor(0.8899))
Code
x[0, :].mean(), x[0, :].std()  # mean,std of a single input from the batch, of its features (raw x here; for x_normalized these would be ~0 and ~1)
(tensor(0.0409), tensor(1.0476))
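
For comparison, a plain Python sketch of what layer normalization does to one example: subtract the mean of its features and divide by their standard deviation, so the normalized row has mean close to 0 and variance close to 1. The sketch uses the biased (population) variance, which is what `torch.nn.LayerNorm` documents, whereas `x.var` in the class above defaults to the unbiased estimate. The input values are arbitrary.

```python
import math


def layer_norm(row, eps=1e-5, gamma=1.0, beta=0.0):
    """Normalize one example over its feature dimension."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)  # biased variance
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in row]


row = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
normed = layer_norm(row)
norm_mean = sum(normed) / len(normed)
norm_var = sum((v - norm_mean) ** 2 for v in normed) / len(normed)
print(norm_mean, norm_var)
```

Unlike batch normalization, nothing here depends on the other examples in the batch.
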
Code
# French to English translation example:

# <--------- ENCODE ------------------><--------------- DECODE ----------------->
# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>

Full finished code, for reference

You may want to refer directly to Karpathy’s git repo, or use NanoGPT instead.

Code
import math
from pathlib import Path
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 1024 # 16  # how many independent sequences will we process in parallel?
block_size = 128 # 32  # what is the maximum context length for predictions?
max_iters = 10000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.2
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
save_path = Path('~').expanduser() / '.cache/input.txt'
if not save_path.exists():
    !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O {save_path}

# read the dataset; the rest of this listing needs `text`
with save_path.open('r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]


# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class Head(nn.Module):
    """ one head of self-attention
        https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Causal (lower-triangular) mask; it limits the context to at most block_size tokens
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)  # (B,T,C=n_embd) -> (B,T,C=head_size)
        q = self.query(x)  # (B,T,C=n_embd) -> (B,T,C=head_size)
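        # compute attention scores ("affinities")
        # Note: C here is n_embd (taken from x.shape); the standard scaled dot-product attention scales by head_size ** -0.5
        wei = q @ k.transpose(-2, -1) * C ** -0.5  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)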
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        # Dropout is applied over the full (T, T) attention matrix, masked entries included; this is biased,
        # and dropping only the unmasked entries might be a better alternative.
        # Conceptually, a symmetric dropout might also make more sense.
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,C=head_size)
        # The matrix multiplication is batched and applied on last two dimensions!
        out = wei @ v  # (B, T, T) @ (B, T, C=head_size) -> (B, T, C=head_size)
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # Projection back to residual pathway, align the basis
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd), # Projection back to residual pathway
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd: int, n_head: int):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        # Ensures the concatenated head outputs have size n_embd (assumes n_embd is divisible by n_head)
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Skip connections improve the flow of gradients
        # The sa and ffwd outputs are projected back onto the residual pathway to "align bases"

        # Pre-norm: layer norm is applied before each sub-layer (the modern convention); the original paper applied it after
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


# No longer a simple bigram model
class GPT2LanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token is mapped to latent space of size n_embd
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # learnable position embedding for positions 0, ..., block_size - 1
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # Further reading on positional encodings, including the more modern rotary embeddings (RoPE):
        # https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/
        # https://afterhoursresearch.hashnode.dev/rope-rotary-positional-embedding
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * n_layer))
                
    def _init_weights(self, module):
        # Not strictly needed, but convergence is faster with this initialization
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C=n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T,C=n_embd)
        x = tok_emb + pos_emb  # (B,T,C=n_embd)
        x = self.blocks(x)  # (B,T,C=n_embd)
        # The pre-norm convention applies layer norm before each sub-layer (the original paper applied it after), so a final layer norm follows the last block
        x = self.ln_f(x)  # (B,T,C=n_embd)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens: the position embedding table and the causal mask only cover block_size positions
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # (B, T, C=vocab_size) becomes (B, C=vocab_size)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C=vocab_size)
            # sample from the multinomial distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


model = GPT2LanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters()) / 1e6, 'M parameters')

# create a PyTorch optimizer
# https://www.ruder.io/optimizing-gradient-descent/#adam
# https://iclr-blogposts.github.io/2023/blog/2023/adamw/
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iters):  # named `step` to avoid shadowing the built-in iter()

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))
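# Temperature motivation: https://en.wikipedia.org/wiki/LogSumExp
# Consider dividing the logits by a positive temperature Temp before the softmax:
# as Temp -> 0 the distribution tends to one-hot on the largest logit (argmax);
# as Temp -> infinity it tends to the uniform distribution (purely random sampling)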
0.215808 M parameters
step 0: train loss 4.1890, val loss 4.1885
step 100: train loss 2.5203, val loss 2.5128
step 200: train loss 2.2087, val loss 2.2280
step 300: train loss 1.9885, val loss 2.0614
step 400: train loss 1.7996, val loss 1.9230
step 500: train loss 1.7208, val loss 1.8648
step 600: train loss 1.6610, val loss 1.8272
step 700: train loss 1.6218, val loss 1.7936
step 800: train loss 1.5859, val loss 1.7645
step 900: train loss 1.5593, val loss 1.7374
step 1000: train loss 1.5727, val loss 1.7391
step 1100: train loss 1.5257, val loss 1.7019
step 1200: train loss 1.5101, val loss 1.6936
step 1300: train loss 1.5033, val loss 1.6867
step 1400: train loss 1.4825, val loss 1.6667
step 1500: train loss 1.4743, val loss 1.6577
step 1600: train loss 1.4717, val loss 1.6572
step 1700: train loss 1.4584, val loss 1.6505
step 1800: train loss 1.4466, val loss 1.6353
step 1900: train loss 1.4423, val loss 1.6301
step 2000: train loss 1.4483, val loss 1.6326
step 2100: train loss 1.4322, val loss 1.6291
step 2200: train loss 1.4267, val loss 1.6201
step 2300: train loss 1.4237, val loss 1.6142
step 2400: train loss 1.4157, val loss 1.6069
step 2500: train loss 1.4104, val loss 1.6050
step 2600: train loss 1.4105, val loss 1.6065
step 2700: train loss 1.4058, val loss 1.6060
step 2800: train loss 1.4023, val loss 1.5983
step 2900: train loss 1.3992, val loss 1.5967
step 3000: train loss 1.3964, val loss 1.5929
step 3100: train loss 1.3934, val loss 1.5959
step 3200: train loss 1.3905, val loss 1.5892
step 3300: train loss 1.3861, val loss 1.5817
step 3400: train loss 1.3873, val loss 1.5865
step 3500: train loss 1.3850, val loss 1.5846
step 3600: train loss 1.3816, val loss 1.5820
step 3700: train loss 1.3786, val loss 1.5816
step 3800: train loss 1.3778, val loss 1.5762
step 3900: train loss 1.3747, val loss 1.5735
step 4000: train loss 1.3722, val loss 1.5791
step 4100: train loss 1.3727, val loss 1.5765
step 4200: train loss 1.3725, val loss 1.5795
step 4300: train loss 1.3736, val loss 1.5785
step 4400: train loss 1.3675, val loss 1.5717
step 4500: train loss 1.3649, val loss 1.5684
step 4600: train loss 1.3663, val loss 1.5669
step 4700: train loss 1.3611, val loss 1.5657
step 4800: train loss 1.3614, val loss 1.5636
step 4900: train loss 1.3591, val loss 1.5697
step 5000: train loss 1.3640, val loss 1.5737
step 5100: train loss 1.8337, val loss 1.9548
step 5200: train loss 1.3989, val loss 1.6005
step 5300: train loss 1.3727, val loss 1.5784
step 5400: train loss 1.3656, val loss 1.5752
step 5500: train loss 1.3609, val loss 1.5700
step 5600: train loss 1.3588, val loss 1.5663
step 5700: train loss 1.3558, val loss 1.5678
step 5800: train loss 1.3536, val loss 1.5654
step 5900: train loss 1.3526, val loss 1.5655
step 6000: train loss 1.3500, val loss 1.5611
step 6100: train loss 1.3543, val loss 1.5676
step 6200: train loss 1.3485, val loss 1.5605
step 6300: train loss 1.3472, val loss 1.5632
step 6400: train loss 1.3463, val loss 1.5652
step 6500: train loss 1.3445, val loss 1.5602
step 6600: train loss 1.3436, val loss 1.5604
step 6700: train loss 1.3427, val loss 1.5672
step 6800: train loss 1.3439, val loss 1.5627
step 6900: train loss 1.3393, val loss 1.5595
step 7000: train loss 1.3419, val loss 1.5651
step 7100: train loss 1.3409, val loss 1.5611
step 7200: train loss 1.3427, val loss 1.5649
step 7300: train loss 1.3393, val loss 1.5593
step 7400: train loss 1.3360, val loss 1.5573
step 7500: train loss 1.3405, val loss 1.5673
step 7600: train loss 1.3387, val loss 1.5630
step 7700: train loss 1.3365, val loss 1.5624
step 7800: train loss 1.3336, val loss 1.5640
step 7900: train loss 1.3351, val loss 1.5637
step 8000: train loss 1.3350, val loss 1.5658
step 8100: train loss 1.3320, val loss 1.5566
step 8200: train loss 1.3315, val loss 1.5572
step 8300: train loss 1.3321, val loss 1.5573
step 8400: train loss 1.3321, val loss 1.5650
step 8500: train loss 1.3326, val loss 1.5600
step 8600: train loss 1.3325, val loss 1.5667
step 8700: train loss 1.3300, val loss 1.5573
step 8800: train loss 1.3289, val loss 1.5581
step 8900: train loss 1.3289, val loss 1.5621
step 9000: train loss 1.3276, val loss 1.5536
step 9100: train loss 1.3263, val loss 1.5544
step 9200: train loss 1.3284, val loss 1.5636
step 9300: train loss 1.3270, val loss 1.5552
step 9400: train loss 1.3262, val loss 1.5569
step 9500: train loss 1.3264, val loss 1.5593
step 9600: train loss 1.3250, val loss 1.5538
step 9700: train loss 1.3251, val loss 1.5588
step 9800: train loss 1.3245, val loss 1.5568
step 9900: train loss 1.3228, val loss 1.5628
step 9999: train loss 1.3235, val loss 1.5614
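
A quick sanity check on the first logged value: a model that assigns equal probability to every character has cross-entropy ln(vocab_size). The sketch below assumes `vocab_size = 65`, the usual character count for the tiny Shakespeare file. That number is not printed in this section, so treat it as an assumption.

```python
import math

# Assumption: 65 distinct characters (not printed in this section).
assumed_vocab_size = 65

# Cross-entropy of a uniform prediction: -ln(1 / V) = ln(V)
uniform_loss = math.log(assumed_vocab_size)

# First logged training loss from the run above
logged_initial_loss = 4.1890

print(f"uniform baseline: {uniform_loss:.4f}")
print(f"gap to logged step-0 loss: {logged_initial_loss - uniform_loss:.4f}")
```

Under that assumption the step-0 loss sits very close to the uniform baseline of about 4.17, which is what one would hope for from an initialization with a small standard deviation.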


DUKE Venery to tiege's nague will in him.
Go, lamen, were no my none, I'll faith? Lower.

NORTHUMBY:
I am proper: and thy rain for chaired
Befice, gone
Hear years; yet sail a inferous to orguate.

JULIET:
Here your daminate to the people, and,
They have. Marry begind-because you
Sweing a wealth. Poor Clareretian seed:
But, care widownest us?

Menen, no surpose him: why, hither, venteren, he is it
sweet a fall furthy this own, therefore I not in
it.

PRINCE ERKICHARD:
A least thee loved a kise of licke worth
Begain and at so sceedd the weakning sight.

ROMEO.

Rush on vailships the and the eny-I swir,
For awil truthfus, both unflowers there to be would
sussles than at the goodlor. Al tell; besimes
To his bearet win a both toes,
And love to soldie the deckoodly for a
whence have Stay 'gawd very.
Fathely we should go break up.'

Most in his friend all.

CORIOLANUS:
Thou play there must it wilt not
To foot twell more my eear of Bolingbroke,
An all with suilench' thou may two Preasage!

Server I will so happy of our bardled
Conter mine strik's too the good rogues,
And kear pout of my heart yea bed.
That's is all, for in the bross them sead?
Did darers o stands yours, are I be for you;
Sheppale knows much ding sighs made a should bear
To shall proft-my gantlen.

Fesorow, if you be a conseged a hour, bastand I
Have this sheople, nor Hanry must might of my sovet?

GLOUCESTER:
What, mercy on'tis treather it was a fortnes?
Alp you my royal, sin, in say.
Withing al toffull my virguide could, as't.

KING RICHARD:
Bring gracious loves sothers, throate i' the
That
Would parinage to secrif' the bloody sigh, and there,
Who stand, unimation their aven ever pleasure
Friend-compet As he bed: you was fold from
And that fair or was him, to to preasures,
To pranset thy love. But, the loators' tear,
Were not and that tell the depop, not cheel?
arklains overy clorge tack in mail buill
With ward her brother'd a dear ClifLo,
You shall perincation, and you are at noble.

BAHORSON:
Servant it
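
The temperature remark in the comments at the end of the training cell can be made concrete with a small plain-Python sketch. The helper name `softmax_with_temperature` and the example logits are made up for illustration; the `generate` method above samples at a fixed temperature of 1.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature (temperature must be positive)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [2.0, 1.0, 0.0]  # made-up values for illustration

cold = softmax_with_temperature(example_logits, temperature=0.05)
neutral = softmax_with_temperature(example_logits, temperature=1.0)
hot = softmax_with_temperature(example_logits, temperature=100.0)

print([round(p, 3) for p in cold])     # almost one-hot on the largest logit
print([round(p, 3) for p in neutral])  # ordinary softmax
print([round(p, 3) for p in hot])      # close to uniform
```

In `generate`, the same effect could be obtained by dividing `logits` by a temperature before `F.softmax`. Such a `temperature` argument would be an addition and is not part of the code above.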

Draw the model graph

# https://github.com/szagoruyko/pytorchviz
# https://github.com/mert-kurttutan/torchview nicer graph
from torchview import draw_graph
import torchlens as tl

# device='meta' -> no memory is consumed for visualization
print(xb.shape)
print(yb.shape)
model_graph = draw_graph(model, input_data=(xb, yb), device=device)  # reuse `device` so this also runs without CUDA
model_graph.visual_graph
save = True
save_path = Path('~/.cache').expanduser()
if save:
    torch.save(model.state_dict(), save_path / 'gpt2.pt')
    torch.save(optimizer.state_dict(), save_path / 'gpt2_opt.pt')
# model_history = tl.log_forward_pass(model, (xb, yb), layers_to_save='all', vis_opt='unrolled')
# print(model_history)
torch.Size([1024, 128])
torch.Size([1024, 128])
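
The first sample was generated right after the training loop, where `estimate_loss` had left the model in training mode, so dropout was still active; the next cell calls `model.eval()` before sampling. The plain-Python sketch below illustrates inverted dropout, the scheme `nn.Dropout` follows. In training each element is zeroed with probability p and the survivors are rescaled by 1/(1-p); in evaluation the input passes through unchanged. The helper name and the value p = 0.2 are made up for illustration, since the dropout rate used in this post is not shown in this section.

```python
import random


def inverted_dropout(values, p, training, rng):
    """Illustrative inverted dropout over a list of floats (hypothetical helper)."""
    if not training or p == 0.0:
        return list(values)  # evaluation mode: identity
    keep_scale = 1.0 / (1.0 - p)
    # Zero each element with probability p and rescale the rest,
    # so that the expected value of every element is unchanged.
    return [0.0 if rng.random() < p else v * keep_scale for v in values]


rng = random.Random(0)  # fixed seed so the sketch is repeatable
activations = [1.0] * 10_000

train_out = inverted_dropout(activations, p=0.2, training=True, rng=rng)
eval_out = inverted_dropout(activations, p=0.2, training=False, rng=rng)

dropped_fraction = sum(v == 0.0 for v in train_out) / len(train_out)
train_mean = sum(train_out) / len(train_out)

print(dropped_fraction)         # close to p
print(train_mean)               # stays close to 1.0
print(eval_out == activations)  # evaluation mode changes nothing
```

With dropout active, sampling adds noise on top of the multinomial draw, which is one reason to switch to evaluation mode before generating text.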
# generate from the model
# Switch to evaluation mode so that dropout is disabled during sampling
model.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=2000)[0].tolist()))

LUCIO:
Nor attume to the whrewnty-shall not doubt good weak
The pluck all on thee on thy heart, much serping.

GLOUCESTER:
Come, he shall be hears that should speak of phyself
From humble and achery a heaver'd mirth,
Contice.

KING RICHARD IIV:
Intil you have 'tis their head.

JOHN OF GAUNT:
They have beart a praise and vals to see yourself;
And not to depose it, and, sweet before
The gown of the sentigal of house,
Might he shall, by where are to be yet?

First Keeper:
What is envy of our dead, for England.

Second Murderess:
A kind; he will may be conclifed thee.

Clown:
Them thrust I cannot be say, if I came; away.
Why, sir, when then stufff that's life
Reportence wings, were if he seek!

JULIET:
This pass good time love more.

Second Servant:
O, he fited upon the voice,
And which a bed-will place, by the hinnguns
hite such rumation.

Second Gentleman:
They have
But ear to such our burths is friend could rewith.

HORTENSIO:
One shall say, think yet shall thy least from my good grace!

ESCALUS:
I would go then, stand good as as to so.

DUKE VINCENTIO:
And was thry face: I'll not condry.
What proudy the great of the port, affter,
And I cired so thee soul, that farewell have;
Parged him me with quite her the accussed,
And with a montages. Thy spacress her peace
Of to their moother friends, and lucks: there,
When we dare any with English uncle Mighcaster!
What's remosted on unterty father whom
the pregned neck it that is be to be
I a joy of crown, of a fear he imposited.

PERDITA:
Foul came to Hacking so dispubble valm,
Than away?

BUCKINGHAM:
Ay, lot leave?

JULIET:
Here ever Rome, follow which even lives a sickning
Which of execut. Away they swift from you fame
And not thee, or in as a beding thee?

Second Citizen:
Here, he's sir; fooly lives' false will,
Strit, sir. Thou must to fe did death.

DUKE VINCENTIO:
I fear. Theu do, you, thou love--
Gentleman we, victory of my head did for all,
They shall call to the king noble Romeo.

Shore.

LUCIO:
You canst and a shall