@lastforkbender
Created April 24, 2026 19:36
HPTB NN using torch
# hierarchical_bspline.py
# Requires: torch, numpy
from typing import List, Optional, Tuple, Dict
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


@dataclass
class BSplineConfig:
    degree: int = 3
    num_basis: int = 8
    knot_distribution: str = "clamped"  # "clamped" or "uniform"
    trainable_knots: bool = False
    eps: float = 1e-8
class BSplineBasis(nn.Module):
    """
    Vectorized evaluator of B-spline basis functions using the iterative
    Cox-de Boor recursion. Input x lies in [0, 1]; returns a basis tensor
    of shape (batch, dims, num_basis).
    """
    def __init__(self, cfg: BSplineConfig):
        super().__init__()
        self.cfg = cfg
        self.p = int(cfg.degree)
        self.num_basis = int(cfg.num_basis)
        assert self.num_basis > self.p, "num_basis must be > degree"
        self.K = self.num_basis + self.p + 1
        if cfg.knot_distribution == "clamped":
            interior_count = self.K - 2 * (self.p + 1)
            if interior_count > 0:
                interior = np.linspace(0.0, 1.0, interior_count + 2)[1:-1].astype(np.float32)
            else:
                interior = np.array([], dtype=np.float32)
            knots = np.concatenate([
                np.zeros(self.p + 1, dtype=np.float32),
                interior,
                np.ones(self.p + 1, dtype=np.float32)
            ])
            assert len(knots) == self.K
            knots_t = torch.as_tensor(knots, dtype=torch.float32)
            # This version keeps trainable_knots=False by default.
            self.register_buffer('knots', knots_t)
            self.interior_params = None
        else:
            knots = np.linspace(0.0, 1.0, self.K).astype(np.float32)
            self.register_buffer('knots', torch.as_tensor(knots))
            self.interior_params = None

    def get_knots(self, device: torch.device):
        # Trainable knots are not used yet; return the buffer on the right device.
        return self.knots.to(device)
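
    # Worked example (added note): for the default config (degree=3, num_basis=8)
    # the clamped construction above yields K = 12 knots with 4 interior knots:
    #   [0, 0, 0, 0, 0.2, 0.4, 0.6, 0.8, 1, 1, 1, 1]
    # i.e. the endpoints repeated degree+1 times, so the spline is clamped at 0 and 1.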
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept x of shape (batch, dims) or (batch,) -> normalize the layout
        if x.dim() == 1:
            x = x.unsqueeze(1)
        batch, dims = x.shape
        device = x.device
        x = x.clamp(0.0, 1.0)
        knots = self.get_knots(device)  # (K,)
        left_knots = knots[:-1]
        right_knots = knots[1:]
        x_exp = x.unsqueeze(-1)  # (batch, dims, 1)
        left = left_knots.view(1, 1, -1)
        right = right_knots.view(1, 1, -1)
        # Initialization: indicator of the interval [left, right)
        N = ((x_exp >= left) & (x_exp < right)).float()  # (batch, dims, K-1)
        # Handle x == 1.0: assign it to the last basis cell
        is_one = (x_exp >= 1.0 - self.cfg.eps)
        if is_one.any():
            N = N.clone()
            idx = is_one.expand_as(N)
            N[idx] = 0.0
            # set the last cell of every entry that is ~1
            last_idx = is_one.squeeze(-1)  # (batch, dims)
            N[last_idx, -1] = 1.0
        # Iterative Cox-de Boor with careful indexing and stable denominators
        for r in range(1, self.p + 1):
            L = N.shape[-1]  # current basis length
            k = L - 1        # length after combination
            if k <= 0:
                break
            N_left = N[..., :k]
            N_right = N[..., 1:L]
            denom_left = knots[r: r + k] - knots[:k]                 # t_{i+r} - t_i, (k,)
            denom_right = knots[r + 1: r + 1 + k] - knots[1: 1 + k]  # t_{i+r+1} - t_{i+1}, (k,)
            denom_left = denom_left.view(1, 1, -1).to(device)
            denom_right = denom_right.view(1, 1, -1).to(device)
            a_num = x_exp - left[..., :k]  # x - t_i
            # Right numerator must use t_{i+r+1}; the original slice right[..., 1:L]
            # gave t_{i+2}, which is only correct for r == 1.
            b_num = knots[r + 1: r + 1 + k].view(1, 1, -1) - x_exp  # t_{i+r+1} - x
            mask_a = denom_left.abs() > self.cfg.eps
            mask_b = denom_right.abs() > self.cfg.eps
            a = torch.where(mask_a, (a_num / (denom_left + self.cfg.eps)) * N_left, torch.zeros_like(a_num))
            b = torch.where(mask_b, (b_num / (denom_right + self.cfg.eps)) * N_right, torch.zeros_like(b_num))
            N = a + b
        # Validate the basis size, cropping or padding if necessary
        expected_len = self.K - 1 - self.p
        if N.shape[-1] != expected_len:
            # safety net: crop or zero-pad instead of failing outright
            if N.shape[-1] > self.num_basis:
                basis = N[..., :self.num_basis]
            else:
                # zero-pad if shorter (edge case)
                pad = self.num_basis - N.shape[-1]
                basis = F.pad(N, (0, pad))
        else:
            basis = N[..., :self.num_basis]
        return basis
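
# Quick sanity check (added sketch, not part of the original gist): B-spline
# bases form a partition of unity on [0, 1], so each (batch, dim) slice of the
# returned basis should sum to ~1. For example:
#
#   cfg = BSplineConfig(degree=3, num_basis=8)
#   bs = BSplineBasis(cfg)
#   x = torch.rand(16, 3)
#   B = bs(x)  # (16, 3, 8)
#   assert torch.allclose(B.sum(-1), torch.ones(16, 3), atol=1e-4)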
class ResidualTracker(nn.Module):
    """
    EMA tracking of residual norms per pathway/sector, with stable normalization.
    """
    def __init__(self, num_pathways: int, ema_decay: float = 0.99, eps: float = 1e-8):
        super().__init__()
        self.ema_decay = float(ema_decay)
        self.eps = float(eps)
        self.num_pathways = int(num_pathways)
        self.register_buffer('residual_ema', torch.ones(self.num_pathways) * 1e-2)
        self.register_buffer('update_count', torch.tensor(0, dtype=torch.long))

    def update(self, residuals: torch.Tensor):
        # residuals has shape (batch, num_sectors, feat)
        with torch.no_grad():
            # norm per pathway, reducing over batch and features
            flat = residuals.view(residuals.shape[0], residuals.shape[1], -1)
            norms = torch.norm(flat, dim=(0, 2))  # (num_sectors,)
            self.residual_ema = self.ema_decay * self.residual_ema + (1.0 - self.ema_decay) * norms
            self.update_count += 1

    def get_scores(self, invert: bool = True) -> torch.Tensor:
        if invert:
            s = 1.0 / (self.residual_ema + self.eps)
        else:
            s = self.residual_ema.clone()
        s = s / (s.sum() + self.eps)
        return s
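
# Usage sketch (added illustration): with invert=True, sectors whose residual
# norms stay small receive proportionally larger normalized scores, e.g.
#
#   tracker = ResidualTracker(num_pathways=4)
#   tracker.update(torch.randn(8, 4, 16))  # one batch of per-sector residuals
#   scores = tracker.get_scores()          # (4,), sums to ~1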
class SectoredBSplineLayer(nn.Module):
    """
    Per-sector layer implementing:
      - polymorphic routes per sector (several linear -> cube -> linear sub-routes),
      - timing (phase, gain, timing logits),
      - residual mixing between the remapped input and the route output.
    """
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_sectors: int = 4,
        bspline_cfg: Optional[BSplineConfig] = None,
        parent_layer: Optional['SectoredBSplineLayer'] = None,
        ema_decay: float = 0.99,
        polymorph_ways: int = 2  # number of polymorphic routes per sector
    ):
        super().__init__()
        assert input_dim >= num_sectors and output_dim >= num_sectors
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sectors = num_sectors
        self.bspline_cfg = bspline_cfg or BSplineConfig()
        self.num_basis = self.bspline_cfg.num_basis
        self.polymorph_ways = polymorph_ways
        base_in = input_dim // num_sectors
        in_rem = input_dim % num_sectors
        self.sector_input_sizes = [base_in + (1 if i < in_rem else 0) for i in range(num_sectors)]
        base_out = output_dim // num_sectors
        out_rem = output_dim % num_sectors
        self.sector_output_sizes = [base_out + (1 if i < out_rem else 0) for i in range(num_sectors)]
        # pre- and post-arranges (per-sector remapping of inputs and outputs)
        self.pre_arrange = nn.ModuleList([
            nn.Linear(self.sector_input_sizes[i], self.sector_input_sizes[i]) for i in range(num_sectors)
        ])
        self.post_arrange = nn.ModuleList([
            nn.Linear(self.sector_output_sizes[i], self.sector_output_sizes[i]) for i in range(num_sectors)
        ])
        # control points per sector (num_basis x sector_out)
        self.control_points = nn.ParameterList([
            nn.Parameter(torch.randn(self.num_basis, self.sector_output_sizes[i]) * 0.01)
            for i in range(num_sectors)
        ])
        # trainable logits that select among the polymorphic routes of each sector
        self.poly_logits = nn.Parameter(torch.zeros(num_sectors, polymorph_ways))
        # per-sector timing and phase parameters
        self.timing_logits = nn.Parameter(torch.zeros(num_sectors))
        self.phase = nn.Parameter(torch.zeros(num_sectors))
        self.gain = nn.Parameter(torch.ones(num_sectors))
        # per-sector residual gate (input -> output mixing)
        self.residual_gate = nn.Parameter(torch.ones(num_sectors) * 0.5)
        # reuse the parent's BSpline if one exists, otherwise create a new one
        if parent_layer is not None and isinstance(parent_layer.bspline_basis, BSplineBasis):
            self.bspline_basis = parent_layer.bspline_basis
        else:
            self.bspline_basis = BSplineBasis(self.bspline_cfg)
        self.residual_tracker = ResidualTracker(num_sectors, ema_decay=ema_decay)
        self.register_buffer('_last_residuals', torch.zeros(1))
        # polymorphic transforms per sector: one ModuleList of routes per sector
        self.polymorph_routes = nn.ModuleList()
        for i in range(num_sectors):
            ways = nn.ModuleList()
            for w in range(polymorph_ways):
                # each route: Linear -> (possible mixing) -> Linear
                ways.append(nn.Sequential(
                    nn.Linear(self.sector_output_sizes[i], self.sector_output_sizes[i]),
                    nn.Linear(self.sector_output_sizes[i], self.sector_output_sizes[i])
                ))
            self.polymorph_routes.append(ways)
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x_sectors = torch.split(x, self.sector_input_sizes, dim=1)
        timing_w = F.softmax(self.timing_logits, dim=0)  # per-sector timing weights
        phase = self.phase
        gain = self.gain
        residual_gate = torch.sigmoid(self.residual_gate)  # into [0, 1]
        outputs = []
        residuals_list = []
        for i, xsec in enumerate(x_sectors):
            # input remapping
            x_pre = self.pre_arrange[i](xsec)  # (batch, sin)
            # per-channel normalization so the B-spline is evaluated on [0, 1]
            x_min = x_pre.amin(dim=1, keepdim=True)
            x_max = x_pre.amax(dim=1, keepdim=True)
            denom = (x_max - x_min).clamp(min=self.bspline_cfg.eps)
            x_norm = (x_pre - x_min) / denom  # (batch, sin)
            # evaluate the B-spline basis; reduce channels by averaging (original behavior)
            basis = self.bspline_basis(x_norm)
            basis_reduced = basis.mean(dim=1)  # (batch, num_basis)
            # projection through the control points
            cp = self.control_points[i]  # (num_basis, sout)
            sector_main = torch.matmul(basis_reduced, cp)  # (batch, sout)
            # polymorphic route: blend several sub-routes using poly_logits
            poly_w = F.softmax(self.poly_logits[i], dim=0)  # (ways,)
            route_outputs = []
            for w, route in enumerate(self.polymorph_routes[i]):
                r_out = route[0](sector_main)  # first linear
                # timed cube: mix with a partner, apply a simple phase rotation, then cube
                W = self.post_arrange[i].weight  # (sout, sout)
                partner = torch.matmul(r_out, W.t())
                theta = phase[i] + float(w) * 0.1  # small per-route offset for diversity
                cos_t = torch.cos(theta)
                sin_t = torch.sin(theta)
                mixed = cos_t * r_out + sin_t * partner
                # cubing (timed cube): raise to the third power for nonlinear intensity
                cubic = mixed * mixed * mixed  # r^3
                r_out2 = route[1](cubic)  # second linear
                route_outputs.append(r_out2 * poly_w[w])
            # sum the polymorphic routes
            poly_sum = torch.stack(route_outputs, dim=0).sum(dim=0)  # (batch, sout)
            # final post-arrange
            sector_out = self.post_arrange[i](poly_sum)
            # residual: mix with a short skip of the input
            # (the original multiplied x_pre by an identity matrix, which is a no-op)
            skip = x_pre
            # match the skip to the output size if necessary
            if skip.shape[1] < sector_out.shape[1]:
                skip = F.pad(skip, (0, sector_out.shape[1] - skip.shape[1]))
            elif skip.shape[1] > sector_out.shape[1]:
                skip = skip[:, :sector_out.shape[1]]
            # final mix with the residual gate plus timing/gain modulation
            modulated = gain[i] * sector_out * timing_w[i]
            out_i = residual_gate[i] * modulated + (1.0 - residual_gate[i]) * skip
            outputs.append(out_i)
            residuals_list.append(out_i)  # the output itself is tracked as the residual
        out = torch.cat(outputs, dim=1)
        # zero-pad sectors to a common width so they stack even when output_dim is
        # not divisible by num_sectors; padding does not change the tracked norms
        max_out = max(self.sector_output_sizes)
        residuals = torch.stack(
            [F.pad(r, (0, max_out - r.shape[1])) for r in residuals_list], dim=1
        )  # (batch, num_sectors, max_sout)
        self.residual_tracker.update(residuals)
        self._last_residuals = residuals.detach()
        return out, residuals

    def get_gradient_scores(self) -> torch.Tensor:
        return self.residual_tracker.get_scores()
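
# Shape sketch (added illustration): the layer returns both the concatenated
# output and the stacked per-sector residual tensor, e.g.
#
#   layer = SectoredBSplineLayer(input_dim=16, output_dim=8, num_sectors=4)
#   out, residuals = layer(torch.randn(8, 16))
#   out.shape        # torch.Size([8, 8])
#   residuals.shape  # torch.Size([8, 4, 2])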
class HierarchicalBSplineNN(nn.Module):
    """
    Hierarchical network built from SectoredBSplineLayer blocks plus a final
    linear output layer.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        num_sectors: int = 4,
        bspline_cfg: Optional[BSplineConfig] = None,
        use_adaptive_backprop: bool = True,
    ):
        super().__init__()
        self.use_adaptive_backprop = use_adaptive_backprop
        self.layers = nn.ModuleList()
        self._layer_residuals: List[torch.Tensor] = []
        bs_cfg = bspline_cfg or BSplineConfig()
        in_dim = input_dim
        for h in hidden_dims:
            layer = SectoredBSplineLayer(in_dim, h, num_sectors=num_sectors, bspline_cfg=bs_cfg)
            self.layers.append(layer)
            in_dim = h
        self.output_layer = nn.Linear(in_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._layer_residuals = []
        out = x
        for layer in self.layers:
            out, residuals = layer(out)
            self._layer_residuals.append(residuals)
            out = F.relu(out)
        out = self.output_layer(out)
        return out

    def apply_adaptive_reweighting(self, optimizer: torch.optim.Optimizer, temperature: float = 1.0, reweight_all: bool = True):
        """
        Scale gradients by pathway score (the requested default behavior).
        This modifies gradients in place, before optimizer.step().
        The optimizer argument is accepted for API symmetry but is not used here.
        """
        for layer in self.layers:
            if not isinstance(layer, SectoredBSplineLayer):
                continue
            scores = layer.get_gradient_scores()
            scores = scores.pow(1.0 / max(temperature, 1e-6))
            scores = scores / (scores.sum() + 1e-8)
            for i in range(layer.num_sectors):
                cp = layer.control_points[i]
                if cp.grad is not None:
                    cp.grad = cp.grad * scores[i]
                if reweight_all:
                    for pname, p in layer.pre_arrange[i].named_parameters():
                        if p.grad is not None:
                            p.grad = p.grad * scores[i]
                    for pname, p in layer.post_arrange[i].named_parameters():
                        if p.grad is not None:
                            p.grad = p.grad * scores[i]
                    # reweight the per-sector scalar parameters; index by sector so
                    # each entry is scaled once (the original multiplied the whole
                    # tensor inside the sector loop, compounding num_sectors times)
                    for scalar_p in (layer.timing_logits, layer.phase, layer.gain, layer.poly_logits, layer.residual_gate):
                        if scalar_p.grad is not None:
                            scalar_p.grad[i] = scalar_p.grad[i] * scores[i]

    def get_pathway_statistics(self) -> Dict[int, Dict[str, torch.Tensor]]:
        stats = {}
        for idx, layer in enumerate(self.layers):
            if isinstance(layer, SectoredBSplineLayer):
                stats[idx] = {
                    'residual_ema': layer.residual_tracker.residual_ema.clone().detach(),
                    'cluster_scores': layer.get_gradient_scores().clone().detach()
                }
        return stats
if __name__ == "__main__":
    # Minimal test: check that forward, backward and reweighting run without errors.
    torch.manual_seed(0)
    cfg = BSplineConfig(degree=3, num_basis=8, knot_distribution="clamped", trainable_knots=False)
    model = HierarchicalBSplineNN(input_dim=16, hidden_dims=[32, 32, 16], output_dim=4, num_sectors=4, bspline_cfg=cfg)
    x = torch.randn(8, 16)
    y = torch.randn(8, 4)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    out = model(x)
    loss = F.mse_loss(out, y)
    opt.zero_grad()
    loss.backward()
    model.apply_adaptive_reweighting(opt, temperature=1.0)
    opt.step()
    print("Loss:", loss.item())