summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-13 17:46:09 +0100
committerskal <pascal.massimino@gmail.com>2026-02-13 17:46:09 +0100
commita7340d378909cadbfd72dbd1f5b756f907c2a3e0 (patch)
tree60a34dd084d746f0c6dad50b0d5cc7f20bc0c409 /training
parentf6b3ea72a03850654b69986bc82bb249aaabe2e3 (diff)
CNN v2 training: Add --grayscale-loss option for luminance-based loss computation
Add option to compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of full RGBA channels. Useful for training models that prioritize luminance accuracy over color accuracy. Changes: - training/train_cnn_v2.py: Add --grayscale-loss flag and grayscale conversion in loss computation - scripts/train_cnn_v2_full.sh: Add --grayscale-loss parameter support - doc/CNN_V2.md: Document grayscale loss in training configuration and checkpoint format - doc/HOWTO.md: Add usage examples for --grayscale-loss flag Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn_v2.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index a9a311a..abe07bc 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -350,7 +350,16 @@ def train(args):
optimizer.zero_grad()
output = model(input_rgbd, static_feat)
- loss = criterion(output, target)
+
+ # Compute loss (grayscale or RGBA)
+ if args.grayscale_loss:
+ # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B
+ output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3]
+ target_gray = 0.299 * target[:, 0:1] + 0.587 * target[:, 1:2] + 0.114 * target[:, 2:3]
+ loss = criterion(output_gray, target_gray)
+ else:
+ loss = criterion(output, target)
+
loss.backward()
optimizer.step()
@@ -376,6 +385,7 @@ def train(args):
'kernel_sizes': kernel_sizes,
'num_layers': args.num_layers,
'mip_level': args.mip_level,
+ 'grayscale_loss': args.grayscale_loss,
'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
}
}, checkpoint_path)
@@ -419,6 +429,8 @@ def main():
parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
+ parser.add_argument('--grayscale-loss', action='store_true',
+ help='Compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of RGBA')
parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
help='Checkpoint directory')
parser.add_argument('--checkpoint-every', type=int, default=1000,