commit    37df61d1a0dbd5e253f9db778c17c4187e453b8d
tree      2879f69f35ddeff1ead30cb5099c1073c5ab376a
parent    fb13e67acbc7d7dd2974a456fcb134966c47cee0
author    skal <pascal.massimino@gmail.com>  2026-03-27 08:41:05 +0100
committer skal <pascal.massimino@gmail.com>  2026-03-27 08:41:05 +0100
fix(cnn_v3): L1 loss + depth-grad tanh normalization to reduce flat convergence

- Switch MSELoss → L1Loss in train_cnn_v3.py (median-seeking, avoids gray-blob outputs)
- Normalize depth_grad channels with tanh(10x) in cnn_v3_utils.py (squashes unbounded signed values into (-1, 1))
- Match normalization in gbuf_pack.wgsl: tanh((right-left)*5.0) == tanh(10*central_diff)

handoff(Gemini): training pipeline only; no C++ or test changes needed.
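The Python/WGSL normalization match claimed above can be sanity-checked numerically: a central difference is (right - left)/2, so tanh(10 * central_diff) equals tanh(5 * (right - left)). A minimal sketch in plain Python (the helper names are illustrative, not code from the repo):

```python
import math

def central_diff_norm(left: float, right: float) -> float:
    # Training-side form: tanh(10x) applied to the central difference.
    central = (right - left) / 2.0
    return math.tanh(10.0 * central)

def wgsl_style_norm(left: float, right: float) -> float:
    # Shader-side form, as in gbuf_pack.wgsl: tanh((right - left) * 5.0).
    return math.tanh((right - left) * 5.0)

# Both agree because 10 * (right - left) / 2 == 5 * (right - left).
```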
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
 cnn_v3/training/train_cnn_v3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index fa0d2e2..5065f22 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -160,7 +160,7 @@ def train(args):
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
- criterion = nn.MSELoss()
+ criterion = nn.L1Loss()
ckpt_dir = Path(args.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
start_epoch = 1
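As a side note on the MSELoss → L1Loss switch: the "median-seeking" claim in the commit message can be illustrated with a tiny plain-Python sketch (illustrative only, not code from the repo). On a bimodal target, the MSE minimizer is the mean, which is a washed-out in-between value (the "gray-blob" failure mode), while the L1 minimizer is the median, which snaps to the dominant mode:

```python
# Bimodal set of target values: 80% background (0.0), 20% foreground (1.0).
samples = [0.0] * 8 + [1.0] * 2

def l1_cost(c):
    # Total absolute error of predicting the constant c.
    return sum(abs(x - c) for x in samples)

def l2_cost(c):
    # Total squared error of predicting the constant c.
    return sum((x - c) ** 2 for x in samples)

# Brute-force minimize both costs over a fine grid of constants.
grid = [i / 1000 for i in range(1001)]
l1_min = min(grid, key=l1_cost)  # the median of the samples: 0.0
l2_min = min(grid, key=l2_cost)  # the mean of the samples: 0.2 (a "gray" value)
```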