Skip to content

Instantly share code, notes, and snippets.

@sieste
Created April 28, 2026 12:50
Show Gist options
  • Select an option

  • Save sieste/ab753fc7ad0e95c258fcf8709d34fc34 to your computer and use it in GitHub Desktop.

Select an option

Save sieste/ab753fc7ad0e95c258fcf8709d34fc34 to your computer and use it in GitHub Desktop.
torch conv2d for wind prediction
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import xarray as xr
import glob
from tqdm import tqdm
# -----------------------
# Dataset
# -----------------------
class UV850Dataset(Dataset):
    """Consecutive pairs of (u, v) wind fields loaded from netCDF files.

    Every file matching ``path_pattern`` is read eagerly, in sorted filename
    order. For each pair of consecutive time indices ``(i, i+1)`` within one
    file the dataset yields

    * input  ``x = stack([u[i],   v[i]])``
    * target ``y = stack([u[i+1], v[i+1]])``

    both as ``float32`` tensors. Pairs are never formed across file
    boundaries.

    Args:
        path_pattern: ``glob`` pattern selecting the netCDF files; each file
            must contain variables ``'u'`` and ``'v'``.

    Raises:
        FileNotFoundError: if ``path_pattern`` matches no file.
    """

    def __init__(self, path_pattern):
        self.files = sorted(glob.glob(path_pattern))
        if not self.files:
            # Fail here with a clear message rather than later inside the
            # DataLoader sampler with an opaque "num_samples=0" error.
            raise FileNotFoundError(f"no files match {path_pattern!r}")
        self.samples = []
        for f in self.files:
            # The context manager closes the file even if reading fails.
            with xr.open_dataset(f, decode_timedelta=False) as ds:
                # .values materialises the data in memory, so the arrays stay
                # valid after the file is closed.
                u = ds['u'].values  # assumed (time, lat, lon) — TODO confirm
                v = ds['v'].values
            for i in range(len(u) - 1):
                self.samples.append((u[i], v[i], u[i + 1], v[i + 1]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(x, y)``: fields at step i and the absolute fields at step i+1."""
        u0, v0, u1, v1 = self.samples[idx]
        # Stack per-array tensors; building a tensor from a Python list of
        # ndarrays goes through a slow element-wise conversion path.
        x = torch.stack([torch.as_tensor(u0, dtype=torch.float32),
                         torch.as_tensor(v0, dtype=torch.float32)])
        # Target is the next-step state itself, not the increment (Δu, Δv).
        y = torch.stack([torch.as_tensor(u1, dtype=torch.float32),
                         torch.as_tensor(v1, dtype=torch.float32)])
        return x, y
# -----------------------
# Model
# -----------------------
class Block(nn.Module):
    """Two 3x3 convolutions with ReLU; height and width are preserved.

    ``conv1`` maps ``ic -> ic`` channels, ``conv2`` maps ``ic -> oc``;
    ``padding=1`` keeps the spatial size unchanged.

    Args:
        ic: number of input channels.
        oc: number of output channels.
        final_act: apply ReLU after ``conv2`` (default ``True``, the original
            behaviour). Use ``False`` for an output layer whose values may be
            negative.
    """

    def __init__(self, ic, oc, final_act=True):
        super().__init__()
        # Creation order (conv1, conv2) is kept so seeded initialisation
        # produces the same weights as before.
        self.conv1 = nn.Conv2d(ic, ic, 3, padding=1)
        self.conv2 = nn.Conv2d(ic, oc, 3, padding=1)
        self.act = nn.ReLU()
        self.final_act = final_act

    def forward(self, x):
        y = self.conv2(self.act(self.conv1(x)))
        return self.act(y) if self.final_act else y


class Model(nn.Module):
    """Coarsen -> conv stack -> upsample -> conv head, mapping (u, v) to (u, v).

    The input is average-pooled with kernel ``coarsen`` and passed through
    three ``Block`` s at the coarse resolution. It is then resized back to
    the input's spatial size (nearest neighbour) and refined by an output
    ``Block`` at full resolution.

    Weights are initialised from ``seed`` without disturbing the global CPU
    RNG stream; the previous state is restored even if construction fails.

    Args:
        seed: seed for weight initialisation.
        nh: number of hidden channels.
        coarsen: pooling kernel size; 1 means no coarsening.
    """

    def __init__(self, seed=0, nh=16, coarsen=1):
        super().__init__()
        self.coarsen = coarsen
        old_rng_state = torch.get_rng_state()
        try:
            torch.set_rng_state(torch.Generator().manual_seed(seed).get_state())
            self.down = nn.AvgPool2d(kernel_size=coarsen)
            self.inp = Block(2, nh)
            self.b1 = Block(nh, nh)
            self.b2 = Block(nh, nh)
            # No ReLU on the output: wind components u and v can be negative,
            # and a final ReLU would clamp every prediction to >= 0.
            self.out = Block(nh, 2, final_act=False)
        finally:
            torch.set_rng_state(old_rng_state)

    def forward(self, x):
        """Predict the next (u, v) state.

        ``x`` is assumed to be a 4-D tensor (batch, 2, H, W) — TODO confirm
        for all callers. The output has the same shape.
        """
        size = x.shape[-2:]
        h = self.down(x)
        h = self.inp(h)
        h = self.b1(h)
        h = self.b2(h)
        # Resize to the exact input size. AvgPool2d floors H/coarsen, so a
        # plain scale_factor upsample would come back too small when H or W
        # is not divisible by `coarsen`. For divisible sizes this is the same
        # as nn.Upsample(scale_factor=coarsen).
        h = nn.functional.interpolate(h, size=size, mode="nearest")
        return self.out(h)
# -----------------------
# Training loop
# -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Training data: files whose year matches 200[0-9], i.e. 2000 ... 2009.
# NOTE(review): an earlier comment said "2005 ... 2009"; narrow the pattern
# to 200[5-9] if that range was intended.
nc_files_train = "/home/stefan/folders/era5-wind-data-download/era5_uv850_6hourly_global1deg_200[0-9]*.nc"
train_ds = UV850Dataset(nc_files_train)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
x_, y_ = next(iter(train_loader))  # one input/output batch, for interactive testing
# FIXME: integrate data coarsening into data loader
# Validation data: years 2010 ... 2015. A glob bracket matches ONE character,
# so the former pattern 20[10-15] meant "20" followed by one of {0, 1, 5}.
# That also selected the training years 2000-2009, leaking them into
# validation.
nc_files_valid = "/home/stefan/folders/era5-wind-data-download/era5_uv850_6hourly_global1deg_201[0-5]*.nc"
valid_ds = UV850Dataset(nc_files_valid)
valid_loader = DataLoader(valid_ds, batch_size=8, shuffle=True)

coarse_level = 10
eval_every = 10  # training steps between validation checks
n_valid_batches = 6  # validation batches per check

model = Model(coarsen=coarse_level).to(device)
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.MSELoss()
# Baseline for the skill score: the input state itself, average-pooled to the
# model's coarse resolution and resized back to full size (nearest neighbour).
ref_pool = nn.AvgPool2d(kernel_size=coarse_level)

for epoch in range(10):
    for ctr, (x, y) in enumerate(tqdm(train_loader, desc=f"epoch {epoch}"), start=1):
        x = x.to(device)
        y = y.to(device)
        opt.zero_grad()
        pred = model(x)
        loss = loss_fn(pred, y)
        loss.backward()
        opt.step()
        if ctr % eval_every != 0:
            continue
        # Periodic validation on a few randomly drawn batches.
        model.eval()
        loss_val = []
        loss_ref = []
        with torch.no_grad():
            for xv, yv in tqdm(valid_loader, desc=f"epoch {epoch}"):
                # Validation batches must be on the same device as the model.
                xv = xv.to(device)
                yv = yv.to(device)
                pred_v = model(xv)
                xc = nn.functional.interpolate(
                    ref_pool(xv), size=xv.shape[-2:], mode="nearest"
                )
                loss_val.append(loss_fn(pred_v, yv))
                loss_ref.append(loss_fn(xc, yv))
                if len(loss_val) >= n_valid_batches:
                    break
        model.train()
        # skill = 1 - MSE_model / MSE_baseline: > 0 beats the baseline,
        # 1 would be a perfect forecast.
        skill = 1 - (sum(loss_val) / sum(loss_ref)).item()
        print(f"validation skill: {skill:.2f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment