summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/README.md233
-rw-r--r--training/ground_truth.pngbin0 -> 127405 bytes
-rwxr-xr-xtraining/train_cnn.py444
-rw-r--r--training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl158
4 files changed, 461 insertions, 374 deletions
diff --git a/training/README.md b/training/README.md
index 08379ee..0a46718 100644
--- a/training/README.md
+++ b/training/README.md
@@ -1,167 +1,182 @@
-# Image Style Processor
+# CNN Training Tools
-A comprehensive Python script that applies artistic hand-drawn and futuristic effects to images.
+Tools for training and preparing data for the CNN post-processing effect.
-## Requirements
+---
-- Python 3
-- OpenCV (cv2)
-- NumPy
+## train_cnn.py
+
+PyTorch-based training script for image-to-image stylization.
+
+### Basic Usage
-Install dependencies:
```bash
-pip install opencv-python numpy
+python3 train_cnn.py --input <input_dir> --target <target_dir> [options]
```
-## Usage
+### Examples
+**Single layer, 3×3 kernel:**
```bash
-python3 image_style_processor.py <input_directory> <output_directory> <style>
+python3 train_cnn.py --input training/input --target training/output \
+ --layers 1 --kernel-sizes 3 --epochs 500
```
-### Arguments
+**Multi-layer, mixed kernels:**
+```bash
+python3 train_cnn.py --input training/input --target training/output \
+ --layers 3 --kernel-sizes 3,5,3 --epochs 1000
+```
-- `input_directory`: Directory containing your input images (PNG, JPG, JPEG)
-- `output_directory`: Directory where processed images will be saved (created if doesn't exist)
-- `style`: The artistic style to apply (see below)
+**With checkpointing:**
+```bash
+python3 train_cnn.py --input training/input --target training/output \
+ --epochs 500 --checkpoint-every 50
+```
-## Available Styles
+**Resume from checkpoint:**
+```bash
+python3 train_cnn.py --input training/input --target training/output \
+ --resume training/checkpoints/checkpoint_epoch_200.pth
+```
+
+### Options
-### Sketch Styles
+| Option | Default | Description |
+|--------|---------|-------------|
+| `--input` | *required* | Input image directory |
+| `--target` | *required* | Target image directory |
+| `--layers` | 1 | Number of CNN layers |
+| `--kernel-sizes` | 3 | Comma-separated kernel sizes (auto-repeats if single value) |
+| `--epochs` | 100 | Training epochs |
+| `--batch-size` | 4 | Batch size |
+| `--learning-rate` | 0.001 | Learning rate |
+| `--output` | `workspaces/main/shaders/cnn/cnn_weights_generated.wgsl` | Output WGSL file |
+| `--checkpoint-every` | 0 | Save checkpoint every N epochs (0=disabled) |
+| `--checkpoint-dir` | `training/checkpoints` | Checkpoint directory |
+| `--resume` | None | Resume from checkpoint file |
-1. **pencil_sketch** - Dense cross-hatching with progressive layers in shadows
- - Best for: Detailed technical drawings, architectural scenes
- - Features: Clean line art, 5 layers of cross-hatching, strong shadow definition
+### Architecture
-2. **ink_drawing** - Bold black outlines with comic book aesthetic
- - Best for: Graphic novel style, high contrast scenes
- - Features: Bold outlines, posterized tones, minimal shading
+- **Layer 0:** `CoordConv2d` - accepts (x,y) patch center + 3×3 RGBA samples
+- **Layers 1+:** Standard `Conv2d` - 3×3 RGBA samples only
+- **Activation:** Tanh between layers
+- **Output:** Residual connection (30% stylization blend)
-3. **charcoal_pastel** - Dramatic contrasts with soft, smudged textures
- - Best for: Portraits, dramatic landscapes
- - Features: Soft blending, grainy texture, highlighted areas
+### Requirements
-4. **conte_crayon** - Directional strokes following image contours
- - Best for: Figure studies, natural forms
- - Features: Stroke direction follows gradients, cross-hatching in dark areas
+```bash
+pip install torch torchvision pillow
+```
-5. **gesture_sketch** - Loose, quick observational sketch style
- - Best for: Quick studies, energetic compositions
- - Features: Randomized line wobble, sparse suggestion lines
+---
-### Futuristic Styles
+## image_style_processor.py
-6. **circuit_board** - Tech blueprint with circuit paths and geometric patterns
- - Best for: Sci-fi imagery, technological themes
- - Features: Multi-layer circuit paths, connection nodes, technical grid overlay
+Generates stylized target images from raw renders.
-7. **glitch_art** - Digital corruption with scan line shifts and pixel sorting
- - Best for: Cyberpunk aesthetics, digital art
- - Features: Horizontal scan artifacts, block displacement, pixel sorting, noise strips
+### Usage
-8. **wireframe_topo** - Topographic contour lines with holographic grid
- - Best for: Landscape, abstract patterns, sci-fi hologram effect
- - Features: 20 contour levels, scan lines, measurement markers, grid overlay
+```bash
+python3 image_style_processor.py <input_dir> <output_dir> <style>
+```
-9. **data_mosaic** - Voronoi geometric fragmentation with angular cells
- - Best for: Abstract art, geometric compositions
- - Features: 200 Voronoi cells, posterized tones, embedded geometric patterns
+### Available Styles
-10. **holographic_scan** - CRT/hologram display with scanlines and HUD elements
- - Best for: Retro-futuristic, heads-up display aesthetic
- - Features: Scanlines, interference patterns, glitch effects, corner brackets, crosshair
+**Sketch:**
+- `pencil_sketch` - Dense cross-hatching
+- `ink_drawing` - Bold outlines, comic style
+- `charcoal_pastel` - Soft, dramatic contrasts
+- `conte_crayon` - Directional strokes
+- `gesture_sketch` - Loose, energetic lines
-## Examples
+**Futuristic:**
+- `circuit_board` - Tech blueprint
+- `glitch_art` - Digital corruption
+- `wireframe_topo` - Topographic contours
+- `data_mosaic` - Voronoi fragmentation
+- `holographic_scan` - CRT/HUD aesthetic
-### Sketch Effects
+### Examples
-Process images with pencil sketch:
```bash
-python3 image_style_processor.py ./photos ./output pencil_sketch
-```
+# Generate pencil sketch targets
+python3 image_style_processor.py input/ output/ pencil_sketch
-Apply ink drawing style:
-```bash
-python3 image_style_processor.py ./input ./sketches ink_drawing
+# Generate glitch art targets
+python3 image_style_processor.py input/ output/ glitch_art
```
-Create charcoal effect:
+### Requirements
+
```bash
-python3 image_style_processor.py ./images ./results charcoal_pastel
+pip install opencv-python numpy
```
-### Futuristic Effects
+---
+
+## Workflow
+
+### 1. Render Raw Frames
-Apply circuit board style:
+Generate raw 3D renders as input:
```bash
-python3 image_style_processor.py ./photos ./output circuit_board
+./build/demo64k --headless --duration 5 --output training/input/
```
-Create glitch art:
+### 2. Generate Stylized Targets
+
+Apply artistic style:
```bash
-python3 image_style_processor.py ./input ./glitched glitch_art
+python3 training/image_style_processor.py training/input/ training/output/ pencil_sketch
```
-Apply holographic effect:
+### 3. Train CNN
+
+Train network to reproduce the style:
```bash
-python3 image_style_processor.py ./images ./holo holographic_scan
+python3 training/train_cnn.py \
+ --input training/input \
+ --target training/output \
+ --epochs 500 \
+ --checkpoint-every 50
```
-## Output
+### 4. Rebuild Demo
-- Processed images are saved to the output directory with **the same filename** as the input
-- Supported input formats: PNG, JPG, JPEG (case-insensitive)
-- Output format: PNG (preserves quality)
-- Original images are never modified
+Weights auto-exported to `cnn_weights_generated.wgsl`:
+```bash
+cmake --build build -j4
+./build/demo64k
+```
-## Style Comparison
+---
-### Sketch Styles
-- **pencil_sketch**: Most detailed, traditional drawing look
-- **ink_drawing**: Boldest, most graphic/comic-like
-- **charcoal_pastel**: Softest, most artistic/painterly
-- **conte_crayon**: Most directional, follows contours
-- **gesture_sketch**: Loosest, most expressive
+## Tips
-### Futuristic Styles
-- **circuit_board**: Cleanest, most technical/blueprint-like
-- **glitch_art**: Most chaotic, digital corruption aesthetic
-- **wireframe_topo**: Most structured, topographic/hologram feel
-- **data_mosaic**: Most geometric, fragmented cells
-- **holographic_scan**: Most retro-futuristic, HUD/CRT display
+- **Training data:** 10-50 image pairs recommended
+- **Resolution:** 256×256 (auto-resized during training)
+- **Checkpoints:** Save every 50-100 epochs for long runs
+- **Loss plateaus:** Try lower learning rate (0.0001) or more layers
+- **Residual connection:** Prevents catastrophic divergence (input always blended in)
-## Tips
+---
-- Images are automatically converted to grayscale before processing
-- All styles work best with high-resolution images (300+ DPI recommended)
-- Processing time varies by style:
- - Fast: ink_drawing, glitch_art, holographic_scan
- - Medium: charcoal_pastel, gesture_sketch, circuit_board, wireframe_topo
- - Slow: pencil_sketch, conte_crayon, data_mosaic (due to intensive computation)
-- For batch processing large collections, consider processing in smaller batches
-- Randomized styles (glitch_art, gesture_sketch, data_mosaic) will produce slightly different results each run
+## Coordinate-Aware Layer 0
-## Technical Notes
+Layer 0 receives normalized (x,y) patch center coordinates, enabling position-dependent effects:
-### Randomization
-Some styles use randomization for natural variation:
-- **glitch_art**: Random scan line shifts, block positions
-- **gesture_sketch**: Random line wobble, stroke placement
-- **data_mosaic**: Random Voronoi cell centers
-- **circuit_board**: Random pattern placement in dark regions
-- **holographic_scan**: Random glitch line positions
+- **Vignetting:** Darker edges
+- **Radial gradients:** Center-focused stylization
+- **Corner effects:** Edge-specific treatments
-### Processing Details
-- **pencil_sketch**: Uses 5-level progressive cross-hatching algorithm
-- **conte_crayon**: Follows Sobel gradients for directional strokes
-- **wireframe_topo**: Generates 20 brightness-based contour levels
-- **data_mosaic**: Creates 200 Voronoi cells via nearest-neighbor algorithm
-- **holographic_scan**: Applies scanline patterns and interference waves
+Training coordinate grid is auto-generated during forward pass. No manual intervention needed.
-## License
+Size impact: +32B coord weights (kernel-agnostic).
-Free to use and modify for any purpose.
+---
-## Version
+## References
-Version 1.0 - Complete collection of 10 artistic styles (5 sketch + 5 futuristic)
+- **CNN Effect Documentation:** `doc/CNN_EFFECT.md`
+- **Training Architecture:** See `train_cnn.py` (CoordConv2d class)
diff --git a/training/ground_truth.png b/training/ground_truth.png
new file mode 100644
index 0000000..6e1f2aa
--- /dev/null
+++ b/training/ground_truth.png
Binary files differ
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 82f0b48..16f8e7a 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -5,10 +5,15 @@ CNN Training Script for Image-to-Image Transformation
Trains a convolutional neural network on multiple input/target image pairs.
Usage:
+ # Training
python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
+ # Inference (generate ground truth)
+ python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png
+
Example:
python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
+ python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth
"""
import torch
@@ -62,7 +67,8 @@ class ImagePairDataset(Dataset):
def __getitem__(self, idx):
input_path, target_path = self.image_pairs[idx]
- input_img = Image.open(input_path).convert('RGB')
+ # Load RGBD input (4 channels: RGB + Depth)
+ input_img = Image.open(input_path).convert('RGBA')
target_img = Image.open(target_path).convert('RGB')
if self.transform:
@@ -72,27 +78,8 @@ class ImagePairDataset(Dataset):
return input_img, target_img
-class CoordConv2d(nn.Module):
- """Conv2d that accepts coordinate input separate from spatial patches"""
-
- def __init__(self, in_channels, out_channels, kernel_size, padding=0):
- super().__init__()
- self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
- self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01)
- self.bias = nn.Parameter(torch.zeros(out_channels))
-
- def forward(self, x, coords):
- # x: [B, C, H, W] image
- # coords: [B, 2, H, W] coordinate grid
- out = self.conv_rgba(x)
- B, C, H, W = out.shape
- coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights)
- out = out + coord_contrib + self.bias.view(1, -1, 1, 1)
- return out
-
-
class SimpleCNN(nn.Module):
- """Simple CNN for image-to-image transformation"""
+ """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""
def __init__(self, num_layers=1, kernel_sizes=None):
super(SimpleCNN, self).__init__()
@@ -107,30 +94,126 @@ class SimpleCNN(nn.Module):
for i, kernel_size in enumerate(kernel_sizes):
padding = kernel_size // 2
- if i == 0:
- self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding))
+ if i < num_layers - 1:
+ # Inner layers: 7→4 (RGBD output)
+ self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
else:
- self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True))
-
- self.use_residual = True
+ # Final layer: 7→1 (grayscale output)
+ self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))
def forward(self, x):
+ # x: [B,4,H,W] - RGBD input (D = 1/z)
B, C, H, W = x.shape
+
+ # Normalize RGBD to [-1,1]
+ x_norm = (x - 0.5) * 2.0
+
+ # Compute coordinates [0,1] then normalize to [-1,1]
y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
- coords = torch.cat([x_coords, y_coords], dim=1)
+ y_coords = (y_coords - 0.5) * 2.0 # [-1,1]
+ x_coords = (x_coords - 0.5) * 2.0 # [-1,1]
+
+ # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
+ gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1]
+ gray = (gray - 0.5) * 2.0 # [-1,1]
+
+ # Layer 0
+ layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W]
+ out = self.layers[0](layer0_input) # [B,4,H,W]
+ out = torch.tanh(out) # [-1,1]
+
+ # Inner layers
+ for i in range(1, len(self.layers)-1):
+ layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+ out = self.layers[i](layer_input)
+ out = torch.tanh(out)
+
+ # Final layer (grayscale output)
+ final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+ out = self.layers[-1](final_input) # [B,1,H,W]
+ out = torch.clamp(out, 0.0, 1.0) # Clip to [0,1]
+ return out.expand(-1, 3, -1, -1)
+
+
+def generate_layer_shader(output_path, num_layers, kernel_sizes):
+ """Generate cnn_layer.wgsl with proper layer switches"""
+
+ with open(output_path, 'w') as f:
+ f.write("// CNN layer shader - uses modular convolution snippets\n")
+ f.write("// Supports multi-pass rendering with residual connections\n")
+ f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+ f.write("@group(0) @binding(0) var smplr: sampler;\n")
+ f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n")
+ f.write("#include \"common_uniforms\"\n")
+ f.write("#include \"cnn_activation\"\n")
+
+ # Include necessary conv functions
+ conv_sizes = set(kernel_sizes)
+ for ks in sorted(conv_sizes):
+ f.write(f"#include \"cnn_conv{ks}x{ks}\"\n")
+ f.write("#include \"cnn_weights_generated\"\n\n")
- out = self.layers[0](x, coords)
- out = torch.tanh(out)
+ f.write("struct CNNLayerParams {\n")
+ f.write(" layer_index: i32,\n")
+ f.write(" blend_amount: f32,\n")
+ f.write(" _pad: vec2<f32>,\n")
+ f.write("};\n\n")
+ f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n")
+ f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n")
+ f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n")
+ f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n")
+ f.write(" var pos = array<vec2<f32>, 3>(\n")
+ f.write(" vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)\n")
+ f.write(" );\n")
+ f.write(" return vec4<f32>(pos[i], 0.0, 1.0);\n")
+ f.write("}\n\n")
+ f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
+ f.write(" let uv = p.xy / uniforms.resolution;\n")
+ f.write(" let original_raw = textureSample(original_input, smplr, uv);\n")
+ f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
+ f.write(" var result = vec4<f32>(0.0);\n\n")
- for i in range(1, len(self.layers)):
- out = self.layers[i](out)
- if i < len(self.layers) - 1:
- out = torch.tanh(out)
+ # Generate layer switches
+ for layer_idx in range(num_layers):
+ is_final = layer_idx == num_layers - 1
+ ks = kernel_sizes[layer_idx]
+ conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"
- if self.use_residual:
- out = x + out * 0.3
- return out
+ if layer_idx == 0:
+ conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src"
+ f.write(f" // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n")
+ f.write(f" if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" }}\n")
+ elif not is_final:
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n")
+ f.write(f" }}\n")
+ else:
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" // gray_out already in [0,1] from clipped training\n")
+ f.write(f" let original_denorm = (original + 1.0) * 0.5;\n")
+ f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
+ f.write(f" let blended = mix(original_denorm, result, params.blend_amount);\n")
+ f.write(f" return blended; // [0,1]\n")
+ f.write(f" }}\n")
+
+ # Add else clause for invalid layer index
+ if num_layers > 0:
+ f.write(f" else {{\n")
+ f.write(f" return textureSample(txt, smplr, uv);\n")
+ f.write(f" }}\n")
+
+ f.write("\n // Non-final layers: denormalize for display\n")
+ f.write(" return (result + 1.0) * 0.5; // [-1,1] → [0,1]\n")
+ f.write("}\n")
def export_weights_to_wgsl(model, output_path, kernel_sizes):
@@ -140,82 +223,95 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
f.write("// Auto-generated CNN weights\n")
f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
- layer_idx = 0
for i, layer in enumerate(model.layers):
- if isinstance(layer, CoordConv2d):
- # Export RGBA weights
- weights = layer.conv_rgba.weight.data.cpu().numpy()
- kernel_size = kernel_sizes[layer_idx]
- out_ch, in_ch, kh, kw = weights.shape
- num_positions = kh * kw
+ weights = layer.weight.data.cpu().numpy()
+ bias = layer.bias.data.cpu().numpy()
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
- f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+ is_final = (i == len(model.layers) - 1)
+
+ if is_final:
+ # Final layer: 7→1, structure: array<array<f32, 8>, 9>
+ # [w0, w1, w2, w3, w4, w5, w6, bias]
+ f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_positions}> = array(\n")
for pos in range(num_positions):
- row = pos // kw
- col = pos % kw
- f.write(" mat4x4<f32>(\n")
- for out_c in range(min(4, out_ch)):
- vals = []
- for in_c in range(min(4, in_ch)):
- vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
- f.write(f" {', '.join(vals)},\n")
- f.write(" )")
- if pos < num_positions - 1:
- f.write(",\n")
- else:
- f.write("\n")
+ row, col = pos // kw, pos % kw
+ vals = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(7)]
+ vals.append(f"{bias[0]:.6f}") # Append bias as 8th element
+ f.write(f" array<f32, 8>({', '.join(vals)})")
+ f.write(",\n" if pos < num_positions-1 else "\n")
f.write(");\n\n")
-
- # Export coordinate weights
- coord_w = layer.coord_weights.data.cpu().numpy()
- f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n")
- for c in range(2):
- vals = [f"{coord_w[out_c, c]:.6f}" for out_c in range(min(4, coord_w.shape[0]))]
- f.write(f" {', '.join(vals)}")
- if c < 1:
- f.write(",\n")
- else:
- f.write("\n")
+ else:
+ # Inner layers: 7→4, structure: array<array<f32, 8>, 36>
+ # Flattened: [pos0_ch0[7w+bias], pos0_ch1[7w+bias], ..., pos8_ch3[7w+bias]]
+ num_entries = num_positions * 4
+ f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_entries}> = array(\n")
+ for pos in range(num_positions):
+ row, col = pos // kw, pos % kw
+ for out_c in range(4):
+ vals = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(7)]
+ vals.append(f"{bias[out_c]:.6f}") # Append bias
+ idx = pos * 4 + out_c
+ f.write(f" array<f32, 8>({', '.join(vals)})")
+ f.write(",\n" if idx < num_entries-1 else "\n")
f.write(");\n\n")
- # Export bias
- bias = layer.bias.data.cpu().numpy()
- f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
- f.write(", ".join([f"{b:.6f}" for b in bias[:4]]))
- f.write(");\n\n")
- layer_idx += 1
- elif isinstance(layer, nn.Conv2d):
- # Standard conv layer
- weights = layer.weight.data.cpu().numpy()
- kernel_size = kernel_sizes[layer_idx]
- out_ch, in_ch, kh, kw = weights.shape
- num_positions = kh * kw
+def generate_conv_src_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0"""
- f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
- for pos in range(num_positions):
- row = pos // kw
- col = pos % kw
- f.write(" mat4x4<f32>(\n")
- for out_c in range(min(4, out_ch)):
- vals = []
- for in_c in range(min(4, in_ch)):
- vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
- f.write(f" {', '.join(vals)},\n")
- f.write(" )")
- if pos < num_positions - 1:
- f.write(",\n")
- else:
- f.write("\n")
- f.write(");\n\n")
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
- # Export bias
- bias = layer.bias.data.cpu().numpy()
- f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
- f.write(", ".join([f"{b:.6f}" for b in bias[:4]]))
- f.write(");\n\n")
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Source layer: 7→4 channels (RGBD output)\n")
+ f.write(f"// Normalizes [0,1] input to [-1,1] internally\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" weights: array<array<f32, 8>, {num_positions * 4}>\n")
+ f.write(f") -> vec4<f32> {{\n")
+ f.write(f" let step = 1.0 / resolution;\n\n")
+
+ # Normalize center pixel for gray channel
+ f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
+ f.write(f" let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+
+ f.write(f" var sum = vec4<f32>(0.0);\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")
- layer_idx += 1
+ # 7-channel input
+ f.write(f" let inputs = array<f32, 7>(\n")
+ f.write(f" rgbd.r, rgbd.g, rgbd.b, rgbd.a,\n")
+ f.write(f" uv_norm.x, uv_norm.y, gray\n")
+ f.write(f" );\n\n")
+
+ # Accumulate
+ f.write(f" for (var out_c = 0; out_c < 4; out_c++) {{\n")
+ f.write(f" let idx = pos * 4 + out_c;\n")
+ f.write(f" var channel_sum = weights[idx][7];\n")
+ f.write(f" for (var in_c = 0; in_c < 7; in_c++) {{\n")
+ f.write(f" channel_sum += weights[idx][in_c] * inputs[in_c];\n")
+ f.write(f" }}\n")
+ f.write(f" sum[out_c] += channel_sum;\n")
+ f.write(f" }}\n")
+ f.write(f" pos++;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
def train(args):
@@ -293,32 +389,166 @@ def train(args):
}, checkpoint_path)
print(f"Saved checkpoint to {checkpoint_path}")
- # Export weights
+ # Export weights and shader
output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
print(f"\nExporting weights to {output_path}...")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
export_weights_to_wgsl(model, output_path, kernel_sizes)
+ # Generate layer shader
+ shader_dir = os.path.dirname(output_path)
+ shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
+ print(f"Generating layer shader to {shader_path}...")
+ generate_layer_shader(shader_path, args.layers, kernel_sizes)
+
+ # Generate _src variants for kernel sizes (skip 3x3, already exists)
+ for ks in set(kernel_sizes):
+ if ks == 3:
+ continue
+ conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+ if not os.path.exists(conv_path):
+ print(f"Warning: {conv_path} not found, skipping _src generation")
+ continue
+
+ # Check if _src already exists
+ with open(conv_path, 'r') as f:
+ content = f.read()
+ if f"cnn_conv{ks}x{ks}_7to4_src" in content:
+ continue
+
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+
print("Training complete!")
+def export_from_checkpoint(checkpoint_path, output_path=None):
+ """Export WGSL files from checkpoint without training"""
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint file '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ kernel_sizes = checkpoint['kernel_sizes']
+ num_layers = checkpoint['num_layers']
+
+ # Recreate model
+ model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes)
+ model.load_state_dict(checkpoint['model_state'])
+
+ # Export weights
+ output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+ print(f"Exporting weights to {output_path}...")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+ # Generate layer shader
+ shader_dir = os.path.dirname(output_path)
+ shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
+ print(f"Generating layer shader to {shader_path}...")
+ generate_layer_shader(shader_path, num_layers, kernel_sizes)
+
+ # Generate _src variants for kernel sizes (skip 3x3, already exists)
+ for ks in set(kernel_sizes):
+ if ks == 3:
+ continue
+ conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+ if not os.path.exists(conv_path):
+ print(f"Warning: {conv_path} not found, skipping _src generation")
+ continue
+
+ # Check if _src already exists
+ with open(conv_path, 'r') as f:
+ content = f.read()
+ if f"cnn_conv{ks}x{ks}_7to4_src" in content:
+ continue
+
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+
+ print("Export complete!")
+
+
+def infer_from_checkpoint(checkpoint_path, input_path, output_path):
+ """Run inference on single image to generate ground truth"""
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ if not os.path.exists(input_path):
+ print(f"Error: Input image '{input_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ # Reconstruct model
+ model = SimpleCNN(
+ num_layers=checkpoint['num_layers'],
+ kernel_sizes=checkpoint['kernel_sizes']
+ )
+ model.load_state_dict(checkpoint['model_state'])
+ model.eval()
+
+ # Load image [0,1]
+ print(f"Loading input image: {input_path}")
+ img = Image.open(input_path).convert('RGBA')
+ img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W]
+
+ # Inference
+ print("Running inference...")
+ with torch.no_grad():
+ out = model(img_tensor) # [1,3,H,W] in [0,1]
+
+ # Save
+ print(f"Saving output to: {output_path}")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ transforms.ToPILImage()(out.squeeze(0)).save(output_path)
+ print("Done!")
+
+
def main():
parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
- parser.add_argument('--input', required=True, help='Input image directory')
- parser.add_argument('--target', required=True, help='Target image directory')
+ parser.add_argument('--input', help='Input image directory (training) or single image (inference)')
+ parser.add_argument('--target', help='Target image directory')
parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
- parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+ parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)')
parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
parser.add_argument('--resume', help='Resume from checkpoint file')
+ parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
+ parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
args = parser.parse_args()
- # Validate directories
+ # Inference mode
+ if args.infer:
+ checkpoint = args.export_only
+ if not checkpoint:
+ print("Error: --infer requires --export-only <checkpoint>")
+ sys.exit(1)
+ output_path = args.output or 'inference_output.png'
+ infer_from_checkpoint(checkpoint, args.infer, output_path)
+ return
+
+ # Export-only mode
+ if args.export_only:
+ export_from_checkpoint(args.export_only, args.output)
+ return
+
+ # Validate directories for training
+ if not args.input or not args.target:
+ print("Error: --input and --target required for training (or use --export-only)")
+ sys.exit(1)
+
if not os.path.isdir(args.input):
print(f"Error: Input directory '{args.input}' does not exist")
sys.exit(1)
diff --git a/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
deleted file mode 100644
index dae81df..0000000
--- a/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
+++ /dev/null
@@ -1,158 +0,0 @@
-// Auto-generated CNN weights
-// DO NOT EDIT - Generated by train_cnn.py
-
-const rgba_weights_layer0: array<mat4x4<f32>, 9> = array(
- mat4x4<f32>(
- -0.019696, -0.045138, -0.059372,
- 0.113637, -0.026176, -0.204699,
- -0.147723, -0.124720, -0.133641,
- ),
- mat4x4<f32>(
- -0.011820, -0.110039, -0.019111,
- 0.102596, -0.053469, -0.090972,
- -0.106286, 0.062616, -0.211309,
- ),
- mat4x4<f32>(
- 0.169672, -0.188668, 0.097992,
- -0.048049, 0.035012, -0.028287,
- 0.041841, 0.113846, 0.092006,
- ),
- mat4x4<f32>(
- 0.084688, -0.173117, -0.130135,
- 0.125052, 0.070060, -0.072493,
- -0.081996, -0.041021, -0.200688,
- ),
- mat4x4<f32>(
- 0.180550, 0.018555, -0.092889,
- 0.105823, 0.109215, 0.042989,
- -0.116116, 0.115354, 0.044726,
- ),
- mat4x4<f32>(
- 0.069597, -0.156086, -0.116919,
- 0.003641, -0.033090, 0.077686,
- -0.090117, 0.047527, 0.093449,
- ),
- mat4x4<f32>(
- -0.007961, -0.201232, -0.094087,
- 0.041521, -0.001265, -0.164458,
- -0.063295, -0.177367, 0.120887,
- ),
- mat4x4<f32>(
- 0.005358, -0.153663, 0.234817,
- 0.094452, -0.030598, -0.159715,
- -0.025096, 0.010606, -0.151786,
- ),
- mat4x4<f32>(
- 0.035922, 0.039006, -0.073426,
- 0.234309, 0.042990, -0.074330,
- 0.129497, -0.084083, -0.165691,
- )
-);
-
-const coord_weights_layer0 = mat2x4<f32>(
- 0.156995, -0.026005, 0.159550,
- 0.112678, -0.021301, 0.106653
-);
-
-const bias_layer0 = vec4<f32>(0.149566, -0.002723, 0.142744);
-
-const weights_layer1: array<mat4x4<f32>, 9> = array(
- mat4x4<f32>(
- 0.198730, -0.060590, -0.126001,
- 0.018094, 0.099855, 0.043531,
- -0.048028, 0.024975, -0.055560,
- ),
- mat4x4<f32>(
- 0.093012, -0.056168, 0.075685,
- -0.104572, 0.202161, 0.093453,
- 0.008470, 0.190414, -0.121853,
- ),
- mat4x4<f32>(
- 0.157523, -0.278521, 0.267972,
- 0.226318, 0.108021, -0.020615,
- 0.116906, 0.094663, 0.103058,
- ),
- mat4x4<f32>(
- 0.184815, -0.167385, -0.081513,
- 0.167595, 0.147724, -0.034069,
- 0.109272, 0.149283, 0.022741,
- ),
- mat4x4<f32>(
- -0.133319, 0.069405, 0.028862,
- -0.044914, -0.121720, 0.074758,
- 0.150973, 0.086887, 0.193997,
- ),
- mat4x4<f32>(
- 0.123384, -0.157817, -0.053264,
- 0.216874, 0.024062, 0.227470,
- 0.092232, 0.156942, 0.098989,
- ),
- mat4x4<f32>(
- -0.074328, -0.265180, 0.065633,
- 0.033679, 0.175748, 0.178567,
- 0.168913, 0.192317, -0.015507,
- ),
- mat4x4<f32>(
- -0.103567, -0.081663, 0.239707,
- 0.020591, 0.031346, 0.089577,
- -0.040636, 0.061481, 0.215428,
- ),
- mat4x4<f32>(
- 0.103399, -0.291323, 0.220388,
- 0.163876, 0.106383, 0.175615,
- 0.050511, 0.210950, -0.143280,
- )
-);
-
-const bias_layer1 = vec4<f32>(0.273340, 0.183151, 0.057200);
-
-const weights_layer2: array<mat4x4<f32>, 9> = array(
- mat4x4<f32>(
- 0.170688, -0.158379, -0.073057,
- -0.213429, -0.075772, -0.117451,
- -0.265536, -0.066896, 0.185188,
- ),
- mat4x4<f32>(
- 0.061069, -0.267237, -0.057030,
- -0.112682, -0.001723, 0.020779,
- -0.158726, -0.027319, -0.133134,
- ),
- mat4x4<f32>(
- -0.036597, 0.000282, -0.286058,
- -0.056992, 0.129227, 0.037650,
- -0.305341, -0.082011, 0.155333,
- ),
- mat4x4<f32>(
- 0.146811, 0.086471, -0.092652,
- -0.083987, -0.164501, 0.005801,
- -0.108568, 0.079618, 0.011061,
- ),
- mat4x4<f32>(
- 0.008716, -0.174373, 0.038516,
- -0.263207, -0.201249, -0.106428,
- -0.321199, 0.139540, -0.069047,
- ),
- mat4x4<f32>(
- -0.099231, -0.037154, -0.189117,
- 0.014380, 0.102996, 0.068944,
- -0.011073, 0.175106, 0.019059,
- ),
- mat4x4<f32>(
- -0.170030, -0.077528, -0.038504,
- 0.042379, -0.198288, 0.008895,
- -0.144090, -0.129658, 0.215823,
- ),
- mat4x4<f32>(
- -0.082481, -0.160808, -0.279220,
- -0.029358, 0.021159, -0.037080,
- -0.194849, -0.013461, 0.057026,
- ),
- mat4x4<f32>(
- -0.063711, -0.198759, -0.037847,
- -0.049292, -0.222896, -0.067384,
- -0.167766, -0.090320, 0.106986,
- )
-);
-
-const bias_layer2 = vec4<f32>(0.021260, -0.056985, 0.000823);
-