#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "numpy>=1.26",
# "torch>=2.4",
# "transformers>=4.45",
# "accelerate>=1.12.0",
# "gguf>=0.17.1",
# ]
# ///
#
# Example usage:
# uv run attention_typing.py \
# --model-id Qwen/Qwen3-4B-Instruct-2507 \
# --prompt-file prompt.txt \
# --layers all \
# --heads all \
# --contrast 2.0 \
# --output typing.html
#
# Example prompt.txt:
# [
# {"role": "system", "content": "You are a helpful assistant."},
# {"role": "user", "content": "Write a short haiku about attention maps."}
# ]
from __future__ import annotations
import argparse
import base64
import json
from pathlib import Path
import numpy as np
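# parse_comma_list turns a --layers/--heads value into sorted, de-duplicated
# indices: "all" expands to range(max_items); otherwise each comma-separated
# entry is validated against [0, max_items - 1].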
def parse_comma_list(raw: str, max_items: int, name: str) -> list[int]:
    raw = raw.strip().lower()
    if raw == "all":
        return list(range(max_items))
    out: list[int] = []
    for part in raw.split(","):
        part = part.strip()
        if not part:
            continue
        idx = int(part)
        if idx < 0 or idx >= max_items:
            raise ValueError(f"{name} index {idx} out of range [0, {max_items - 1}]")
        out.append(idx)
    if not out:
        raise ValueError(f"No valid {name} indices provided")
    return sorted(set(out))
def parse_chat_messages(raw: object, source: str) -> list[dict[str, str]]:
    if isinstance(raw, dict):
        raw = raw.get("messages")
    if not isinstance(raw, list) or not raw:
        raise ValueError(
            f"{source} must be a non-empty JSON array of chat messages or an object with a 'messages' array."
        )
    out: list[dict[str, str]] = []
    for idx, item in enumerate(raw):
        if not isinstance(item, dict):
            raise ValueError(f"{source} message at index {idx} must be an object.")
        role = item.get("role")
        content = item.get("content")
        if not isinstance(role, str) or not role.strip():
            raise ValueError(f"{source} message at index {idx} has invalid 'role'.")
        if not isinstance(content, str) or not content:
            raise ValueError(f"{source} message at index {idx} has invalid 'content'.")
        out.append({"role": role.strip(), "content": content})
    return out
def resolve_prompt(prompt: str | None, prompt_file: str | None) -> list[dict[str, str]]:
    if prompt_file:
        text = Path(prompt_file).read_text(encoding="utf-8").strip()
        if not text:
            raise ValueError(f"Prompt file is empty: {prompt_file}")
        try:
            payload = json.loads(text)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Prompt file must be JSON chat messages: {prompt_file}"
            ) from exc
        return parse_chat_messages(payload, f"Prompt file '{prompt_file}'")
    if prompt is None or not prompt.strip():
        raise ValueError("Provide --prompt or --prompt-file")
    return [{"role": "user", "content": prompt}]
def resolve_model_source(
    model_id: str, gguf_file: str | None
) -> tuple[str, str | None]:
    model_path = Path(model_id).expanduser()
    if model_path.is_file():
        if model_path.suffix.lower() == ".gguf":
            if gguf_file is not None:
                raise ValueError(
                    "Pass either a GGUF file as --model-id or pass --gguf-file, not both."
                )
            return str(model_path.parent), model_path.name
        raise ValueError(
            "Expected --model-id to be a Hugging Face model id or local model directory, "
            f"but got a file path: '{model_id}'."
        )
    if gguf_file is None:
        return model_id, None
    gguf_path = Path(gguf_file).expanduser()
    if gguf_path.is_absolute():
        if not gguf_path.is_file():
            raise ValueError(f"GGUF file not found: {gguf_file}")
        return model_id, str(gguf_path)
    if model_path.is_dir():
        candidate = model_path / gguf_path
        if candidate.is_file():
            return model_id, candidate.name
    return model_id, gguf_file
def validate_gguf_dependencies(gguf_file: str | None) -> None:
    if gguf_file is None:
        return
    missing: list[str] = []
    try:
        import accelerate  # noqa: F401
    except ImportError:
        missing.append("accelerate")
    try:
        import gguf  # noqa: F401
    except ImportError:
        missing.append("gguf")
    if missing:
        joined = ", ".join(missing)
        raise ValueError(
            "GGUF loading requires additional dependencies missing from this environment: "
            f"{joined}. Install them with: `uv add {' '.join(missing)}`."
        )
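# load_attention_tensor runs a single forward pass with output_attentions=True and
# returns the selected attention weights as a numpy array of shape
# [len(layers), len(heads), seq, seq], plus the tokenizer's character offsets,
# token strings, the resolved layer/head indices, and the chat-templated prompt text.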
def load_attention_tensor(
    model_id: str,
    gguf_file: str | None,
    messages: list[dict[str, str]],
    dtype: str,
    device: str,
    allow_download: bool,
    layers_arg: str,
    heads_arg: str,
) -> tuple[np.ndarray, list[tuple[int, int]], list[str], list[int], list[int], str]:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    dtype_map = {
        "auto": "auto",
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    if dtype not in dtype_map:
        raise ValueError(f"Unsupported dtype: {dtype}")
    local_only = not allow_download
    tokenizer_kwargs: dict[str, object] = {
        "local_files_only": local_only,
        "trust_remote_code": False,
    }
    if gguf_file is not None:
        tokenizer_kwargs["gguf_file"] = gguf_file
    tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
    if tokenizer.chat_template is None:
        raise ValueError("Tokenizer has no chat template; this script requires one.")
    model_input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    model_kwargs: dict[str, object] = {
        "local_files_only": local_only,
        "trust_remote_code": False,
        "output_attentions": True,
    }
    if gguf_file is not None:
        model_kwargs["gguf_file"] = gguf_file
    selected_dtype = dtype_map[dtype]
    if selected_dtype != "auto":
        model_kwargs["torch_dtype"] = selected_dtype
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    if device != "auto":
        model = model.to(device)
    elif torch.backends.mps.is_available():
        model = model.to("mps")
    elif torch.cuda.is_available():
        model = model.to("cuda")
    model.eval()
    encoded = tokenizer(
        model_input_text, return_tensors="pt", return_offsets_mapping=True
    )
    offsets_raw = encoded.pop("offset_mapping")
    model_device = next(model.parameters()).device
    encoded = {k: v.to(model_device) for k, v in encoded.items()}
    with torch.no_grad():
        outputs = model(**encoded, output_attentions=True, use_cache=False)
    attentions = outputs.attentions
    if attentions is None:
        raise RuntimeError("Model did not return attentions")
    n_layers = len(attentions)
    n_heads = attentions[0].shape[1]
    layers = parse_comma_list(layers_arg, n_layers, "layer")
    heads = parse_comma_list(heads_arg, n_heads, "head")
    tensor_layers: list[np.ndarray] = []
    for layer_idx in layers:
        layer = (
            attentions[layer_idx][0].detach().float().cpu().numpy()
        )  # [heads, seq, seq]
        selected_heads: list[np.ndarray] = []
        for head_idx in heads:
            selected_heads.append(layer[head_idx])
        tensor_layers.append(np.stack(selected_heads, axis=0))
    attention = np.stack(tensor_layers, axis=0)
    token_ids = encoded["input_ids"][0].detach().cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    offsets = [(int(s), int(e)) for s, e in offsets_raw[0].detach().cpu().tolist()]
    return attention, offsets, tokens, layers, heads, model_input_text
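# build_segments maps token offsets back onto the rendered prompt string. Special
# tokens with empty (start == end) offsets are skipped; the prompt is split into
# alternating plain-text gap segments (token_id None) and token segments, and a
# parallel token list records each visible token's index into the full sequence.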
def build_segments(
    prompt: str, offsets: list[tuple[int, int]]
) -> tuple[list[dict], list[dict]]:
    non_special: list[tuple[int, int, int]] = []
    for full_idx, (start, end) in enumerate(offsets):
        if end > start:
            non_special.append((full_idx, start, end))
    non_special.sort(key=lambda x: (x[1], x[2]))
    segments: list[dict] = []
    tokens: list[dict] = []
    cursor = 0
    token_id = 0
    for full_idx, start, end in non_special:
        if start > cursor:
            segments.append(
                {
                    "text": prompt[cursor:start],
                    "start": cursor,
                    "end": start,
                    "token_id": None,
                }
            )
        token_text = prompt[start:end]
        if token_text:
            tokens.append(
                {
                    "id": token_id,
                    "full_index": full_idx,
                    "start": start,
                    "end": end,
                    "text": token_text,
                }
            )
            segments.append(
                {
                    "text": token_text,
                    "start": start,
                    "end": end,
                    "token_id": token_id,
                }
            )
            token_id += 1
        cursor = max(cursor, end)
    if cursor < len(prompt):
        segments.append(
            {
                "text": prompt[cursor:],
                "start": cursor,
                "end": len(prompt),
                "token_id": None,
            }
        )
    return segments, tokens
def validate_opacity_params(contrast: float, opacity_floor: float) -> None:
    if contrast <= 0:
        raise ValueError("--contrast must be > 0")
    if not (0 <= opacity_floor < 1):
        raise ValueError("--opacity-floor must be in [0, 1)")
def detect_causal_attention(attention: np.ndarray, atol: float = 1e-8) -> bool:
    token_count = attention.shape[-1]
    for row_idx in range(token_count - 1):
        if np.any(np.abs(attention[..., row_idx, row_idx + 1 :]) > atol):
            return False
    return True
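# encode_opacity_matrix converts one attention matrix into per-row opacity bytes:
# each row is min-max normalised, raised to the power 1/contrast to boost low
# scores, rescaled into [opacity_floor, 1], and written as uint8 values in 0..255.
# For causal attention only the lower triangle is stored, row r contributing r + 1 bytes.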
def encode_opacity_matrix(
    matrix: np.ndarray,
    out: np.ndarray,
    contrast: float,
    opacity_floor: float,
    causal: bool,
) -> None:
    token_count = matrix.shape[0]
    expected = (
        token_count * (token_count + 1) // 2 if causal else token_count * token_count
    )
    if out.size != expected:
        raise ValueError(
            f"Output buffer size {out.size} does not match expected {expected}."
        )
    p = 1.0 / contrast
    floor_byte = int(round(opacity_floor * 255.0))
    cursor = 0
    for row_idx in range(token_count):
        row_full = matrix[row_idx, :token_count]
        row_len = row_idx + 1 if causal else token_count
        row = row_full[:row_len]
        lo = float(np.min(row_full))
        hi = float(np.max(row_full))
        dst = out[cursor : cursor + row_len]
        if abs(hi - lo) < 1e-12:
            dst.fill(floor_byte)
        else:
            norm = (row - lo) / (hi - lo)
            boosted = np.power(norm, p)
            opacity = opacity_floor + (1.0 - opacity_floor) * boosted
            dst[:] = np.clip(np.rint(opacity * 255.0), 0, 255).astype(np.uint8)
        cursor += row_len
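# build_compact_opacity_payload packs every selectable view into one base64 blob.
# Combo order matches comboIndexForSelection in the generated JS: index 0 is the
# mean over all layers and heads, indices 1..layer_count are per-layer means over
# heads, and the remaining layer_count * head_count entries are individual heads.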
def build_compact_opacity_payload(
    attention: np.ndarray,
    contrast: float,
    opacity_floor: float,
) -> tuple[str, dict[str, int | bool]]:
    layer_count, head_count, token_count, _ = attention.shape
    causal = detect_causal_attention(attention)
    bytes_per_combo = (
        token_count * (token_count + 1) // 2 if causal else token_count * token_count
    )
    combo_count = 1 + layer_count + layer_count * head_count
    packed = np.empty(combo_count * bytes_per_combo, dtype=np.uint8)
    layer_means = attention.mean(axis=1)
    all_mean = layer_means.mean(axis=0)
    combo_idx = 0
    encode_opacity_matrix(
        all_mean,
        packed[combo_idx * bytes_per_combo : (combo_idx + 1) * bytes_per_combo],
        contrast=contrast,
        opacity_floor=opacity_floor,
        causal=causal,
    )
    combo_idx += 1
    for layer_idx in range(layer_count):
        encode_opacity_matrix(
            layer_means[layer_idx],
            packed[combo_idx * bytes_per_combo : (combo_idx + 1) * bytes_per_combo],
            contrast=contrast,
            opacity_floor=opacity_floor,
            causal=causal,
        )
        combo_idx += 1
    for layer_idx in range(layer_count):
        for head_idx in range(head_count):
            encode_opacity_matrix(
                attention[layer_idx, head_idx],
                packed[combo_idx * bytes_per_combo : (combo_idx + 1) * bytes_per_combo],
                contrast=contrast,
                opacity_floor=opacity_floor,
                causal=causal,
            )
            combo_idx += 1
    payload_b64 = base64.b64encode(packed.tobytes()).decode("ascii")
    meta = {
        "causal": causal,
        "bytes_per_combo": int(bytes_per_combo),
        "combo_count": int(combo_count),
        "floor_byte": int(round(opacity_floor * 255.0)),
    }
    return payload_b64, meta
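# generate_html embeds the JSON payload into a single self-contained page: the
# inline script decodes the base64 opacity rows, rebuilds the layer/head selector
# grid, and replays the prompt as a typing animation, updating token opacity at
# each token boundary.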
def generate_html(data: dict) -> str:
    payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
    return f"""<!doctype html>
<html lang=\"en\">
<head>
<meta charset=\"utf-8\" />
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />
<title>Typing Attention Timeline</title>
<style>
body {{
margin: 0;
background: #fff;
color: #000;
font-family: "Times New Roman", Times, serif;
}}
.wrap {{
max-width: 900px;
margin: 22px auto;
padding: 0 16px 28px;
}}
.controls {{
display: flex;
gap: 10px;
flex-wrap: wrap;
align-items: center;
justify-content: center;
padding: 2px 0 10px;
margin-bottom: 12px;
}}
button {{
border: 1px solid #111;
background: #111;
color: #fff;
border-radius: 999px;
padding: 6px 14px;
font-family: "Times New Roman", Times, serif;
font-size: 15px;
line-height: 1;
cursor: pointer;
transition: background-color 120ms ease, transform 120ms ease;
}}
button:hover {{ background: #2a2a2a; }}
button:active {{ transform: translateY(1px); }}
.attn-panel {{
display: flex;
flex-direction: column;
align-items: center;
gap: 6px;
margin: 4px 0 12px;
}}
.attn-top-row {{
display: flex;
align-items: center;
gap: 12px;
}}
.attn-heads-title {{
font-size: 24px;
line-height: 1;
font-weight: 600;
}}
.attn-grid-wrap {{
display: flex;
align-items: center;
justify-content: center;
gap: 12px;
max-width: 100%;
overflow-x: auto;
padding-bottom: 4px;
}}
.attn-layers-label {{
font-size: 24px;
line-height: 1;
font-weight: 600;
}}
.attention-grid {{
display: flex;
flex-direction: column;
gap: 8px;
align-items: flex-start;
}}
.attn-row {{
display: flex;
align-items: center;
gap: 6px;
}}
.attn-head-row {{
display: flex;
flex-wrap: nowrap;
gap: 4px;
}}
.attn-btn {{
border: 2px solid #222;
background: #f6f6f6;
width: 16px;
height: 16px;
border-radius: 2px;
padding: 0;
cursor: pointer;
transition: background-color 120ms ease, transform 120ms ease, border-color 120ms ease;
}}
.attn-btn:hover {{
background: #e8e8e8;
border-color: #111;
}}
.attn-btn:active {{
transform: translateY(1px);
}}
.attn-btn.layer-square {{
background: #ececec;
}}
.attn-btn.layer-square:hover {{
background: #dfdfdf;
}}
.attn-btn.attn-all-btn {{
width: 24px;
height: 24px;
font-family: "Times New Roman", Times, serif;
font-size: 8px;
line-height: 1;
color: #222;
text-transform: lowercase;
}}
.attn-btn.attn-all-btn:hover {{
background: #e8e8e8;
border-color: #111;
color: #111;
}}
.attn-btn.is-active {{
background: #111;
border-color: #111;
}}
.attn-btn.is-active.attn-all-btn {{
color: #fff;
}}
@media (max-width: 760px) {{
.attn-top-row {{
gap: 8px;
}}
.attn-heads-title {{
font-size: 18px;
}}
.attn-layers-label {{
font-size: 18px;
}}
.attn-grid-wrap {{
gap: 12px;
}}
.attn-btn {{
width: 12px;
height: 12px;
}}
.attn-btn.attn-all-btn {{
width: 18px;
height: 18px;
font-size: 6px;
}}
.attn-head-row {{
gap: 3px;
}}
}}
.stage {{
background: #fff;
padding: 18px 16px;
min-height: 300px;
line-height: 1.7;
font-size: 22px;
white-space: pre-wrap;
word-wrap: break-word;
}}
.token {{
color: #000;
transition: opacity 580ms ease;
opacity: 0.2;
}}
.cursor {{
display: inline-block;
width: 0.07em;
height: 1.05em;
vertical-align: text-bottom;
background: #000;
color: #000;
animation: blink 1s steps(1) infinite;
user-select: none;
margin-left: 1px;
}}
@keyframes blink {{
50% {{ opacity: 0; }}
}}
</style>
</head>
<body>
<div class=\"wrap\">
<div class=\"controls\">
<button id=\"play\">Play</button>
<button id=\"pause\">Pause</button>
<button id=\"reset\">Reset</button>
</div>
<div class=\"attn-panel\" aria-label=\"Attention selector\">
<div class=\"attn-top-row\">
<div class=\"attn-heads-title\">heads</div>
<button id=\"attn-all-btn\" class=\"attn-btn attn-all-btn\" type=\"button\">all</button>
</div>
<div class=\"attn-grid-wrap\">
<div class=\"attn-layers-label\">layers</div>
<div id=\"attention-grid\" class=\"attention-grid\"></div>
</div>
</div>
<div id=\"stage\" class=\"stage\" aria-live=\"polite\"></div>
</div>
<script>
const DATA = {payload};
const stage = document.getElementById('stage');
const playBtn = document.getElementById('play');
const pauseBtn = document.getElementById('pause');
const resetBtn = document.getElementById('reset');
const allSelectBtn = document.getElementById('attn-all-btn');
const attentionGrid = document.getElementById('attention-grid');
const CHARS_PER_SEC = 30;
const FLOOR_OPACITY = String(DATA.opacity_floor);
const tokenCount = DATA.tokens.length;
const layerCount = DATA.meta.layers.length;
const headCount = DATA.meta.heads.length;
const bytesPerCombo = DATA.meta.bytes_per_combo;
const usesCausalRows = Boolean(DATA.meta.causal);
function decodeBase64ToUint8(base64Text) {{
const binary = atob(base64Text);
const out = new Uint8Array(binary.length);
for (let i = 0; i < binary.length; i++) {{
out[i] = binary.charCodeAt(i);
}}
return out;
}}
let opacityRowsB64 = DATA.opacity_rows_b64;
const opacityRowsBytes = decodeBase64ToUint8(opacityRowsB64);
opacityRowsB64 = '';
const expectedBytes = DATA.meta.combo_count * bytesPerCombo;
if (opacityRowsBytes.length !== expectedBytes) {{
throw new Error(`Decoded opacity payload has ${{opacityRowsBytes.length}} bytes, expected ${{expectedBytes}}.`);
}}
delete DATA.opacity_rows_b64;
const segmentEls = [];
const tokenEls = [];
for (const seg of DATA.segments) {{
const span = document.createElement('span');
if (seg.token_id !== null) {{
span.className = 'token';
span.style.opacity = FLOOR_OPACITY;
tokenEls[seg.token_id] = span;
}}
segmentEls.push({{ span, seg }});
stage.appendChild(span);
}}
const cursor = document.createElement('span');
cursor.className = 'cursor';
cursor.textContent = '';
stage.appendChild(cursor);
const tokenEndToId = new Map();
for (const tok of DATA.tokens) {{
tokenEndToId.set(tok.end, tok.id);
}}
let typedLen = 0;
let timer = null;
let lastAppliedToken = null;
let activeComboIndex = 0;
let selectionLayerRaw = 'all';
let selectionHeadRaw = 'all';
const layerKeys = DATA.meta.layers.map((layer) => String(layer));
const headKeys = DATA.meta.heads.map((head) => String(head));
const layerPosByRaw = new Map(layerKeys.map((key, i) => [key, i]));
const headPosByRaw = new Map(headKeys.map((key, i) => [key, i]));
const layerButtons = new Map();
const headButtonsByLayer = new Map();
function comboIndexForSelection(layerRaw, headRaw) {{
if (layerRaw === 'all' && headRaw === 'all') {{
return 0;
}}
const layerPos = layerPosByRaw.get(layerRaw);
if (layerPos === undefined) {{
return null;
}}
if (headRaw === 'all') {{
return 1 + layerPos;
}}
const headPos = headPosByRaw.get(headRaw);
if (headPos === undefined) {{
return null;
}}
return 1 + layerCount + layerPos * headCount + headPos;
}}
function applyCurrentSelection() {{
const comboIdx = comboIndexForSelection(selectionLayerRaw, selectionHeadRaw);
if (comboIdx === null) return;
activeComboIndex = comboIdx;
if (lastAppliedToken !== null) {{
applyRow(lastAppliedToken);
}} else {{
for (const el of tokenEls) {{
if (!el) continue;
el.style.opacity = FLOOR_OPACITY;
}}
}}
}}
function clearSelectionHighlights() {{
allSelectBtn.classList.remove('is-active');
for (const btn of layerButtons.values()) {{
btn.classList.remove('is-active');
}}
for (const btns of headButtonsByLayer.values()) {{
for (const btn of btns) {{
btn.classList.remove('is-active');
}}
}}
}}
function onSelectionClick(layerRaw, headRaw, btn) {{
selectionLayerRaw = layerRaw;
selectionHeadRaw = headRaw;
clearSelectionHighlights();
if (layerRaw === 'all' && headRaw === 'all') {{
btn.classList.add('is-active');
}} else if (layerRaw !== 'all' && headRaw === 'all') {{
const layerBtn = layerButtons.get(layerRaw);
if (layerBtn) {{
layerBtn.classList.add('is-active');
}}
const rowHeadBtns = headButtonsByLayer.get(layerRaw) || [];
for (const rowBtn of rowHeadBtns) {{
rowBtn.classList.add('is-active');
}}
}} else if (layerRaw !== 'all') {{
btn.classList.add('is-active');
}}
applyCurrentSelection();
}}
function moveSelectionByArrow(layerRaw, headRaw, key) {{
if (key !== 'ArrowUp' && key !== 'ArrowDown' && key !== 'ArrowLeft' && key !== 'ArrowRight') {{
return null;
}}
if (layerRaw === 'all' && headRaw === 'all') {{
if ((key === 'ArrowDown' || key === 'ArrowRight') && layerKeys.length > 0) {{
const firstLayerKey = layerKeys[0];
const firstHeadKey = headKeys[0] || 'all';
if (firstHeadKey === 'all') {{
return {{ layerRaw: firstLayerKey, headRaw: 'all', btn: layerButtons.get(firstLayerKey) || null }};
}}
const rowHeadButtons = headButtonsByLayer.get(firstLayerKey) || [];
return {{ layerRaw: firstLayerKey, headRaw: firstHeadKey, btn: rowHeadButtons[0] || null }};
}}
return null;
}}
const layerPos = layerKeys.indexOf(layerRaw);
if (layerPos < 0) return null;
const isLayerSquare = headRaw === 'all';
const headPos = isLayerSquare ? -1 : headKeys.indexOf(headRaw);
if (!isLayerSquare && headPos < 0) return null;
let nextLayerPos = layerPos;
let nextHeadPos = headPos;
if (key === 'ArrowLeft') {{
if (headPos > 0) {{
nextHeadPos = headPos - 1;
}} else if (headPos === 0) {{
nextHeadPos = -1;
}} else {{
return null;
}}
}} else if (key === 'ArrowRight') {{
if (headPos < headKeys.length - 1) {{
nextHeadPos = headPos + 1;
}} else if (headPos === -1 && headKeys.length > 0) {{
nextHeadPos = 0;
}} else {{
return null;
}}
}} else if (key === 'ArrowUp') {{
if (layerPos === 0) return null;
nextLayerPos = layerPos - 1;
}} else if (key === 'ArrowDown') {{
if (layerPos >= layerKeys.length - 1) return null;
nextLayerPos = layerPos + 1;
}}
const nextLayerRaw = layerKeys[nextLayerPos];
if (nextHeadPos === -1) {{
return {{ layerRaw: nextLayerRaw, headRaw: 'all', btn: layerButtons.get(nextLayerRaw) || null }};
}}
const rowHeadButtons = headButtonsByLayer.get(nextLayerRaw) || [];
return {{
layerRaw: nextLayerRaw,
headRaw: headKeys[nextHeadPos],
btn: rowHeadButtons[nextHeadPos] || null,
}};
}}
function buildSelectionButton(layerRaw, headRaw, extraClass, tooltip) {{
const btn = document.createElement('button');
btn.type = 'button';
btn.className = extraClass ? `attn-btn ${{extraClass}}` : 'attn-btn';
btn.setAttribute('aria-label', tooltip);
btn.title = tooltip;
btn.addEventListener('click', () => onSelectionClick(layerRaw, headRaw, btn));
btn.addEventListener('keydown', (event) => {{
const target = moveSelectionByArrow(layerRaw, headRaw, event.key);
if (!target || !target.btn) return;
event.preventDefault();
target.btn.focus();
onSelectionClick(target.layerRaw, target.headRaw, target.btn);
}});
return btn;
}}
function buildAttentionGrid() {{
for (const layer of DATA.meta.layers) {{
const layerKey = String(layer);
const row = document.createElement('div');
row.className = 'attn-row';
const layerBtn = buildSelectionButton(
layerKey,
'all',
'layer-square',
`Layer ${{layer}} (all heads)`
);
row.appendChild(layerBtn);
layerButtons.set(layerKey, layerBtn);
const headRow = document.createElement('div');
headRow.className = 'attn-head-row';
const rowHeadButtons = [];
for (const head of DATA.meta.heads) {{
const headBtn = buildSelectionButton(
layerKey,
String(head),
'',
`Layer ${{layer}}, head ${{head}}`
);
headRow.appendChild(headBtn);
rowHeadButtons.push(headBtn);
}}
headButtonsByLayer.set(layerKey, rowHeadButtons);
row.appendChild(headRow);
attentionGrid.appendChild(row);
}}
}}
function applyRow(tokenId) {{
if (tokenId < 0 || tokenId >= tokenCount) return;
const comboOffset = activeComboIndex * bytesPerCombo;
if (usesCausalRows) {{
const rowOffset = (tokenId * (tokenId + 1)) / 2;
const rowStart = comboOffset + rowOffset;
for (let i = 0; i <= tokenId; i++) {{
const el = tokenEls[i];
if (!el) continue;
el.style.opacity = String(opacityRowsBytes[rowStart + i] / 255);
}}
for (let i = tokenId + 1; i < tokenEls.length; i++) {{
const el = tokenEls[i];
if (!el) continue;
el.style.opacity = FLOOR_OPACITY;
}}
return;
}}
const rowStart = comboOffset + tokenId * tokenCount;
for (let i = 0; i < tokenEls.length; i++) {{
const el = tokenEls[i];
if (!el) continue;
el.style.opacity = String(opacityRowsBytes[rowStart + i] / 255);
}}
}}
function renderTyped() {{
for (const item of segmentEls) {{
const seg = item.seg;
const span = item.span;
let text = '';
if (typedLen <= seg.start) {{
text = '';
}} else if (typedLen >= seg.end) {{
text = seg.text;
}} else {{
text = seg.text.slice(0, typedLen - seg.start);
}}
span.textContent = text;
if (seg.token_id !== null && typedLen > seg.start && typedLen < seg.end) {{
// Keep the actively typed token fully visible until its boundary applies a row.
span.style.opacity = '1';
}}
}}
const boundaryToken = tokenEndToId.get(typedLen);
if (boundaryToken !== undefined) {{
lastAppliedToken = boundaryToken;
applyRow(boundaryToken);
}}
}}
function measureFullPromptHeight() {{
const prevTypedLen = typedLen;
const prevLastAppliedToken = lastAppliedToken;
typedLen = DATA.prompt_length;
lastAppliedToken = null;
renderTyped();
const measured = stage.scrollHeight;
typedLen = prevTypedLen;
lastAppliedToken = prevLastAppliedToken;
renderTyped();
return measured;
}}
function lockStageHeightToFullPrompt() {{
const measured = measureFullPromptHeight();
if (measured > 0) {{
stage.style.minHeight = `${{measured}}px`;
}}
}}
function tick() {{
if (typedLen >= DATA.prompt_length) {{
pause();
return;
}}
typedLen += 1;
renderTyped();
}}
function play() {{
if (timer !== null) return;
const intervalMs = Math.max(8, Math.floor(1000 / CHARS_PER_SEC));
timer = setInterval(tick, intervalMs);
}}
function pause() {{
if (timer === null) return;
clearInterval(timer);
timer = null;
}}
function reset() {{
pause();
typedLen = 0;
lastAppliedToken = null;
for (const el of tokenEls) {{
if (!el) continue;
el.style.opacity = FLOOR_OPACITY;
}}
renderTyped();
}}
function showFullPrompt() {{
pause();
typedLen = DATA.prompt_length;
lastAppliedToken = null;
for (const el of tokenEls) {{
if (!el) continue;
el.style.opacity = FLOOR_OPACITY;
}}
renderTyped();
}}
playBtn.addEventListener('click', play);
pauseBtn.addEventListener('click', pause);
resetBtn.addEventListener('click', reset);
allSelectBtn.addEventListener('click', () => onSelectionClick('all', 'all', allSelectBtn));
allSelectBtn.addEventListener('keydown', (event) => {{
const target = moveSelectionByArrow('all', 'all', event.key);
if (!target || !target.btn) return;
event.preventDefault();
target.btn.focus();
onSelectionClick(target.layerRaw, target.headRaw, target.btn);
}});
buildAttentionGrid();
onSelectionClick('all', 'all', allSelectBtn);
showFullPrompt();
lockStageHeightToFullPrompt();
reset();
window.addEventListener('resize', lockStageHeightToFullPrompt);
</script>
</body>
</html>
"""
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description=(
            "Generate a self-contained HTML animation that types the prompt character-by-character "
            "and updates token opacity at each token boundary from selectable layer/head attention."
        )
    )
    parser.add_argument(
        "--model-id",
        required=True,
        help=(
            "HF Transformers model id or local Transformers model directory. "
            "You can also pass a .gguf file path directly."
        ),
    )
    parser.add_argument(
        "--gguf-file",
        default=None,
        help=(
            "Optional GGUF filename/path to load via Transformers GGUF support. "
            "If --model-id is a .gguf file, this is inferred automatically."
        ),
    )
    parser.add_argument("--prompt", default=None, help="Prompt text")
    parser.add_argument(
        "--prompt-file",
        default=None,
        help="Path to JSON chat messages file (array of {role, content} or object with messages)",
    )
    parser.add_argument(
        "--output",
        default="outputs/typing_attention_timeline/index.html",
        help="Output HTML file path",
    )
    parser.add_argument(
        "--layers",
        default="all",
        help="Comma-separated layer indices or 'all' for all available in the UI",
    )
    parser.add_argument(
        "--heads",
        default="all",
        help="Comma-separated head indices or 'all' for all available in the UI",
    )
    parser.add_argument(
        "--dtype",
        choices=["auto", "float16", "float32", "bfloat16"],
        default="auto",
        help="Torch dtype",
    )
    parser.add_argument("--device", default="auto", help="auto, cpu, cuda, mps")
    parser.add_argument(
        "--contrast",
        type=float,
        default=3.0,
        help="Opacity contrast for each attention row (>1 boosts low scores)",
    )
    parser.add_argument(
        "--opacity-floor",
        type=float,
        default=0.2,
        help="Minimum token opacity in [0, 1)",
    )
    parser.add_argument(
        "--allow-download",
        action="store_true",
        help="Allow download if model is not already local",
    )
    return parser
def main() -> None:
    args = build_parser().parse_args()
    messages = resolve_prompt(args.prompt, args.prompt_file)
    validate_opacity_params(args.contrast, args.opacity_floor)
    model_id, gguf_file = resolve_model_source(args.model_id, args.gguf_file)
    validate_gguf_dependencies(gguf_file)
    attention, offsets, _tokens, layers, heads, prompt_for_timeline = (
        load_attention_tensor(
            model_id=model_id,
            gguf_file=gguf_file,
            messages=messages,
            dtype=args.dtype,
            device=args.device,
            allow_download=args.allow_download,
            layers_arg=args.layers,
            heads_arg=args.heads,
        )
    )
    segments, tokens = build_segments(prompt_for_timeline, offsets)
    if not tokens:
        raise ValueError("No non-special tokens found in prompt")
    token_full_indices = np.array(
        [int(t["full_index"]) for t in tokens], dtype=np.int64
    )
    visible_attention = attention[:, :, token_full_indices, :][
        :, :, :, token_full_indices
    ]
    opacity_payload_b64, compact_meta = build_compact_opacity_payload(
        visible_attention,
        contrast=args.contrast,
        opacity_floor=args.opacity_floor,
    )
    data = {
        "prompt_length": len(prompt_for_timeline),
        "segments": segments,
        "tokens": tokens,
        "opacity_rows_b64": opacity_payload_b64,
        "contrast": args.contrast,
        "opacity_floor": args.opacity_floor,
        "meta": {
            "layers": layers,
            "heads": heads,
            **compact_meta,
        },
    }
    html = generate_html(data)
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(html, encoding="utf-8")
    print(f"Saved: {out.resolve()}")
    print(f"Prompt chars: {len(prompt_for_timeline)}, rendered tokens: {len(tokens)}")
if __name__ == "__main__":
    main()