Created
April 28, 2026 12:50
-
-
Save sieste/ab753fc7ad0e95c258fcf8709d34fc34 to your computer and use it in GitHub Desktop.
torch conv2d for wind prediction
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| import xarray as xr | |
| import glob | |
| from tqdm import tqdm | |
| # ----------------------- | |
| # Dataset | |
| # ----------------------- | |
class UV850Dataset(Dataset):
    """Pairs of consecutive-time-step 850 hPa wind fields read from netCDF files.

    Every file matching ``path_pattern`` (processed in sorted filename order) is
    opened with xarray, and its ``u`` and ``v`` variables are loaded into memory.
    Within each file, every pair of consecutive time indices ``(i, i + 1)``
    becomes one sample. Pairs are never formed across file boundaries, so the
    last time step of one file is not paired with the first step of the next.

    Each item is ``(x, y)``, both float32:

    * ``x`` stacks ``(u[i], v[i])`` along a new leading channel axis.
    * ``y`` stacks ``(u[i + 1], v[i + 1])`` the same way. This is the full
      next-step field, not an increment.

    Args:
        path_pattern: ``glob.glob`` pattern selecting the netCDF files.

    Raises:
        FileNotFoundError: if ``path_pattern`` matches no files.
    """

    def __init__(self, path_pattern):
        self.files = sorted(glob.glob(path_pattern))
        if not self.files:
            # Fail here on a mistyped pattern rather than silently building an
            # empty dataset that only errors later (e.g. in a shuffling DataLoader).
            raise FileNotFoundError(f"no files match {path_pattern!r}")
        self.samples = []
        for f in self.files:
            # The context manager closes the file even if reading a variable
            # fails. `.values` materialises numpy arrays before the file closes.
            with xr.open_dataset(f, decode_timedelta=False) as ds:
                u = ds["u"].values  # shape: time, lat, lon
                v = ds["v"].values
            self.samples.extend(
                (u[i], v[i], u[i + 1], v[i + 1]) for i in range(len(u) - 1)
            )

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _stack_channels(a, b):
        """Return ``a`` and ``b`` stacked as a float32 tensor with a leading channel axis."""
        # Convert each array on its own and stack. torch.tensor() on a Python
        # list of ndarrays goes through a slow element-wise path (PyTorch warns).
        return torch.stack(
            [torch.tensor(a, dtype=torch.float32), torch.tensor(b, dtype=torch.float32)]
        )

    def __getitem__(self, idx):
        u0, v0, u1, v1 = self.samples[idx]
        x = self._stack_channels(u0, v0)
        y = self._stack_channels(u1, v1)  # target: full (u, v) at the next time step
        return x, y
| # ----------------------- | |
| # Model | |
| # ----------------------- | |
class Block(nn.Module):
    """Two 3x3 convolutions (``ic -> ic -> oc``) with ReLU activations.

    ``padding=1`` preserves the spatial size. The ReLU after the second
    convolution is applied only when ``final_act`` is true (the default), so a
    block used as an output layer can still produce negative values.

    Args:
        ic: number of input channels.
        oc: number of output channels.
        final_act: apply ReLU to the block's output. Defaults to True.
    """

    def __init__(self, ic, oc, final_act=True):
        super().__init__()
        self.conv1 = nn.Conv2d(ic, ic, 3, padding=1)
        self.conv2 = nn.Conv2d(ic, oc, 3, padding=1)
        self.act = nn.ReLU()
        self.final_act = final_act

    def forward(self, x):
        y = self.conv2(self.act(self.conv1(x)))
        return self.act(y) if self.final_act else y


