diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-13 17:46:09 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-13 17:46:09 +0100 |
| commit | a7340d378909cadbfd72dbd1f5b756f907c2a3e0 (patch) | |
| tree | 60a34dd084d746f0c6dad50b0d5cc7f20bc0c409 /training/train_cnn_v2.py | |
| parent | f6b3ea72a03850654b69986bc82bb249aaabe2e3 (diff) | |
CNN v2 training: Add --grayscale-loss option for luminance-based loss computation
Add option to compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of full RGBA channels. Useful for training models that prioritize luminance accuracy over color accuracy.
Changes:
- training/train_cnn_v2.py: Add --grayscale-loss flag and grayscale conversion in loss computation
- scripts/train_cnn_v2_full.sh: Add --grayscale-loss parameter support
- doc/CNN_V2.md: Document grayscale loss in training configuration and checkpoint format
- doc/HOWTO.md: Add usage examples for --grayscale-loss flag
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index a9a311a..abe07bc 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -350,7 +350,16 @@ def train(args): optimizer.zero_grad() output = model(input_rgbd, static_feat) - loss = criterion(output, target) + + # Compute loss (grayscale or RGBA) + if args.grayscale_loss: + # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B + output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3] + target_gray = 0.299 * target[:, 0:1] + 0.587 * target[:, 1:2] + 0.114 * target[:, 2:3] + loss = criterion(output_gray, target_gray) + else: + loss = criterion(output, target) + loss.backward() optimizer.step() @@ -376,6 +385,7 @@ def train(args): 'kernel_sizes': kernel_sizes, 'num_layers': args.num_layers, 'mip_level': args.mip_level, + 'grayscale_loss': args.grayscale_loss, 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias'] } }, checkpoint_path) @@ -419,6 +429,8 @@ def main(): parser.add_argument('--epochs', type=int, default=5000, help='Training epochs') parser.add_argument('--batch-size', type=int, default=16, help='Batch size') parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') + parser.add_argument('--grayscale-loss', action='store_true', + help='Compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of RGBA') parser.add_argument('--checkpoint-dir', type=str, default='checkpoints', help='Checkpoint directory') parser.add_argument('--checkpoint-every', type=int, default=1000, |
