path: root/training/train_cnn.py
author    skal <pascal.massimino@gmail.com>  2026-02-10 10:27:44 +0100
committer skal <pascal.massimino@gmail.com>  2026-02-10 10:27:44 +0100
commit    96a349b9874c6cdaac525ba062a0f4f90c9bc3ed (patch)
tree      a4eb24fdb417393cbe5a0dc84bf5063cffc94daf /training/train_cnn.py
parent    75af266889b61b5722d842a1a1eb23f79bc06a85 (diff)
feat: Add coordinate-aware CNN layer 0 for position-dependent stylization
- Implement CoordConv2d custom layer accepting (x,y) patch center
- Split layer 0 weights: rgba_weights (9x mat4x4) + coord_weights (mat2x4)
- Add *_with_coord() functions to 3x3/5x5/7x7 convolution shaders
- Update training script to generate coordinate grid and export split weights
- Regenerate placeholder weights with new format

Size impact: +32B coord weights + ~100B shader code = +132B total

All 36 tests passing (100%)

handoff(Claude): CNN coordinate awareness implemented, ready for training

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
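The coordinate-aware layer 0 amounts to a standard convolution over the RGB patch plus a linear term in the normalized (x, y) patch center. A minimal PyTorch sketch of that per-pixel computation, mirroring the CoordConv2d class in the diff below (all shapes and values here are illustrative, not trained weights):

    import torch
    import torch.nn.functional as F

    rgb_w   = torch.randn(3, 3, 3, 3)   # out_ch x in_ch x 3x3 spatial taps
    coord_w = torch.randn(3, 2)         # (x, y) -> 3 output channels
    bias    = torch.zeros(3)

    x      = torch.rand(1, 3, 64, 64)   # RGB image batch
    coords = torch.rand(1, 2, 64, 64)   # normalized (x, y) grid

    out = F.conv2d(x, rgb_w, padding=1)                         # spatial term
    out = out + torch.einsum('bchw,oc->bohw', coords, coord_w)  # coordinate term
    out = out + bias.view(1, -1, 1, 1)                          # shared bias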
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-x  training/train_cnn.py  301
1 file changed, 301 insertions, 0 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
new file mode 100755
index 0000000..4fc3a6c
--- /dev/null
+++ b/training/train_cnn.py
@@ -0,0 +1,301 @@
+#!/usr/bin/env python3
+"""
+CNN Training Script for Image-to-Image Transformation
+
+Trains a convolutional neural network on multiple input/target image pairs.
+
+Usage:
+ python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
+
+Example:
+ python3 train_cnn.py --input ./training/input --target ./training/output --layers 3 --epochs 100
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import os
+import sys
+import argparse
+import glob
+
+
+class ImagePairDataset(Dataset):
+ """Dataset for loading matching input/target image pairs"""
+
+ def __init__(self, input_dir, target_dir, transform=None):
+ self.input_dir = input_dir
+ self.target_dir = target_dir
+ self.transform = transform
+
+        # Find all images in the input directory; dedupe in case the
+        # case-variant patterns overlap on case-insensitive filesystems
+        input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
+        self.image_pairs = []
+        seen = set()
+
+        for pattern in input_patterns:
+            input_files = glob.glob(os.path.join(input_dir, pattern))
+            for input_path in sorted(input_files):
+                if input_path in seen:
+                    continue
+                seen.add(input_path)
+                base_name = os.path.splitext(os.path.basename(input_path))[0]
+                # Try to find a matching target with the same base name and
+                # any supported extension
+                target_path = None
+                for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
+                    candidate = os.path.join(target_dir, f"{base_name}.{ext}")
+                    if os.path.exists(candidate):
+                        target_path = candidate
+                        break
+
+                if target_path:
+                    self.image_pairs.append((input_path, target_path))
+
+ if not self.image_pairs:
+ raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
+
+ print(f"Found {len(self.image_pairs)} matching image pairs")
+
+ def __len__(self):
+ return len(self.image_pairs)
+
+ def __getitem__(self, idx):
+ input_path, target_path = self.image_pairs[idx]
+
+ input_img = Image.open(input_path).convert('RGB')
+ target_img = Image.open(target_path).convert('RGB')
+
+ if self.transform:
+ input_img = self.transform(input_img)
+ target_img = self.transform(target_img)
+
+ return input_img, target_img
+
+
+class CoordConv2d(nn.Module):
+ """Conv2d that accepts coordinate input separate from spatial patches"""
+
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+ super().__init__()
+ self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
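+        # coord_weights and bias live outside conv_rgba (bias=False above) so
+        # they can be exported separately, matching the split WGSL weight format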
+ self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01)
+ self.bias = nn.Parameter(torch.zeros(out_channels))
+
+ def forward(self, x, coords):
+ # x: [B, C, H, W] image
+ # coords: [B, 2, H, W] coordinate grid
+ out = self.conv_rgba(x)
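+        # Contract the two coordinate channels against coord_weights [out_ch, 2]
+        # to get a per-pixel contribution of shape [B, out_ch, H, W]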
+ coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights)
+ out = out + coord_contrib + self.bias.view(1, -1, 1, 1)
+ return out
+
+
+class SimpleCNN(nn.Module):
+ """Simple CNN for image-to-image transformation"""
+
+ def __init__(self, num_layers=1, kernel_sizes=None):
+        super().__init__()
+
+ if kernel_sizes is None:
+ kernel_sizes = [3] * num_layers
+
+ assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
+
+ self.kernel_sizes = kernel_sizes
+ self.layers = nn.ModuleList()
+
+ for i, kernel_size in enumerate(kernel_sizes):
+ padding = kernel_size // 2
+ if i == 0:
+ self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding))
+ else:
+ self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True))
+
+ self.use_residual = True
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+        # Build a normalized [0, 1] coordinate grid matching the input size;
+        # channel 0 is x, channel 1 is y, as consumed by CoordConv2d
+        y_coords = torch.linspace(0, 1, H, device=x.device).view(1, 1, H, 1).expand(B, 1, H, W)
+        x_coords = torch.linspace(0, 1, W, device=x.device).view(1, 1, 1, W).expand(B, 1, H, W)
+        coords = torch.cat([x_coords, y_coords], dim=1)
+
+ out = self.layers[0](x, coords)
+ out = torch.tanh(out)
+
+ for i in range(1, len(self.layers)):
+ out = self.layers[i](out)
+ if i < len(self.layers) - 1:
+ out = torch.tanh(out)
+
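+        # Residual blend: add only a scaled correction to the input, so an
+        # untrained network stays close to the identity mapping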
+ if self.use_residual:
+ out = x + out * 0.3
+ return out
+
+
+def export_weights_to_wgsl(model, output_path, kernel_sizes):
+ """Export trained weights to WGSL format"""
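+    # Per-layer declarations emitted below (illustrative skeleton):
+    #   const rgba_weights_layer0: array<mat4x4<f32>, K*K> = array(...);   (layer 0)
+    #   const coord_weights_layer0 = mat2x4<f32>(...);                     (layer 0)
+    #   const weights_layerN: array<mat4x4<f32>, K*K> = array(...);        (layers 1+)
+    #   const bias_layerN = vec4<f32>(...);                                (every layer)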
+
+ with open(output_path, 'w') as f:
+ f.write("// Auto-generated CNN weights\n")
+ f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+
+ layer_idx = 0
+ for i, layer in enumerate(model.layers):
+ if isinstance(layer, CoordConv2d):
+ # Export RGBA weights
+ weights = layer.conv_rgba.weight.data.cpu().numpy()
+ kernel_size = kernel_sizes[layer_idx]
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
+
+ f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+ for pos in range(num_positions):
+ row = pos // kw
+ col = pos % kw
+ f.write(" mat4x4<f32>(\n")
+                    # mat4x4<f32> takes 16 values: zero-pad past the trained
+                    # channels so the generated constant is valid WGSL
+                    for out_c in range(4):
+                        vals = []
+                        for in_c in range(4):
+                            if out_c < out_ch and in_c < in_ch:
+                                vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
+                            else:
+                                vals.append("0.000000")
+                        f.write(f" {', '.join(vals)},\n")
+ f.write(" )")
+ if pos < num_positions - 1:
+ f.write(",\n")
+ else:
+ f.write("\n")
+ f.write(");\n\n")
+
+ # Export coordinate weights
+ coord_w = layer.coord_weights.data.cpu().numpy()
+ f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n")
+ for c in range(2):
+                    # mat2x4 columns hold 4 values: zero-pad past out_ch
+                    vals = [f"{coord_w[out_c, c]:.6f}" if out_c < coord_w.shape[0] else "0.000000"
+                            for out_c in range(4)]
+ f.write(f" {', '.join(vals)}")
+ if c < 1:
+ f.write(",\n")
+ else:
+ f.write("\n")
+ f.write(");\n\n")
+
+                # Export bias, zero-padded to a full vec4
+                bias = layer.bias.data.cpu().numpy()
+                vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.000000" for i in range(4)]
+                f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
+                f.write(", ".join(vals))
+                f.write(");\n\n")
+
+ layer_idx += 1
+ elif isinstance(layer, nn.Conv2d):
+ # Standard conv layer
+ weights = layer.weight.data.cpu().numpy()
+ kernel_size = kernel_sizes[layer_idx]
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
+
+ f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+ for pos in range(num_positions):
+ row = pos // kw
+ col = pos % kw
+ f.write(" mat4x4<f32>(\n")
+                    # Zero-pad to the 16 values a mat4x4<f32> requires
+                    for out_c in range(4):
+                        vals = []
+                        for in_c in range(4):
+                            if out_c < out_ch and in_c < in_ch:
+                                vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
+                            else:
+                                vals.append("0.000000")
+                        f.write(f" {', '.join(vals)},\n")
+ f.write(" )")
+ if pos < num_positions - 1:
+ f.write(",\n")
+ else:
+ f.write("\n")
+ f.write(");\n\n")
+
+                # Export bias, zero-padded to a full vec4
+                bias = layer.bias.data.cpu().numpy()
+                vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.000000" for i in range(4)]
+                f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
+                f.write(", ".join(vals))
+                f.write(");\n\n")
+
+ layer_idx += 1
+
+
+def train(args):
+ """Main training loop"""
+
+ # Setup device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+
+ # Prepare dataset
+ transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ ])
+
+ dataset = ImagePairDataset(args.input, args.target, transform=transform)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+
+    # Parse kernel sizes; a single value is broadcast across all layers so
+    # the documented `--layers 3` example works with the default of '3'
+    kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+    if len(kernel_sizes) == 1 and args.layers > 1:
+        kernel_sizes = kernel_sizes * args.layers
+
+ # Create model
+ model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)
+
+ # Loss and optimizer
+ criterion = nn.MSELoss()
+ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ # Training loop
+ print(f"\nTraining for {args.epochs} epochs...")
+ for epoch in range(args.epochs):
+ epoch_loss = 0.0
+ for batch_idx, (inputs, targets) in enumerate(dataloader):
+ inputs, targets = inputs.to(device), targets.to(device)
+
+ optimizer.zero_grad()
+ outputs = model(inputs)
+ loss = criterion(outputs, targets)
+ loss.backward()
+ optimizer.step()
+
+ epoch_loss += loss.item()
+
+ avg_loss = epoch_loss / len(dataloader)
+ if (epoch + 1) % 10 == 0:
+ print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
+
+ # Export weights
+ output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+ print(f"\nExporting weights to {output_path}...")
+    out_dir = os.path.dirname(output_path)
+    if out_dir:
+        os.makedirs(out_dir, exist_ok=True)
+ export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+ print("Training complete!")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
+ parser.add_argument('--input', required=True, help='Input image directory')
+ parser.add_argument('--target', required=True, help='Target image directory')
+ parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
+ parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
+ parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
+ parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
+ parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
+ parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+
+ args = parser.parse_args()
+
+ # Validate directories
+ if not os.path.isdir(args.input):
+ print(f"Error: Input directory '{args.input}' does not exist")
+ sys.exit(1)
+
+ if not os.path.isdir(args.target):
+ print(f"Error: Target directory '{args.target}' does not exist")
+ sys.exit(1)
+
+ train(args)
+
+
+if __name__ == "__main__":
+ main()
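
For reference, a multi-layer invocation consistent with the flags defined in main() above (directory paths are illustrative):

    python3 train_cnn.py --input ./training/input --target ./training/output \
        --layers 3 --kernel_sizes 3,5,7 --epochs 200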