| author | skal <pascal.massimino@gmail.com> | 2026-02-10 10:27:44 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 10:27:44 +0100 |
| commit | 96a349b9874c6cdaac525ba062a0f4f90c9bc3ed (patch) | |
| tree | a4eb24fdb417393cbe5a0dc84bf5063cffc94daf /training/train_cnn.py | |
| parent | 75af266889b61b5722d842a1a1eb23f79bc06a85 (diff) | |
feat: Add coordinate-aware CNN layer 0 for position-dependent stylization
- Implement CoordConv2d custom layer accepting (x,y) patch center
- Split layer 0 weights: rgba_weights (9x mat4x4) + coord_weights (mat2x4); see the sketch below
- Add *_with_coord() functions to 3x3/5x5/7x7 convolution shaders
- Update training script to generate coordinate grid and export split weights
- Regenerate placeholder weights with new format
Size impact: +32B coord weights + ~100B shader code = +132B total
All 36 tests passing (100%)
handoff(Claude): CNN coordinate awareness implemented, ready for training
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
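
To make the split concrete, here is a minimal NumPy sketch of the per-pixel computation the rgba_weights/coord_weights split implies for a 3x3 layer 0. It mirrors `CoordConv2d.forward()` from the diff further down; the function name `layer0_pixel` and the plain-array types are illustrative assumptions, not code from this commit.

```python
# Illustrative sketch only (not from the commit). Assumes the normalized
# [0, 1] patch-center coordinates (cx, cy) that SimpleCNN.forward() generates.
import numpy as np

def layer0_pixel(taps, rgba_weights, coord_weights, bias, cx, cy):
    """taps: 9 RGBA samples of the 3x3 neighborhood, each of shape [4];
    rgba_weights: 9 matrices of shape [4, 4] (the '9x mat4x4');
    coord_weights: shape [4, 2] (the 'mat2x4'); bias: shape [4]."""
    out = np.zeros(4)
    for tap in range(9):
        out += rgba_weights[tap] @ taps[tap]      # spatial convolution term
    out += coord_weights @ np.array([cx, cy])     # position-dependent term
    return out + bias
```

Because the coordinate term is linear in (cx, cy), position awareness costs only eight extra f32 weights, which is the +32B noted in the size impact above.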
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 301 |
1 file changed, 301 insertions, 0 deletions
```diff
diff --git a/training/train_cnn.py b/training/train_cnn.py
new file mode 100755
index 0000000..4fc3a6c
--- /dev/null
+++ b/training/train_cnn.py
@@ -0,0 +1,301 @@
+#!/usr/bin/env python3
+"""
+CNN Training Script for Image-to-Image Transformation
+
+Trains a convolutional neural network on multiple input/target image pairs.
+
+Usage:
+    python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
+
+Example:
+    python3 train_cnn.py --input ./training/input --target ./training/output --layers 3 --epochs 100
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import os
+import sys
+import argparse
+import glob
+
+
+class ImagePairDataset(Dataset):
+    """Dataset for loading matching input/target image pairs"""
+
+    def __init__(self, input_dir, target_dir, transform=None):
+        self.input_dir = input_dir
+        self.target_dir = target_dir
+        self.transform = transform
+
+        # Find all images in input directory
+        input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
+        self.image_pairs = []
+
+        for pattern in input_patterns:
+            input_files = glob.glob(os.path.join(input_dir, pattern))
+            for input_path in input_files:
+                filename = os.path.basename(input_path)
+                # Try to find matching target with same name but any supported extension
+                target_path = None
+                for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
+                    base_name = os.path.splitext(filename)[0]
+                    candidate = os.path.join(target_dir, f"{base_name}.{ext}")
+                    if os.path.exists(candidate):
+                        target_path = candidate
+                        break
+
+                if target_path:
+                    self.image_pairs.append((input_path, target_path))
+
+        if not self.image_pairs:
+            raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
+
+        print(f"Found {len(self.image_pairs)} matching image pairs")
+
+    def __len__(self):
+        return len(self.image_pairs)
+
+    def __getitem__(self, idx):
+        input_path, target_path = self.image_pairs[idx]
+
+        input_img = Image.open(input_path).convert('RGB')
+        target_img = Image.open(target_path).convert('RGB')
+
+        if self.transform:
+            input_img = self.transform(input_img)
+            target_img = self.transform(target_img)
+
+        return input_img, target_img
+
+
+class CoordConv2d(nn.Module):
+    """Conv2d that accepts coordinate input separate from spatial patches"""
+
+    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+        super().__init__()
+        self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
+        self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01)
+        self.bias = nn.Parameter(torch.zeros(out_channels))
+
+    def forward(self, x, coords):
+        # x: [B, C, H, W] image
+        # coords: [B, 2, H, W] coordinate grid
+        out = self.conv_rgba(x)
+        # Linear position term: one (out_channels x 2) matrix applied per pixel
+        coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights)
+        out = out + coord_contrib + self.bias.view(1, -1, 1, 1)
+        return out
+
+
+class SimpleCNN(nn.Module):
+    """Simple CNN for image-to-image transformation"""
+
+    def __init__(self, num_layers=1, kernel_sizes=None):
+        super(SimpleCNN, self).__init__()
+
+        if kernel_sizes is None:
+            kernel_sizes = [3] * num_layers
+
+        assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
+
+        self.kernel_sizes = kernel_sizes
+        self.layers = nn.ModuleList()
+
+        for i, kernel_size in enumerate(kernel_sizes):
+            padding = kernel_size // 2
+            if i == 0:
+                self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding))
+            else:
+                self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True))
+
+        self.use_residual = True
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # Normalized [0, 1] coordinate grid fed to the coordinate-aware layer 0
+        y_coords = torch.linspace(0, 1, H, device=x.device).view(1, 1, H, 1).expand(B, 1, H, W)
+        x_coords = torch.linspace(0, 1, W, device=x.device).view(1, 1, 1, W).expand(B, 1, H, W)
+        coords = torch.cat([x_coords, y_coords], dim=1)
+
+        out = self.layers[0](x, coords)
+        out = torch.tanh(out)
+
+        for i in range(1, len(self.layers)):
+            out = self.layers[i](out)
+            if i < len(self.layers) - 1:
+                out = torch.tanh(out)
+
+        if self.use_residual:
+            out = x + out * 0.3
+        return out
+
+
+def export_weights_to_wgsl(model, output_path, kernel_sizes):
+    """Export trained weights to WGSL format"""
+
+    with open(output_path, 'w') as f:
+        f.write("// Auto-generated CNN weights\n")
+        f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+
+        layer_idx = 0
+        for i, layer in enumerate(model.layers):
+            if isinstance(layer, CoordConv2d):
+                # Export RGBA weights
+                weights = layer.conv_rgba.weight.data.cpu().numpy()
+                out_ch, in_ch, kh, kw = weights.shape
+                num_positions = kh * kw
+
+                f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+                for pos in range(num_positions):
+                    row = pos // kw
+                    col = pos % kw
+                    f.write("    mat4x4<f32>(\n")
+                    for out_c in range(4):  # mat4x4 needs 16 scalars; zero-pad channels beyond RGB
+                        vals = []
+                        for in_c in range(4):
+                            v = weights[out_c, in_c, row, col] if out_c < out_ch and in_c < in_ch else 0.0
+                            vals.append(f"{v:.6f}")
+                        f.write(f"        {', '.join(vals)},\n")
+                    f.write("    )")
+                    if pos < num_positions - 1:
+                        f.write(",\n")
+                    else:
+                        f.write("\n")
+                f.write(");\n\n")
+
+                # Export coordinate weights
+                coord_w = layer.coord_weights.data.cpu().numpy()
+                f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n")
+                for c in range(2):
+                    # Zero-pad each column to a full vec4
+                    vals = [f"{coord_w[out_c, c]:.6f}" if out_c < coord_w.shape[0] else "0.000000" for out_c in range(4)]
+                    f.write(f"    {', '.join(vals)}")
+                    if c < 1:
+                        f.write(",\n")
+                    else:
+                        f.write("\n")
+                f.write(");\n\n")
+
+                # Export bias
+                bias = layer.bias.data.cpu().numpy()
+                f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
+                f.write(", ".join([f"{bias[i]:.6f}" if i < bias.shape[0] else "0.000000" for i in range(4)]))
+                f.write(");\n\n")
+
+                layer_idx += 1
+            elif isinstance(layer, nn.Conv2d):
+                # Standard conv layer
+                weights = layer.weight.data.cpu().numpy()
+                out_ch, in_ch, kh, kw = weights.shape
+                num_positions = kh * kw
+
+                f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+                for pos in range(num_positions):
+                    row = pos // kw
+                    col = pos % kw
+                    f.write("    mat4x4<f32>(\n")
+                    for out_c in range(4):  # zero-pad to a full mat4x4, as above
+                        vals = []
+                        for in_c in range(4):
+                            v = weights[out_c, in_c, row, col] if out_c < out_ch and in_c < in_ch else 0.0
+                            vals.append(f"{v:.6f}")
+                        f.write(f"        {', '.join(vals)},\n")
+                    f.write("    )")
+                    if pos < num_positions - 1:
+                        f.write(",\n")
+                    else:
+                        f.write("\n")
+                f.write(");\n\n")
+
+                # Export bias
+                bias = layer.bias.data.cpu().numpy()
+                f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
+                f.write(", ".join([f"{bias[i]:.6f}" if i < bias.shape[0] else "0.000000" for i in range(4)]))
+                f.write(");\n\n")
+
+                layer_idx += 1
+
+
+def train(args):
+    """Main training loop"""
+
+    # Setup device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Using device: {device}")
+
+    # Prepare dataset
+    transform = transforms.Compose([
+        transforms.Resize((256, 256)),
+        transforms.ToTensor(),
+    ])
+
+    dataset = ImagePairDataset(args.input, args.target, transform=transform)
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+
+    # Parse kernel sizes
+    kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+
+    # Create model
+    model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)
+
+    # Loss and optimizer
+    criterion = nn.MSELoss()
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    # Training loop
+    print(f"\nTraining for {args.epochs} epochs...")
+    for epoch in range(args.epochs):
+        epoch_loss = 0.0
+        for inputs, targets in dataloader:
+            inputs, targets = inputs.to(device), targets.to(device)
+
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            loss = criterion(outputs, targets)
+            loss.backward()
+            optimizer.step()
+
+            epoch_loss += loss.item()
+
+        avg_loss = epoch_loss / len(dataloader)
+        if (epoch + 1) % 10 == 0:
+            print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
+
+    # Export weights
+    output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+    print(f"\nExporting weights to {output_path}...")
+    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)  # dirname is '' for bare filenames
+    export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+    print("Training complete!")
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
+    parser.add_argument('--input', required=True, help='Input image directory')
+    parser.add_argument('--target', required=True, help='Target image directory')
+    parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
+    parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
+    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
+    parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
+    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
+    parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+
+    args = parser.parse_args()
+
+    # Validate directories
+    if not os.path.isdir(args.input):
+        print(f"Error: Input directory '{args.input}' does not exist")
+        sys.exit(1)
+
+    if not os.path.isdir(args.target):
+        print(f"Error: Target directory '{args.target}' does not exist")
+        sys.exit(1)
+
+    train(args)
+
+
+if __name__ == "__main__":
+    main()
```
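
As a quick end-to-end check of the position dependence (an illustrative sketch, not one of the commit's 36 tests; it assumes `train_cnn.py` is importable from the working directory): shifting the coordinate grid by a constant should shift the layer's output by `coord_weights @ delta`, uniformly at every pixel.

```python
# Sanity sketch (hypothetical, not part of the commit's test suite).
import torch
from train_cnn import CoordConv2d  # assumes train_cnn.py is on sys.path

layer = CoordConv2d(3, 3, kernel_size=3, padding=1)
x = torch.rand(1, 3, 8, 8)       # dummy RGB image
coords = torch.rand(1, 2, 8, 8)  # dummy (x, y) grid
delta = torch.tensor([0.25, -0.5])

# The output difference is content-independent and equals coord_weights @ delta
diff = layer(x, coords + delta.view(1, 2, 1, 1)) - layer(x, coords)
assert torch.allclose(diff, (layer.coord_weights @ delta).view(1, -1, 1, 1).expand_as(diff), atol=1e-6)
```

Because the convolution and bias terms cancel in the subtraction, any residual difference comes entirely from the linear coordinate term, which is exactly what makes the stylization position-dependent.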
