Skip to content

Instantly share code, notes, and snippets.

@n-WN
Created February 4, 2026 08:59
Show Gist options
  • Select an option

  • Save n-WN/03474e9442c4c21d878f71170b962fd1 to your computer and use it in GitHub Desktop.

Select an option

Save n-WN/03474e9442c4c21d878f71170b962fd1 to your computer and use it in GitHub Desktop.
Solver script for the AliCTF 2026 crypto challenge "Griffin".
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import ast
import base64
import dataclasses
import gzip
import hashlib
import json
import logging
import os
import random
import re
import shutil
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable
import numpy as np
try:
from fpylll import BKZ, IntegerMatrix, LLL
except Exception as exc: # pragma: no cover
raise SystemExit(f"fpylll import failed: {exc}")
try:
from Crypto.Util.number import bytes_to_long
except Exception as exc: # pragma: no cover
raise SystemExit(f"pycryptodome import failed: {exc}")
try:
from sage.all import ( # type: ignore
GF,
ZZ,
EllipticCurve,
Integer,
PolynomialRing,
Zmod,
crt,
factor,
matrix,
)
except Exception as exc: # pragma: no cover
raise SystemExit(
"Sage import failed. Run under a Sage-enabled Python env on the remote host.\n"
f"Original error: {exc}"
)
# Avoid accidental oversubscription inside multiprocess BKZ scanning.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
UUID4_RE = re.compile(
r"^alictf\{[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}\}$"
)
# === BEGIN EMBEDDED OUTPUT (base64+gzip of task/attachment/output.txt) ===
# Optional: fill this constant with the compressed output.txt content so the
# solver can run without external input files.
EMBEDDED_OUTPUT_GZ_B64 = ""
# === END EMBEDDED OUTPUT ===
def _utc_now_id() -> str:
return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
def _sha256(data: bytes) -> str:
return hashlib.sha256(data).hexdigest()
def _configure_logger(log_path: Path) -> logging.Logger:
log_path.parent.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger("griffin_onefile_gap")
logger.setLevel(logging.INFO)
logger.handlers.clear()
fmt = logging.Formatter("%(asctime)s.%(msecs)03dZ %(levelname)s %(message)s", "%Y-%m-%dT%H:%M:%S")
fmt.converter = time.gmtime # make the trailing 'Z' true (UTC)
h_file = logging.FileHandler(log_path, mode="a", encoding="utf-8")
h_file.setFormatter(fmt)
h_file.setLevel(logging.INFO)
logger.addHandler(h_file)
h_console = logging.StreamHandler(sys.stdout)
h_console.setFormatter(fmt)
h_console.setLevel(logging.INFO)
logger.addHandler(h_console)
logger.propagate = False
return logger
def _save_json(path: Path, obj, logger: logging.Logger) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp")
tmp.write_text(json.dumps(obj, indent=2, sort_keys=True), encoding="utf-8")
tmp.replace(path)
logger.info("[checkpoint] wrote %s", path)
def _read_embedded_output_bytes() -> bytes | None:
    """Decode the embedded base64+gzip output blob, or None when it is unset."""
    encoded = EMBEDDED_OUTPUT_GZ_B64.strip()
    if not encoded:
        return None
    compressed = base64.b64decode(encoded.encode("ascii"))
    return gzip.decompress(compressed)
@dataclasses.dataclass(frozen=True)
class ChallengeData:
    """Parsed contents of the challenge's output.txt."""

    # Rows of curve points (x, y) parsed from the 'Griffin = ' line.
    rows: list[list[tuple[int, int]]]
    # Integer ciphertext parsed from the 'flagct = ' line.
    flagct: int
    # SHA-256 of the raw input bytes; used to validate cached checkpoints.
    output_sha256: str
def load_challenge_data(*, output_path: Path | None, logger: logging.Logger) -> ChallengeData:
    """Load and parse output.txt, preferring the embedded blob over a file.

    Args:
        output_path: Fallback path to output.txt when no embedded blob is set.
        logger: Run logger for provenance messages.

    Returns:
        ChallengeData with the point rows, the flag ciphertext, and the
        sha256 of the raw input bytes (used to key checkpoint caches).

    Raises:
        SystemExit: No embedded blob and no --output path provided.
        ValueError / TypeError: Malformed output.txt content.
    """
    raw = _read_embedded_output_bytes()
    from_embedded = raw is not None
    if raw is None:
        if output_path is None:
            raise SystemExit("No embedded output and --output not provided.")
        raw = output_path.read_bytes()
    # Hash/decode once for both sources (previously duplicated per branch).
    out_sha = _sha256(raw)
    if from_embedded:
        logger.info("[input] using embedded output block (%d bytes, sha256=%s)", len(raw), out_sha)
    else:
        logger.info("[input] reading %s (%d bytes, sha256=%s)", output_path, len(raw), out_sha)
    # Strict decoding so corrupted input fails loudly rather than silently.
    text = raw.decode("utf-8", errors="strict")
    lines = text.strip().splitlines()
    if len(lines) != 2:
        raise ValueError("output.txt line count mismatch")
    if not lines[0].startswith("Griffin = "):
        raise ValueError("output.txt missing 'Griffin = ' prefix")
    if not lines[1].startswith("flagct = "):
        raise ValueError("output.txt missing 'flagct = ' prefix")
    # literal_eval is safe for the list-of-tuples payload (no code execution).
    rows = ast.literal_eval(lines[0].split("=", 1)[1].strip())
    flagct = int(lines[1].split("=", 1)[1].strip())
    if not isinstance(rows, list):
        raise TypeError("parsed Griffin is not a list")
    return ChallengeData(rows=rows, flagct=int(flagct), output_sha256=out_sha)
# ------------------------
# Phase0: curve recovery
# ------------------------
def recover_prime_from_points(points: list[tuple[int, int]], logger: logging.Logger) -> int:
    """Recover the base-field prime p from curve points by gcd elimination.

    Each point satisfies y^2 = x^3 + a*x + b (mod p) for unknown a, b, p;
    eliminating a and b across triples of points yields integers divisible
    by p, whose running gcd converges to p (or a small multiple of it).
    """
    # Deduplicate to stabilize gcd.
    pts = list(dict.fromkeys(points))
    if len(pts) < 6:
        raise ValueError("need more points to recover p")
    # For curve y^2 = x^3 + a*x + b (mod p):
    #   r_i = y_i^2 - x_i^3 ≡ a*x_i + b (mod p), so r is linear in x.
    # Eliminate a,b using 3 points -> expression is 0 mod p; gcd reveals p.
    rs = [y * y - x * x * x for x, y in pts]
    xs = [x for x, _y in pts]
    g = 0
    for i in range(0, min(len(pts) - 2, 2000)):
        x1, x2, x3 = xs[i], xs[i + 1], xs[i + 2]
        r1, r2, r3 = rs[i], rs[i + 1], rs[i + 2]
        # Cross-difference cancels both a and b; result is ≡ 0 (mod p).
        t = (r1 - r2) * (x2 - x3) - (r2 - r3) * (x1 - x2)
        t = abs(int(t))
        if t == 0:
            continue
        g = int(Integer(g).gcd(Integer(t)))
        # Stop as soon as the accumulated gcd is itself prime.
        if g != 0 and Integer(g).is_prime():
            logger.info("[curve] gcd converged to prime at i=%d (bits=%d)", i, int(Integer(g).nbits()))
            return int(g)
    if g == 0:
        raise ValueError("gcd stayed 0; points may be degenerate")
    if Integer(g).is_prime():
        return int(g)
    # If gcd is a multiple of p, factor it lightly (p is expected to be a prime factor).
    fac = list(factor(Integer(g)))
    primes = sorted(int(p) for p, _e in fac)
    if not primes:
        raise ValueError("failed to factor gcd to prime")
    logger.info("[curve] gcd is composite; candidate primes=%s", primes[:8])
    # Heuristic: the field prime is the largest prime factor of the gcd.
    return int(primes[-1])
