Skip to content

Instantly share code, notes, and snippets.

@danielfone
Created May 8, 2026 09:11
Show Gist options
  • Select an option

  • Save danielfone/a1c63e8878c580dce21864293ebd895d to your computer and use it in GitHub Desktop.

Select an option

Save danielfone/a1c63e8878c580dce21864293ebd895d to your computer and use it in GitHub Desktop.
"""
tiny_llm.py — a minimal character-level transformer trained live.
Demonstrates zero-shot completion from scratch.
Requirements: pip install torch
Run: python tiny_llm.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import argparse
# ── Corpus ────────────────────────────────────────────────────────────────────
# A handful of Shakespeare sonnets (public domain). Swap for anything you like.
# This is the entire training set; .strip() below drops the leading/trailing
# newlines of the triple-quoted literal so the first token is a real character.
TEXT = """
Shall I compare thee to a summer's day?
Thou art more lovely and more temperate.
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date.
Sometime too hot the eye of heaven shines,
And often is his gold complexion dimmed;
And every fair from fair sometime declines,
By chance, or nature's changing course, untrimmed;
But thy eternal summer shall not fade,
Nor lose possession of that fair thou ow'st,
Nor shall death brag thou wand'rest in his shade,
When in eternal lines to Time thou grow'st.
So long as men can breathe, or eyes can see,
So long lives this, and this gives life to thee.
Let me not to the marriage of true minds
Admit impediments. Love is not love
Which alters when it alteration finds,
Or bends with the remover to remove.
O no! it is an ever-fixed mark
That looks on tempests and is never shaken;
It is the star to every wand'ring bark,
Whose worth's unknown, although his height be taken.
Love's not Time's fool, though rosy lips and cheeks
Within his bending sickle's compass come;
Love alters not with his brief hours and weeks,
But bears it out even to the edge of doom.
If this be error and upon me proved,
I never writ, nor no man ever loved.
When I do count the clock that tells the time,
And see the brave day sunk in hideous night;
When I behold the violet past prime,
And sable curls all silvered o'er with white;
When lofty trees I see barren of leaves,
Which erst from heat did canopy the herd,
And summer's green all girded up in sheaves,
Borne on the bier with white and bristly beard;
Then of thy beauty do I question make
That thou among the wastes of time must go,
Since sweets and beauties do themselves forsake,
And die as fast as they see others grow.
And nothing 'gainst Time's scythe can make defence
Save breed, to brave him when he takes thee hence.
""".strip()
# ── Tokenisation ──────────────────────────────────────────────────────────────
# Character-level vocabulary: every distinct character in the corpus gets an id.
chars = sorted(set(TEXT))
VOCAB = len(chars)
s2i = {c: i for i, c in enumerate(chars)}  # char -> token id
i2s = {i: c for i, c in enumerate(chars)}  # token id -> char


def encode(s):
    """Map a string to a list of integer token ids (KeyError on unseen chars)."""
    return [s2i[c] for c in s]


def decode(t):
    """Map an iterable of token ids back to the corresponding string."""
    return ''.join(i2s[i] for i in t)


# The whole corpus as one long 1-D tensor of token ids.
data = torch.tensor(encode(TEXT), dtype=torch.long)
# ── Hyperparameters ───────────────────────────────────────────────────────────
CTX = 64 # context window (characters)
EMBED = 64 # embedding dimension
HEADS = 4 # attention heads
LAYERS = 2 # transformer blocks
BATCH = 32 # sequences per optimisation step
STEPS = 800 # NOTE(review): unused — the --steps CLI flag (default 100) is what actually runs
LR = 3e-3 # AdamW learning rate
DEVICE = 'cpu' # model is moved here; NOTE(review): forward() builds arange on the default device — confirm before switching to 'cuda'
# ── Data loader ───────────────────────────────────────────────────────────────
def get_batch():
    """Sample a random batch of (input, target) windows from the corpus.

    Returns two (BATCH, CTX) long tensors; the target is the input shifted
    one character to the right (next-character prediction).
    """
    starts = torch.randint(len(data) - CTX, (BATCH,))
    inputs, targets = [], []
    for s in starts:
        inputs.append(data[s:s + CTX])
        targets.append(data[s + 1:s + CTX + 1])
    return torch.stack(inputs), torch.stack(targets)
# ── Model ─────────────────────────────────────────────────────────────────────
class Head(nn.Module):
    """One head of masked (causal) self-attention."""

    def __init__(self, head_size):
        super().__init__()
        self.q = nn.Linear(EMBED, head_size, bias=False)
        self.k = nn.Linear(EMBED, head_size, bias=False)
        self.v = nn.Linear(EMBED, head_size, bias=False)
        # Lower-triangular mask: position t may only attend to positions <= t.
        self.register_buffer('mask', torch.tril(torch.ones(CTX, CTX)))

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.q(x), self.k(x), self.v(x)
        # Scaled dot-product attention. Scale by 1/sqrt(head_size) — the key
        # dimension, per "Attention Is All You Need" — not 1/sqrt(C): C is the
        # full embedding width, which over-shrinks the attention logits.
        w = q @ k.transpose(-2, -1) * k.size(-1) ** -0.5
        w = w.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        w = F.softmax(w, dim=-1)
        return w @ v
class MultiHead(nn.Module):
    """Several attention heads in parallel; outputs concatenated and projected."""

    def __init__(self):
        super().__init__()
        head_size = EMBED // HEADS
        self.heads = nn.ModuleList(Head(head_size) for _ in range(HEADS))
        self.proj = nn.Linear(EMBED, EMBED)

    def forward(self, x):
        outs = [head(x) for head in self.heads]
        return self.proj(torch.cat(outs, dim=-1))
class Block(nn.Module):
    """Transformer block: pre-norm attention then pre-norm feed-forward,
    each wrapped in a residual connection."""

    def __init__(self):
        super().__init__()
        self.attn = MultiHead()
        self.ff = nn.Sequential(
            nn.Linear(EMBED, 4 * EMBED),
            nn.ReLU(),
            nn.Linear(4 * EMBED, EMBED),
        )
        self.ln1 = nn.LayerNorm(EMBED)
        self.ln2 = nn.LayerNorm(EMBED)

    def forward(self, x):
        attended = x + self.attn(self.ln1(x))
        return attended + self.ff(self.ln2(attended))
class TinyLLM(nn.Module):
    """Minimal character-level GPT: token + position embeddings, LAYERS
    transformer blocks, a final LayerNorm, and a linear head to VOCAB logits."""

    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(VOCAB, EMBED)
        self.pos_emb = nn.Embedding(CTX, EMBED)
        self.blocks = nn.Sequential(*[Block() for _ in range(LAYERS)])
        self.ln = nn.LayerNorm(EMBED)
        self.head = nn.Linear(EMBED, VOCAB)

    def forward(self, idx, targets=None):
        """Return (logits, loss).

        idx: (B, T) long tensor of token ids, T <= CTX.
        targets: optional (B, T) long tensor; when given, loss is the
        cross-entropy over all positions, otherwise loss is None.
        """
        B, T = idx.shape
        # Build positions on idx's device so the model also works off-CPU
        # (a bare torch.arange(T) always allocates on the default device).
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        logits = self.head(self.ln(self.blocks(x)))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, VOCAB), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def complete(self, prompt, n=200, temperature=0.8):
        """Stream n sampled characters continuing `prompt` to stdout."""
        # Allocate the running sequence on the model's device, not the default.
        device = next(self.parameters()).device
        idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
        print(prompt, end='', flush=True)
        for _ in range(n):
            # Condition on at most the last CTX tokens (position table size).
            logits, _ = self(idx[:, -CTX:])
            probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
            next_tok = torch.multinomial(probs, 1)
            idx = torch.cat([idx, next_tok], dim=1)
            print(decode([next_tok.item()]), end='', flush=True)
        print()
# ── Training ──────────────────────────────────────────────────────────────────
# Build the model and print a small banner before training starts.
model = TinyLLM().to(DEVICE)
# Total trainable + non-trainable parameter count, for the banner.
total = sum(p.numel() for p in model.parameters())
print(f"\n vocab size : {VOCAB} characters")
print(f" parameters : {total:,}")
print(f" training on: {len(TEXT):,} characters\n")
print(" step loss")
print(" ────────────")
def train(steps, log_every=100):
    """Run `steps` optimisation steps, printing the loss every `log_every`.

    Uses the module-level `model`, `get_batch`, and `opt` (the optimiser is
    created after this definition but before the call at the bottom of the
    file, so the global lookup succeeds at call time).
    """
    t0 = time.time()
    for step in range(1, steps + 1):
        x, y = get_batch()
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % log_every == 0:
            print(f" {step:4d} {loss.item():.4f}")
    print(f"\n trained in {time.time()-t0:.1f}s\n")
def prompt_loop():
    """Read prompts from stdin and stream completions until 'q' or EOF."""
    banner = "=" * 60
    print(banner)
    print("INTERACTIVE — type a prompt, enter to complete, q to quit")
    print(banner + "\n")
    while True:
        try:
            line = input("Prompt> ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not line or line.lower() == 'q':
            break
        # Replace characters the vocabulary has never seen with spaces,
        # so encode() cannot fail on out-of-vocabulary input.
        cleaned = []
        for ch in line:
            cleaned.append(ch if ch in s2i else ' ')
        prompt = ''.join(cleaned)
        print("Output: ", end='')
        model.complete(prompt, n=200)
        print()
# Parse CLI args, train, then drop into the interactive prompt loop — guarded
# so that importing this module does not also kick off training and the REPL.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--steps', type=int, default=100)
    args = parser.parse_args()
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    train(args.steps)
    prompt_loop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment