diff options
Diffstat (limited to 'cnn_v3/training/gen_test_vectors.py')
| -rw-r--r-- | cnn_v3/training/gen_test_vectors.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py index 3f81247..96b175a 100644 --- a/cnn_v3/training/gen_test_vectors.py +++ b/cnn_v3/training/gen_test_vectors.py @@ -193,8 +193,8 @@ def dec1_forward(bn, enc1, w, gamma, beta): def dec0_forward(dec1, enc0, w, gamma, beta): """ - NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid - → rgba16float (full-res, final output). + NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + sigmoid + → rgba16float (full-res, final output). No ReLU before sigmoid. dec1: (hH, hW, 8) f32 — half-res enc0: (H, W, 8) f32 — full-res enc0 skip """ @@ -219,8 +219,8 @@ def dec0_forward(dec1, enc0, w, gamma, beta): for kx in range(3): wv = get_w(w, wo, o * DEC0_IN * 9 + i * 9 + ky * 3 + kx) s += wv * fp[ky:ky+H, kx:kx+W, i] - # FiLM + ReLU + sigmoid (matches WGSL dec0 shader) - v = np.maximum(0.0, gamma[o] * s + beta[o]) + # FiLM + sigmoid (matches WGSL dec0 shader — no ReLU before sigmoid) + v = gamma[o] * s + beta[o] out[:, :, o] = 1.0 / (1.0 + np.exp(-v.astype(np.float64))).astype(np.float32) return np.float16(out).astype(np.float32) # rgba16float boundary |