def derive_curve_params(*, p: int, sample_points: list[tuple[int, int]], logger: logging.Logger) -> tuple[int, int]:
    """Solve for (a, b) in y^2 = x^3 + a*x + b (mod p) from two points.

    Two points with distinct x give two linear equations in (a, b); the
    first such pair found determines the parameters.
    """
    if len(sample_points) < 2:
        raise ValueError("need at least 2 points to derive curve params")
    for idx, (xa, ya) in enumerate(sample_points):
        for xb, yb in sample_points[idx + 1:]:
            if xa == xb:
                continue  # need distinct x to solve the linear system
            # r = y^2 - x^3 ≡ a*x + b (mod p) for each point.
            ra = (ya * ya - xa * xa * xa) % p
            rb = (yb * yb - xb * xb * xb) % p
            slope_den = (xa - xb) % p
            a_coef = ((ra - rb) * pow(int(slope_den), -1, int(p))) % p
            b_coef = (ra - a_coef * xa) % p
            logger.info("[curve] y^2 = x^3 + a*x + b (mod p), a=%d, b=%d", int(a_coef), int(b_coef))
            return int(a_coef), int(b_coef)
    raise ValueError("failed to find sample points with distinct x")
# ------------------------
# Small modular linear algebra (mod prime)
# ------------------------
def inv_mod_matrix(M: np.ndarray, q: int) -> np.ndarray | None:
n = int(M.shape[0])
A = np.concatenate([M.copy() % q, np.eye(n, dtype=np.int64)], axis=1) % q
for col in range(n):
pivot = None
for r in range(col, n):
if int(A[r, col]) % q != 0:
pivot = r
break
if pivot is None:
return None
if pivot != col:
A[[col, pivot], :] = A[[pivot, col], :]
inv_p = pow(int(A[col, col]), q - 2, q)
A[col, :] = (A[col, :] * inv_p) % q
factors = A[:, col].copy()
factors[col] = 0
if np.any(factors):
A = (A - (factors[:, None] * A[col, :][None, :])) % q
return A[:, n:]
def row_space_basis(vectors: list[np.ndarray], q: int) -> tuple[list[np.ndarray], list[int]]:
    """Reduced row-echelon basis (mod prime q) of the span of *vectors*.

    Returns (basis, pivots): unit-pivot vectors sorted by pivot column, with
    each pivot column cleared from every other basis vector.
    """
    echelon: list[np.ndarray] = []
    pivot_cols: list[int] = []
    for raw in vectors:
        vec = (np.array(raw, dtype=np.int64).copy()) % q
        # Reduce the incoming vector against the current basis.
        for bvec, col in zip(echelon, pivot_cols):
            if int(vec[col]) != 0:
                vec = (vec - vec[col] * bvec) % q
        nonzero = np.nonzero(vec)[0]
        if nonzero.size == 0:
            continue  # linearly dependent; contributes nothing new
        col = int(nonzero[0])
        # Scale so the pivot entry becomes 1 (Fermat inverse, q prime).
        vec = (vec * pow(int(vec[col]), q - 2, q)) % q
        # Back-substitute: clear the new pivot column from existing rows.
        for k, (bvec, _c) in enumerate(list(zip(echelon, pivot_cols))):
            if int(bvec[col]) != 0:
                echelon[k] = (bvec - bvec[col] * vec) % q
        echelon.append(vec)
        pivot_cols.append(col)
    order = np.argsort(pivot_cols)
    return [echelon[i] for i in order], [pivot_cols[i] for i in order]
def in_span(vec: np.ndarray, basis: list[np.ndarray], pivots: list[int], q: int) -> bool:
    """Return True iff *vec* lies in the row space of *basis* (mod prime q)."""
    residual = (np.array(vec, dtype=np.int64).copy()) % q
    # Reduce against each pivot; anything left means vec is outside the span.
    for bvec, col in zip(basis, pivots):
        coeff = int(residual[col])
        if coeff != 0:
            residual = (residual - coeff * bvec) % q
    return not np.any(residual)
def bkz_reduce_int_basis(rows_as_basis: list[list[int]], block_size: int) -> list[list[int]]:
    """LLL- then BKZ-reduce an integer lattice basis (rows); return reduced rows."""
    A = IntegerMatrix.from_matrix(rows_as_basis)
    # LLL preprocessing makes the subsequent BKZ call substantially cheaper.
    LLL.reduction(A)
    BKZ.reduction(A, BKZ.Param(block_size=int(block_size)))
    return [[int(A[i, j]) for j in range(A.ncols)] for i in range(A.nrows)]
# ------------------------
# Phase1: find structured rows + recover xs
# ------------------------
@dataclasses.dataclass(frozen=True)
class Phase1Result:
    """Output of phase 1: subgroup prime, structured rows, recovered xs."""

    # Largest prime factor of the group order; dlogs are taken in this subgroup.
    q: int
    # Indices of the structured ("inlier") rows found by the circuit search.
    inlier_rows: list[int]
    # Recovered evaluation points, normalized into [1, 256].
    xs: list[int]
