| """ | |
| tiny_llm.py — a minimal character-level transformer trained live. | |
| Demonstrates zero-shot completion from scratch. | |
| Requirements: pip install torch | |
| Run: python tiny_llm.py | |
| """ | |
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import argparse
# ── Corpus ────────────────────────────────────────────────────────────────────
# A handful of Shakespeare sonnets (public domain). Swap for anything you like.
TEXT = """
Shall I compare thee to a summer's day?
Thou art more lovely and more temperate.
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date.
Sometime too hot the eye of heaven shines,
And often is his gold complexion dimmed;
And every fair from fair sometime declines,
By chance, or nature's changing course, untrimmed;
But thy eternal summer shall not fade,
Nor lose possession of that fair thou ow'st,
Nor shall death brag thou wand'rest in his shade,
When in eternal lines to Time thou grow'st.
So long as men can breathe, or eyes can see,
So long lives this, and this gives life to thee.
Let me not to the marriage of true minds
Admit impediments. Love is not love
Which alters when it alteration finds,
Or bends with the remover to remove.
O no! it is an ever-fixed mark
That looks on tempests and is never shaken;
It is the star to every wand'ring bark,
Whose worth's unknown, although his height be taken.
Love's not Time's fool, though rosy lips and cheeks
Within his bending sickle's compass come;
Love alters not with his brief hours and weeks,
But bears it out even to the edge of doom.
If this be error and upon me proved,
I never writ, nor no man ever loved.
When I do count the clock that tells the time,
And see the brave day sunk in hideous night;
When I behold the violet past prime,
And sable curls all silvered o'er with white;
When lofty trees I see barren of leaves,
Which erst from heat did canopy the herd,
And summer's green all girded up in sheaves,
Borne on the bier with white and bristly beard;
Then of thy beauty do I question make
That thou among the wastes of time must go,
Since sweets and beauties do themselves forsake,
And die as fast as they see others grow.
And nothing 'gainst Time's scythe can make defence
Save breed, to brave him when he takes thee hence.
""".strip()
# ── Tokenisation ──────────────────────────────────────────────────────────────
chars = sorted(set(TEXT))
VOCAB = len(chars)
s2i = {c: i for i, c in enumerate(chars)}
i2s = {i: c for i, c in enumerate(chars)}
encode = lambda s: [s2i[c] for c in s]
decode = lambda t: ''.join(i2s[i] for i in t)
data = torch.tensor(encode(TEXT), dtype=torch.long)
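# Optional sanity check (not in the original gist): encoding then decoding
# should round-trip exactly for any string built from characters in the corpus.
assert decode(encode("Shall I compare thee")) == "Shall I compare thee"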
# ── Hyperparameters ───────────────────────────────────────────────────────────
CTX = 64     # context window (characters)
EMBED = 64   # embedding dimension
HEADS = 4    # attention heads
LAYERS = 2   # transformer blocks
BATCH = 32   # sequences per optimisation step
STEPS = 800  # default number of training steps (override with --steps)
LR = 3e-3    # AdamW learning rate
DEVICE = 'cpu'  # the model is tiny; CPU is fine
# ── Data loader ───────────────────────────────────────────────────────────────
def get_batch():
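    # Pick BATCH random offsets into the corpus; each input x is a CTX-character
    # window and its target y is the same window shifted one character right.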
    ix = torch.randint(len(data) - CTX, (BATCH,))
    x = torch.stack([data[i:i+CTX] for i in ix])
    y = torch.stack([data[i+1:i+CTX+1] for i in ix])
    return x, y
# ── Model ─────────────────────────────────────────────────────────────────────
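# One head of causal self-attention: each position may attend only to itself
# and to earlier positions (lower-triangular mask), so text can be generated
# left to right.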
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.q = nn.Linear(EMBED, head_size, bias=False)
        self.k = nn.Linear(EMBED, head_size, bias=False)
        self.v = nn.Linear(EMBED, head_size, bias=False)
        self.register_buffer('mask', torch.tril(torch.ones(CTX, CTX)))
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.q(x), self.k(x), self.v(x)
        w = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # scale by head size, not EMBED
        w = w.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        w = F.softmax(w, dim=-1)
        return w @ v
class MultiHead(nn.Module):
    def __init__(self):
        super().__init__()
        hs = EMBED // HEADS
        self.heads = nn.ModuleList([Head(hs) for _ in range(HEADS)])
        self.proj = nn.Linear(EMBED, EMBED)
    def forward(self, x):
        return self.proj(torch.cat([h(x) for h in self.heads], dim=-1))
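# Pre-norm transformer block: LayerNorm feeds each sub-layer (attention, then a
# small MLP), and each sub-layer's output is added back via a residual connection.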
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = MultiHead()
        self.ff = nn.Sequential(
            nn.Linear(EMBED, 4 * EMBED), nn.ReLU(), nn.Linear(4 * EMBED, EMBED)
        )
        self.ln1 = nn.LayerNorm(EMBED)
        self.ln2 = nn.LayerNorm(EMBED)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x
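# Full model: learned token and position embeddings, LAYERS transformer blocks,
# a final LayerNorm, and a linear head producing next-character logits over VOCAB.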
class TinyLLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(VOCAB, EMBED)
        self.pos_emb = nn.Embedding(CTX, EMBED)
        self.blocks = nn.Sequential(*[Block() for _ in range(LAYERS)])
        self.ln = nn.LayerNorm(EMBED)
        self.head = nn.Linear(EMBED, VOCAB)
    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.tok_emb(idx) + self.pos_emb(torch.arange(T))
        logits = self.head(self.ln(self.blocks(x)))
        loss = None if targets is None else \
            F.cross_entropy(logits.view(-1, VOCAB), targets.view(-1))
        return logits, loss
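    # Autoregressive sampling: repeatedly feed the last CTX characters back in,
    # scale the final position's logits by temperature, and sample one character.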
    @torch.no_grad()
    def complete(self, prompt, n=200, temperature=0.8):
        idx = torch.tensor([encode(prompt)], dtype=torch.long)
        print(prompt, end='', flush=True)
        for _ in range(n):
            logits, _ = self(idx[:, -CTX:])
            probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
            next_tok = torch.multinomial(probs, 1)
            idx = torch.cat([idx, next_tok], dim=1)
            print(decode([next_tok.item()]), end='', flush=True)
        print()
# ── Training ──────────────────────────────────────────────────────────────────
model = TinyLLM().to(DEVICE)
total = sum(p.numel() for p in model.parameters())
print(f"\n vocab size : {VOCAB} characters")
print(f" parameters : {total:,}")
print(f" training on: {len(TEXT):,} characters\n")
print(" step loss")
print(" ────────────")
def train(steps):
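    # Standard training loop: sample a random batch, compute the cross-entropy
    # loss, backpropagate, take one AdamW step, and log the loss every 100 steps.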
    t0 = time.time()
    for step in range(1, steps + 1):
        x, y = get_batch()
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % 100 == 0:
            print(f" {step:4d} {loss.item():.4f}")
    print(f"\n trained in {time.time()-t0:.1f}s\n")
def prompt_loop():
    print("=" * 60)
    print("INTERACTIVE — type a prompt, enter to complete, q to quit")
    print("=" * 60 + "\n")
    while True:
        try:
            line = input("Prompt> ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not line or line.lower() == 'q':
            break
        prompt = ''.join(c if c in s2i else ' ' for c in line)  # map out-of-vocab characters to spaces
        print("Output: ", end='')
        model.complete(prompt, n=200)
        print()
parser = argparse.ArgumentParser()
parser.add_argument('--steps', type=int, default=STEPS)
args = parser.parse_args()
opt = torch.optim.AdamW(model.parameters(), lr=LR)
train(args.steps)
prompt_loop()
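# Tip: training defaults to STEPS (800) optimisation steps; pass a smaller value,
# e.g. `python tiny_llm.py --steps 200`, for a quicker but rougher run.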