summaryrefslogtreecommitdiff
path: root/cnn_v3
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-27 08:41:05 +0100
committerskal <pascal.massimino@gmail.com>2026-03-27 08:41:05 +0100
commit37df61d1a0dbd5e253f9db778c17c4187e453b8d (patch)
tree2879f69f35ddeff1ead30cb5099c1073c5ab376a /cnn_v3
parentfb13e67acbc7d7dd2974a456fcb134966c47cee0 (diff)
fix(cnn_v3): L1 loss + depth-grad tanh normalization to reduce flat convergenceHEADmain
- Switch MSELoss → L1Loss in train_cnn_v3.py (median-seeking, avoids gray-blob) - Normalize depth_grad channels with tanh(10x) in cnn_v3_utils.py (bounds ±∞ signed values) - Match normalization in gbuf_pack.wgsl: tanh((right-left)*5.0) == tanh(10*central_diff) handoff(Gemini): training pipeline only; no C++ or test changes needed.
Diffstat (limited to 'cnn_v3')
-rw-r--r--cnn_v3/shaders/gbuf_pack.wgsl5
-rw-r--r--cnn_v3/training/cnn_v3_utils.py2
-rw-r--r--cnn_v3/training/train_cnn_v3.py2
3 files changed, 5 insertions, 4 deletions
diff --git a/cnn_v3/shaders/gbuf_pack.wgsl b/cnn_v3/shaders/gbuf_pack.wgsl
index 777b4e5..5870938 100644
--- a/cnn_v3/shaders/gbuf_pack.wgsl
+++ b/cnn_v3/shaders/gbuf_pack.wgsl
@@ -60,8 +60,9 @@ fn pack_features(@builtin(global_invocation_id) id: vec3u) {
let depth_raw = load_depth(coord);
// Finite-difference depth gradient (central difference, clamped coords)
- let dzdx = (load_depth(coord + vec2i(1, 0)) - load_depth(coord - vec2i(1, 0))) * 0.5;
- let dzdy = (load_depth(coord + vec2i(0, 1)) - load_depth(coord - vec2i(0, 1))) * 0.5;
+ // tanh(10x) keeps typical gradients (±0.05–0.1) in [-1,1]; matches training normalization.
+ let dzdx = tanh((load_depth(coord + vec2i(1, 0)) - load_depth(coord - vec2i(1, 0))) * 5.0);
+ let dzdy = tanh((load_depth(coord + vec2i(0, 1)) - load_depth(coord - vec2i(0, 1))) * 5.0);
// Normal: stored as oct-encoded [0,1] in RG; extract just the encoded xy for feat_tex0
let normal_enc = nm.rg; // already in [0,1] — decode to get the xy for CNN input
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index 68c0798..bd0ca2f 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -140,7 +140,7 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray,
mip1 = _upsample_nearest(pyrdown(albedo), h, w)
mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
- dgrad = depth_gradient(depth)
+ dgrad = np.tanh(depth_gradient(depth) * 10.0)
if prev is None:
prev = np.zeros((h, w, 3), dtype=np.float32)
nor3 = oct_decode(normal)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index fa0d2e2..5065f22 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -160,7 +160,7 @@ def train(args):
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
- criterion = nn.MSELoss()
+ criterion = nn.L1Loss()
ckpt_dir = Path(args.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
start_epoch = 1