#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy>=1.26",
#     "torch>=2.4",
#     "transformers>=4.45",
#     "accelerate>=1.12.0",
#     "gguf>=0.17.1",
# ]
# ///
#
# Example usage:
#   uv run attention_typing.py \
#     --model-id Qwen/Qwen3-4B-Instruct-2507 \
#     --prompt-file prompt.txt \
#     --layers all \
#     --heads all \
#     --contrast 2.0 \
#     --output typing.html
#
# Example prompt.txt:
#   [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Write a short haiku about attention maps."}
#   ]
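#
# GGUF example (the local file name below is hypothetical; any GGUF file that
# Transformers' GGUF support can load should work the same way):
#   uv run attention_typing.py \
#     --model-id ./models/qwen3-4b-instruct-q4_k_m.gguf \
#     --prompt-file prompt.txt \
#     --output typing.html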
from __future__ import annotations

import argparse
import base64
import json
from pathlib import Path

import numpy as np
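# torch and transformers are imported lazily inside load_attention_tensor, so
# the heavy ML imports only run after the CLI arguments have been validated.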
def parse_comma_list(raw: str, max_items: int, name: str) -> list[int]:
    raw = raw.strip().lower()
    if raw == "all":
        return list(range(max_items))
    out: list[int] = []
    for part in raw.split(","):
        part = part.strip()
        if not part:
            continue
        idx = int(part)
        if idx < 0 or idx >= max_items:
            raise ValueError(f"{name} index {idx} out of range [0, {max_items - 1}]")
        out.append(idx)
    if not out:
        raise ValueError(f"No valid {name} indices provided")
    return sorted(set(out))
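# For example, parse_comma_list("0, 2, 2", 4, "head") returns [0, 2], and
# parse_comma_list("all", 4, "head") returns [0, 1, 2, 3].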
def parse_chat_messages(raw: object, source: str) -> list[dict[str, str]]:
    if isinstance(raw, dict):
        raw = raw.get("messages")
    if not isinstance(raw, list) or not raw:
        raise ValueError(
            f"{source} must be a non-empty JSON array of chat messages or an object with a 'messages' array."
        )
    out: list[dict[str, str]] = []
    for idx, item in enumerate(raw):
        if not isinstance(item, dict):
            raise ValueError(f"{source} message at index {idx} must be an object.")
        role = item.get("role")
        content = item.get("content")
        if not isinstance(role, str) or not role.strip():
            raise ValueError(f"{source} message at index {idx} has invalid 'role'.")
        if not isinstance(content, str) or not content:
            raise ValueError(f"{source} message at index {idx} has invalid 'content'.")
        out.append({"role": role.strip(), "content": content})
    return out
def resolve_prompt(prompt: str | None, prompt_file: str | None) -> list[dict[str, str]]:
    if prompt_file:
        text = Path(prompt_file).read_text(encoding="utf-8").strip()
        if not text:
            raise ValueError(f"Prompt file is empty: {prompt_file}")
        try:
            payload = json.loads(text)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Prompt file must be JSON chat messages: {prompt_file}"
            ) from exc
        return parse_chat_messages(payload, f"Prompt file '{prompt_file}'")
    if prompt is None or not prompt.strip():
        raise ValueError("Provide --prompt or --prompt-file")
    return [{"role": "user", "content": prompt}]
def resolve_model_source(
    model_id: str, gguf_file: str | None
) -> tuple[str, str | None]:
    model_path = Path(model_id).expanduser()
    if model_path.is_file():
        if model_path.suffix.lower() == ".gguf":
            if gguf_file is not None:
                raise ValueError(
                    "Pass either a GGUF file as --model-id or pass --gguf-file, not both."
                )
            return str(model_path.parent), model_path.name
        raise ValueError(
            "Expected --model-id to be a Hugging Face model id or local model directory, "
            f"but got a file path: '{model_id}'."
        )
    if gguf_file is None:
        return model_id, None
    gguf_path = Path(gguf_file).expanduser()
    if gguf_path.is_absolute():
        if not gguf_path.is_file():
            raise ValueError(f"GGUF file not found: {gguf_file}")
        return model_id, str(gguf_path)
    if model_path.is_dir():
        candidate = model_path / gguf_path
        if candidate.is_file():
            return model_id, candidate.name
    return model_id, gguf_file
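# Illustrative behaviour (paths hypothetical): passing a local "model.gguf"
# path as --model-id yields (its parent directory, "model.gguf"), while a plain
# Hub id such as "Qwen/Qwen3-4B-Instruct-2507" with no --gguf-file passes
# through unchanged with gguf_file=None.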
def validate_gguf_dependencies(gguf_file: str | None) -> None:
    if gguf_file is None:
        return
    missing: list[str] = []
    try:
        import accelerate  # noqa: F401
    except ImportError:
        missing.append("accelerate")
    try:
        import gguf  # noqa: F401
    except ImportError:
        missing.append("gguf")
    if missing:
        joined = ", ".join(missing)
        raise ValueError(
            "GGUF loading requires additional dependencies missing from this environment: "
            f"{joined}. Install them with: `uv add {' '.join(missing)}`."
        )
def load_attention_tensor(
    model_id: str,
    gguf_file: str | None,
    messages: list[dict[str, str]],
    dtype: str,
    device: str,
    allow_download: bool,
    layers_arg: str,
    heads_arg: str,
) -> tuple[np.ndarray, list[tuple[int, int]], list[str], list[int], list[int], str]:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype_map = {
        "auto": "auto",
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    if dtype not in dtype_map:
        raise ValueError(f"Unsupported dtype: {dtype}")
    local_only = not allow_download
    tokenizer_kwargs: dict[str, object] = {
        "local_files_only": local_only,
        "trust_remote_code": False,
    }
    if gguf_file is not None:
        tokenizer_kwargs["gguf_file"] = gguf_file
    tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
    if tokenizer.chat_template is None:
        raise ValueError("Tokenizer has no chat template; this script requires one.")
    model_input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    model_kwargs: dict[str, object] = {
        "local_files_only": local_only,
        "trust_remote_code": False,
        "output_attentions": True,
    }
    if gguf_file is not None:
        model_kwargs["gguf_file"] = gguf_file
    selected_dtype = dtype_map[dtype]
    if selected_dtype != "auto":
        model_kwargs["torch_dtype"] = selected_dtype
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    if device != "auto":
        model = model.to(device)
    elif torch.backends.mps.is_available():
        model = model.to("mps")
    elif torch.cuda.is_available():
        model = model.to("cuda")
    model.eval()
    encoded = tokenizer(
        model_input_text, return_tensors="pt", return_offsets_mapping=True
    )
    offsets_raw = encoded.pop("offset_mapping")
    model_device = next(model.parameters()).device
    encoded = {k: v.to(model_device) for k, v in encoded.items()}
    with torch.no_grad():
        outputs = model(**encoded, output_attentions=True, use_cache=False)
    attentions = outputs.attentions
    if attentions is None:
        raise RuntimeError("Model did not return attentions")
    n_layers = len(attentions)
    n_heads = attentions[0].shape[1]
    layers = parse_comma_list(layers_arg, n_layers, "layer")
    heads = parse_comma_list(heads_arg, n_heads, "head")
    tensor_layers: list[np.ndarray] = []
    for layer_idx in layers:
        layer = (
            attentions[layer_idx][0].detach().float().cpu().numpy()
        )  # [heads, seq, seq]
        selected_heads: list[np.ndarray] = []
        for head_idx in heads:
            selected_heads.append(layer[head_idx])
        tensor_layers.append(np.stack(selected_heads, axis=0))
    attention = np.stack(tensor_layers, axis=0)
    token_ids = encoded["input_ids"][0].detach().cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    offsets = [(int(s), int(e)) for s, e in offsets_raw[0].detach().cpu().tolist()]
    return attention, offsets, tokens, layers, heads, model_input_text
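# The returned attention array has shape [len(layers), len(heads), seq, seq];
# offsets holds the (start, end) character span of each token within the
# chat-templated prompt text that the HTML animation later types out.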
def build_segments(
    prompt: str, offsets: list[tuple[int, int]]
) -> tuple[list[dict], list[dict]]:
    non_special: list[tuple[int, int, int]] = []
    for full_idx, (start, end) in enumerate(offsets):
        if end > start:
            non_special.append((full_idx, start, end))
    non_special.sort(key=lambda x: (x[1], x[2]))
    segments: list[dict] = []
    tokens: list[dict] = []
    cursor = 0
    token_id = 0
    for full_idx, start, end in non_special:
        if start > cursor:
            segments.append(
                {
                    "text": prompt[cursor:start],
                    "start": cursor,
                    "end": start,
                    "token_id": None,
                }
            )
        token_text = prompt[start:end]
        if token_text:
            tokens.append(
                {
                    "id": token_id,
                    "full_index": full_idx,
                    "start": start,
                    "end": end,
                    "text": token_text,
                }
            )
            segments.append(
                {
                    "text": token_text,
                    "start": start,
                    "end": end,
                    "token_id": token_id,
                }
            )
            token_id += 1
        cursor = max(cursor, end)
    if cursor < len(prompt):
        segments.append(
            {
                "text": prompt[cursor:],
                "start": cursor,
                "end": len(prompt),
                "token_id": None,
            }
        )
    return segments, tokens
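# Segments cover the prompt text contiguously and in order: spans that map to a
# visible token carry its token_id, while gaps (e.g. template text whose tokens
# report empty offsets) get token_id None so they are typed but never dimmed.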
def validate_opacity_params(contrast: float, opacity_floor: float) -> None:
    if contrast <= 0:
        raise ValueError("--contrast must be > 0")
    if not (0 <= opacity_floor < 1):
        raise ValueError("--opacity-floor must be in [0, 1)")
def detect_causal_attention(attention: np.ndarray, atol: float = 1e-8) -> bool:
    token_count = attention.shape[-1]
    for row_idx in range(token_count - 1):
        if np.any(np.abs(attention[..., row_idx, row_idx + 1 :]) > atol):
            return False
    return True
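# With causal attention, row i only attends to positions 0..i, so the packed
# payload below stores i + 1 bytes per row: a 4-token prompt needs
# 4 * 5 / 2 = 10 bytes per layer/head combination instead of 16.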
def encode_opacity_matrix(
    matrix: np.ndarray,
    out: np.ndarray,
    contrast: float,
    opacity_floor: float,
    causal: bool,
) -> None:
    token_count = matrix.shape[0]
    expected = (
        token_count * (token_count + 1) // 2 if causal else token_count * token_count
    )
    if out.size != expected:
        raise ValueError(
            f"Output buffer size {out.size} does not match expected {expected}."
        )
    p = 1.0 / contrast
    floor_byte = int(round(opacity_floor * 255.0))
    cursor = 0
    for row_idx in range(token_count):
        row_full = matrix[row_idx, :token_count]
        row_len = row_idx + 1 if causal else token_count
        row = row_full[:row_len]
        lo = float(np.min(row_full))
        hi = float(np.max(row_full))
        dst = out[cursor : cursor + row_len]
        if abs(hi - lo) < 1e-12:
            dst.fill(floor_byte)
        else:
            norm = (row - lo) / (hi - lo)
            boosted = np.power(norm, p)
            opacity = opacity_floor + (1.0 - opacity_floor) * boosted
            dst[:] = np.clip(np.rint(opacity * 255.0), 0, 255).astype(np.uint8)
        cursor += row_len
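# Contrast acts as a power curve on each row-normalised score: with
# contrast=2.0 the exponent is p = 0.5, so a normalised weight of 0.25 becomes
# 0.25 ** 0.5 = 0.5 before the opacity floor is mixed in, lifting faint tokens.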
def build_compact_opacity_payload(
    attention: np.ndarray,
    contrast: float,
    opacity_floor: float,
) -> tuple[str, dict[str, int | bool]]:
    layer_count, head_count, token_count, _ = attention.shape
    causal = detect_causal_attention(attention)
    bytes_per_combo = (
        token_count * (token_count + 1) // 2 if causal else token_count * token_count
    )
    combo_count = 1 + layer_count + layer_count * head_count
    packed = np.empty(combo_count * bytes_per_combo, dtype=np.uint8)
    layer_means = attention.mean(axis=1)
    all_mean = layer_means.mean(axis=0)
    combo_idx = 0
    encode_opacity_matrix(
        all_mean,
        packed[combo_idx * bytes_per_combo : (combo_idx + 1) * bytes_per_combo],
        contrast=contrast,
        opacity_floor=opacity_floor,
        causal=causal,
    )
    combo_idx += 1
    for layer_idx in range(layer_count):
        encode_opacity_matrix(
            layer_means[layer_idx],
            packed[combo_idx * bytes_per_combo : (combo_idx + 1) * bytes_per_combo],
            contrast=contrast,
            opacity_floor=opacity_floor,
            causal=causal,
        )
        combo_idx += 1
    for layer_idx in range(layer_count):
        for head_idx in range(head_count):
            encode_opacity_matrix(
                attention[layer_idx, head_idx],
                packed[combo_idx * bytes_per_combo : (combo_idx + 1) * bytes_per_combo],
                contrast=contrast,
                opacity_floor=opacity_floor,
                causal=causal,
            )
            combo_idx += 1
    payload_b64 = base64.b64encode(packed.tobytes()).decode("ascii")
    meta = {
        "causal": causal,
        "bytes_per_combo": int(bytes_per_combo),
        "combo_count": int(combo_count),
        "floor_byte": int(round(opacity_floor * 255.0)),
    }
    return payload_b64, meta
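# Packed combo order matches comboIndexForSelection in the generated JS:
# [mean over all layers/heads] + [one mean per selected layer] + [layer 0 head 0,
# layer 0 head 1, ...], i.e. 1 + L + L * H blocks of bytes_per_combo bytes each.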
def generate_html(data: dict) -> str:
    payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
    return f"""<!doctype html>
<html lang=\"en\">
<head>
  <meta charset=\"utf-8\" />
  <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />
  <title>Typing Attention Timeline</title>
  <style>
    body {{
      margin: 0;
      background: #fff;
      color: #000;
      font-family: "Times New Roman", Times, serif;
    }}
    .wrap {{
      max-width: 900px;
      margin: 22px auto;
      padding: 0 16px 28px;
    }}
    .controls {{
      display: flex;
      gap: 10px;
      flex-wrap: wrap;
      align-items: center;
      justify-content: center;
      padding: 2px 0 10px;
      margin-bottom: 12px;
    }}
    button {{
      border: 1px solid #111;
      background: #111;
      color: #fff;
      border-radius: 999px;
      padding: 6px 14px;
      font-family: "Times New Roman", Times, serif;
      font-size: 15px;
      line-height: 1;
      cursor: pointer;
      transition: background-color 120ms ease, transform 120ms ease;
    }}
    button:hover {{ background: #2a2a2a; }}
    button:active {{ transform: translateY(1px); }}
    .attn-panel {{
      display: flex;
      flex-direction: column;
      align-items: center;
      gap: 6px;
      margin: 4px 0 12px;
    }}
    .attn-top-row {{
      display: flex;
      align-items: center;
      gap: 12px;
    }}
    .attn-heads-title {{
      font-size: 24px;
      line-height: 1;
      font-weight: 600;
    }}
    .attn-grid-wrap {{
      display: flex;
      align-items: center;
      justify-content: center;
      gap: 12px;
      max-width: 100%;
      overflow-x: auto;
      padding-bottom: 4px;
    }}
    .attn-layers-label {{
      font-size: 24px;
      line-height: 1;
      font-weight: 600;
    }}
    .attention-grid {{
      display: flex;
      flex-direction: column;
      gap: 8px;
      align-items: flex-start;
    }}
    .attn-row {{
      display: flex;
      align-items: center;
      gap: 6px;
    }}
    .attn-head-row {{
      display: flex;
      flex-wrap: nowrap;
      gap: 4px;
    }}
    .attn-btn {{
      border: 2px solid #222;
      background: #f6f6f6;
      width: 16px;
      height: 16px;
      border-radius: 2px;
      padding: 0;
      cursor: pointer;
      transition: background-color 120ms ease, transform 120ms ease, border-color 120ms ease;
    }}
    .attn-btn:hover {{
      background: #e8e8e8;
      border-color: #111;
    }}
    .attn-btn:active {{
      transform: translateY(1px);
    }}
    .attn-btn.layer-square {{
      background: #ececec;
    }}
    .attn-btn.layer-square:hover {{
      background: #dfdfdf;
    }}
    .attn-btn.attn-all-btn {{
      width: 24px;
      height: 24px;
      font-family: "Times New Roman", Times, serif;
      font-size: 8px;
      line-height: 1;
      color: #222;
      text-transform: lowercase;
    }}
    .attn-btn.attn-all-btn:hover {{
      background: #e8e8e8;
      border-color: #111;
      color: #111;
    }}
    .attn-btn.is-active {{
      background: #111;
      border-color: #111;
    }}
    .attn-btn.is-active.attn-all-btn {{
      color: #fff;
    }}
    @media (max-width: 760px) {{
      .attn-top-row {{
        gap: 8px;
      }}
      .attn-heads-title {{
        font-size: 18px;
      }}
      .attn-layers-label {{
        font-size: 18px;
      }}
      .attn-grid-wrap {{
        gap: 12px;
      }}
      .attn-btn {{
        width: 12px;
        height: 12px;
      }}
      .attn-btn.attn-all-btn {{
        width: 18px;
        height: 18px;
        font-size: 6px;
      }}
      .attn-head-row {{
        gap: 3px;
      }}
    }}
    .stage {{
      background: #fff;
      padding: 18px 16px;
      min-height: 300px;
      line-height: 1.7;
      font-size: 22px;
      white-space: pre-wrap;
      word-wrap: break-word;
    }}
    .token {{
      color: #000;
      transition: opacity 580ms ease;
      opacity: 0.2;
    }}
    .cursor {{
      display: inline-block;
      width: 0.07em;
      height: 1.05em;
      vertical-align: text-bottom;
      background: #000;
      color: #000;
      animation: blink 1s steps(1) infinite;
      user-select: none;
      margin-left: 1px;
    }}
    @keyframes blink {{
      50% {{ opacity: 0; }}
    }}
  </style>
</head>
<body>
  <div class=\"wrap\">
    <div class=\"controls\">
      <button id=\"play\">Play</button>
      <button id=\"pause\">Pause</button>
      <button id=\"reset\">Reset</button>
    </div>
    <div class=\"attn-panel\" aria-label=\"Attention selector\">
      <div class=\"attn-top-row\">
        <div class=\"attn-heads-title\">heads</div>
        <button id=\"attn-all-btn\" class=\"attn-btn attn-all-btn\" type=\"button\">all</button>
      </div>
      <div class=\"attn-grid-wrap\">
        <div class=\"attn-layers-label\">layers</div>
        <div id=\"attention-grid\" class=\"attention-grid\"></div>
      </div>
    </div>
    <div id=\"stage\" class=\"stage\" aria-live=\"polite\"></div>
  </div>
  <script>
    const DATA = {payload};
    const stage = document.getElementById('stage');
    const playBtn = document.getElementById('play');
    const pauseBtn = document.getElementById('pause');
    const resetBtn = document.getElementById('reset');
    const allSelectBtn = document.getElementById('attn-all-btn');
    const attentionGrid = document.getElementById('attention-grid');
    const CHARS_PER_SEC = 30;
    const FLOOR_OPACITY = String(DATA.opacity_floor);
    const tokenCount = DATA.tokens.length;
    const layerCount = DATA.meta.layers.length;
    const headCount = DATA.meta.heads.length;
    const bytesPerCombo = DATA.meta.bytes_per_combo;
    const usesCausalRows = Boolean(DATA.meta.causal);

    function decodeBase64ToUint8(base64Text) {{
      const binary = atob(base64Text);
      const out = new Uint8Array(binary.length);
      for (let i = 0; i < binary.length; i++) {{
        out[i] = binary.charCodeAt(i);
      }}
      return out;
    }}

    let opacityRowsB64 = DATA.opacity_rows_b64;
    const opacityRowsBytes = decodeBase64ToUint8(opacityRowsB64);
    opacityRowsB64 = '';
    const expectedBytes = DATA.meta.combo_count * bytesPerCombo;
    if (opacityRowsBytes.length !== expectedBytes) {{
      throw new Error(`Decoded opacity payload has ${{opacityRowsBytes.length}} bytes, expected ${{expectedBytes}}.`);
    }}
    delete DATA.opacity_rows_b64;

    const segmentEls = [];
    const tokenEls = [];
    for (const seg of DATA.segments) {{
      const span = document.createElement('span');
      if (seg.token_id !== null) {{
        span.className = 'token';
        span.style.opacity = FLOOR_OPACITY;
        tokenEls[seg.token_id] = span;
      }}
      segmentEls.push({{ span, seg }});
      stage.appendChild(span);
    }}

    const cursor = document.createElement('span');
    cursor.className = 'cursor';
    cursor.textContent = '';
    stage.appendChild(cursor);

    const tokenEndToId = new Map();
    for (const tok of DATA.tokens) {{
      tokenEndToId.set(tok.end, tok.id);
    }}

    let typedLen = 0;
    let timer = null;
    let lastAppliedToken = null;
    let activeComboIndex = 0;
    let selectionLayerRaw = 'all';
    let selectionHeadRaw = 'all';
    const layerKeys = DATA.meta.layers.map((layer) => String(layer));
    const headKeys = DATA.meta.heads.map((head) => String(head));
    const layerPosByRaw = new Map(layerKeys.map((key, i) => [key, i]));
    const headPosByRaw = new Map(headKeys.map((key, i) => [key, i]));
    const layerButtons = new Map();
    const headButtonsByLayer = new Map();

    function comboIndexForSelection(layerRaw, headRaw) {{
      if (layerRaw === 'all' && headRaw === 'all') {{
        return 0;
      }}
      const layerPos = layerPosByRaw.get(layerRaw);
      if (layerPos === undefined) {{
        return null;
      }}
      if (headRaw === 'all') {{
        return 1 + layerPos;
      }}
      const headPos = headPosByRaw.get(headRaw);
      if (headPos === undefined) {{
        return null;
      }}
      return 1 + layerCount + layerPos * headCount + headPos;
    }}

    function applyCurrentSelection() {{
      const comboIdx = comboIndexForSelection(selectionLayerRaw, selectionHeadRaw);
      if (comboIdx === null) return;
      activeComboIndex = comboIdx;
      if (lastAppliedToken !== null) {{
        applyRow(lastAppliedToken);
      }} else {{
        for (const el of tokenEls) {{
          if (!el) continue;
          el.style.opacity = FLOOR_OPACITY;
        }}
      }}
    }}

    function clearSelectionHighlights() {{
      allSelectBtn.classList.remove('is-active');
      for (const btn of layerButtons.values()) {{
        btn.classList.remove('is-active');
      }}
      for (const btns of headButtonsByLayer.values()) {{
        for (const btn of btns) {{
          btn.classList.remove('is-active');
        }}
      }}
    }}

    function onSelectionClick(layerRaw, headRaw, btn) {{
      selectionLayerRaw = layerRaw;
      selectionHeadRaw = headRaw;
      clearSelectionHighlights();
      if (layerRaw === 'all' && headRaw === 'all') {{
        btn.classList.add('is-active');
      }} else if (layerRaw !== 'all' && headRaw === 'all') {{
        const layerBtn = layerButtons.get(layerRaw);
        if (layerBtn) {{
          layerBtn.classList.add('is-active');
        }}
        const rowHeadBtns = headButtonsByLayer.get(layerRaw) || [];
        for (const rowBtn of rowHeadBtns) {{
          rowBtn.classList.add('is-active');
        }}
      }} else if (layerRaw !== 'all') {{
        btn.classList.add('is-active');
      }}
      applyCurrentSelection();
    }}

    function moveSelectionByArrow(layerRaw, headRaw, key) {{
      if (key !== 'ArrowUp' && key !== 'ArrowDown' && key !== 'ArrowLeft' && key !== 'ArrowRight') {{
        return null;
      }}
      if (layerRaw === 'all' && headRaw === 'all') {{
        if ((key === 'ArrowDown' || key === 'ArrowRight') && layerKeys.length > 0) {{
          const firstLayerKey = layerKeys[0];
          const firstHeadKey = headKeys[0] || 'all';
          if (firstHeadKey === 'all') {{
            return {{ layerRaw: firstLayerKey, headRaw: 'all', btn: layerButtons.get(firstLayerKey) || null }};
          }}
          const rowHeadButtons = headButtonsByLayer.get(firstLayerKey) || [];
          return {{ layerRaw: firstLayerKey, headRaw: firstHeadKey, btn: rowHeadButtons[0] || null }};
        }}
        return null;
      }}
      const layerPos = layerKeys.indexOf(layerRaw);
      if (layerPos < 0) return null;
      const isLayerSquare = headRaw === 'all';
      const headPos = isLayerSquare ? -1 : headKeys.indexOf(headRaw);
      if (!isLayerSquare && headPos < 0) return null;
      let nextLayerPos = layerPos;
      let nextHeadPos = headPos;
      if (key === 'ArrowLeft') {{
        if (headPos > 0) {{
          nextHeadPos = headPos - 1;
        }} else if (headPos === 0) {{
          nextHeadPos = -1;
        }} else {{
          return null;
        }}
      }} else if (key === 'ArrowRight') {{
        if (headPos < headKeys.length - 1) {{
          nextHeadPos = headPos + 1;
        }} else if (headPos === -1 && headKeys.length > 0) {{
          nextHeadPos = 0;
        }} else {{
          return null;
        }}
      }} else if (key === 'ArrowUp') {{
        if (layerPos === 0) return null;
        nextLayerPos = layerPos - 1;
      }} else if (key === 'ArrowDown') {{
        if (layerPos >= layerKeys.length - 1) return null;
        nextLayerPos = layerPos + 1;
      }}
      const nextLayerRaw = layerKeys[nextLayerPos];
      if (nextHeadPos === -1) {{
        return {{ layerRaw: nextLayerRaw, headRaw: 'all', btn: layerButtons.get(nextLayerRaw) || null }};
      }}
      const rowHeadButtons = headButtonsByLayer.get(nextLayerRaw) || [];
      return {{
        layerRaw: nextLayerRaw,
        headRaw: headKeys[nextHeadPos],
        btn: rowHeadButtons[nextHeadPos] || null,
      }};
    }}

    function buildSelectionButton(layerRaw, headRaw, extraClass, tooltip) {{
      const btn = document.createElement('button');
      btn.type = 'button';
      btn.className = extraClass ? `attn-btn ${{extraClass}}` : 'attn-btn';
      btn.setAttribute('aria-label', tooltip);
      btn.title = tooltip;
      btn.addEventListener('click', () => onSelectionClick(layerRaw, headRaw, btn));
      btn.addEventListener('keydown', (event) => {{
        const target = moveSelectionByArrow(layerRaw, headRaw, event.key);
        if (!target || !target.btn) return;
        event.preventDefault();
        target.btn.focus();
        onSelectionClick(target.layerRaw, target.headRaw, target.btn);
      }});
      return btn;
    }}

    function buildAttentionGrid() {{
      for (const layer of DATA.meta.layers) {{
        const layerKey = String(layer);
        const row = document.createElement('div');
        row.className = 'attn-row';
        const layerBtn = buildSelectionButton(
          layerKey,
          'all',
          'layer-square',
          `Layer ${{layer}} (all heads)`
        );
        row.appendChild(layerBtn);
        layerButtons.set(layerKey, layerBtn);
        const headRow = document.createElement('div');
        headRow.className = 'attn-head-row';
        const rowHeadButtons = [];
        for (const head of DATA.meta.heads) {{
          const headBtn = buildSelectionButton(
            layerKey,
            String(head),
            '',
            `Layer ${{layer}}, head ${{head}}`
          );
          headRow.appendChild(headBtn);
          rowHeadButtons.push(headBtn);
        }}
        headButtonsByLayer.set(layerKey, rowHeadButtons);
        row.appendChild(headRow);
        attentionGrid.appendChild(row);
      }}
    }}

    function applyRow(tokenId) {{
      if (tokenId < 0 || tokenId >= tokenCount) return;
      const comboOffset = activeComboIndex * bytesPerCombo;
      if (usesCausalRows) {{
        const rowOffset = (tokenId * (tokenId + 1)) / 2;
        const rowStart = comboOffset + rowOffset;
        for (let i = 0; i <= tokenId; i++) {{
          const el = tokenEls[i];
          if (!el) continue;
          el.style.opacity = String(opacityRowsBytes[rowStart + i] / 255);
        }}
        for (let i = tokenId + 1; i < tokenEls.length; i++) {{
          const el = tokenEls[i];
          if (!el) continue;
          el.style.opacity = FLOOR_OPACITY;
        }}
        return;
      }}
      const rowStart = comboOffset + tokenId * tokenCount;
      for (let i = 0; i < tokenEls.length; i++) {{
        const el = tokenEls[i];
        if (!el) continue;
        el.style.opacity = String(opacityRowsBytes[rowStart + i] / 255);
      }}
    }}

    function renderTyped() {{
      for (const item of segmentEls) {{
        const seg = item.seg;
        const span = item.span;
        let text = '';
        if (typedLen <= seg.start) {{
          text = '';
        }} else if (typedLen >= seg.end) {{
          text = seg.text;
        }} else {{
          text = seg.text.slice(0, typedLen - seg.start);
        }}
        span.textContent = text;
        if (seg.token_id !== null && typedLen > seg.start && typedLen < seg.end) {{
          // Keep the actively typed token fully visible until its boundary applies a row.
          span.style.opacity = '1';
        }}
      }}
      const boundaryToken = tokenEndToId.get(typedLen);
      if (boundaryToken !== undefined) {{
        lastAppliedToken = boundaryToken;
        applyRow(boundaryToken);
      }}
    }}

    function measureFullPromptHeight() {{
      const prevTypedLen = typedLen;
      const prevLastAppliedToken = lastAppliedToken;
      typedLen = DATA.prompt_length;
      lastAppliedToken = null;
      renderTyped();
      const measured = stage.scrollHeight;
      typedLen = prevTypedLen;
      lastAppliedToken = prevLastAppliedToken;
      renderTyped();
      return measured;
    }}

    function lockStageHeightToFullPrompt() {{
      const measured = measureFullPromptHeight();
      if (measured > 0) {{
        stage.style.minHeight = `${{measured}}px`;
      }}
    }}

    function tick() {{
      if (typedLen >= DATA.prompt_length) {{
        pause();
        return;
      }}
      typedLen += 1;
      renderTyped();
    }}

    function play() {{
      if (timer !== null) return;
      const intervalMs = Math.max(8, Math.floor(1000 / CHARS_PER_SEC));
      timer = setInterval(tick, intervalMs);
    }}

    function pause() {{
      if (timer === null) return;
      clearInterval(timer);
      timer = null;
    }}

    function reset() {{
      pause();
      typedLen = 0;
      lastAppliedToken = null;
      for (const el of tokenEls) {{
        if (!el) continue;
        el.style.opacity = FLOOR_OPACITY;
      }}
      renderTyped();
    }}

    function showFullPrompt() {{
      pause();
      typedLen = DATA.prompt_length;
      lastAppliedToken = null;
      for (const el of tokenEls) {{
        if (!el) continue;
        el.style.opacity = FLOOR_OPACITY;
      }}
      renderTyped();
    }}

    playBtn.addEventListener('click', play);
    pauseBtn.addEventListener('click', pause);
    resetBtn.addEventListener('click', reset);
    allSelectBtn.addEventListener('click', () => onSelectionClick('all', 'all', allSelectBtn));
    allSelectBtn.addEventListener('keydown', (event) => {{
      const target = moveSelectionByArrow('all', 'all', event.key);
      if (!target || !target.btn) return;
      event.preventDefault();
      target.btn.focus();
      onSelectionClick(target.layerRaw, target.headRaw, target.btn);
    }});

    buildAttentionGrid();
    onSelectionClick('all', 'all', allSelectBtn);
    showFullPrompt();
    lockStageHeightToFullPrompt();
    reset();
    window.addEventListener('resize', lockStageHeightToFullPrompt);
  </script>
</body>
</html>
"""
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description=(
            "Generate a self-contained HTML animation that types the prompt character-by-character "
            "and updates token opacity at each token boundary from selectable layer/head attention."
        )
    )
    parser.add_argument(
        "--model-id",
        required=True,
        help=(
            "HF Transformers model id or local Transformers model directory. "
            "You can also pass a .gguf file path directly."
        ),
    )
    parser.add_argument(
        "--gguf-file",
        default=None,
        help=(
            "Optional GGUF filename/path to load via Transformers GGUF support. "
            "If --model-id is a .gguf file, this is inferred automatically."
        ),
    )
    parser.add_argument("--prompt", default=None, help="Prompt text")
    parser.add_argument(
        "--prompt-file",
        default=None,
        help="Path to JSON chat messages file (array of {role, content} or object with messages)",
    )
    parser.add_argument(
        "--output",
        default="outputs/typing_attention_timeline/index.html",
        help="Output HTML file path",
    )
    parser.add_argument(
        "--layers",
        default="all",
        help="Comma-separated layer indices or 'all' for all available in the UI",
    )
    parser.add_argument(
        "--heads",
        default="all",
        help="Comma-separated head indices or 'all' for all available in the UI",
    )
    parser.add_argument(
        "--dtype",
        choices=["auto", "float16", "float32", "bfloat16"],
        default="auto",
        help="Torch dtype",
    )
    parser.add_argument("--device", default="auto", help="auto, cpu, cuda, mps")
    parser.add_argument(
        "--contrast",
        type=float,
        default=3.0,
        help="Opacity contrast for each attention row (>1 boosts low scores)",
    )
    parser.add_argument(
        "--opacity-floor",
        type=float,
        default=0.2,
        help="Minimum token opacity in [0, 1)",
    )
    parser.add_argument(
        "--allow-download",
        action="store_true",
        help="Allow download if model is not already local",
    )
    return parser
def main() -> None:
    args = build_parser().parse_args()
    messages = resolve_prompt(args.prompt, args.prompt_file)
    validate_opacity_params(args.contrast, args.opacity_floor)
    model_id, gguf_file = resolve_model_source(args.model_id, args.gguf_file)
    validate_gguf_dependencies(gguf_file)
    attention, offsets, _tokens, layers, heads, prompt_for_timeline = (
        load_attention_tensor(
            model_id=model_id,
            gguf_file=gguf_file,
            messages=messages,
            dtype=args.dtype,
            device=args.device,
            allow_download=args.allow_download,
            layers_arg=args.layers,
            heads_arg=args.heads,
        )
    )
    segments, tokens = build_segments(prompt_for_timeline, offsets)
    if not tokens:
        raise ValueError("No non-special tokens found in prompt")
    token_full_indices = np.array(
        [int(t["full_index"]) for t in tokens], dtype=np.int64
    )
    visible_attention = attention[:, :, token_full_indices, :][
        :, :, :, token_full_indices
    ]
    opacity_payload_b64, compact_meta = build_compact_opacity_payload(
        visible_attention,
        contrast=args.contrast,
        opacity_floor=args.opacity_floor,
    )
    data = {
        "prompt_length": len(prompt_for_timeline),
        "segments": segments,
        "tokens": tokens,
        "opacity_rows_b64": opacity_payload_b64,
        "contrast": args.contrast,
        "opacity_floor": args.opacity_floor,
        "meta": {
            "layers": layers,
            "heads": heads,
            **compact_meta,
        },
    }
    html = generate_html(data)
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(html, encoding="utf-8")
    print(f"Saved: {out.resolve()}")
    print(f"Prompt chars: {len(prompt_for_timeline)}, rendered tokens: {len(tokens)}")


if __name__ == "__main__":
    main()