def phase1_find_inliers_and_xs(
    *,
    rows_points: list[list[tuple[int, int]]],
    curve,
    gen,
    group_order: int,
    logger: logging.Logger,
    seed: int,
    max_trials: int,
) -> Phase1Result:
    """Find the 40 structured rows and recover the xs values in [1, 256].

    Works in the subgroup of order q (the largest prime factor of the group
    order): builds a matrix of discrete logs mod q, searches for a small
    "fundamental circuit" among rows to identify the structured subset, then
    BKZ-reduces a lattice derived from the column space and normalizes a
    short vector into the xs values.

    Raises ValueError when q is too small, the circuit search exhausts
    max_trials, or the structural expectations (40 rows / rank 20) fail.
    """
    fac = list(factor(Integer(group_order)))
    primes = [int(p) for p, _e in fac]
    q = max(primes)
    if q <= 256:
        raise ValueError("largest prime factor too small for xs recovery")
    sub_mul = int(group_order) // int(q)
    gen_sub = sub_mul * gen
    if int(gen_sub.order()) != int(q):
        raise ValueError("subgroup order mismatch")
    # Precompute discrete log table in the subgroup of order q.
    t0 = time.perf_counter()
    logger.info("[phase1] build dlog table mod q=%d (size=%d)", q, q)
    table: dict[tuple[int, int] | None, int] = {}
    P = curve(0)
    for k in range(int(q)):
        # None keys the point at infinity (k = 0 case).
        key = None if P.is_zero() else (int(P[0]), int(P[1]))
        table[key] = int(k)
        P += gen_sub
    logger.info("[phase1] table built in %.3fs", time.perf_counter() - t0)
    # Build matrix A (rows=290, cols=80) of dlogs mod q.
    t1 = time.perf_counter()
    rows = len(rows_points)
    cols = len(rows_points[0]) if rows else 0
    logger.info("[phase1] compute dlog mod q for %d x %d points", rows, cols)
    A = np.empty((rows, cols), dtype=np.int64)
    for i, row in enumerate(rows_points):
        for j, (x, y) in enumerate(row):
            Pxy = curve(int(x), int(y))
            # Project into the order-q subgroup, then look up the dlog.
            Psub = sub_mul * Pxy
            key = None if Psub.is_zero() else (int(Psub[0]), int(Psub[1]))
            A[i, j] = int(table[key])
        if (i + 1) % 20 == 0:
            logger.info("[phase1] dlog progress %d/%d (elapsed %.1fs)", i + 1, rows, time.perf_counter() - t1)
    logger.info("[phase1] dlog matrix ready in %.3fs", time.perf_counter() - t1)
    # Randomized fundamental-circuit search to recover the 40 structured rows.
    rng = random.Random(int(seed))
    all_idx = list(range(rows))
    found: list[int] | None = None
    t2 = time.perf_counter()
    logger.info("[phase1] search small circuit (max_trials=%d, seed=%d)", max_trials, seed)
    for trial in range(1, int(max_trials) + 1):
        basis_idx = rng.sample(all_idx, cols)  # 80 rows -> square matrix
        B = A[basis_idx, :]
        invBt = inv_mod_matrix(B.T, q)
        if invBt is None:
            continue  # sampled rows are not a basis; resample
        outside = [j for j in all_idx if j not in set(basis_idx)]
        rng.shuffle(outside)
        outside = outside[:50]
        for j in outside:
            # Express an outside row in the sampled basis; a low-weight
            # coefficient vector exposes a small circuit of structured rows.
            v = A[j, :]
            coeff = (invBt @ v) % q
            weight = int(np.count_nonzero(coeff))
            if weight <= 24:
                supp = [basis_idx[t] for t in np.nonzero(coeff)[0].tolist()] + [j]
                found = sorted(set(supp))
                logger.info("[phase1] hit: trial=%d circuit_size=%d", trial, len(found))
                break
        if found is not None:
            break
        if trial % 200 == 0:
            logger.info("[phase1] trials=%d elapsed=%.1fs", trial, time.perf_counter() - t2)
    if found is None:
        raise ValueError("failed to find circuit; increase --phase1-trials or change --seed")
    # Grow the circuit to the full structured-row set: every row in the
    # circuit's row space is an inlier.
    basis, pivots = row_space_basis([A[i, :] for i in found], q)
    inliers = [i for i in range(rows) if in_span(A[i, :], basis, pivots, q)]
    if len(inliers) != 40:
        raise ValueError(f"expected 40 structured rows, got {len(inliers)}")
    logger.info("[phase1] structured rows recovered: %d", len(inliers))
    # Recover xs using a lattice in Z^40 from the column space.
    H = A[inliers, :] % q  # 40x80
    col_basis, _piv = row_space_basis([H[:, j] for j in range(H.shape[1])], q)
    if len(col_basis) != 20:
        raise ValueError(f"expected column-space dimension 20, got {len(col_basis)}")
    Bc = np.stack(col_basis, axis=1) % q  # 40x20
    logger.info("[phase1] lattice reduce to recover xs (BKZ-25)")
    t3 = time.perf_counter()
    # Lattice generators: q*Z^40 plus the 20 column-space vectors.
    gens: list[list[int]] = []
    for i in range(40):
        v = [0] * 40
        v[i] = int(q)
        gens.append(v)
    for t in range(20):
        gens.append([int(Bc[i, t]) for i in range(40)])
    M = matrix(ZZ, gens).transpose()  # 40 x 60, columns generate the lattice
    basis_mat = M.column_module().basis_matrix()
    basis_rows = [list(map(int, row)) for row in basis_mat.rows()]
    reduced = bkz_reduce_int_basis(basis_rows, block_size=25)
    logger.info("[phase1] BKZ done in %.3fs", time.perf_counter() - t3)
    # Score short vectors; try to normalize them into xs ∈ [1,256].
    scored: list[tuple[int, int, list[int]]] = []
    for r in reduced:
        v = [int(x) % int(q) for x in r[:40]]
        vc = [x - int(q) if x > int(q) // 2 else x for x in v]
        maxabs = max(abs(int(x)) for x in vc)
        l1 = sum(abs(int(x)) for x in vc)
        scored.append((int(maxabs), int(l1), v))
    scored.sort(key=lambda t: (t[0], t[1]))
    candidates = [v for _ma, _l1, v in scored[: min(80, len(scored))]]
    if not candidates:
        raise ValueError("no short vectors collected from BKZ output")
    # Precomputed inverses of small positive residues speed up try_normalize.
    inv_pos = {d: pow(int(d), int(q) - 2, int(q)) for d in range(1, 256)}

    def inv_small(d: int) -> int:
        # Modular inverse mod q with a fast path for |d| <= 255.
        d %= int(q)
        if d < 0:
            d += int(q)
        if 1 <= d <= 255:
            return int(inv_pos[d])
        if int(q) - 255 <= d <= int(q) - 1:
            return int(q) - int(inv_pos[int(q) - d])
        return pow(int(d), int(q) - 2, int(q))

    def try_normalize(v: list[int]) -> list[int] | None:
        # Search an affine map v -> (v - b)/a that lands every entry in
        # [1, 256] with all 40 values distinct; anchored on the first two entries.
        i0, i1 = 0, 1
        v0, v1 = int(v[i0]), int(v[i1])
        for X0 in range(1, 257):
            for X1 in range(1, 257):
                if X0 == X1:
                    continue
                denom = (X0 - X1) % int(q)
                if denom == 0:
                    continue
                a = ((v0 - v1) * inv_small(int(denom))) % int(q)
                if a == 0:
                    continue
                b = (v0 - a * X0) % int(q)
                inv_a = pow(int(a), int(q) - 2, int(q))
                xs = [int(((int(vk) - int(b)) * inv_a) % int(q)) for vk in v]
                if all(1 <= x <= 256 for x in xs) and len(set(xs)) == 40:
                    return xs
        return None

    for idx, v in enumerate(candidates, 1):
        xs = try_normalize(v)
        if xs is not None:
            logger.info("[phase1] xs recovered from candidate #%d", idx)
            return Phase1Result(q=int(q), inlier_rows=inliers, xs=xs)
        # BKZ returns vectors up to sign; try the negation as well.
        xs = try_normalize([(int(q) - int(x)) % int(q) for x in v])
        if xs is not None:
            logger.info("[phase1] xs recovered from negated candidate #%d", idx)
            return Phase1Result(q=int(q), inlier_rows=inliers, xs=xs)
    raise ValueError("failed to normalize xs from BKZ short vectors")
# ------------------------
# Phase2: full dlog + polynomial interpolation + roots mod n
# ------------------------
@dataclasses.dataclass(frozen=True)
class Phase2Result:
    """Output of phase 2: interpolated polynomial and candidate roots."""

    # Coefficients of the interpolated polynomial f0 over Zmod(n),
    # as returned by Sage's poly.list().
    poly_coeffs: list[int]
    # Roots of f0 - flagct mod n, CRT-combined across the prime factors of n.
    x_candidates: list[int]
    # Full discrete logs (mod n) of the column-0 points used for interpolation.
    y_values: list[int]
def full_dlog_points_mod_n(*, curve, gen, group_order: int, points, logger: logging.Logger) -> list[int]:
    """Discrete log of each point to base *gen*, via per-prime tables + CRT.

    For each prime factor of the (smooth) group order, projects all points
    into the order-prime subgroup, resolves the residue with a full lookup
    table, and CRT-combines the residues into a dlog mod the group order.

    NOTE(review): each prime factor is used once (exponents ignored), which
    assumes the group order is squarefree — confirm against the curve.
    """
    fac = list(factor(Integer(group_order)))
    primes = sorted(int(p) for p, _e in fac)
    num = len(points)
    residues = [0] * num
    mod = 1
    t0 = time.perf_counter()
    logger.info("[phase2] full dlog for %d points via CRT over %d primes", num, len(primes))
    for idx_prime, prime in enumerate(primes, 1):
        sub_mul = int(group_order) // int(prime)
        gen_sub = sub_mul * gen
        if int(gen_sub.order()) != int(prime):
            raise ValueError("subgroup order mismatch (prime factor)")
        # Exhaustive dlog table for the order-`prime` subgroup.
        t_tab = time.perf_counter()
        table: dict[tuple[int, int] | None, int] = {}
        P = curve(0)
        for k in range(int(prime)):
            key = None if P.is_zero() else (int(P[0]), int(P[1]))
            table[key] = int(k)
            P += gen_sub
        logger.info(
            "[phase2] prime %d/%d=%d table built (%d entries) in %.3fs",
            idx_prime,
            len(primes),
            int(prime),
            len(table),
            time.perf_counter() - t_tab,
        )
        for i, Pxy in enumerate(points):
            Psub = sub_mul * Pxy
            key = None if Psub.is_zero() else (int(Psub[0]), int(Psub[1]))
            r = int(table[key])
            # Fold the new residue into the running CRT solution.
            residues[i] = int(crt(Integer(residues[i]), Integer(r), Integer(mod), Integer(prime)))
        mod *= int(prime)
        if idx_prime % 2 == 0:
            logger.info("[phase2] CRT progress %d/%d (elapsed %.1fs)", idx_prime, len(primes), time.perf_counter() - t0)
        # Free the (potentially large) table before building the next one.
        del table
    logger.info("[phase2] full dlog done in %.3fs", time.perf_counter() - t0)
    return [int(x % int(group_order)) for x in residues]
def interpolate_poly_mod_n(*, n: int, xs: list[int], ys: list[int], degree: int, logger: logging.Logger):
    """Lagrange-interpolate f over Zmod(n) from the first *degree* points.

    Uses the remaining (xs, ys) pairs purely as a consistency check and
    raises ValueError if any of them disagrees with the interpolant.

    NOTE(review): `den ** (-1)` requires every xi - xj difference to be
    invertible mod the composite n — confirm the xs guarantee this.
    """
    Zn = Zmod(n)
    R = PolynomialRing(Zn, names=("X",))
    (X,) = R.gens()
    pts = list(zip(xs[:degree], ys[:degree]))
    f = R(0)
    for i, (xi, yi) in enumerate(pts):
        xi = Zn(int(xi))
        yi = Zn(int(yi))
        # Build the i-th Lagrange basis polynomial: prod (X - xj) / (xi - xj).
        num = R(1)
        den = Zn(1)
        for j, (xj, _yj) in enumerate(pts):
            if i == j:
                continue
            xj = Zn(int(xj))
            num *= X - xj
            den *= xi - xj
        f += yi * num * den ** (-1)
    # Verify on all provided points.
    for x, y in zip(xs, ys):
        if int(f(Zn(int(x)))) != int(int(y) % int(n)):
            raise ValueError("interpolation verification failed")
    logger.info("[phase2] interpolation verified on %d points", len(xs))
    return f
def roots_mod_composite(*, poly, n: int, logger: logging.Logger, limit: int = 8000) -> list[int]:
    """Roots of *poly* mod composite n: per-prime roots combined by CRT.

    Raises ValueError if the polynomial has no root mod some prime factor,
    or if the CRT product of root sets exceeds *limit* candidates.

    NOTE(review): prime-power factors are reduced mod the prime only (no
    Hensel lifting) — fine when n is squarefree; confirm for this n.
    """
    fac = list(factor(Integer(n)))
    pairs: list[tuple[int, list[int]]] = []
    for prime, _e in fac:
        prime = int(prime)
        Rp = PolynomialRing(GF(prime), names=("X",))
        poly_p = Rp([int(c) % prime for c in poly.list()])
        rts = [int(r[0]) for r in poly_p.roots()]
        if not rts:
            raise ValueError(f"no roots mod prime {prime} (unexpected)")
        pairs.append((prime, rts))
        logger.info("[phase2] roots mod %d: %d", prime, len(rts))
    pairs.sort(key=lambda t: t[0])
    primes = [p for p, _ in pairs]
    roots_by_prime = [rts for _p, rts in pairs]
    # Cartesian CRT combine: every choice of one root per prime is a candidate.
    candidates = [0]
    mod = 1
    for prime, rts in zip(primes, roots_by_prime):
        new: list[int] = []
        for a0 in candidates:
            for b0 in rts:
                new.append(int(crt(Integer(a0), Integer(b0), Integer(mod), Integer(prime))))
        mod *= int(prime)
        candidates = sorted(set(int(x) % int(mod) for x in new))
        logger.info("[phase2] CRT combine mod=%d candidates=%d", mod, len(candidates))
        if len(candidates) > int(limit):
            raise ValueError(f"root explosion: {len(candidates)} > {limit}")
    out = [int(x % int(n)) for x in candidates]
    logger.info("[phase2] total x candidates mod n: %d", len(out))
    return out
# ------------------------
# Phase3: gap-model lattice embedding (BKZ scan)
# ------------------------
def centered_mod(x: int, mod: int) -> int:
    """Return the representative of x mod *mod* centered around zero.

    Residues above mod // 2 are shifted down by mod, so the result lies in
    the balanced range around 0.
    """
    m = int(mod)
    r = int(x) % m
    return r - m if r > m // 2 else r
@dataclasses.dataclass(frozen=True)
class _GapLayout:
    """Static description of the flag byte layout used by the gap lattice."""

    # Total flag length in bytes (44 for alictf{uuid4}).
    L: int
    # position -> byte value for known characters (prefix, dashes, '4', '}').
    fixed_bytes: dict[int, int]
    # Positions of the unknown hex characters.
    var_pos: list[int]
    # Centered residues of 256^(L-1-pos) mod n for each unknown position.
    coeffs_centered: list[int]
    # Sum of byte * 256^(L-1-pos) over all fixed positions, reduced mod n.
    fixed_mod: int
@dataclasses.dataclass(frozen=True)
class _GapParams:
    """Tunables for the phase-3 gap-lattice scan."""

    # Scaling weight on the modular-equation column; forces it to 0 in short vectors.
    weight_k: int = 1 << 40
    # Magnitude of the e-coordinates; the sign of ±delta selects the
    # letter pivot 'c' vs the digit pivot '5' during decoding.
    delta: int = 4
    # Successive BKZ block sizes tried for each candidate residue.
    bkz_blocks: tuple[int, ...] = (30, 36, 40)
    # max_loops passed to each BKZ.Param.
    bkz_max_loops: int = 4
def _gap_build_layout(n: int) -> _GapLayout:
    """Describe the 44-byte alictf{uuid4} layout: fixed bytes vs. unknown hex slots."""
    total_len = 44
    known: dict[int, int] = dict(enumerate(b"alictf{"))
    known[43] = ord("}")
    known.update({pos: ord("-") for pos in (15, 20, 25, 30)})
    known[21] = ord("4")  # uuid4 version character
    unknown = [pos for pos in range(7, 43) if pos not in known]
    if len(unknown) != 31:
        raise ValueError(f"unexpected var_pos size: {len(unknown)}")
    # Centered base-256 weight of each unknown byte position mod n.
    weights = [centered_mod(pow(256, total_len - 1 - pos, n), n) for pos in unknown]
    # Contribution of all fixed bytes to bytes_to_long(flag) mod n.
    known_sum = sum(int(byte) * pow(256, total_len - 1 - pos, n) for pos, byte in known.items()) % n
    return _GapLayout(
        L=int(total_len),
        fixed_bytes=known,
        var_pos=unknown,
        coeffs_centered=weights,
        fixed_mod=int(known_sum),
    )
def _gap_decode_if_valid(*, n: int, r: int, layout: _GapLayout, params: _GapParams, vec: list[int]) -> str | None:
    """Try to decode a reduced lattice vector into a valid flag string.

    Vector layout: [equation, e_1..e_k, s_1..s_k, embedding]. A usable
    vector has equation == 0 and embedding == ±delta; each unknown byte is
    pivot + s_i, where e_i == +delta selects the letter pivot 'c' and
    e_i == -delta the digit pivot '5'. Returns the flag, or None if any
    structural, charset, uuid4, or modular check fails.
    """
    k = len(layout.var_pos)
    dim = 2 * k + 2
    if len(vec) != dim:
        return None
    # The weighted equation coordinate must be exactly zero.
    if int(vec[0]) != 0:
        return None
    last = dim - 1
    if abs(int(vec[last])) != int(params.delta):
        return None
    # Short vectors come up to sign; flip so the embedding entry is +delta.
    if int(vec[last]) < 0:
        vec = [-int(x) for x in vec]
    p1 = ord("5")  # digit pivot: '5' + s with s in [-5, 4] covers '0'..'9'
    p2 = ord("c")  # letter pivot: 'c' + s with s in [-2, 3] covers 'a'..'f'
    el = vec[1 : 1 + k]
    sl = vec[1 + k : 1 + 2 * k]
    out = bytearray(layout.L)
    for pos, byte in layout.fixed_bytes.items():
        out[int(pos)] = int(byte)
    for i, pos in enumerate(layout.var_pos):
        e = int(el[i])
        s = int(sl[i])
        if e == int(params.delta):
            if not (-2 <= s <= 3):
                return None
            ch = int(p2 + s)
        elif e == -int(params.delta):
            if not (-5 <= s <= 4):
                return None
            ch = int(p1 + s)
        else:
            return None
        # Position 26 is the uuid variant nibble and must be one of 8/9/a/b.
        if int(pos) == 26 and ch not in b"89ab":
            return None
        out[int(pos)] = int(ch)
    try:
        flag = out.decode("ascii")
    except Exception:
        return None
    if not UUID4_RE.fullmatch(flag):
        return None
    # Final check: the decoded flag must reproduce the target residue mod n.
    if bytes_to_long(flag.encode("ascii")) % int(n) != int(r) % int(n):
        return None
    return flag
def _gap_build_basis_for_rhs(*, n: int, rhs: int, coeffs_centered: list[int], params: _GapParams) -> "IntegerMatrix":
    """Assemble the (2k+2)-dimensional embedding lattice for one target rhs.

    Column 0 carries the modular equation scaled by weight_k so reduction
    drives it to zero; the remaining columns hold the e/s coordinates and
    the embedding marker consumed by _gap_decode_if_valid.
    """
    k = len(coeffs_centered)
    dim = 2 * k + 2
    p1 = ord("5")  # digit pivot
    p2 = ord("c")  # letter pivot
    B = IntegerMatrix(dim, dim)
    # e rows: a*(p2-p1)*K + (2*delta)*e_i
    for i, a in enumerate(coeffs_centered):
        B[i, 0] = int(a) * int(p2 - p1) * int(params.weight_k)
        B[i, 1 + i] = 2 * int(params.delta)
    # s rows: a*K + 1*s_i
    for i, a in enumerate(coeffs_centered):
        row = k + i
        B[row, 0] = int(a) * int(params.weight_k)
        B[row, 1 + k + i] = 1
    # embedding row: rhs*K + (-delta,...,-delta) + last=delta
    row = 2 * k
    B[row, 0] = int(rhs) * int(params.weight_k)
    for i in range(k):
        B[row, 1 + i] = -int(params.delta)
    B[row, dim - 1] = int(params.delta)
    # modulus row: n*K
    row = 2 * k + 1
    B[row, 0] = int(n) * int(params.weight_k)
    return B
def _gap_try_solve_r(*, n: int, r: int, layout: _GapLayout, params: _GapParams) -> str | None:
    """Attempt to recover the flag for one candidate residue r.

    Builds the gap lattice for r, reduces it through LLL and the configured
    BKZ stages, and decodes every reduced row; on the final stage it also
    tries small ±1 combinations of near-miss rows. Returns the flag or None.
    """
    # Target for the unknown-byte contribution: r minus the fixed bytes,
    # re-centered around the all-'5' baseline of the unknown positions.
    cp = (int(r) - int(layout.fixed_mod)) % int(n)
    c0 = 0
    for a in layout.coeffs_centered:
        c0 += int(a) * ord("5")
    rhs = centered_mod(int(c0) - int(cp), int(n))
    B = _gap_build_basis_for_rhs(n=int(n), rhs=int(rhs), coeffs_centered=layout.coeffs_centered, params=params)
    LLL.reduction(B)
    bkz_blocks = tuple(int(x) for x in params.bkz_blocks)
    for stage_idx, bs in enumerate(bkz_blocks, 1):
        BKZ.reduction(B, BKZ.Param(int(bs), max_loops=int(params.bkz_max_loops)))
        rows = [[int(B[i, j]) for j in range(B.ncols)] for i in range(B.nrows)]
        for row in rows:
            flag = _gap_decode_if_valid(n=int(n), r=int(r), layout=layout, params=params, vec=row)
            if flag is not None:
                return str(flag)
        # On the last stage only: try small linear combinations to recover the hidden vector.
        if stage_idx != len(bkz_blocks):
            continue
        dim = len(rows[0])
        last = dim - 1
        delta = int(params.delta)
        # "bases" carry the embedding marker; "adjust" are short kernel vectors.
        bases = [row for row in rows if int(row[0]) == 0 and abs(int(row[last])) == delta]
        adjust = [row for row in rows if int(row[0]) == 0 and int(row[last]) == 0]
        if not bases or not adjust:
            continue
        adjust.sort(key=lambda v: sum(int(x) * int(x) for x in v))
        adjust = adjust[:20]
        for base in bases:
            # base ± one adjustment vector.
            for v in adjust:
                cand = [int(base[j]) + int(v[j]) for j in range(dim)]
                flag = _gap_decode_if_valid(n=int(n), r=int(r), layout=layout, params=params, vec=cand)
                if flag is not None:
                    return str(flag)
                cand = [int(base[j]) - int(v[j]) for j in range(dim)]
                flag = _gap_decode_if_valid(n=int(n), r=int(r), layout=layout, params=params, vec=cand)
                if flag is not None:
                    return str(flag)
            # base ± two adjustment vectors, over the 12 shortest.
            top2 = adjust[:12]
            for i1 in range(len(top2)):
                v1 = top2[i1]
                for i2 in range(i1 + 1, len(top2)):
                    v2 = top2[i2]
                    for s1 in (-1, 1):
                        for s2 in (-1, 1):
                            cand = [int(base[j]) + s1 * int(v1[j]) + s2 * int(v2[j]) for j in range(dim)]
                            flag = _gap_decode_if_valid(n=int(n), r=int(r), layout=layout, params=params, vec=cand)
                            if flag is not None:
                                return str(flag)
    return None
# Per-process state for the multiprocessing gap-scan workers;
# populated by _gap_worker_init in each pool process.
_GAP_GLOBAL_N: int | None = None
_GAP_GLOBAL_LAYOUT: _GapLayout | None = None
_GAP_GLOBAL_PARAMS: _GapParams | None = None
def _gap_worker_init(n: int, params: _GapParams) -> None:
    """Pool initializer: cache n, params, and the derived layout in process globals."""
    global _GAP_GLOBAL_N, _GAP_GLOBAL_LAYOUT, _GAP_GLOBAL_PARAMS
    _GAP_GLOBAL_N = int(n)
    _GAP_GLOBAL_PARAMS = params
    # The layout depends only on n, so build it once per worker process.
    _GAP_GLOBAL_LAYOUT = _gap_build_layout(int(n))
def _gap_worker_solve(r: int) -> tuple[int, str] | None:
    """Worker task: try one candidate residue; return (r, flag) on a hit, else None."""
    if _GAP_GLOBAL_N is None or _GAP_GLOBAL_LAYOUT is None or _GAP_GLOBAL_PARAMS is None:
        raise RuntimeError("gap worker not initialized")
    result = _gap_try_solve_r(
        n=int(_GAP_GLOBAL_N),
        r=int(r),
        layout=_GAP_GLOBAL_LAYOUT,
        params=_GAP_GLOBAL_PARAMS,
    )
    return None if result is None else (int(r), str(result))
def phase3_recover_flag_gap(
    *,
    n: int,
    m_candidates: list[int],
    logger: logging.Logger,
    jobs: int,
    maxtasksperchild: int,
    params: _GapParams,
) -> tuple[int, str]:
    """Scan candidate residues with the gap lattice; return the first (r, flag) hit.

    Runs sequentially when jobs <= 1, otherwise fans out over a process
    pool (fork start method when available). Raises ValueError when the
    candidate list is empty or no candidate yields a valid flag.
    """
    import multiprocessing as mp

    total = len(m_candidates)
    if total == 0:
        raise ValueError("empty candidate list")
    if int(jobs) <= 1:
        # Sequential path: one lattice reduction per candidate, in order.
        logger.info("[phase3] gap scan sequential: candidates=%d bkz=%s", total, params.bkz_blocks)
        layout = _gap_build_layout(int(n))
        t0 = time.perf_counter()
        for idx, r in enumerate(m_candidates, 1):
            t_r = time.perf_counter()
            flag = _gap_try_solve_r(n=int(n), r=int(r), layout=layout, params=params)
            dt = time.perf_counter() - t_r
            if flag is not None:
                logger.info("[phase3] gap hit idx=%d/%d dt=%.3fs total=%.1fs", idx, total, dt, time.perf_counter() - t0)
                return int(r), str(flag)
            if idx % 20 == 0:
                logger.info("[phase3] gap progress %d/%d elapsed=%.1fs (last_dt=%.3fs)", idx, total, time.perf_counter() - t0, dt)
        raise ValueError("gap scan failed: no candidate produced a valid flag")
    # Multiprocess path: prefer fork so workers inherit the module state cheaply.
    ctx = mp.get_context("fork") if "fork" in mp.get_all_start_methods() else mp.get_context()
    logger.info(
        "[phase3] gap scan multiprocess: candidates=%d jobs=%d maxtasksperchild=%d bkz=%s",
        total,
        int(jobs),
        int(maxtasksperchild),
        params.bkz_blocks,
    )
    t0 = time.perf_counter()
    tried = 0
    with ctx.Pool(
        processes=int(jobs),
        initializer=_gap_worker_init,
        initargs=(int(n), params),
        maxtasksperchild=int(maxtasksperchild) if int(maxtasksperchild) > 0 else None,
    ) as pool:
        # chunksize=1 so a hit can stop the scan as early as possible.
        for res in pool.imap_unordered(_gap_worker_solve, (int(r) for r in m_candidates), chunksize=1):
            tried += 1
            if tried % 20 == 0:
                logger.info("[phase3] gap tried %d/%d elapsed=%.1fs", tried, total, time.perf_counter() - t0)
            if res is None:
                continue
            # First hit wins: stop all outstanding workers immediately.
            pool.terminate()
            matched_r, flag = res
            logger.info("[phase3] gap found after %d/%d (elapsed=%.1fs)", tried, total, time.perf_counter() - t0)
            return int(matched_r), str(flag)
    raise ValueError("gap scan failed: pool exhausted without a solution")
# ------------------------
# Affine ambiguity handling
# ------------------------
def _infer_x_affine_transforms_from_xs(xs: list[int]) -> list[tuple[int, int]]:
if not xs:
return [(1, 0)]
mn = int(min(xs))
mx = int(max(xs))
out: list[tuple[int, int]] = []
# a=+1: require 1 <= x_rec - b <= 256 for all x_rec in xs
lo = mx - 256
hi = mn - 1
for b in range(int(lo), int(hi) + 1):
out.append((1, int(b)))
# a=-1: require 1 <= b - x_rec <= 256 for all x_rec in xs
lo = mx + 1
hi = mn + 256
for b in range(int(lo), int(hi) + 1):
out.append((-1, int(b)))
out = sorted(set(out), key=lambda t: (0 if t[0] == 1 else 1, abs(int(t[1])), int(t[1])))
if (1, 0) not in out:
out.insert(0, (1, 0))
return out
def _affine_mod(x: int, mod: int, a: int, b: int) -> int:
return int((int(a) * int(x) + int(b)) % int(mod))
def _affine_inv_mod(x: int, mod: int, a: int, b: int) -> int:
a_mod = int(a) % int(mod)
inv_a = pow(int(a_mod), -1, int(mod))
return int((int(x) - (int(b) % int(mod))) * int(inv_a) % int(mod))
# ------------------------
# Main
# ------------------------
def main() -> int:
    """CLI entry point: run phases 0-3 with JSON checkpointing; print the flag.

    Sets up a run directory and logger, loads the challenge data, then:
    phase 0 recovers the curve and group order n; phase 1 finds the
    structured rows and xs; phase 2 interpolates f0 and enumerates root
    candidates; phase 3 scans the gap lattice across affine hypotheses.
    Each phase's result is cached to a checkpoint keyed by the input sha256.
    Returns 0 on success (for use with SystemExit).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--output", type=str, default=None, help="path to task/attachment/output.txt (used when no embedded block)")
    ap.add_argument("--workdir", type=str, default="runs", help="base directory to store run artifacts")
    ap.add_argument("--run-dir", type=str, default="", help="reuse/resume an existing run directory")
    ap.add_argument("--seed", type=int, default=20260131)
    ap.add_argument("--phase1-trials", type=int, default=6000)
    ap.add_argument("--x-affine", type=str, default="auto", choices=("auto", "none", "manual"))
    ap.add_argument("--x-affine-a", type=int, default=1)
    ap.add_argument("--x-affine-b", type=int, default=0)
    ap.add_argument(
        "--gap-jobs",
        type=int,
        default=0,
        help="gap scan parallelism: 0=auto, 1=sequential, >1=multiprocess pool",
    )
    ap.add_argument("--gap-maxtasksperchild", type=int, default=1)
    ap.add_argument("--gap-bkz", type=str, default="30,36,40")
    ap.add_argument("--gap-loops", type=int, default=4)
    ap.add_argument("--gap-K", type=int, default=1 << 40)
    ap.add_argument("--gap-delta", type=int, default=4)
    args = ap.parse_args()
    gap_jobs = int(args.gap_jobs)
    if gap_jobs < 0:
        raise SystemExit("--gap-jobs must be >= 0")
    if gap_jobs == 0:
        # Auto mode: one worker per CPU.
        gap_jobs = int(os.cpu_count() or 1)
    if args.run_dir:
        # Resume mode: reuse the given directory (and its checkpoints).
        run_dir = Path(str(args.run_dir)).resolve()
        run_dir.mkdir(parents=True, exist_ok=True)
        run_id = run_dir.name
    else:
        run_id = _utc_now_id()
        base_dir = Path(args.workdir).resolve()
        run_dir = base_dir / f"griffin_onefile_{run_id}"
        run_dir.mkdir(parents=True, exist_ok=True)
    logger = _configure_logger(run_dir / "run.log")
    logger.info("[run] id=%s run_dir=%s", run_id, run_dir)
    logger.info("[run] python=%s", sys.version.replace("\n", " "))
    # Save a copy of this script for reproducibility.
    try:
        shutil.copy2(Path(__file__).resolve(), run_dir / "solver_snapshot.py")
    except Exception:
        pass  # best-effort only; a failed snapshot must not abort the run
    ck0 = run_dir / "checkpoint_phase0.json"
    ck1 = run_dir / "checkpoint_phase1.json"
    ck2 = run_dir / "checkpoint_phase2.json"
    ck3 = run_dir / "checkpoint_phase3.json"
    # ---- Load input ----
    t0 = time.perf_counter()
    data = load_challenge_data(output_path=Path(args.output) if args.output else None, logger=logger)
    logger.info(
        "[input] parsed rows=%d cols=%d flagct=%d in %.3fs",
        len(data.rows),
        len(data.rows[0]),
        data.flagct,
        time.perf_counter() - t0,
    )
    # ---- Phase0: recover curve + group order n ----
    p = a = b = n = 0
    if ck0.exists():
        meta = json.loads(ck0.read_text(encoding="utf-8"))
        # Only trust the cache if it was produced from the same input bytes.
        if meta.get("output_sha256") == data.output_sha256:
            p = int(meta["p"])
            a = int(meta["a"])
            b = int(meta["b"])
            n = int(meta["n"])
            logger.info("[phase0] cache hit: p=%d a=%d b=%d n=%d", p, a, b, n)
    if n == 0:
        t0 = time.perf_counter()
        some_pts = [pt for row in data.rows[:10] for pt in row[:10]]
        p = recover_prime_from_points(some_pts, logger)
        a, b = derive_curve_params(p=p, sample_points=some_pts, logger=logger)
        curve = EllipticCurve(GF(p), [a, b])
        # Generator convention: the point with x = 3137 lifted onto the curve.
        gen = curve.lift_x(curve.base_field()(3137))
        n = int(gen.order())
        logger.info("[phase0] recovered group order n=%d (bits=%d) in %.3fs", n, int(Integer(n).nbits()), time.perf_counter() - t0)
        _save_json(ck0, {"output_sha256": data.output_sha256, "p": p, "a": a, "b": b, "n": n}, logger)
    # Rebuild curve/generator unconditionally so the cache-hit path has them too.
    curve = EllipticCurve(GF(p), [a, b])
    gen = curve.lift_x(curve.base_field()(3137))
    # ---- Phase1 ----
    phase1: Phase1Result | None = None
    if ck1.exists():
        obj = json.loads(ck1.read_text(encoding="utf-8"))
        if obj.get("output_sha256") == data.output_sha256 and int(obj.get("n", 0)) == int(n):
            phase1 = Phase1Result(
                q=int(obj["q"]),
                inlier_rows=list(map(int, obj["inlier_rows"])),
                xs=list(map(int, obj["xs"])),
            )
            logger.info("[phase1] cache hit: q=%d inliers=%d", phase1.q, len(phase1.inlier_rows))
    if phase1 is None:
        t0 = time.perf_counter()
        phase1 = phase1_find_inliers_and_xs(
            rows_points=data.rows,
            curve=curve,
            gen=gen,
            group_order=n,
            logger=logger,
            seed=int(args.seed),
            max_trials=int(args.phase1_trials),
        )
        logger.info("[phase1] done in %.3fs", time.perf_counter() - t0)
        _save_json(
            ck1,
            {
                "output_sha256": data.output_sha256,
                "n": int(n),
                "q": int(phase1.q),
                "inlier_rows": list(map(int, phase1.inlier_rows)),
                "xs": list(map(int, phase1.xs)),
            },
            logger,
        )
    # ---- Phase2 ----
    phase2: Phase2Result | None = None
    if ck2.exists():
        obj = json.loads(ck2.read_text(encoding="utf-8"))
        if obj.get("output_sha256") == data.output_sha256 and int(obj.get("n", 0)) == int(n):
            phase2 = Phase2Result(
                poly_coeffs=list(map(int, obj["poly_coeffs"])),
                x_candidates=list(map(int, obj["x_candidates"])),
                y_values=list(map(int, obj["y_values"])),
            )
            logger.info("[phase2] cache hit: candidates=%d", len(phase2.x_candidates))
    if phase2 is None:
        t0 = time.perf_counter()
        # Column-0 points of the structured rows carry the interpolation data.
        col0_points = [curve(*data.rows[idx][0]) for idx in phase1.inlier_rows]
        y_values = full_dlog_points_mod_n(curve=curve, gen=gen, group_order=n, points=col0_points, logger=logger)
        f0 = interpolate_poly_mod_n(n=n, xs=phase1.xs, ys=y_values, degree=20, logger=logger)
        target = int(data.flagct) % int(n)
        poly = f0 - Zmod(n)(target)
        x_candidates = roots_mod_composite(poly=poly, n=n, logger=logger)
        phase2 = Phase2Result(
            poly_coeffs=[int(c) for c in f0.list()],
            x_candidates=list(map(int, x_candidates)),
            y_values=list(map(int, y_values)),
        )
        logger.info("[phase2] done in %.3fs", time.perf_counter() - t0)
        _save_json(
            ck2,
            {
                "output_sha256": data.output_sha256,
                "n": int(n),
                "poly_coeffs": list(map(int, phase2.poly_coeffs)),
                "x_candidates": list(map(int, phase2.x_candidates)),
                "y_values": list(map(int, phase2.y_values)),
            },
            logger,
        )
    # ---- Phase3 ----
    if ck3.exists():
        obj = json.loads(ck3.read_text(encoding="utf-8"))
        if obj.get("output_sha256") == data.output_sha256 and obj.get("flag"):
            flag = str(obj["flag"])
            logger.info("[phase3] cache hit: %s", flag)
            print(flag)
            return 0
    mode = str(args.x_affine)
    if mode == "none":
        affine_list = [(1, 0)]
    elif mode == "manual":
        affine_list = [(int(args.x_affine_a), int(args.x_affine_b))]
    else:
        affine_list = _infer_x_affine_transforms_from_xs(list(map(int, phase1.xs)))
    xs_min = int(min(phase1.xs)) if phase1.xs else -1
    xs_max = int(max(phase1.xs)) if phase1.xs else -1
    logger.info("[affine] mode=%s transforms=%d xs_min=%d xs_max=%d", mode, len(affine_list), xs_min, xs_max)
    x_root_set = set(map(int, phase2.x_candidates))
    params = _GapParams(
        weight_k=int(args.gap_K),
        delta=int(args.gap_delta),
        bkz_blocks=tuple(int(x.strip()) for x in str(args.gap_bkz).split(",") if x.strip()),
        bkz_max_loops=int(args.gap_loops),
    )
    # Rebuild f0 for final validation (avoid trusting checkpointed coeffs blindly).
    Zn = Zmod(n)
    R = PolynomialRing(Zn, names=("X",))
    (X,) = R.gens()
    f0 = R([Zn(int(c)) for c in phase2.poly_coeffs])
    t_phase3_all = time.perf_counter()
    last_exc: Exception | None = None
    chosen_a = 1
    chosen_b = 0
    matched_r = 0
    flag = ""
    for t_idx, (a_try, b_try) in enumerate(affine_list, 1):
        a_try = int(a_try)
        b_try = int(b_try)
        logger.info("[affine] try %d/%d: a=%d b=%d", t_idx, len(affine_list), a_try, b_try)
        # Roots are in "x-space"; candidates for m are preimages under x = a*m + b.
        m_candidates = sorted({_affine_inv_mod(int(x), int(n), a_try, b_try) for x in phase2.x_candidates})
        try:
            t0 = time.perf_counter()
            matched_r, flag = phase3_recover_flag_gap(
                n=int(n),
                m_candidates=m_candidates,
                logger=logger,
                jobs=int(gap_jobs),
                maxtasksperchild=int(args.gap_maxtasksperchild),
                params=params,
            )
            logger.info("[affine] solved with a=%d b=%d dt=%.3fs", a_try, b_try, time.perf_counter() - t0)
            chosen_a, chosen_b = a_try, b_try
            break
        except Exception as exc:
            # A failed transform is expected; remember the error and move on.
            last_exc = exc
            logger.info("[affine] no hit for a=%d b=%d (%s): %s", a_try, b_try, type(exc).__name__, exc)
    else:
        # for-else: no transform produced a flag.
        raise last_exc if last_exc is not None else ValueError("phase3 failed: no affine transform produced a flag")
    logger.info("[phase3] done in %.3fs (a=%d b=%d)", time.perf_counter() - t_phase3_all, chosen_a, chosen_b)
    # Final validation against polynomial and flagct.
    if not UUID4_RE.fullmatch(flag):
        raise ValueError("final flag format mismatch")
    m_val = bytes_to_long(flag.encode("ascii"))
    m_res = int(m_val % int(n))
    if int(matched_r) != int(m_res):
        logger.info("[phase3] note: matched_r adjusted from %d to %d", int(matched_r), int(m_res))
        matched_r = int(m_res)
    x_res = _affine_mod(int(m_res), int(n), int(chosen_a), int(chosen_b))
    if x_res not in x_root_set:
        raise ValueError("final flag residue does not map into root candidate set under affine")
    if int(f0(Zn(int(x_res)))) != int(data.flagct) % int(n):
        raise ValueError("final flag does not satisfy f0(a*m+b)=flagct mod n under affine")
    _save_json(
        ck3,
        {"output_sha256": data.output_sha256, "flag": flag, "affine": {"a": int(chosen_a), "b": int(chosen_b)}},
        logger,
    )
    (run_dir / "final_flag.txt").write_text(flag + "\n", encoding="utf-8")
    logger.info("[ok] %s", flag)
    print(flag)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment