#!/usr/bin/env python3
"""
CNN Training Script for Image-to-Image Transformation

Trains a convolutional neural network on multiple input/target image pairs.

Usage:
    python3 train_cnn.py --input input_dir/ --target target_dir/ [options]

Example:
    python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import sys
import argparse
import glob


class ImagePairDataset(Dataset):
    """Dataset for loading matching input/target image pairs"""

    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.transform = transform

        # Find all images in input directory
        input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
        self.image_pairs = []

        for pattern in input_patterns:
            input_files = glob.glob(os.path.join(input_dir, pattern))
            for input_path in input_files:
                filename = os.path.basename(input_path)
                # Try to find a matching target with the same name but any supported extension
                target_path = None
                for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
                    base_name = os.path.splitext(filename)[0]
                    candidate = os.path.join(target_dir, f"{base_name}.{ext}")
                    if os.path.exists(candidate):
                        target_path = candidate
                        break
                if target_path:
                    self.image_pairs.append((input_path, target_path))

        if not self.image_pairs:
            raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")

        print(f"Found {len(self.image_pairs)} matching image pairs")

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        input_path, target_path = self.image_pairs[idx]
        input_img = Image.open(input_path).convert('RGB')
        target_img = Image.open(target_path).convert('RGB')

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img


class CoordConv2d(nn.Module):
    """Conv2d that accepts coordinate input separate from spatial patches"""

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   padding=padding, bias=False)
        self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x, coords):
        # x: [B, C, H, W] image
        # coords: [B, 2, H, W] coordinate grid
        out = self.conv_rgba(x)
        B, C, H, W = out.shape
        coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights)
        out = out + coord_contrib + self.bias.view(1, -1, 1, 1)
        return out


class SimpleCNN(nn.Module):
    """Simple CNN for image-to-image transformation"""

    def __init__(self, num_layers=1, kernel_sizes=None):
        super(SimpleCNN, self).__init__()

        if kernel_sizes is None:
            kernel_sizes = [3] * num_layers

        assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"

        self.kernel_sizes = kernel_sizes
        self.layers = nn.ModuleList()

        for i, kernel_size in enumerate(kernel_sizes):
            padding = kernel_size // 2
            if i == 0:
                self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding))
            else:
                self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size,
                                             padding=padding, bias=True))

        self.use_residual = True

    def forward(self, x):
        B, C, H, W = x.shape
        y_coords = torch.linspace(0, 1, H, device=x.device).view(1, 1, H, 1).expand(B, 1, H, W)
        x_coords = torch.linspace(0, 1, W, device=x.device).view(1, 1, 1, W).expand(B, 1, H, W)
        coords = torch.cat([x_coords, y_coords], dim=1)
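        # coords is a [B, 2, H, W] grid of normalized positions in [0, 1]:
        # channel 0 holds x and channel 1 holds y, the layout CoordConv2d.forward expects.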
        out = self.layers[0](x, coords)
        out = torch.tanh(out)

        for i in range(1, len(self.layers)):
            out = self.layers[i](out)
            if i < len(self.layers) - 1:
                out = torch.tanh(out)

        if self.use_residual:
            out = x + out * 0.3

        return out


def export_weights_to_wgsl(model, output_path, kernel_sizes):
    """Export trained weights to WGSL format"""
    with open(output_path, 'w') as f:
        f.write("// Auto-generated CNN weights\n")
        f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")

        layer_idx = 0
        for i, layer in enumerate(model.layers):
            if isinstance(layer, CoordConv2d):
                # Export RGBA weights
                weights = layer.conv_rgba.weight.data.cpu().numpy()
                kernel_size = kernel_sizes[layer_idx]
                out_ch, in_ch, kh, kw = weights.shape
                num_positions = kh * kw

                f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
                for pos in range(num_positions):
                    row = pos // kw
                    col = pos % kw
                    f.write("    mat4x4(\n")
                    # Zero-pad to a full 4x4 matrix when the model has fewer than
                    # 4 channels, so the generated constant is valid WGSL.
                    for out_c in range(4):
                        vals = []
                        for in_c in range(4):
                            if out_c < out_ch and in_c < in_ch:
                                vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
                            else:
                                vals.append("0.000000")
                        f.write(f"        {', '.join(vals)},\n")
                    f.write("    )")
                    if pos < num_positions - 1:
                        f.write(",\n")
                    else:
                        f.write("\n")
                f.write(");\n\n")

                # Export coordinate weights: one mat2x4 column per coordinate channel,
                # zero-padded to four output channels
                coord_w = layer.coord_weights.data.cpu().numpy()
                f.write(f"const coord_weights_layer{layer_idx} = mat2x4(\n")
                for c in range(2):
                    vals = [f"{coord_w[out_c, c]:.6f}" if out_c < coord_w.shape[0] else "0.000000"
                            for out_c in range(4)]
                    f.write(f"    {', '.join(vals)}")
                    if c < 1:
                        f.write(",\n")
                    else:
                        f.write("\n")
                f.write(");\n\n")

                # Export bias (zero-padded to a vec4)
                bias = layer.bias.data.cpu().numpy()
                bias_vals = [f"{b:.6f}" for b in bias[:4]]
                bias_vals += ["0.000000"] * (4 - len(bias_vals))
                f.write(f"const bias_layer{layer_idx} = vec4(")
                f.write(", ".join(bias_vals))
                f.write(");\n\n")

                layer_idx += 1

            elif isinstance(layer, nn.Conv2d):
                # Standard conv layer
                weights = layer.weight.data.cpu().numpy()
                kernel_size = kernel_sizes[layer_idx]
                out_ch, in_ch, kh, kw = weights.shape
                num_positions = kh * kw

                f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
                for pos in range(num_positions):
                    row = pos // kw
                    col = pos % kw
                    f.write("    mat4x4(\n")
                    # Zero-pad to a full 4x4 matrix, as above
                    for out_c in range(4):
                        vals = []
                        for in_c in range(4):
                            if out_c < out_ch and in_c < in_ch:
                                vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
                            else:
                                vals.append("0.000000")
                        f.write(f"        {', '.join(vals)},\n")
                    f.write("    )")
                    if pos < num_positions - 1:
                        f.write(",\n")
                    else:
                        f.write("\n")
                f.write(");\n\n")

                # Export bias (zero-padded to a vec4)
                bias = layer.bias.data.cpu().numpy()
                bias_vals = [f"{b:.6f}" for b in bias[:4]]
                bias_vals += ["0.000000"] * (4 - len(bias_vals))
                f.write(f"const bias_layer{layer_idx} = vec4(")
                f.write(", ".join(bias_vals))
                f.write(");\n\n")

                layer_idx += 1


def train(args):
    """Main training loop"""
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Prepare dataset
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    dataset = ImagePairDataset(args.input, args.target, transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Parse kernel sizes
    kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
    if len(kernel_sizes) == 1 and args.layers > 1:
        kernel_sizes = kernel_sizes * args.layers

    # Create model
    model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)

    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # Resume from checkpoint
    start_epoch = 0
    if args.resume:
        if os.path.exists(args.resume):
            print(f"Loading checkpoint from {args.resume}...")
            checkpoint = torch.load(args.resume, map_location=device)
            model.load_state_dict(checkpoint['model_state'])
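            # Also restore optimizer state (e.g. Adam moment estimates) so the
            # resumed run keeps the same update dynamics as the original run.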
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resumed from epoch {start_epoch}")
        else:
            print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")

    # Training loop
    print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
    for epoch in range(start_epoch, args.epochs):
        epoch_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")

        # Save checkpoint
        if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
            checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'loss': avg_loss,
                'kernel_sizes': kernel_sizes,
                'num_layers': args.layers
            }, checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")

    # Export weights
    output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
    print(f"\nExporting weights to {output_path}...")
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    export_weights_to_wgsl(model, output_path, kernel_sizes)

    print("Training complete!")


def main():
    parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
    parser.add_argument('--input', required=True, help='Input image directory')
    parser.add_argument('--target', required=True, help='Target image directory')
    parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
    parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--output',
                        help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
    parser.add_argument('--checkpoint-every', type=int, default=0,
                        help='Save checkpoint every N epochs (default: 0 = disabled)')
    parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
    parser.add_argument('--resume', help='Resume from checkpoint file')

    args = parser.parse_args()

    # Validate directories
    if not os.path.isdir(args.input):
        print(f"Error: Input directory '{args.input}' does not exist")
        sys.exit(1)
    if not os.path.isdir(args.target):
        print(f"Error: Target directory '{args.target}' does not exist")
        sys.exit(1)

    train(args)


if __name__ == "__main__":
    main()