From 0d5775a2295ed9323330ecf577c7927c8d4b65e8 Mon Sep 17 00:00:00 2001 From: skal Date: Tue, 10 Feb 2026 20:04:50 +0100 Subject: feat: Add inference mode to train_cnn.py for ground truth generation - Added --infer flag for single-image inference - Loads checkpoint, runs forward pass, saves PNG output - Useful for verifying shader matches trained model Co-Authored-By: Claude Sonnet 4.5 --- training/train_cnn.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) (limited to 'training/train_cnn.py') diff --git a/training/train_cnn.py b/training/train_cnn.py index 3312768..16f8e7a 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -5,10 +5,15 @@ CNN Training Script for Image-to-Image Transformation Trains a convolutional neural network on multiple input/target image pairs. Usage: + # Training python3 train_cnn.py --input input_dir/ --target target_dir/ [options] + # Inference (generate ground truth) + python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png + Example: python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100 + python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth """ import torch @@ -467,23 +472,73 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print("Export complete!") +def infer_from_checkpoint(checkpoint_path, input_path, output_path): + """Run inference on single image to generate ground truth""" + + if not os.path.exists(checkpoint_path): + print(f"Error: Checkpoint '{checkpoint_path}' not found") + sys.exit(1) + + if not os.path.exists(input_path): + print(f"Error: Input image '{input_path}' not found") + sys.exit(1) + + print(f"Loading checkpoint from {checkpoint_path}...") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Reconstruct model + model = SimpleCNN( + num_layers=checkpoint['num_layers'], + kernel_sizes=checkpoint['kernel_sizes'] + ) + model.load_state_dict(checkpoint['model_state']) + model.eval() + + # Load image [0,1] + print(f"Loading input image: {input_path}") + img = Image.open(input_path).convert('RGBA') + img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W] + + # Inference + print("Running inference...") + with torch.no_grad(): + out = model(img_tensor) # [1,3,H,W] in [0,1] + + # Save + print(f"Saving output to: {output_path}") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + transforms.ToPILImage()(out.squeeze(0)).save(output_path) + print("Done!") + + def main(): parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation') - parser.add_argument('--input', help='Input image directory') + parser.add_argument('--input', help='Input image directory (training) or single image (inference)') parser.add_argument('--target', help='Target image directory') parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)') parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)') parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)') parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)') parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)') - parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)') + parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)') parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)') parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)') parser.add_argument('--resume', help='Resume from checkpoint file') parser.add_argument('--export-only', help='Export WGSL from checkpoint without training') + parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)') args = parser.parse_args() + # Inference mode + if args.infer: + checkpoint = args.export_only + if not checkpoint: + print("Error: --infer requires --export-only ") + sys.exit(1) + output_path = args.output or 'inference_output.png' + infer_from_checkpoint(checkpoint, args.infer, output_path) + return + # Export-only mode if args.export_only: export_from_checkpoint(args.export_only, args.output) -- cgit v1.2.3