class Model(nn.Module):
    """Small fully convolutional network mapping a 2-channel field to a 2-channel field.

    Processing steps:

    1. Average-pool the input with kernel and stride ``coarsen``.
    2. Pass it through three ReLU blocks of width ``nh``.
    3. Upsample (nearest neighbour) back to the input's exact spatial size.
    4. Map to 2 channels with a final block that has no output ReLU.

    Args:
        seed: seed for the weight initialisation. The global CPU RNG state is
            saved before construction and restored afterwards.
        nh: hidden channel count.
        coarsen: pooling kernel/stride and nominal upsampling factor.
    """

    def __init__(self, seed=0, nh=16, coarsen=1):
        super().__init__()
        self.coarsen = coarsen
        old_rng_state = torch.get_rng_state()
        # Initialise weights from a private seed without disturbing the
        # caller's global CPU RNG stream. `finally` restores it even if
        # construction raises.
        torch.set_rng_state(torch.Generator().manual_seed(seed).get_state())
        try:
            self.down = nn.AvgPool2d(kernel_size=coarsen)
            self.up = nn.Upsample(scale_factor=coarsen)
            self.inp = Block(2, nh)
            self.b1 = Block(nh, nh)
            self.b2 = Block(nh, nh)
            # No ReLU on the output: predicted wind components must be able to
            # go negative.
            self.out = Block(nh, 2, final_act=False)
        finally:
            torch.set_rng_state(old_rng_state)

    def forward(self, x):
        size = x.shape[-2:]
        h = self.down(x)
        h = self.inp(h)
        h = self.b1(h)
        h = self.b2(h)
        # Interpolate to the recorded input size, not by a fixed scale factor.
        # Pooling floors, so a scale_factor upsample comes back short when H or
        # W is not a multiple of `coarsen`. When they divide evenly, the result
        # matches the scale_factor upsample.
        h = nn.functional.interpolate(h, size=size, mode=self.up.mode)
        return self.out(h)
| # ----------------------- | |
| # Training loop | |
| # ----------------------- | |
device = "cuda" if torch.cuda.is_available() else "cpu"

data_dir = "/home/stefan/folders/era5-wind-data-download"

# Training data: this glob matches the years 2000-2009.
# NOTE(review): an earlier comment said "load only 2005 ... 2009". If that was
# the intent, narrow the character class to 200[5-9].
nc_files_train = f"{data_dir}/era5_uv850_6hourly_global1deg_200[0-9]*.nc"
train_ds = UV850Dataset(nc_files_train)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
x_, y_ = next(iter(train_loader))  # one input/output batch, for testing
# FIXME: integrate data coarsening into data loader

# Validation data: 2010-2015. A glob character class matches ONE character, so
# the previous "20[10-15]*" meant "20" + one of {0, 1, 5}. That also matched
# "200*" and leaked the 2000-2009 training files into the validation set.
nc_files_valid = f"{data_dir}/era5_uv850_6hourly_global1deg_201[0-5]*.nc"
valid_ds = UV850Dataset(nc_files_valid)
valid_loader = DataLoader(valid_ds, batch_size=8, shuffle=True)

coarse_level = 10
model = Model(coarsen=coarse_level).to(device)
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.MSELoss()

# Reference forecast for the skill score: the input averaged over
# coarse_level x coarse_level cells, then upsampled back. Built once here
# instead of once per validation batch.
ref_pool = nn.AvgPool2d(kernel_size=coarse_level)
n_valid_batches = 6  # validation batches evaluated per check
valid_every = 10  # training steps between validation checks

for epoch in range(10):
    for step, (x, y) in enumerate(tqdm(train_loader, desc=f"epoch {epoch}"), start=1):
        x = x.to(device)
        y = y.to(device)
        opt.zero_grad()
        pred = model(x)
        loss = loss_fn(pred, y)
        loss.backward()
        opt.step()

        if step % valid_every == 0:
            model.eval()
            loss_val = []
            loss_ref = []
            with torch.no_grad():
                for xv, yv in tqdm(valid_loader, desc=f"epoch {epoch}", leave=False):
                    # Validation batches must live on the same device as the model.
                    xv = xv.to(device)
                    yv = yv.to(device)
                    pred_v = model(xv)
                    # Upsample to the target's exact grid so the shapes match
                    # even if lat/lon are not multiples of coarse_level.
                    xc = nn.functional.interpolate(
                        ref_pool(xv), size=yv.shape[-2:], mode="nearest"
                    )
                    loss_val.append(loss_fn(pred_v, yv).item())
                    loss_ref.append(loss_fn(xc, yv).item())
                    if len(loss_val) >= n_valid_batches:
                        break
            model.train()
            print(f"validation skill: {(1 - sum(loss_val) / sum(loss_ref)):.2f}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment