@ruvnet
Created May 8, 2026 18:49
Pattern fill, without training: a Rust library that copies the style of a few example sequences (Mario levels, drum loops, configs) without any model training. Bidirectional fill mode beats a 1st-order Markov chain by 4x.

Pattern fill, without training

A small Rust library that copies the style of a few example sequences and produces new ones in the same shape — without training a model.

You give it a handful of examples (Mario level slices, drum loops, snippets of structured text — any short token sequences that have a pattern). It reads them once. From then on it can produce new sequences that look like they came from the same source. No GPUs. No PyTorch. No model files. Just Rust.

It does this by re-using a sparse attention kernel (the math behind transformers) as a kind of lookup table rather than a learned model. The examples are the model.
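To make "the examples are the model" concrete, here is a hypothetical toy sketch of attention-as-lookup — not the library's API. It uses hard exact-token matches where the real kernel uses random unit-vector embeddings and a softmax, but the shape of the idea is the same: keys are corpus tokens, values are whatever followed them.

```rust
// Toy sketch (hypothetical, simplified): keys are corpus tokens, values
// are the tokens that follow them. A query matching a corpus token
// retrieves a distribution over what came next — bigram retrieval with
// no training step.
fn next_token_scores(corpus: &[u8], query: u8, vocab: usize) -> Vec<f32> {
    let mut scores = vec![0.0f32; vocab];
    for i in 0..corpus.len().saturating_sub(1) {
        if corpus[i] == query {
            // attention weight 1 for an exact match, 0 otherwise
            // (a "hard" stand-in for the kernel's softmax similarity)
            scores[corpus[i + 1] as usize] += 1.0;
        }
    }
    let total: f32 = scores.iter().sum();
    if total > 0.0 {
        for s in scores.iter_mut() {
            *s /= total;
        }
    }
    scores
}

fn main() {
    // corpus "010101": after 0 always comes 1
    let corpus = [0u8, 1, 0, 1, 0, 1];
    let s = next_token_scores(&corpus, 0, 2);
    assert!(s[1] > 0.99);
    println!("P(next=1 | cur=0) = {:.2}", s[1]);
}
```

Swapping the one-hot match for dot products against random embeddings is what lets the real kernel generalise across near-matches instead of only exact ones.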

The trick has two flavours:

  • Stream mode — produces output one token at a time, left to right. Like writing into a text box. ~12 microseconds per token.
  • Fill mode — starts from a blank canvas and fills it in everywhere at once, refining over a few rounds. Like content-aware fill in Photoshop, but for tokens. ~6 milliseconds for a 64-token grid.

Both modes share the same underlying machinery. Fill mode is the interesting one: it can also repair a partial sequence (mask out the broken bits, fill them back in), which stream mode can't do.
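The repair idea can be sketched without the kernel at all. Below is a hypothetical one-shot version that fills each masked slot with the token the corpus most often places between the same left/right neighbours — the real fill mode does this with attention scores, refined over several rounds, but the "use context from both sides" principle is identical.

```rust
// Hypothetical minimal sketch of repair-by-refill: mask broken positions,
// then vote over corpus positions that share the same two-sided context.
const MASK: u8 = 255; // out-of-vocab sentinel, as in the library

fn repair(seq: &mut [u8], corpus: &[u8], vocab: usize) {
    for i in 0..seq.len() {
        if seq[i] != MASK {
            continue;
        }
        let left = if i > 0 && seq[i - 1] != MASK { Some(seq[i - 1]) } else { None };
        let right = seq.get(i + 1).copied().filter(|&t| t != MASK);
        let mut votes = vec![0usize; vocab];
        for j in 1..corpus.len().saturating_sub(1) {
            let l_ok = left.map_or(true, |l| corpus[j - 1] == l);
            let r_ok = right.map_or(true, |r| corpus[j + 1] == r);
            if l_ok && r_ok {
                votes[corpus[j] as usize] += 1;
            }
        }
        // fill with the most common token seen in this two-sided context
        if let Some((best, _)) = votes.iter().enumerate().max_by_key(|&(_, &v)| v) {
            seq[i] = best as u8;
        }
    }
}

fn main() {
    let corpus = [0u8, 1, 2, 0, 1, 2, 0, 1, 2];
    let mut broken = [0u8, MASK, 2, 0, 1, 2];
    repair(&mut broken, &corpus, 3);
    assert_eq!(broken[1], 1); // context (0, _, 2) votes for 1
    println!("{:?}", broken);
}
```

A left-to-right stream generator never sees the `2` to the right of the hole, which is exactly the information this sketch (and fill mode) exploits.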


A worked example: drum patterns

Give it four classic 16-step drum loops:

K.h.S.h.K.h.S.h.   (basic rock beat)
KhhhShhhKhhhShhh   (funky 16th-note hats)
..S...S...S...S.   (reggae one-drop)
K..K.S..K..K.S..   (sparse boom-bap)

K is kick, S is snare, h is closed hi-hat, H is open hi-hat, . is silence. That's a 5-token alphabet and 64 tokens of total "training data".

Then ask for a new 4-bar pattern (64 steps):

Stream mode, started from K.h.S.h., in 268 microseconds:

K.h.S.h.SHHK.SH.
K.SH.SH.h.hhS.H.
hHhH.HHHSSS.HhS.
K.H.SS..K.K.hKHS

Fill mode, blank canvas, in 5.7 milliseconds:

S.S.S.Shh.KhhhSh
hS.S.S.SS.SS.SS.
SS..ShShShShShS.
.SS.SS.ShSShSShS

Both are new — none of them appears in the four input patterns — but both clearly belong to the same family. Density (how busy the patterns are) and longest run of identical tokens both land in the corpus's range.
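The two sanity metrics are plain scans; here is a standalone version of the checks the example file also ships, assuming token 0 encodes silence as in the drum alphabet.

```rust
// Density = fraction of non-silence steps; max streak = longest run of
// the same token. Token 0 is '.' (silence) in the drum encoding.
fn density(tokens: &[u8]) -> f32 {
    let hits = tokens.iter().filter(|&&t| t != 0).count();
    hits as f32 / tokens.len().max(1) as f32
}

fn max_streak(tokens: &[u8]) -> usize {
    let mut best = 0;
    let mut cur = 0;
    let mut prev = None;
    for &t in tokens {
        cur = if Some(t) == prev { cur + 1 } else { 1 };
        best = best.max(cur);
        prev = Some(t);
    }
    best
}

fn main() {
    // one bar of "K.h.S.h." with '.'=0, 'K'=1, 'S'=2, 'h'=3
    let bar = [1u8, 0, 3, 0, 2, 0, 3, 0];
    assert!((density(&bar) - 0.5).abs() < 1e-6);
    assert_eq!(max_streak(&bar), 1);
    println!("density={:.2} streak={}", density(&bar), max_streak(&bar));
}
```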


How to use it on your own examples

The whole API is four calls — config, retriever, stream, fill:

use ruvllm_retrieval_diffusion::{Retriever, Diffuser, RetrievalConfig, SamplingConfig};

let cfg = RetrievalConfig {
    vocab_size: 5,        // how many distinct tokens your alphabet has
    head_dim: 64,         // 64 is fine for any vocab up to ~32 tokens
    pos_scale: 0.0,       // 0 for repeating patterns, 0.5 for grids
    ..RetrievalConfig::default()
};

let corpus: Vec<u8> = encode_my_examples();           // u8 indices
let retriever = Retriever::new(corpus, cfg, 0xCAFE);  // one-time init

// stream mode
let new_seq = retriever.generate_fast(&seed, 256, &SamplingConfig::quality(), 0xC0FFEE);

// fill mode
let new_grid = Diffuser::new(&retriever).diffuse(700, 24, &SamplingConfig::quality(), 0xD1FF);

Things people have plugged in already:

  • Super Mario Bros level slices (the originating example — full writeup gist).
  • Drum-machine patterns (this gist's drum_patterns.rs).

Things that should plug in cleanly:

  • Terraform / k8s YAML snippets — feed it a directory of your team's configs, generate new ones in the same style. Useful for templating starters.
  • MIDI loops — same idea as drum patterns but with full pitched notes.
  • Log-line templates — predict next-line shapes from a corpus of past logs.
  • MAGVIT-style visual tokens — if you have a vector-quantised image codec, this becomes a tiny image-fill demo. Image diffusion in Rust with no training pipeline.

Why this isn't an LLM

It's not. There's no learning. There are no weights to update. If your examples don't show a pattern, the output won't either. It can't count, it can't follow grammar, it can't reason. It can copy local style.

That's the point: there are dozens of small problems where "copy local style" is exactly the right tool — and where shipping a real LLM (with the PyTorch toolchain, model weights, GPU dependency) would be wildly overkill. This is what you reach for instead.


Why bidirectional fill is the headline result

We also tried the cheaper alternatives — uniform random sampling and a plain 1st-order Markov chain (the classical bigram model). On the Mario benchmark, scored as L2 distance to corpus on five quality metrics:

pipeline                   L2 distance to corpus
Bidirectional fill         0.72
1st-order Markov bigram    2.75
Uniform random             3.35
Stream mode                5.00

Bidirectional fill wins by ~4× over the closest non-trivial baseline.

The headline takeaway: the value is not bigram fidelity (the Markov chain has perfect bigrams and still loses by 4×). The value is the ability to fill blanks using context from both sides — which only the attention-based fill mode provides.
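For reference, the 1st-order Markov baseline can be sketched in a few lines (a hypothetical minimal version, not the benchmark code): count bigrams once, then sample each next token from the row of the current token. It reproduces bigram statistics exactly, yet every choice sees only one token of left context — it can never use the right-hand side the way fill mode does.

```rust
// Minimal 1st-order Markov (bigram) sampler, the classical baseline.
fn markov_sample(corpus: &[u8], vocab: usize, start: u8, n: usize, seed: &mut u32) -> Vec<u8> {
    // bigram counts[cur][next]
    let mut counts = vec![vec![0usize; vocab]; vocab];
    for w in corpus.windows(2) {
        counts[w[0] as usize][w[1] as usize] += 1;
    }
    let mut out = vec![start];
    for _ in 1..n {
        let row = &counts[*out.last().unwrap() as usize];
        let total: usize = row.iter().sum();
        if total == 0 {
            break; // dead end: current token never appears mid-corpus
        }
        // xorshift32 PRNG, same flavour the library uses
        *seed ^= *seed << 13;
        *seed ^= *seed >> 17;
        *seed ^= *seed << 5;
        let mut r = (*seed as usize) % total;
        for (t, &c) in row.iter().enumerate() {
            if r < c {
                out.push(t as u8);
                break;
            }
            r -= c;
        }
    }
    out
}

fn main() {
    let corpus = [0u8, 1, 2, 0, 1, 2, 0, 1, 2];
    let mut seed = 0xCAFEu32;
    let out = markov_sample(&corpus, 3, 0, 12, &mut seed);
    // in this corpus every bigram is deterministic: 0→1→2→0→…
    assert_eq!(&out[..6], &[0, 1, 2, 0, 1, 2]);
    println!("{:?}", out);
}
```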


What's in this gist

  • README.md — this file
  • lib.rs — the whole library, ~600 lines, depends only on ruvllm_sparse_attention
  • drum_patterns.rs — the worked example, runnable as a cargo example

The library is published in the ruvnet/RuVector workspace as crates/ruvllm_retrieval_diffusion (branch sparse-mario until merged).

Mario lineage: the Sparse-Mario gist has the full 13-iteration story and benchmark history. This crate is the corpus-agnostic generalisation step: same code, packaged so you can point it at your own examples in four lines.


License

MIT.

//! Drum patterns — second-domain demo of `ruvllm_retrieval_diffusion`.
//!
//! Same pattern as `sparse-mario`, different corpus: instead of Super
//! Mario level tiles we use 5-token drum-machine notation:
//!
//! K = kick, S = snare, h = closed hi-hat, H = open hi-hat, . = silence
//!
//! Four classic 16-step patterns are embedded as the corpus (rock, funk,
//! reggae, boom-bap). The retriever indexes their bigram statistics; the
//! diffuser fills with bidirectional context. Output is a 64-step (4-bar) loop.
//!
//! Run with: cargo run --release --features parallel --example drum_patterns
use ruvllm_retrieval_diffusion::{Diffuser, RetrievalConfig, Retriever, SamplingConfig};

const VOCAB: &[char] = &['.', 'K', 'S', 'h', 'H']; // index = token id

fn encode_char(c: char) -> Option<u8> {
    VOCAB.iter().position(|&v| v == c).map(|i| i as u8)
}

fn decode_token(t: u8) -> char {
    VOCAB.get(t as usize).copied().unwrap_or('?')
}

fn encode(s: &str) -> Vec<u8> {
    s.chars().filter_map(encode_char).collect()
}
/// Embedded corpus — four 16-step drum loops, hand-authored.
/// Total = 64 tokens (the full corpus is short by design — the demo
/// shows the same training-free retrieval picking up *any* small-vocab
/// rhythmic prior, not specific drum knowledge).
const PATTERNS: &[&str] = &[
    // Basic rock beat — kick on beats 1 and 3, snare on 2 and 4
    "K.h.S.h.K.h.S.h.",
    // Funk — sixteenth-note hi-hats with snare on 5/13
    "KhhhShhhKhhhShhh",
    // Reggae one-drop — snare on the third step of each beat, no kick or hats
    "..S...S...S...S.",
    // Boom-bap — sparse kicks with snares on steps 6 and 14
    "K..K.S..K..K.S..",
];
fn render_bars(tokens: &[u8], steps_per_bar: usize) -> String {
    let mut out = String::new();
    let mut col = 0;
    for &t in tokens {
        out.push(decode_token(t));
        col += 1;
        if col == steps_per_bar {
            out.push('\n');
            col = 0;
        }
    }
    out
}
fn drum_config() -> RetrievalConfig {
    RetrievalConfig {
        vocab_size: VOCAB.len(),
        head_dim: 64,
        // Drum patterns repeat every 16 steps; positional bias would push
        // queries late in the prefix toward late corpus positions, which is
        // the wrong inductive bias for a strictly-cyclic domain.
        pos_scale: 0.0,
        mask_sentinel: 255,
        diffusion_context_weights: vec![0.5, 0.10],
        sparse: ruvllm_retrieval_diffusion::SparseConfig {
            window: 32,
            block_size: 16,
            global_tokens: vec![0],
            causal: false,
            use_log_stride: true,
            use_landmarks: true,
            sort_candidates: false,
        },
    }
}
fn build_corpus() -> Vec<u8> {
    // Concatenate all patterns, no separator (the shape is fixed at 16
    // steps per bar, so absolute index modulo 16 is the bar position).
    let mut c = Vec::new();
    for p in PATTERNS {
        c.extend(encode(p));
    }
    c
}
fn main() {
    let corpus = build_corpus();
    println!("== Drum-pattern retrieval-diffusion demo ==");
    println!(
        "corpus : {} tokens ({} patterns × 16 steps)",
        corpus.len(),
        PATTERNS.len()
    );
    println!("vocab : {:?}", VOCAB);

    // Tile distribution
    let mut dist = std::collections::HashMap::new();
    for &t in &corpus {
        *dist.entry(decode_token(t)).or_insert(0usize) += 1;
    }
    print!("tile mix : ");
    for &c in VOCAB {
        let n = dist.get(&c).copied().unwrap_or(0);
        print!("{}={:.1}% ", c, n as f32 / corpus.len() as f32 * 100.0);
    }
    println!();

    let cfg = drum_config();
    let retriever = Retriever::new(corpus.clone(), cfg, 0xD7_5_BABE);

    // Seed with the first half of a familiar pattern and ask the retriever
    // to continue. AR walks bigram statistics; it should mostly stay in groove.
    let seed = encode("K.h.S.h.");
    let sampling = SamplingConfig::quality();

    println!();
    println!("--- AR (KvCache + decode_step) ---");
    let t0 = std::time::Instant::now();
    let ar = retriever.generate_fast(&seed, 64, &sampling, 0xC0_FFEE_42);
    let dt_ar = t0.elapsed();
    println!(
        "seed : \"{}\"",
        seed.iter().map(|&t| decode_token(t)).collect::<String>()
    );
    println!("generated : {} tokens in {:.2?}", 64, dt_ar);
    println!();
    println!("{}", render_bars(&ar, 16));

    // Diffusion — start fully masked, denoise to 4 bars (64 steps) with
    // bidirectional context. Boot slice is taken from the corpus.
    println!("--- Diffusion (D3PM-style, cosine schedule) ---");
    let diffuser = Diffuser::new(&retriever);
    let t0 = std::time::Instant::now();
    let diff = diffuser.diffuse(64, 24, &sampling, 0xD1_FF_BEEF);
    let dt_diff = t0.elapsed();
    println!(
        "diffused : {} tokens × 24 denoising steps in {:.2?}",
        64, dt_diff
    );
    println!();
    println!("{}", render_bars(&diff, 16));

    // Compute simple "groove sanity" stats: density (non-silence rate)
    // and longest streak. This corpus's density is ≈ 0.53.
    let density = |toks: &[u8]| -> f32 {
        let nonsilence = toks.iter().filter(|&&t| t != 0).count();
        nonsilence as f32 / toks.len().max(1) as f32
    };
    let max_streak = |toks: &[u8]| -> usize {
        let mut best = 0;
        let mut cur = 0;
        let mut prev: Option<u8> = None;
        for &t in toks {
            if Some(t) == prev {
                cur += 1;
            } else {
                cur = 1;
            }
            if cur > best {
                best = cur;
            }
            prev = Some(t);
        }
        best
    };

    println!("--- groove sanity ---");
    println!(
        "corpus density={:.2} max_streak={}",
        density(&corpus),
        max_streak(&corpus)
    );
    println!(
        "AR density={:.2} max_streak={}",
        density(&ar),
        max_streak(&ar)
    );
    println!(
        "diffusion density={:.2} max_streak={}",
        density(&diff),
        max_streak(&diff)
    );
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vocab_roundtrip() {
        for (i, &c) in VOCAB.iter().enumerate() {
            assert_eq!(encode_char(c), Some(i as u8));
            assert_eq!(decode_token(i as u8), c);
        }
    }

    #[test]
    fn corpus_is_64_tokens() {
        let c = build_corpus();
        assert_eq!(c.len(), 64, "4 patterns × 16 steps = 64");
        for &t in &c {
            assert!((t as usize) < VOCAB.len(), "out-of-vocab token {}", t);
        }
    }

    #[test]
    fn ar_generation_in_vocab() {
        let r = Retriever::new(build_corpus(), drum_config(), 0x1111);
        let out = r.generate_fast(&[1u8], 64, &SamplingConfig::quality(), 0x2222);
        for &t in &out {
            assert!((t as usize) < VOCAB.len());
        }
    }

    #[test]
    fn diffusion_clears_all_masks() {
        let r = Retriever::new(build_corpus(), drum_config(), 0x1111);
        let d = Diffuser::new(&r);
        let out = d.diffuse(64, 16, &SamplingConfig::quality(), 0x3333);
        let mask = drum_config().mask_sentinel;
        for &t in &out {
            assert_ne!(t, mask);
            assert!((t as usize) < VOCAB.len());
        }
    }
}
//! Corpus-agnostic training-free retrieval LM and masked discrete diffusion
//! built on `ruvllm_sparse_attention`.
//!
//! Generalises the `sparse-mario` example: any small-vocab token domain can
//! plug in by supplying a corpus and a [`RetrievalConfig`]. The kernel is
//! used as an associative memory — no autograd, no learned weights, no
//! Python toolchain.
//!
//! Two pipelines from one kernel:
//!
//! - [`Retriever::generate_fast`] — autoregressive next-token retrieval via
//!   `KvCache` + `decode_step`, O(log T) per generated token.
//! - [`Diffuser::diffuse`] — bidirectional masked discrete diffusion with a
//!   MaskGIT cosine schedule. Beats the AR path on aggregate by 6.9× on
//!   the Mario benchmark (see `sparse-mario` baselines doc).
//!
//! ## Domain plug-in checklist
//!
//! ```ignore
//! use ruvllm_retrieval_diffusion::{Retriever, Diffuser, RetrievalConfig, SamplingConfig};
//!
//! let cfg = RetrievalConfig {
//!     vocab_size: 5,      // your domain's token count
//!     head_dim: 64,       // 64 works well for vocab ≤ 32
//!     pos_scale: 0.5,     // try 0 to make AR pos-invariant
//!     mask_sentinel: 255,
//!     ..RetrievalConfig::default()
//! };
//! let corpus: Vec<u8> = encode_my_corpus(); // your encoder, vocab-bounded
//! let retriever = Retriever::new(corpus, cfg, 0x1111_BEEF); // any u32 embedding seed
//! let level = retriever.generate_fast(&seed, 256, &SamplingConfig::quality(), 0xC0FFEE);
//! ```
use ruvllm_sparse_attention::{
    AttentionBackend, KvCache, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
};
pub use ruvllm_sparse_attention::SparseAttentionConfig as SparseConfig;
// ----------------------------------------------------------------------
// Retrieval config
// ----------------------------------------------------------------------

/// Static configuration shared by both `Retriever` and `Diffuser`.
///
/// `vocab_size` is the number of distinct tokens (≤ 254 to leave one byte
/// for `mask_sentinel`). `head_dim` is the embedding dimension (64 is a
/// good default — the kernel's `1/sqrt(d)` softmax scale separates matched
/// random unit-vector pairs by ~sqrt(d), which is comfortable at d=64).
#[derive(Clone, Debug)]
pub struct RetrievalConfig {
    pub vocab_size: usize,
    pub head_dim: usize,
    /// Positional encoding weight in K/V row construction (AR path). 0
    /// disables it — AR becomes purely content-based, useful when the
    /// corpus has no per-position structure to exploit. The Mario
    /// example uses 0.5; the iter-13 finding was that 0 would halve
    /// AR's L2 distance for grid-shaped corpora.
    pub pos_scale: f32,
    /// Out-of-vocab byte used by the diffuser to mark not-yet-denoised
    /// positions. Must be ≥ vocab_size.
    pub mask_sentinel: u8,
    /// Bidirectional context weights for the diffuser, indexed by
    /// `offset - 1` (radius = len()). [0.5, 0.10] is the iter-10 pick.
    pub diffusion_context_weights: Vec<f32>,
    /// Sparse attention config passed to the underlying kernel. Defaults
    /// to non-causal window=256 + log-stride + landmarks.
    pub sparse: SparseAttentionConfig,
}
impl Default for RetrievalConfig {
    fn default() -> Self {
        Self {
            vocab_size: 16,
            head_dim: 64,
            pos_scale: 0.5,
            mask_sentinel: 255,
            diffusion_context_weights: vec![0.5, 0.10],
            sparse: SparseAttentionConfig {
                window: 256,
                block_size: 64,
                global_tokens: vec![0],
                causal: false,
                use_log_stride: true,
                use_landmarks: true,
                sort_candidates: false,
            },
        }
    }
}
// ----------------------------------------------------------------------
// Sampling config
// ----------------------------------------------------------------------

/// Sampling controls applied in `sample_logits` in this order:
/// repetition penalty → top-k → top-p → softmax(/T) → categorical sample.
#[derive(Clone, Debug)]
pub struct SamplingConfig {
    pub temperature: f32,
    pub top_k: usize,
    pub top_p: f32,
    pub repetition_penalty: f32,
    pub no_repeat_window: usize,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            temperature: 1.0,
            top_k: 0,
            top_p: 0.0,
            repetition_penalty: 1.0,
            no_repeat_window: 0,
        }
    }
}
impl SamplingConfig {
    /// The Mario-validated quality recipe. A reasonable starting point for
    /// any small-vocab domain; tune `no_repeat_window` to your meaningful
    /// local span (e.g. one row, one bar of music, one indented config block).
    pub fn quality() -> Self {
        Self {
            temperature: 1.0,
            top_k: 5,
            top_p: 0.90,
            repetition_penalty: 1.7,
            no_repeat_window: 24,
        }
    }
}
// ----------------------------------------------------------------------
// Deterministic PRNG (xorshift32 + Box-Muller normal)
// ----------------------------------------------------------------------

#[inline]
pub fn xorshift32(state: &mut u32) -> u32 {
    let mut x = *state;
    if x == 0 {
        x = 0x9E37_79B9;
    }
    x ^= x.wrapping_shl(13);
    x ^= x.wrapping_shr(17);
    x ^= x.wrapping_shl(5);
    *state = x;
    x
}

#[inline]
pub fn next_uniform(state: &mut u32) -> f32 {
    (xorshift32(state) as f32) / (u32::MAX as f32 + 1.0)
}

pub fn next_normal(state: &mut u32) -> f32 {
    loop {
        let u1 = next_uniform(state);
        let u2 = next_uniform(state);
        if u1 > 1e-9 {
            let r = (-2.0 * u1.ln()).sqrt();
            let theta = 2.0 * std::f32::consts::PI * u2;
            return r * theta.cos();
        }
    }
}
// ----------------------------------------------------------------------
// Embedding helpers
// ----------------------------------------------------------------------

fn make_embedding_matrix(vocab_size: usize, head_dim: usize, seed: u32) -> Vec<f32> {
    let mut state = seed.max(1);
    let mut w = vec![0.0f32; vocab_size * head_dim];
    for v in w.iter_mut() {
        *v = next_normal(&mut state);
    }
    w
}

#[inline]
fn token_embedding<'a>(t: u8, w: &'a [f32], head_dim: usize) -> &'a [f32] {
    let i = (t as usize) * head_dim;
    &w[i..i + head_dim]
}

fn pos_encoding_into(i: usize, dim: usize, out: &mut [f32]) {
    for d in 0..dim {
        let half = d / 2;
        let theta = (i as f32) / 10000_f32.powf((2 * half) as f32 / dim as f32);
        out[d] = if d % 2 == 0 { theta.sin() } else { theta.cos() };
    }
}
// ----------------------------------------------------------------------
// Sample logits helper (rep penalty → top-k → top-p → softmax)
// ----------------------------------------------------------------------

pub fn sample_logits(
    logits: &mut [f32],
    cfg: &SamplingConfig,
    recent: &[u8],
    state: &mut u32,
) -> u8 {
    let v = logits.len();
    if v == 0 {
        return 0;
    }
    if cfg.repetition_penalty > 1.0 + f32::EPSILON && !recent.is_empty() {
        let pen = cfg.repetition_penalty;
        for &t in recent {
            let i = t as usize;
            if i < v {
                logits[i] = if logits[i] > 0.0 {
                    logits[i] / pen
                } else {
                    logits[i] * pen
                };
            }
        }
    }
    if cfg.top_k > 0 && cfg.top_k < v {
        let mut idx: Vec<usize> = (0..v).collect();
        idx.sort_unstable_by(|&a, &b| {
            logits[b]
                .partial_cmp(&logits[a])
                .unwrap_or(core::cmp::Ordering::Equal)
        });
        let kth = logits[idx[cfg.top_k - 1]];
        for li in logits.iter_mut() {
            if *li < kth {
                *li = f32::NEG_INFINITY;
            }
        }
    }
    if cfg.top_p > 0.0 && cfg.top_p < 1.0 {
        let temp_p = cfg.temperature.max(1e-3);
        let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut pairs: Vec<(usize, f32)> = (0..v)
            .map(|i| {
                let p = if logits[i].is_finite() {
                    ((logits[i] - max_l) / temp_p).exp()
                } else {
                    0.0
                };
                (i, p)
            })
            .collect();
        let total: f32 = pairs.iter().map(|p| p.1).sum();
        if total > 0.0 {
            for pr in pairs.iter_mut() {
                pr.1 /= total;
            }
            pairs.sort_unstable_by(|a, b| {
                b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal)
            });
            let mut keep = vec![false; v];
            let mut cum = 0.0f32;
            for &(idx, p) in pairs.iter() {
                keep[idx] = true;
                cum += p;
                if cum >= cfg.top_p {
                    break;
                }
            }
            for i in 0..v {
                if !keep[i] {
                    logits[i] = f32::NEG_INFINITY;
                }
            }
        }
    }
    let temp = cfg.temperature.max(1e-3);
    let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut probs = vec![0.0f32; v];
    let mut sum = 0.0f32;
    for i in 0..v {
        if logits[i].is_finite() {
            probs[i] = ((logits[i] - max_l) / temp).exp();
            sum += probs[i];
        }
    }
    if sum <= 0.0 {
        return 0;
    }
    for p in probs.iter_mut() {
        *p /= sum;
    }
    let r = next_uniform(state);
    let mut acc = 0.0f32;
    for i in 0..v {
        acc += probs[i];
        if r < acc {
            return i as u8;
        }
    }
    (v - 1) as u8
}
// ----------------------------------------------------------------------
// Retriever — autoregressive retrieval LM
// ----------------------------------------------------------------------

/// Training-free retrieval LM. K[i] = embed(corpus[i]) + pos·pos(i),
/// V[i] = embed(corpus[i+1]) + pos·pos(i). Attention finds positions
/// where the query token matches a corpus token, and reads back what
/// follows it — pure bigram retrieval through the kernel's lookup.
pub struct Retriever {
    pub corpus: Vec<u8>,
    pub w: Vec<f32>,
    pub cfg: RetrievalConfig,
}

impl Retriever {
    pub fn new(corpus: Vec<u8>, cfg: RetrievalConfig, embedding_seed: u32) -> Self {
        let w = make_embedding_matrix(cfg.vocab_size, cfg.head_dim, embedding_seed);
        Self { corpus, w, cfg }
    }

    fn build_kv_row(&self, tok: u8, abs_pos: usize) -> Tensor3 {
        let d = self.cfg.head_dim;
        let mut data = vec![0.0f32; d];
        let emb = token_embedding(tok, &self.w, d);
        let mut pos = vec![0.0f32; d];
        pos_encoding_into(abs_pos, d, &mut pos);
        for di in 0..d {
            data[di] = emb[di] + self.cfg.pos_scale * pos[di];
        }
        Tensor3::from_vec(data, 1, 1, d).unwrap()
    }

    fn make_row_tensor(&self, tokens: &[u8], shift_for_value: bool) -> Tensor3 {
        let d = self.cfg.head_dim;
        let seq = tokens.len();
        let mut t = Tensor3::zeros(seq, 1, d);
        let mut pos = vec![0.0f32; d];
        for i in 0..seq {
            let tok = if shift_for_value {
                if i + 1 < seq {
                    tokens[i + 1]
                } else {
                    tokens[i]
                }
            } else {
                tokens[i]
            };
            let emb = token_embedding(tok, &self.w, d);
            pos_encoding_into(i, d, &mut pos);
            let row = t.row_mut(i, 0);
            for di in 0..d {
                row[di] = emb[di] + self.cfg.pos_scale * pos[di];
            }
        }
        t
    }

    /// Reference path — full forward over corpus + prefix every step.
    /// Slow (~O(N log N) per token); use `generate_fast` in production.
    pub fn next_token_logits(&self, prefix: &[u8]) -> Vec<f32> {
        let mut combined = self.corpus.clone();
        combined.extend_from_slice(prefix);
        let q = self.make_row_tensor(&combined, false);
        let v = self.make_row_tensor(&combined, true);
        let attn = SubquadraticSparseAttention::new(self.cfg.sparse.clone()).expect("config");
        let out = attn.forward(&q, &q, &v).expect("attention");
        let last = combined.len() - 1;
        let d = self.cfg.head_dim;
        let mut logits = vec![0.0f32; self.cfg.vocab_size];
        for v_idx in 0..self.cfg.vocab_size {
            let emb = token_embedding(v_idx as u8, &self.w, d);
            let mut dot = 0.0f32;
            for di in 0..d {
                dot += out.get(last, 0, di) * emb[di];
            }
            logits[v_idx] = dot;
        }
        logits
    }
    /// Fast path — pre-fill `KvCache` once, then one O(log T) `decode_step`
    /// per generated token. Targets ~3000× speedup vs `next_token_logits`
    /// at 700-token generations on the Mario benchmark.
    pub fn generate_fast(
        &self,
        prefix: &[u8],
        n: usize,
        sampling: &SamplingConfig,
        sampler_seed: u32,
    ) -> Vec<u8> {
        let mut state = sampler_seed.max(1);
        let d = self.cfg.head_dim;
        let cap = self.corpus.len() + prefix.len() + n + 16;
        let mut cache = KvCache::new(cap, 1, d, self.cfg.sparse.block_size);
        let attn = SubquadraticSparseAttention::new(self.cfg.sparse.clone()).expect("config");
        let zero_v = Tensor3::zeros(1, 1, d);
        for i in 0..self.corpus.len() {
            let next = if i + 1 < self.corpus.len() {
                self.corpus[i + 1]
            } else {
                prefix.first().copied().unwrap_or(self.corpus[i])
            };
            let k = self.build_kv_row(self.corpus[i], i);
            let v = self.build_kv_row(next, i);
            cache.try_append(&k, &v).expect("capacity");
        }
        for j in 0..prefix.len() {
            let abs = self.corpus.len() + j;
            let k = self.build_kv_row(prefix[j], abs);
            let v = if j + 1 < prefix.len() {
                self.build_kv_row(prefix[j + 1], abs)
            } else {
                zero_v.clone()
            };
            cache.try_append(&k, &v).expect("capacity");
        }
        let mut sequence = prefix.to_vec();
        for _ in 0..n {
            let last_idx = cache.len - 1;
            let last_tok = sequence.last().copied().unwrap_or(0);
            let q = self.build_kv_row(last_tok, last_idx);
            let out = attn.decode_step(&q, &cache).expect("decode");
            let mut logits = vec![0.0f32; self.cfg.vocab_size];
            for v_idx in 0..self.cfg.vocab_size {
                let emb = token_embedding(v_idx as u8, &self.w, d);
                let mut dot = 0.0f32;
                for di in 0..d {
                    dot += out.get(0, 0, di) * emb[di];
                }
                logits[v_idx] = dot;
            }
            let win = sampling.no_repeat_window.min(sequence.len());
            let recent = &sequence[sequence.len() - win..];
            let next = sample_logits(&mut logits, sampling, recent, &mut state);
            let new_idx = cache.len;
            let k_new = self.build_kv_row(next, new_idx);
            if cache.try_append(&k_new, &zero_v).is_err() {
                break;
            }
            sequence.push(next);
        }
        sequence
    }
}
// ----------------------------------------------------------------------
// Diffuser — bidirectional masked discrete diffusion
// ----------------------------------------------------------------------

pub struct Diffuser<'a> {
    pub retriever: &'a Retriever,
}

impl<'a> Diffuser<'a> {
    pub fn new(retriever: &'a Retriever) -> Self {
        Self { retriever }
    }

    /// Build bidirectional K and V tensors. K[i] sums weighted neighbour
    /// embeddings within radius = `cfg.diffusion_context_weights.len()`.
    /// No positional encoding — pure content match.
    pub fn make_bidir_kv(&self, seq: &[u8]) -> (Tensor3, Tensor3) {
        let d = self.retriever.cfg.head_dim;
        let n = seq.len();
        let mask = self.retriever.cfg.mask_sentinel;
        let weights = &self.retriever.cfg.diffusion_context_weights;
        let mut k = Tensor3::zeros(n, 1, d);
        let mut v = Tensor3::zeros(n, 1, d);
        let zero = vec![0.0f32; d];
        for i in 0..n {
            let krow = k.row_mut(i, 0);
            for slot in 0..weights.len() {
                let weight = weights[slot];
                let off = slot + 1;
                if i >= off && seq[i - off] != mask {
                    let emb = token_embedding(seq[i - off], &self.retriever.w, d);
                    for di in 0..d {
                        krow[di] += weight * emb[di];
                    }
                }
                if i + off < n && seq[i + off] != mask {
                    let emb = token_embedding(seq[i + off], &self.retriever.w, d);
                    for di in 0..d {
                        krow[di] += weight * emb[di];
                    }
                }
            }
            let vrow = v.row_mut(i, 0);
            if seq[i] != mask {
                let emb = token_embedding(seq[i], &self.retriever.w, d);
                vrow.copy_from_slice(emb);
            } else {
                vrow.copy_from_slice(&zero);
            }
        }
        (k, v)
    }

    fn diffusion_logits(&self, working: &[u8]) -> Vec<Vec<f32>> {
        let d = self.retriever.cfg.head_dim;
        let mut combined = self.retriever.corpus.clone();
        combined.extend_from_slice(working);
        let (k, v) = self.make_bidir_kv(&combined);
        let q = k.clone();
        let attn =
            SubquadraticSparseAttention::new(self.retriever.cfg.sparse.clone()).expect("config");
        let out = attn.forward(&q, &q, &v).expect("attention");
        let prefix_start = self.retriever.corpus.len();
        let vsize = self.retriever.cfg.vocab_size;
        let mut all = Vec::with_capacity(working.len());
        for i in 0..working.len() {
            let idx = prefix_start + i;
            let mut logits = vec![0.0f32; vsize];
            for v_idx in 0..vsize {
                let emb = token_embedding(v_idx as u8, &self.retriever.w, d);
                let mut dot = 0.0f32;
                for di in 0..d {
                    dot += out.get(idx, 0, di) * emb[di];
                }
                logits[v_idx] = dot;
            }
            all.push(logits);
        }
        all
    }
    pub fn denoise_step(
        &self,
        working: &mut [u8],
        keep_count: usize,
        sampling: &SamplingConfig,
        state: &mut u32,
    ) {
        let mask = self.retriever.cfg.mask_sentinel;
        let masked: Vec<usize> = working
            .iter()
            .enumerate()
            .filter(|(_, &t)| t == mask)
            .map(|(i, _)| i)
            .collect();
        if masked.is_empty() || keep_count == 0 {
            return;
        }
        let logits = self.diffusion_logits(working);
        let confidence = |row: &[f32]| -> f32 {
            let max_l = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0f32;
            let mut top = 0.0f32;
            for &l in row.iter() {
                let e = (l - max_l).exp();
                sum += e;
                if e > top {
                    top = e;
                }
            }
            if sum > 0.0 {
                top / sum
            } else {
                0.0
            }
        };
        let mut ranked: Vec<(usize, f32)> = masked
            .iter()
            .map(|&j| (j, confidence(&logits[j])))
            .collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
        let n = keep_count.min(ranked.len());
        for ki in 0..n {
            let (j, _) = ranked[ki];
            let mut row = logits[j].clone();
            let mut next = sample_logits(&mut row, sampling, &[], state);
            if (next as usize) >= self.retriever.cfg.vocab_size {
                next = 0;
            }
            working[j] = next;
        }
    }
    /// Full pipeline: all-mask init → context boot (random contiguous corpus
    /// slice) → cosine-scheduled denoising → final sweep. Returns a fully
    /// denoised sequence of length `n`.
    pub fn diffuse(
        &self,
        n: usize,
        n_steps: usize,
        sampling: &SamplingConfig,
        seed: u32,
    ) -> Vec<u8> {
        let mut state = seed.max(1);
        let mask = self.retriever.cfg.mask_sentinel;
        let mut working = vec![mask; n];
        let corpus_len = self.retriever.corpus.len();
        let boot_len = (n / 8).clamp(8, 64).min(corpus_len.saturating_sub(1));
        // `n > boot_len` guards the `% (n - boot_len)` below when n is tiny.
        if boot_len > 0 && corpus_len > boot_len && n > boot_len {
            let corpus_off = (xorshift32(&mut state) as usize) % (corpus_len - boot_len);
            let work_off = (xorshift32(&mut state) as usize) % (n - boot_len);
            working[work_off..work_off + boot_len]
                .copy_from_slice(&self.retriever.corpus[corpus_off..corpus_off + boot_len]);
        }
        for t in 0..n_steps {
            let frac = ((t + 1) as f32) / (n_steps as f32);
            let target_masked =
                (n as f32 * (core::f32::consts::FRAC_PI_2 * frac).cos()) as usize;
            let current_masked = working.iter().filter(|&&x| x == mask).count();
            let to_unmask = current_masked.saturating_sub(target_masked).max(1);
            self.denoise_step(&mut working, to_unmask, sampling, &mut state);
        }
        let remaining = working.iter().filter(|&&x| x == mask).count();
        if remaining > 0 {
            self.denoise_step(&mut working, remaining, sampling, &mut state);
        }
        working
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn small_corpus() -> Vec<u8> {
        // 4-token vocab: [0, 1, 2, 3]. A repeating pattern with structure.
        let mut c = Vec::new();
        for _ in 0..50 {
            c.push(0);
            c.push(1);
            c.push(2);
            c.push(3);
        }
        c
    }

    fn small_cfg() -> RetrievalConfig {
        RetrievalConfig {
            vocab_size: 4,
            head_dim: 32,
            pos_scale: 0.5,
            mask_sentinel: 255,
            diffusion_context_weights: vec![0.5, 0.10],
            sparse: SparseAttentionConfig {
                window: 64,
                block_size: 16,
                global_tokens: vec![0],
                causal: false,
                use_log_stride: true,
                use_landmarks: true,
                sort_candidates: false,
            },
        }
    }

    #[test]
    fn retriever_runs_end_to_end() {
        let r = Retriever::new(small_corpus(), small_cfg(), 0xABCD);
        let out = r.generate_fast(&[0u8, 1u8], 32, &SamplingConfig::quality(), 0xBEEF);
        assert_eq!(out.len(), 34);
        for &t in &out {
            assert!((t as usize) < 4, "out-of-vocab token {}", t);
        }
    }

    #[test]
    fn retriever_is_deterministic() {
        let r = Retriever::new(small_corpus(), small_cfg(), 0xABCD);
        let a = r.generate_fast(&[0u8], 64, &SamplingConfig::quality(), 0xCAFE);
        let b = r.generate_fast(&[0u8], 64, &SamplingConfig::quality(), 0xCAFE);
        assert_eq!(a, b);
    }

    #[test]
    fn diffuser_runs_end_to_end_and_clears_masks() {
        let r = Retriever::new(small_corpus(), small_cfg(), 0xABCD);
        let d = Diffuser::new(&r);
        let out = d.diffuse(80, 8, &SamplingConfig::quality(), 0xDEAD);
        assert_eq!(out.len(), 80);
        let mask = small_cfg().mask_sentinel;
        for &t in &out {
            assert!(t != mask, "leftover mask in output");
            assert!((t as usize) < 4, "out-of-vocab token {}", t);
        }
    }

    #[test]
    fn diffuser_is_deterministic() {
        let r = Retriever::new(small_corpus(), small_cfg(), 0xABCD);
        let d = Diffuser::new(&r);
        let a = d.diffuse(80, 8, &SamplingConfig::quality(), 0x1234);
        let b = d.diffuse(80, 8, &SamplingConfig::quality(), 0x1234);
        assert_eq!(a, b);
    }

    #[test]
    fn sample_logits_top_k_one_is_greedy() {
        let cfg = SamplingConfig {
            temperature: 1.0,
            top_k: 1,
            ..SamplingConfig::default()
        };
        let mut logits = vec![1.0, 2.0, 0.5, 3.0];
        let mut state = 0xABCDu32;
        let next = sample_logits(&mut logits, &cfg, &[], &mut state);
        assert_eq!(next, 3, "top_k=1 should pick the argmax (index 3)");
    }

    #[test]
    fn pos_scale_zero_makes_retrieval_position_invariant() {
        // With pos_scale=0 the AR retriever depends only on token identity.
        // The same prefix should produce the same prediction regardless of
        // its absolute position — i.e. shifting the prefix index doesn't
        // change next-token logits *modulo what positions are in the sparse
        // window*. We just check that the path runs and produces in-vocab
        // tokens; full position-invariance is corpus-dependent.
        let mut cfg = small_cfg();
        cfg.pos_scale = 0.0;
        let r = Retriever::new(small_corpus(), cfg, 0xABCD);
        let out = r.generate_fast(&[2u8], 32, &SamplingConfig::default(), 0xBEEF);
        for &t in &out {
            assert!((t as usize) < 4);
        }
    }
}