HPTB NN using torch
# hierarchical_bspline.py
# Requires: torch, numpy
from typing import List, Optional, Tuple, Dict
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


@dataclass
class BSplineConfig:
    degree: int = 3
    num_basis: int = 8
    knot_distribution: str = "clamped"  # "clamped" or "uniform"
    trainable_knots: bool = False
    eps: float = 1e-8
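
# Worked example of the clamped layout (a sketch of what BSplineBasis.__init__
# builds below): with degree=3 and num_basis=8, the knot vector has
# K = num_basis + degree + 1 = 12 entries; 0 and 1 are each repeated
# degree + 1 = 4 times, leaving K - 2*(degree + 1) = 4 interior knots:
#   [0, 0, 0, 0, 0.2, 0.4, 0.6, 0.8, 1, 1, 1, 1]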

class BSplineBasis(nn.Module):
    """
    Vectorized B-spline basis evaluator using the iterative Cox-de Boor
    recursion. Expects x in [0, 1]; returns a basis tensor of shape
    (batch, dims, num_basis).
    """
    def __init__(self, cfg: BSplineConfig):
        super().__init__()
        self.cfg = cfg
        self.p = int(cfg.degree)
        self.num_basis = int(cfg.num_basis)
        assert self.num_basis > self.p, "num_basis must be > degree"
        self.K = self.num_basis + self.p + 1
        if cfg.knot_distribution == "clamped":
            interior_count = self.K - 2 * (self.p + 1)
            if interior_count > 0:
                interior = np.linspace(0.0, 1.0, interior_count + 2)[1:-1].astype(np.float32)
            else:
                interior = np.array([], dtype=np.float32)
            knots = np.concatenate([
                np.zeros(self.p + 1, dtype=np.float32),
                interior,
                np.ones(self.p + 1, dtype=np.float32)
            ])
            assert len(knots) == self.K
            knots_t = torch.as_tensor(knots, dtype=torch.float32)
            # This version keeps trainable_knots=False by default.
            self.register_buffer('knots', knots_t)
            self.interior_params = None
        else:
            knots = np.linspace(0.0, 1.0, self.K).astype(np.float32)
            self.register_buffer('knots', torch.as_tensor(knots))
            self.interior_params = None

    def get_knots(self, device: torch.device):
        # Knots are not trainable for now; just move the buffer to the device.
        return self.knots.to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept x of shape (batch, dims) or (batch,); normalize the format.
        if x.dim() == 1:
            x = x.unsqueeze(1)
        batch, dims = x.shape
        device = x.device
        x = x.clamp(0.0, 1.0)
        knots = self.get_knots(device)  # (K,)
        left_knots = knots[:-1]
        right_knots = knots[1:]
        x_exp = x.unsqueeze(-1)  # (batch, dims, 1)
        left = left_knots.view(1, 1, -1)
        right = right_knots.view(1, 1, -1)
        # Degree-0 initialization: indicator of the half-open interval [left, right)
        N = ((x_exp >= left) & (x_exp < right)).float()  # (batch, dims, K-1)
        # Handle x == 1.0: place it in the last *non-degenerate* interval.
        # With clamped knots the trailing intervals have zero width, so an
        # indicator there would be annihilated by the masked denominators below.
        is_one = (x_exp >= 1.0 - self.cfg.eps)
        if is_one.any():
            N = N.clone()
            N[is_one.expand_as(N)] = 0.0
            widths = right_knots - left_knots
            last_cell = int((widths > self.cfg.eps).nonzero().max())
            last_idx = is_one.squeeze(-1)  # (batch, dims)
            N[last_idx, last_cell] = 1.0
        # Iterative Cox-de Boor with careful indexing and stable denominators.
        for r in range(1, self.p + 1):
            L = N.shape[-1]  # current number of basis segments
            k = L - 1        # number of segments after this combination step
            if k <= 0:
                break
            N_left = N[..., :k]
            N_right = N[..., 1:L]
            denom_left = knots[r: r + k] - knots[:k]                 # (k,)
            denom_right = knots[r + 1: r + 1 + k] - knots[1: 1 + k]  # (k,)
            denom_left = denom_left.view(1, 1, -1).to(device)
            denom_right = denom_right.view(1, 1, -1).to(device)
            a_num = (x_exp - left[..., :k])
            b_num = (right[..., 1:L] - x_exp)
            # Treat 0/0 terms as 0, per the B-spline convention.
            mask_a = denom_left.abs() > self.cfg.eps
            mask_b = denom_right.abs() > self.cfg.eps
            a = torch.where(mask_a, (a_num / (denom_left + self.cfg.eps)) * N_left, torch.zeros_like(a_num))
            b = torch.where(mask_b, (b_num / (denom_right + self.cfg.eps)) * N_right, torch.zeros_like(b_num))
            N = a + b
        # Size check: after p recursion steps the basis has K - 1 - p == num_basis
        # columns. Truncate or zero-pad defensively in the (unexpected) edge case.
        expected_len = self.K - 1 - self.p
        if N.shape[-1] != expected_len:
            if N.shape[-1] > self.num_basis:
                basis = N[..., :self.num_basis]
            else:
                pad = self.num_basis - N.shape[-1]
                basis = F.pad(N, (0, pad))
        else:
            basis = N[..., :self.num_basis]
        return basis
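
# Illustrative check (a sketch using the classes above): for a clamped knot
# vector, the degree-p basis forms a partition of unity on [0, 1], so each
# row of the returned basis should sum to 1:
#
#   bb = BSplineBasis(BSplineConfig(degree=3, num_basis=8))
#   xs = torch.rand(4, 2)
#   assert torch.allclose(bb(xs).sum(dim=-1), torch.ones(4, 2), atol=1e-5)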

class ResidualTracker(nn.Module):
    """
    EMA tracking of residual norms per pathway/sector, with stable updates.
    """
    def __init__(self, num_pathways: int, ema_decay: float = 0.99, eps: float = 1e-8):
        super().__init__()
        self.ema_decay = float(ema_decay)
        self.eps = float(eps)
        self.num_pathways = int(num_pathways)
        self.register_buffer('residual_ema', torch.ones(self.num_pathways) * 1e-2)
        self.register_buffer('update_count', torch.tensor(0, dtype=torch.long))

    def update(self, residuals: torch.Tensor):
        # residuals has shape (batch, num_sectors, feat)
        with torch.no_grad():
            # Per-pathway norm, reducing over the batch and feature dims.
            flat = residuals.view(residuals.shape[0], residuals.shape[1], -1)
            norms = torch.norm(flat, dim=(0, 2))  # (num_sectors,)
            self.residual_ema = self.ema_decay * self.residual_ema + (1.0 - self.ema_decay) * norms
            self.update_count += 1

    def get_scores(self, invert: bool = True) -> torch.Tensor:
        if invert:
            s = 1.0 / (self.residual_ema + self.eps)
        else:
            s = self.residual_ema.clone()
        s = s / (s.sum() + self.eps)
        return s
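
# Illustrative usage (sketch): feed per-sector residuals and read back
# normalized scores; with invert=True, sectors with smaller residual EMAs
# receive larger scores.
#
#   tracker = ResidualTracker(num_pathways=4)
#   tracker.update(torch.randn(8, 4, 16))  # (batch, sectors, features)
#   scores = tracker.get_scores()          # shape (4,), sums to ~1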

class SectoredBSplineLayer(nn.Module):
    """
    Per-sector layer implementing:
      - polymorphic routes per sector (several linear -> cube -> linear sub-routes),
      - timing (phase, gain, timing logits),
      - a residual mix between the remapped input and the route output.
    """
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_sectors: int = 4,
        bspline_cfg: Optional[BSplineConfig] = None,
        parent_layer: Optional['SectoredBSplineLayer'] = None,
        ema_decay: float = 0.99,
        polymorph_ways: int = 2  # number of polymorphic routes per sector
    ):
        super().__init__()
        assert input_dim >= num_sectors and output_dim >= num_sectors
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sectors = num_sectors
        self.bspline_cfg = bspline_cfg or BSplineConfig()
        self.num_basis = self.bspline_cfg.num_basis
        self.polymorph_ways = polymorph_ways
        base_in = input_dim // num_sectors
        in_rem = input_dim % num_sectors
        self.sector_input_sizes = [base_in + (1 if i < in_rem else 0) for i in range(num_sectors)]
        base_out = output_dim // num_sectors
        out_rem = output_dim % num_sectors
        self.sector_output_sizes = [base_out + (1 if i < out_rem else 0) for i in range(num_sectors)]
        # Pre- and post-arrange (per-sector input/output remaps).
        self.pre_arrange = nn.ModuleList([
            nn.Linear(self.sector_input_sizes[i], self.sector_input_sizes[i]) for i in range(num_sectors)
        ])
        self.post_arrange = nn.ModuleList([
            nn.Linear(self.sector_output_sizes[i], self.sector_output_sizes[i]) for i in range(num_sectors)
        ])
        # Per-sector control points (num_basis x sector_out).
        self.control_points = nn.ParameterList([
            nn.Parameter(torch.randn(self.num_basis, self.sector_output_sizes[i]) * 0.01)
            for i in range(num_sectors)
        ])
        # Trainable logits for mixing the polymorphic routes of each sector.
        self.poly_logits = nn.Parameter(torch.zeros(num_sectors, polymorph_ways))
        # Per-sector timing and phase parameters.
        self.timing_logits = nn.Parameter(torch.zeros(num_sectors))
        self.phase = nn.Parameter(torch.zeros(num_sectors))
        self.gain = nn.Parameter(torch.ones(num_sectors))
        # Per-sector residual gate (input -> output mix).
        self.residual_gate = nn.Parameter(torch.ones(num_sectors) * 0.5)
        # Share the parent's B-spline basis when available, else create one.
        if parent_layer is not None and isinstance(parent_layer.bspline_basis, BSplineBasis):
            self.bspline_basis = parent_layer.bspline_basis
        else:
            self.bspline_basis = BSplineBasis(self.bspline_cfg)
        self.residual_tracker = ResidualTracker(num_sectors, ema_decay=ema_decay)
        self.register_buffer('_last_residuals', torch.zeros(1))
        # Per-sector polymorphic transforms: a ModuleList of routes per sector.
        self.polymorph_routes = nn.ModuleList()
        for i in range(num_sectors):
            ways = nn.ModuleList()
            for w in range(polymorph_ways):
                # Each route: Linear -> (mixing/cubing applied in forward) -> Linear.
                ways.append(nn.Sequential(
                    nn.Linear(self.sector_output_sizes[i], self.sector_output_sizes[i]),
                    nn.Linear(self.sector_output_sizes[i], self.sector_output_sizes[i])
                ))
            self.polymorph_routes.append(ways)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch = x.shape[0]
        device = x.device
        x_sectors = torch.split(x, self.sector_input_sizes, dim=1)
        timing_w = F.softmax(self.timing_logits, dim=0)  # per-sector timing weights
        phase = self.phase
        gain = self.gain
        residual_gate = torch.sigmoid(self.residual_gate)  # squash to [0, 1]
        outputs = []
        residuals_list = []
        for i, xsec in enumerate(x_sectors):
            # Input remap.
            x_pre = self.pre_arrange[i](xsec)  # (batch, sin)
            # Per-sample min-max normalization so the B-spline sees [0, 1].
            x_min = x_pre.amin(dim=1, keepdim=True)
            x_max = x_pre.amax(dim=1, keepdim=True)
            denom = (x_max - x_min).clamp(min=self.bspline_cfg.eps)
            x_norm = (x_pre - x_min) / denom  # (batch, sin)
            # Evaluate the B-spline basis; reduce channels by averaging
            # (original behavior).
            basis = self.bspline_basis(x_norm)
            basis_reduced = basis.mean(dim=1)  # (batch, num_basis)
            # Project through the control points.
            cp = self.control_points[i]  # (num_basis, sout)
            sector_main = torch.matmul(basis_reduced, cp)  # (batch, sout)
            # Polymorphic route: blend several sub-routes via poly_logits.
            poly_w = F.softmax(self.poly_logits[i], dim=0)  # (ways,)
            route_outputs = []
            for w, route in enumerate(self.polymorph_routes[i]):
                r_out = route[0](sector_main)  # first linear
                # Timed cube: mix with a partner projection, apply a simple
                # phase rotation, then cube.
                W = self.post_arrange[i].weight  # (sout, sout)
                partner = torch.matmul(r_out, W.t())
                theta = phase[i] + float(w) * 0.1  # small per-route offset for diversity
                cos_t = torch.cos(theta)
                sin_t = torch.sin(theta)
                mixed = cos_t * r_out + sin_t * partner
                # Cube for a sharper nonlinearity.
                cubic = mixed * mixed * mixed  # r^3
                r_out2 = route[1](cubic)  # second linear
                route_outputs.append(r_out2 * poly_w[w])
            # Sum the polymorphic routes.
            poly_sum = torch.stack(route_outputs, dim=0).sum(dim=0)  # (batch, sout)
            # Final post-arrange.
            sector_out = self.post_arrange[i](poly_sum)
            # Residual: mix with a skip copy of the remapped input.
            skip = x_pre
            # Match the skip to the output width if needed (zero-pad or crop).
            if skip.shape[1] != sector_out.shape[1]:
                skip = F.pad(skip, (0, sector_out.shape[1] - skip.shape[1]))
            # Final mix with the residual gate plus timing/gain modulation.
            modulated = gain[i] * sector_out * timing_w[i]
            out_i = residual_gate[i] * modulated + (1.0 - residual_gate[i]) * skip
            outputs.append(out_i)
            residuals_list.append(out_i)  # treat the output as the residual for tracking
        out = torch.cat(outputs, dim=1)
        # Sector outputs may have unequal widths (when output_dim % num_sectors
        # != 0), so zero-pad to a common width before stacking for the tracker.
        max_sout = max(o.shape[1] for o in residuals_list)
        padded = [F.pad(o, (0, max_sout - o.shape[1])) for o in residuals_list]
        residuals = torch.stack(padded, dim=1)  # (batch, num_sectors, max_sout)
        self.residual_tracker.update(residuals)
        self._last_residuals = residuals.detach()
        return out, residuals

    def get_gradient_scores(self) -> torch.Tensor:
        return self.residual_tracker.get_scores()
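
# Illustrative usage (sketch): a standalone sector layer mapping 16 -> 8
# features across 4 sectors, also returning the per-sector tracking tensor.
#
#   layer = SectoredBSplineLayer(input_dim=16, output_dim=8, num_sectors=4)
#   out, residuals = layer(torch.randn(2, 16))
#   out.shape        # (2, 8)
#   residuals.shape  # (2, 4, 2): batch x sectors x per-sector width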

class HierarchicalBSplineNN(nn.Module):
    """
    Hierarchical network built from SectoredBSplineLayer blocks plus a final
    linear output layer.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        num_sectors: int = 4,
        bspline_cfg: Optional[BSplineConfig] = None,
        use_adaptive_backprop: bool = True,
    ):
        super().__init__()
        self.use_adaptive_backprop = use_adaptive_backprop
        self.layers = nn.ModuleList()
        bs_cfg = bspline_cfg or BSplineConfig()
        in_dim = input_dim
        for h in hidden_dims:
            layer = SectoredBSplineLayer(in_dim, h, num_sectors=num_sectors, bspline_cfg=bs_cfg)
            self.layers.append(layer)
            in_dim = h
        self.output_layer = nn.Linear(in_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._layer_residuals = []
        out = x
        for layer in self.layers:
            out, residuals = layer(out)
            self._layer_residuals.append(residuals)
            out = F.relu(out)
        out = self.output_layer(out)
        return out

    def apply_adaptive_reweighting(self, optimizer: torch.optim.Optimizer, temperature: float = 1.0, reweight_all: bool = True):
        """
        Scale gradients by per-pathway scores (the requested default behavior).
        Modifies gradients in place, so call it after loss.backward() and
        before optimizer.step(); the optimizer argument is currently unused.
        """
        for layer in self.layers:
            if not isinstance(layer, SectoredBSplineLayer):
                continue
            scores = layer.get_gradient_scores()
            scores = scores.pow(1.0 / max(temperature, 1e-6))
            scores = scores / (scores.sum() + 1e-8)
            for i in range(layer.num_sectors):
                cp = layer.control_points[i]
                if cp.grad is not None:
                    cp.grad = cp.grad * scores[i]
                if reweight_all:
                    for p in layer.pre_arrange[i].parameters():
                        if p.grad is not None:
                            p.grad = p.grad * scores[i]
                    for p in layer.post_arrange[i].parameters():
                        if p.grad is not None:
                            p.grad = p.grad * scores[i]
            # Reweight the per-sector vector parameters once, element-wise by
            # sector score (doing this inside the sector loop would compound
            # the scaling num_sectors times).
            for vec_p in (layer.timing_logits, layer.phase, layer.gain, layer.residual_gate):
                if vec_p.grad is not None:
                    vec_p.grad = vec_p.grad * scores
            if layer.poly_logits.grad is not None:
                layer.poly_logits.grad = layer.poly_logits.grad * scores.view(-1, 1)

    def get_pathway_statistics(self) -> Dict[int, Dict[str, torch.Tensor]]:
        stats = {}
        for idx, layer in enumerate(self.layers):
            if isinstance(layer, SectoredBSplineLayer):
                stats[idx] = {
                    'residual_ema': layer.residual_tracker.residual_ema.clone().detach(),
                    'cluster_scores': layer.get_gradient_scores().clone().detach()
                }
        return stats
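
# Illustrative read-out (sketch): inspect the tracked statistics per layer.
#
#   stats = model.get_pathway_statistics()
#   stats[0]['residual_ema']    # per-sector EMA of residual norms, layer 0
#   stats[0]['cluster_scores']  # normalized inverse-residual scores, layer 0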

if __name__ == "__main__":
    # Minimal test: check that forward, backward, and reweighting run cleanly.
    torch.manual_seed(0)
    cfg = BSplineConfig(degree=3, num_basis=8, knot_distribution="clamped", trainable_knots=False)
    model = HierarchicalBSplineNN(input_dim=16, hidden_dims=[32, 32, 16], output_dim=4, num_sectors=4, bspline_cfg=cfg)
    x = torch.randn(8, 16)
    y = torch.randn(8, 4)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    out = model(x)
    loss = F.mse_loss(out, y)
    opt.zero_grad()
    loss.backward()
    model.apply_adaptive_reweighting(opt, temperature=1.0)
    opt.step()
    print("Loss:", loss.item())