summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/ground_truth.pngbin0 -> 127405 bytes
-rwxr-xr-xtraining/train_cnn.py59
2 files changed, 57 insertions, 2 deletions
diff --git a/training/ground_truth.png b/training/ground_truth.png
new file mode 100644
index 0000000..6e1f2aa
--- /dev/null
+++ b/training/ground_truth.png
Binary files differ
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 3312768..16f8e7a 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -5,10 +5,15 @@ CNN Training Script for Image-to-Image Transformation
Trains a convolutional neural network on multiple input/target image pairs.
Usage:
+ # Training
python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
+ # Inference (generate ground truth)
+ python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png
+
Example:
python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
+ python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth
"""
import torch
@@ -467,23 +472,73 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print("Export complete!")
+def infer_from_checkpoint(checkpoint_path, input_path, output_path):
+ """Run inference on single image to generate ground truth"""
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ if not os.path.exists(input_path):
+ print(f"Error: Input image '{input_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ # Reconstruct model
+ model = SimpleCNN(
+ num_layers=checkpoint['num_layers'],
+ kernel_sizes=checkpoint['kernel_sizes']
+ )
+ model.load_state_dict(checkpoint['model_state'])
+ model.eval()
+
+ # Load image [0,1]
+ print(f"Loading input image: {input_path}")
+ img = Image.open(input_path).convert('RGBA')
+ img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W]
+
+ # Inference
+ print("Running inference...")
+ with torch.no_grad():
+ out = model(img_tensor) # [1,3,H,W] in [0,1]
+
+ # Save
+ print(f"Saving output to: {output_path}")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ transforms.ToPILImage()(out.squeeze(0)).save(output_path)
+ print("Done!")
+
+
def main():
parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
- parser.add_argument('--input', help='Input image directory')
+ parser.add_argument('--input', help='Input image directory (training) or single image (inference)')
parser.add_argument('--target', help='Target image directory')
parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
- parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+ parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)')
parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
parser.add_argument('--resume', help='Resume from checkpoint file')
parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
+ parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
args = parser.parse_args()
+ # Inference mode
+ if args.infer:
+ checkpoint = args.export_only
+ if not checkpoint:
+ print("Error: --infer requires --export-only <checkpoint>")
+ sys.exit(1)
+ output_path = args.output or 'inference_output.png'
+ infer_from_checkpoint(checkpoint, args.infer, output_path)
+ return
+
# Export-only mode
if args.export_only:
export_from_checkpoint(args.export_only, args.output)