summary | refs | log | tree | commit | diff
path: root/cnn_v3/training
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training')
-rw-r--r-- cnn_v3/training/gen_test_vectors.py | 8
-rw-r--r-- cnn_v3/training/train_cnn_v3.py | 7
2 files changed, 7 insertions, 8 deletions
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index 3f81247..96b175a 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -193,8 +193,8 @@ def dec1_forward(bn, enc1, w, gamma, beta):
def dec0_forward(dec1, enc0, w, gamma, beta):
"""
- NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid
- → rgba16float (full-res, final output).
+ NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + sigmoid
+ → rgba16float (full-res, final output). No ReLU before sigmoid.
dec1: (hH, hW, 8) f32 — half-res
enc0: (H, W, 8) f32 — full-res enc0 skip
"""
@@ -219,8 +219,8 @@ def dec0_forward(dec1, enc0, w, gamma, beta):
for kx in range(3):
wv = get_w(w, wo, o * DEC0_IN * 9 + i * 9 + ky * 3 + kx)
s += wv * fp[ky:ky+H, kx:kx+W, i]
- # FiLM + ReLU + sigmoid (matches WGSL dec0 shader)
- v = np.maximum(0.0, gamma[o] * s + beta[o])
+ # FiLM + sigmoid (matches WGSL dec0 shader — no ReLU before sigmoid)
+ v = gamma[o] * s + beta[o]
out[:, :, o] = 1.0 / (1.0 + np.exp(-v.astype(np.float64))).astype(np.float32)
return np.float16(out).astype(np.float32) # rgba16float boundary
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index e48f684..fa0d2e2 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -10,8 +10,7 @@ Architecture (enc_channels=[8,16]):
enc1 Conv(8→16, 3×3) + FiLM + ReLU + pool2 H/2×W/2 2× rgba32uint (16ch split)
bottleneck Conv(16→16, 3×3, dilation=2) + ReLU H/4×W/4 2× rgba32uint (16ch split)
dec1 upsample×2 + cat(enc1) Conv(32→8) + FiLM H/2×W/2 rgba32uint (8ch)
- dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM H×W rgba16float (4ch)
- output sigmoid → RGBA
+ dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM + sigmoid H×W rgba16float (4ch)
FiLM MLP: Linear(5→16) → ReLU → Linear(16→72)
72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4)
@@ -93,9 +92,9 @@ class CNNv3(nn.Module):
torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
), gd1, bd1))
- x = F.relu(film_apply(self.dec0(
+ x = film_apply(self.dec0(
torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
- ), gd0, bd0))
+ ), gd0, bd0)
return torch.sigmoid(x)