summaryrefslogtreecommitdiff
path: root/training/train_cnn.py
diff options
context:
space:
mode:
author: skal <pascal.massimino@gmail.com> 2026-02-10 12:48:43 +0100
committer: skal <pascal.massimino@gmail.com> 2026-02-10 12:48:43 +0100
commit 6944733a6a2f05c18e7e0b73f847a4c9144801fd (patch)
tree 10713cd41a0e038a016a2e6b357471690f232834 /training/train_cnn.py
parent cc9cbeb75353181193e3afb880dc890aa8bf8985 (diff)
feat: Add multi-layer CNN support with framebuffer capture and blend control
Implements automatic layer chaining and generic framebuffer capture API for multi-layer neural network effects with proper original input preservation.

Key changes:
- Effect::needs_framebuffer_capture() - generic API for pre-render capture
- MainSequence: auto-capture to "captured_frame" auxiliary texture
- CNNEffect: multi-layer support via layer_index/total_layers params
- seq_compiler: expands "layers=N" to N chained effect instances
- Shader: @binding(4) original_input available to all layers
- Training: generates layer switches and original input binding
- Blend: mix(original, result, blend_amount) uses layer 0 input

Timeline syntax: CNNEffect layers=3 blend=0.7

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-x  training/train_cnn.py  159
1 file changed, 142 insertions, 17 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 82f0b48..1cd6579 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -112,8 +112,6 @@ class SimpleCNN(nn.Module):
else:
self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True))
- self.use_residual = True
-
def forward(self, x):
B, C, H, W = x.shape
y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
@@ -128,11 +126,77 @@ class SimpleCNN(nn.Module):
if i < len(self.layers) - 1:
out = torch.tanh(out)
- if self.use_residual:
- out = x + out * 0.3
return out
+def generate_layer_shader(output_path, num_layers, kernel_sizes):
+ """Generate cnn_layer.wgsl with proper layer switches"""
+
+ with open(output_path, 'w') as f:
+ f.write("// CNN layer shader - uses modular convolution snippets\n")
+ f.write("// Supports multi-pass rendering with residual connections\n")
+ f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+ f.write("@group(0) @binding(0) var smplr: sampler;\n")
+ f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n")
+ f.write("#include \"common_uniforms\"\n")
+ f.write("#include \"cnn_activation\"\n")
+
+ # Include necessary conv functions
+ conv_sizes = set(kernel_sizes)
+ for ks in sorted(conv_sizes):
+ f.write(f"#include \"cnn_conv{ks}x{ks}\"\n")
+ f.write("#include \"cnn_weights_generated\"\n\n")
+
+ f.write("struct CNNLayerParams {\n")
+ f.write(" layer_index: i32,\n")
+ f.write(" blend_amount: f32,\n")
+ f.write(" _pad: vec2<f32>,\n")
+ f.write("};\n\n")
+ f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n")
+ f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n")
+ f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n")
+ f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n")
+ f.write(" var pos = array<vec2<f32>, 3>(\n")
+ f.write(" vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)\n")
+ f.write(" );\n")
+ f.write(" return vec4<f32>(pos[i], 0.0, 1.0);\n")
+ f.write("}\n\n")
+ f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
+ f.write(" let uv = p.xy / uniforms.resolution;\n")
+ f.write(" let input = textureSample(txt, smplr, uv);\n")
+ f.write(" let original = textureSample(original_input, smplr, uv);\n")
+ f.write(" var result = vec4<f32>(0.0);\n\n")
+
+ # Generate layer switches
+ for layer_idx in range(num_layers):
+ ks = kernel_sizes[layer_idx]
+ if layer_idx == 0:
+ f.write(f" // Layer 0 uses coordinate-aware convolution\n")
+ f.write(f" if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = cnn_conv{ks}x{ks}_with_coord(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" rgba_weights_layer{layer_idx}, coord_weights_layer{layer_idx}, bias_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" }}\n")
+ else:
+ is_last = layer_idx == num_layers - 1
+ f.write(f" {'else ' if layer_idx > 0 else ''}if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = cnn_conv{ks}x{ks}(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" weights_layer{layer_idx}, bias_layer{layer_idx});\n")
+ if not is_last:
+ f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" }}\n")
+
+ # Add else clause for invalid layer index
+ if num_layers > 1:
+ f.write(f" else {{\n")
+ f.write(f" result = input;\n")
+ f.write(f" }}\n")
+
+ f.write("\n // Blend with ORIGINAL input from layer 0\n")
+ f.write(" return mix(original, result, params.blend_amount);\n")
+ f.write("}\n")
+
+
def export_weights_to_wgsl(model, output_path, kernel_sizes):
"""Export trained weights to WGSL format"""
@@ -154,10 +218,13 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
row = pos // kw
col = pos % kw
f.write(" mat4x4<f32>(\n")
- for out_c in range(min(4, out_ch)):
+ for out_c in range(4):
vals = []
- for in_c in range(min(4, in_ch)):
- vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
+ for in_c in range(4):
+ if out_c < out_ch and in_c < in_ch:
+ vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
+ else:
+ vals.append("0.0")
f.write(f" {', '.join(vals)},\n")
f.write(" )")
if pos < num_positions - 1:
@@ -170,7 +237,12 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
coord_w = layer.coord_weights.data.cpu().numpy()
f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n")
for c in range(2):
- vals = [f"{coord_w[out_c, c]:.6f}" for out_c in range(min(4, coord_w.shape[0]))]
+ vals = []
+ for out_c in range(4):
+ if out_c < coord_w.shape[0]:
+ vals.append(f"{coord_w[out_c, c]:.6f}")
+ else:
+ vals.append("0.0")
f.write(f" {', '.join(vals)}")
if c < 1:
f.write(",\n")
@@ -180,8 +252,9 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
# Export bias
bias = layer.bias.data.cpu().numpy()
+ bias_vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.0" for i in range(4)]
f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
- f.write(", ".join([f"{b:.6f}" for b in bias[:4]]))
+ f.write(", ".join(bias_vals))
f.write(");\n\n")
layer_idx += 1
@@ -197,10 +270,13 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
row = pos // kw
col = pos % kw
f.write(" mat4x4<f32>(\n")
- for out_c in range(min(4, out_ch)):
+ for out_c in range(4):
vals = []
- for in_c in range(min(4, in_ch)):
- vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
+ for in_c in range(4):
+ if out_c < out_ch and in_c < in_ch:
+ vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
+ else:
+ vals.append("0.0")
f.write(f" {', '.join(vals)},\n")
f.write(" )")
if pos < num_positions - 1:
@@ -211,8 +287,9 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
# Export bias
bias = layer.bias.data.cpu().numpy()
+ bias_vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.0" for i in range(4)]
f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
- f.write(", ".join([f"{b:.6f}" for b in bias[:4]]))
+ f.write(", ".join(bias_vals))
f.write(");\n\n")
layer_idx += 1
@@ -293,19 +370,57 @@ def train(args):
}, checkpoint_path)
print(f"Saved checkpoint to {checkpoint_path}")
- # Export weights
+ # Export weights and shader
output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
print(f"\nExporting weights to {output_path}...")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
export_weights_to_wgsl(model, output_path, kernel_sizes)
+ # Generate layer shader
+ shader_dir = os.path.dirname(output_path)
+ shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
+ print(f"Generating layer shader to {shader_path}...")
+ generate_layer_shader(shader_path, args.layers, kernel_sizes)
+
print("Training complete!")
+def export_from_checkpoint(checkpoint_path, output_path=None):
+ """Export WGSL files from checkpoint without training"""
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint file '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ kernel_sizes = checkpoint['kernel_sizes']
+ num_layers = checkpoint['num_layers']
+
+ # Recreate model
+ model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes)
+ model.load_state_dict(checkpoint['model_state'])
+
+ # Export weights
+ output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+ print(f"Exporting weights to {output_path}...")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+ # Generate layer shader
+ shader_dir = os.path.dirname(output_path)
+ shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
+ print(f"Generating layer shader to {shader_path}...")
+ generate_layer_shader(shader_path, num_layers, kernel_sizes)
+
+ print("Export complete!")
+
+
def main():
parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
- parser.add_argument('--input', required=True, help='Input image directory')
- parser.add_argument('--target', required=True, help='Target image directory')
+ parser.add_argument('--input', help='Input image directory')
+ parser.add_argument('--target', help='Target image directory')
parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
@@ -315,10 +430,20 @@ def main():
parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
parser.add_argument('--resume', help='Resume from checkpoint file')
+ parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
args = parser.parse_args()
- # Validate directories
+ # Export-only mode
+ if args.export_only:
+ export_from_checkpoint(args.export_only, args.output)
+ return
+
+ # Validate directories for training
+ if not args.input or not args.target:
+ print("Error: --input and --target required for training (or use --export-only)")
+ sys.exit(1)
+
if not os.path.isdir(args.input):
print(f"Error: Input directory '{args.input}' does not exist")
sys.exit(1)