Diffstat (limited to 'training')
-rw-r--r--   training/README.md                |   2
-rwxr-xr-x   training/export_cnn_v2_shader.py  | 214
-rwxr-xr-x   training/export_cnn_v2_weights.py | 284
-rwxr-xr-x   training/gen_identity_weights.py  | 171
-rwxr-xr-x   training/train_cnn.py             | 943
-rwxr-xr-x   training/train_cnn_v2.py          | 472
6 files changed, 1 insertion, 2085 deletions
diff --git a/training/README.md b/training/README.md
index e78b471..bddf4d5 100644
--- a/training/README.md
+++ b/training/README.md
@@ -174,6 +174,6 @@ pip install torch torchvision pillow opencv-python numpy
## References
-- **CNN Effect:** `doc/CNN_EFFECT.md`
+- **CNN Effect:** `cnn_v1/docs/CNN_V1_EFFECT.md`
- **Timeline:** `doc/SEQUENCE.md`
- **HOWTO:** `doc/HOWTO.md`
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py
deleted file mode 100755
index 1c74ad0..0000000
--- a/training/export_cnn_v2_shader.py
+++ /dev/null
@@ -1,214 +0,0 @@
-#!/usr/bin/env python3
-"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture
-
-Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
-Generates one shader per layer with embedded weight arrays.
-
-Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size.
- This script is for debugging/testing with per-layer shaders.
-"""
-
-import argparse
-import numpy as np
-import torch
-from pathlib import Path
-
-
-def export_layer_shader(layer_idx, weights, kernel_size, output_dir, mip_level=0, is_output_layer=False):
- """Generate WGSL compute shader for a single CNN layer.
-
- Args:
- layer_idx: Layer index (0, 1, 2, ...)
- weights: (4, 12, k, k) weight tensor (uniform 12D→4D)
- kernel_size: Kernel size (3, 5, etc.)
- output_dir: Output directory path
- mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.)
- is_output_layer: True if this is the final RGBA output layer
- """
- weights_flat = weights.flatten()
- weights_f16 = weights_flat.astype(np.float16)
- weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals
-
- # Format weights as WGSL array
- weights_str = ",\n ".join(
- ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8])
- for i in range(0, len(weights_f32), 8)
- )
-
- radius = kernel_size // 2
- if is_output_layer:
- activation = "output[c] = clamp(sum, 0.0, 1.0); // Output layer"
- elif layer_idx == 0:
- activation = "output[c] = clamp(sum, 0.0, 1.0); // Layer 0: clamp [0,1]"
- else:
- activation = "output[c] = max(0.0, sum); // Middle layers: ReLU"
-
- shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
-// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
-// Mip level: {mip_level} (p0-p3 features)
-
-const KERNEL_SIZE: u32 = {kernel_size}u;
-const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static)
-const OUT_CHANNELS: u32 = 4u; // Uniform output
-const KERNEL_RADIUS: i32 = {radius};
-
-// Weights quantized to float16 (stored as f32 in WGSL)
-const weights: array<f32, {len(weights_f32)}> = array(
- {weights_str}
-);
-
-@group(0) @binding(0) var static_features: texture_2d<u32>;
-@group(0) @binding(1) var layer_input: texture_2d<u32>;
-@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
-
-fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
- let packed = textureLoad(static_features, coord, 0);
- let v0 = unpack2x16float(packed.x);
- let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
-}}
-
-fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{
- let packed = textureLoad(layer_input, coord, 0);
- let v0 = unpack2x16float(packed.x);
- let v1 = unpack2x16float(packed.y);
- return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
-}}
-
-fn pack_channels(values: vec4<f32>) -> vec4<u32> {{
- return vec4<u32>(
- pack2x16float(vec2<f32>(values.x, values.y)),
- pack2x16float(vec2<f32>(values.z, values.w)),
- 0u, // Unused
- 0u // Unused
- );
-}}
-
-@compute @workgroup_size(8, 8)
-fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
- let coord = vec2<i32>(id.xy);
- let dims = textureDimensions(static_features);
-
- if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{
- return;
- }}
-
-    // Center-pixel static features (currently unused; the loop below reloads per tap)
-    let static_feat = unpack_static_features(coord);
-
- // Convolution: 12D input (4 prev + 8 static) → 4D output
- var output: vec4<f32> = vec4<f32>(0.0);
- for (var c: u32 = 0u; c < 4u; c++) {{
- var sum: f32 = 0.0;
-
- for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
- for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{
- let sample_coord = coord + vec2<i32>(kx, ky);
-
- // Border handling (clamp)
- let clamped = vec2<i32>(
- clamp(sample_coord.x, 0, i32(dims.x) - 1),
- clamp(sample_coord.y, 0, i32(dims.y) - 1)
- );
-
- // Load features at this spatial location
- let static_local = unpack_static_features(clamped);
- let layer_local = unpack_layer_channels(clamped); // 4D
-
- // Weight index calculation
- let ky_idx = u32(ky + KERNEL_RADIUS);
- let kx_idx = u32(kx + KERNEL_RADIUS);
- let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
-
- // Accumulate: previous/input channels (4D)
- for (var i: u32 = 0u; i < 4u; i++) {{
- let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
- i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * layer_local[i];
- }}
-
- // Accumulate: static features (8D)
- for (var i: u32 = 0u; i < 8u; i++) {{
- let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
- (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * static_local[i];
- }}
- }}
- }}
-
- {activation}
- }}
-
- // Pack and store
- textureStore(output_tex, coord, pack_channels(output));
-}}
-"""
-
- output_path = Path(output_dir) / "cnn_v2" / f"cnn_v2_layer_{layer_idx}.wgsl"
- output_path.write_text(shader_code)
- print(f" → {output_path}")
-
-
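The f16 round-trip at the top of export_layer_shader is what bounds weight precision. A minimal sketch of the implied error (round-to-nearest half precision keeps 10 mantissa bits, so relative error stays under 2^-11 for normal-range values):

    import numpy as np

    w = np.random.randn(4, 12, 3, 3).astype(np.float32)
    w16 = w.astype(np.float16).astype(np.float32)  # same round-trip as the exporter
    err = np.abs(w - w16)
    assert np.all(err <= np.abs(w) * 2.0**-10 + 1e-7)  # generous bound, covers subnormals
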
-def export_checkpoint(checkpoint_path, output_dir):
- """Export PyTorch checkpoint to WGSL shaders.
-
- Args:
- checkpoint_path: Path to .pth checkpoint
- output_dir: Output directory for shaders
- """
- print(f"Loading checkpoint: {checkpoint_path}")
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
-
- state_dict = checkpoint['model_state_dict']
- config = checkpoint['config']
-
- kernel_size = config.get('kernel_size', 3)
- num_layers = config.get('num_layers', 3)
- mip_level = config.get('mip_level', 0)
-
- print(f"Configuration:")
- print(f" Kernel size: {kernel_size}×{kernel_size}")
- print(f" Layers: {num_layers}")
- print(f" Mip level: {mip_level} (p0-p3 features)")
- print(f" Architecture: uniform 12D→4D")
-
- output_dir = Path(output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
- print(f"\nExporting shaders to {output_dir}/")
-
- # All layers uniform: 12D→4D
- for i in range(num_layers):
- layer_key = f'layers.{i}.weight'
- if layer_key not in state_dict:
- raise ValueError(f"Missing weights for layer {i}: {layer_key}")
-
- layer_weights = state_dict[layer_key].detach().numpy()
- is_output = (i == num_layers - 1)
-
- export_layer_shader(
- layer_idx=i,
- weights=layer_weights,
- kernel_size=kernel_size,
- output_dir=output_dir,
- mip_level=mip_level,
- is_output_layer=is_output
- )
-
- print(f"\nExport complete! Generated {num_layers} shader files.")
-
-
-def main():
- parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders')
- parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
- parser.add_argument('--output-dir', type=str, default='workspaces/main/shaders',
- help='Output directory for shaders')
-
- args = parser.parse_args()
- export_checkpoint(args.checkpoint, args.output_dir)
-
-
-if __name__ == '__main__':
- main()
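The w_idx arithmetic baked into the generated WGSL (c * 12 * K*K + i * K*K + spatial) assumes C-order flattening of the (4, 12, k, k) PyTorch weight tensor, which is exactly what weights.flatten() produces. A quick self-contained check of that assumption:

    import numpy as np

    k = 3
    w = np.arange(4 * 12 * k * k, dtype=np.float32).reshape(4, 12, k, k)
    flat = w.flatten()  # C order, as in export_layer_shader
    for c in range(4):
        for i in range(12):
            idx = c * 12 * k * k + i * k * k + 1 * k + 2  # ky=1, kx=2
            assert flat[idx] == w[c, i, 1, 2]
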
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
deleted file mode 100755
index f64bd8d..0000000
--- a/training/export_cnn_v2_weights.py
+++ /dev/null
@@ -1,284 +0,0 @@
-#!/usr/bin/env python3
-"""CNN v2 Weight Export Script
-
-Converts PyTorch checkpoints to binary weight format for storage buffer.
-Exports single shader template + binary weights asset.
-"""
-
-import argparse
-import numpy as np
-import torch
-import struct
-from pathlib import Path
-
-
-def export_weights_binary(checkpoint_path, output_path, quiet=False):
- """Export CNN v2 weights to binary format.
-
- Binary format:
- Header (20 bytes):
- uint32 magic ('CNN2')
- uint32 version (2)
- uint32 num_layers
- uint32 total_weights (f16 count)
- uint32 mip_level (0-3)
-
- LayerInfo × num_layers (20 bytes each):
- uint32 kernel_size
- uint32 in_channels
- uint32 out_channels
- uint32 weight_offset (f16 index)
- uint32 weight_count
-
- Weights (f16 array):
- float16[] all_weights
-
- Args:
- checkpoint_path: Path to .pth checkpoint
- output_path: Output .bin file path
-
- Returns:
- config dict for shader generation
- """
- if not quiet:
- print(f"Loading checkpoint: {checkpoint_path}")
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
-
- state_dict = checkpoint['model_state_dict']
- config = checkpoint['config']
-
- # Support both old (kernel_size) and new (kernel_sizes) format
- if 'kernel_sizes' in config:
- kernel_sizes = config['kernel_sizes']
- elif 'kernel_size' in config:
- kernel_size = config['kernel_size']
- num_layers = config.get('num_layers', 3)
- kernel_sizes = [kernel_size] * num_layers
- else:
- kernel_sizes = [3, 3, 3] # fallback
-
- num_layers = config.get('num_layers', len(kernel_sizes))
- mip_level = config.get('mip_level', 0)
-
- if not quiet:
- print(f"Configuration:")
- print(f" Kernel sizes: {kernel_sizes}")
- print(f" Layers: {num_layers}")
- print(f" Mip level: {mip_level} (p0-p3 features)")
- print(f" Architecture: uniform 12D→4D (bias=False)")
-
- # Collect layer info - all layers uniform 12D→4D
- layers = []
- all_weights = []
- weight_offset = 0
-
- for i in range(num_layers):
- layer_key = f'layers.{i}.weight'
- if layer_key not in state_dict:
- raise ValueError(f"Missing weights for layer {i}: {layer_key}")
-
- layer_weights = state_dict[layer_key].detach().numpy()
- layer_flat = layer_weights.flatten()
- kernel_size = kernel_sizes[i]
-
- layers.append({
- 'kernel_size': kernel_size,
- 'in_channels': 12, # 4 (input/prev) + 8 (static)
- 'out_channels': 4, # Uniform output
- 'weight_offset': weight_offset,
- 'weight_count': len(layer_flat)
- })
- all_weights.extend(layer_flat)
- weight_offset += len(layer_flat)
-
- if not quiet:
- print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights")
-
- # Convert to f16
- # TODO: Use 8-bit quantization for 2× size reduction
- # Requires quantization-aware training (QAT) to maintain accuracy
- all_weights_f16 = np.array(all_weights, dtype=np.float16)
-
- # Pack f16 pairs into u32 for storage buffer
- # Pad to even count if needed
- if len(all_weights_f16) % 2 == 1:
- all_weights_f16 = np.append(all_weights_f16, np.float16(0.0))
-
- # Pack pairs using numpy view
- weights_u32 = all_weights_f16.view(np.uint32)
-
- binary_size = 20 + len(layers) * 20 + len(weights_u32) * 4
- if not quiet:
- print(f"\nWeight statistics:")
- print(f" Total layers: {len(layers)}")
- print(f" Total weights: {len(all_weights_f16)} (f16)")
- print(f" Packed: {len(weights_u32)} u32")
- print(f" Binary size: {binary_size} bytes")
-
- # Write binary file
- output_path = Path(output_path)
- output_path.parent.mkdir(parents=True, exist_ok=True)
-
- with open(output_path, 'wb') as f:
- # Header (20 bytes) - version 2 with mip_level
- f.write(struct.pack('<4sIIII',
- b'CNN2', # magic
- 2, # version (bumped to 2)
- len(layers), # num_layers
- len(all_weights_f16), # total_weights (f16 count)
- mip_level)) # mip_level
-
- # Layer info (20 bytes per layer)
- for layer in layers:
- f.write(struct.pack('<IIIII',
- layer['kernel_size'],
- layer['in_channels'],
- layer['out_channels'],
- layer['weight_offset'],
- layer['weight_count']))
-
- # Weights (u32 packed f16 pairs)
- f.write(weights_u32.tobytes())
-
- if quiet:
- print(f" Exported {num_layers} layers, {len(all_weights_f16)} weights, {binary_size} bytes → {output_path}")
- else:
- print(f" → {output_path}")
-
- return {
- 'num_layers': len(layers),
- 'layers': layers
- }
-
-
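The view(np.uint32) packing above relies on the first f16 of each pair landing in the low 16 bits of the u32, which is the half that WGSL's unpack2x16float(...).x reads back. A sketch of that correspondence, assuming a little-endian host:

    import numpy as np

    pair = np.array([1.5, -0.25], dtype=np.float16)
    packed = pair.view(np.uint32)  # one u32 per f16 pair
    lo = (packed & 0xFFFF).astype(np.uint16).view(np.float16)[0]
    hi = (packed >> 16).astype(np.uint16).view(np.float16)[0]
    assert lo == pair[0] and hi == pair[1]  # .x = even index, .y = odd index
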
-def export_shader_template(config, output_dir):
- """Generate single WGSL shader template with storage buffer binding.
-
- Args:
- config: Layer configuration from export_weights_binary()
- output_dir: Output directory path
- """
- shader_code = """// CNN v2 Compute Shader - Storage Buffer Version
-// Reads weights from storage buffer, processes all layers in sequence
-
-struct CNNv2Header {
-    magic: u32,          // 'CNN2'
-    version: u32,        // 2 (matches the exporter above)
-    num_layers: u32,     // Number of layers
-    total_weights: u32,  // Total f16 weight count
-    mip_level: u32,      // Mip level for p0-p3 features
-}
-
-struct CNNv2LayerInfo {
- kernel_size: u32,
- in_channels: u32,
- out_channels: u32,
- weight_offset: u32, // Offset in weights array
- weight_count: u32,
-}
-
-@group(0) @binding(0) var static_features: texture_2d<u32>;
-@group(0) @binding(1) var layer_input: texture_2d<u32>;
-@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
-@group(0) @binding(3) var<storage, read> weights: array<u32>; // Packed f16 pairs
-
-fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
- let packed = textureLoad(static_features, coord, 0);
- let v0 = unpack2x16float(packed.x);
- let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
-}
-
-fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
- let packed = textureLoad(layer_input, coord, 0);
- let v0 = unpack2x16float(packed.x);
- let v1 = unpack2x16float(packed.y);
- return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
-}
-
-fn pack_channels(values: vec4<f32>) -> vec4<u32> {
- return vec4<u32>(
- pack2x16float(vec2<f32>(values.x, values.y)),
- pack2x16float(vec2<f32>(values.z, values.w)),
- 0u, // Unused
- 0u // Unused
- );
-}
-
-fn get_weight(idx: u32) -> f32 {
-    let pair_idx = idx / 2u;
-    // Skip 5-u32 header + 3 layer-info blocks of 5 u32 each (3 layers hardcoded below)
-    let packed = weights[20u + pair_idx];
-    let unpacked = unpack2x16float(packed);
-    return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
-}
-
-@compute @workgroup_size(8, 8)
-fn main(@builtin(global_invocation_id) id: vec3<u32>) {
- let coord = vec2<i32>(id.xy);
- let dims = textureDimensions(static_features);
-
- if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
- return;
- }
-
-    // Read header (the exporter writes each field as a full u32)
-    let magic = weights[0];
-    let version = weights[1];
-    let num_layers = weights[2];
-
- // Load static features
- let static_feat = unpack_static_features(coord);
-
- // Process each layer (hardcoded for 3 layers for now)
- // TODO: Dynamic layer loop when needed
-
- // Example for layer 0 - expand to full multi-layer when tested
-    let layer_info_offset = 5u; // After the 5-u32 header
-    let layer0_info_base = layer_info_offset;
-
- // Read layer 0 info (5 u32 values = 20 bytes)
- let kernel_size = weights[layer0_info_base];
- let in_channels = weights[layer0_info_base + 1u];
- let out_channels = weights[layer0_info_base + 2u];
- let weight_offset = weights[layer0_info_base + 3u];
-
- // Convolution: 12D input (4 prev + 8 static) → 4D output
- var output: vec4<f32> = vec4<f32>(0.0);
- for (var c: u32 = 0u; c < 4u; c++) {
- output[c] = 0.0; // TODO: Actual convolution
- }
-
- textureStore(output_tex, coord, pack_channels(output));
-}
-"""
-
- output_path = Path(output_dir) / "cnn_v2" / "cnn_v2_compute.wgsl"
- output_path.write_text(shader_code)
- print(f" → {output_path}")
-
-
-def main():
- parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format')
- parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
- parser.add_argument('--output-weights', type=str, default='workspaces/main/weights/cnn_v2_weights.bin',
- help='Output binary weights file')
- parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders',
- help='Output directory for shader template')
- parser.add_argument('--quiet', action='store_true',
- help='Suppress detailed output')
-
- args = parser.parse_args()
-
- if not args.quiet:
- print("=== CNN v2 Weight Export ===\n")
- config = export_weights_binary(args.checkpoint, args.output_weights, quiet=args.quiet)
- if not args.quiet:
- print()
- # Shader is manually maintained in cnn_v2_compute.wgsl
- # export_shader_template(config, args.output_shader)
- print("\nExport complete!")
-
-
-if __name__ == '__main__':
- main()
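For inspecting exported .bin files, a minimal reader that follows the layout documented in export_weights_binary's docstring (an illustrative sketch, not part of the deleted tooling):

    import struct
    import numpy as np

    def read_cnn2(path):
        with open(path, 'rb') as f:
            data = f.read()
        magic, version, num_layers, total_weights, mip_level = \
            struct.unpack('<4sIIII', data[:20])
        assert magic == b'CNN2'
        layers, off = [], 20
        for _ in range(num_layers):
            layers.append(struct.unpack('<IIIII', data[off:off + 20]))
            off += 20
        # Weights are u32-packed f16 pairs; reinterpret and trim the pad element
        weights = np.frombuffer(data[off:], dtype=np.float16)[:total_weights]
        return version, mip_level, layers, weights
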
diff --git a/training/gen_identity_weights.py b/training/gen_identity_weights.py
deleted file mode 100755
index 7865d68..0000000
--- a/training/gen_identity_weights.py
+++ /dev/null
@@ -1,171 +0,0 @@
-#!/usr/bin/env python3
-"""Generate Identity CNN v2 Weights
-
-Creates trivial .bin with 1 layer, 1×1 kernel, identity passthrough.
-Output Ch{0,1,2,3} = Input Ch{0,1,2,3} (ignores static features).
-
-With --mix: Output Ch{i} = 0.5*prev[i] + 0.5*static_p{4+i}
- (50-50 blend of prev layer with uv_x, uv_y, sin20_y, bias)
-
-With --p47: Output Ch{i} = static p{4+i} (uv_x, uv_y, sin20_y, bias)
- (p4/uv_x→ch0, p5/uv_y→ch1, p6/sin20_y→ch2, p7/bias→ch3)
-
-Usage:
- ./training/gen_identity_weights.py [output.bin]
- ./training/gen_identity_weights.py --mix [output.bin]
- ./training/gen_identity_weights.py --p47 [output.bin]
-"""
-
-import argparse
-import numpy as np
-import struct
-from pathlib import Path
-
-
-def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False, p47=False):
- """Generate identity weights: output = input (ignores static features).
-
- If mix=True, 50-50 blend: 0.5*p0+0.5*p4, 0.5*p1+0.5*p5, etc (avoids overflow).
- If p47=True, transfers static p4-p7 (uv_x, uv_y, sin20_y, bias) to output channels.
-
- Input channel layout: [0-3: prev layer, 4-11: static (p0-p7)]
- Static features: p0-p3 (RGB+D), p4 (uv_x), p5 (uv_y), p6 (sin20_y), p7 (bias)
-
- Binary format:
- Header (20 bytes):
- uint32 magic ('CNN2')
- uint32 version (2)
- uint32 num_layers (1)
- uint32 total_weights (f16 count)
- uint32 mip_level
-
- LayerInfo (20 bytes):
- uint32 kernel_size
- uint32 in_channels (12)
- uint32 out_channels (4)
- uint32 weight_offset (0)
- uint32 weight_count
-
- Weights (u32 packed f16):
- Identity matrix for first 4 input channels
- Zeros for static features (channels 4-11) OR
- Mix matrix (p0+p4, p1+p5, p2+p6, p3+p7) if mix=True
- """
- # Identity: 4 output channels, 12 input channels
- # Weight shape: [out_ch, in_ch, kernel_h, kernel_w]
- in_channels = 12 # 4 input + 8 static
- out_channels = 4
-
- # Identity matrix: diagonal 1.0 for first 4 channels, 0.0 for rest
- weights = np.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=np.float32)
-
- # Center position for kernel
- center = kernel_size // 2
-
- if p47:
- # p47 mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3 (static features only)
- # Input channels: [0-3: prev layer, 4-11: static features (p0-p7)]
- # p4-p7 are at input channels 8-11
- for i in range(out_channels):
- weights[i, i + 8, center, center] = 1.0
- elif mix:
- # Mix mode: 50-50 blend (p0+p4, p1+p5, p2+p6, p3+p7)
- # p0-p3 are at channels 0-3 (prev layer), p4-p7 at channels 8-11 (static)
- for i in range(out_channels):
- weights[i, i, center, center] = 0.5 # 0.5*p{i} (prev layer)
- weights[i, i + 8, center, center] = 0.5 # 0.5*p{i+4} (static)
- else:
- # Identity: output ch i = input ch i
- for i in range(out_channels):
- weights[i, i, center, center] = 1.0
-
- # Flatten
- weights_flat = weights.flatten()
- weight_count = len(weights_flat)
-
- mode_name = 'p47' if p47 else ('mix' if mix else 'identity')
- print(f"Generating {mode_name} weights:")
- print(f" Kernel size: {kernel_size}×{kernel_size}")
- print(f" Channels: 12D→4D")
- print(f" Weights: {weight_count}")
- print(f" Mip level: {mip_level}")
- if mix:
- print(f" Mode: 0.5*prev[i] + 0.5*static_p{{4+i}} (blend with uv/sin/bias)")
- elif p47:
- print(f" Mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3")
-
- # Convert to f16
- weights_f16 = np.array(weights_flat, dtype=np.float16)
-
- # Pad to even count
- if len(weights_f16) % 2 == 1:
- weights_f16 = np.append(weights_f16, np.float16(0.0))
-
- # Pack f16 pairs into u32
- weights_u32 = weights_f16.view(np.uint32)
-
- print(f" Packed: {len(weights_u32)} u32")
- print(f" Binary size: {20 + 20 + len(weights_u32) * 4} bytes")
-
- # Write binary
- output_path = Path(output_path)
- output_path.parent.mkdir(parents=True, exist_ok=True)
-
- with open(output_path, 'wb') as f:
- # Header (20 bytes)
- f.write(struct.pack('<4sIIII',
- b'CNN2', # magic
- 2, # version
- 1, # num_layers
- len(weights_f16), # total_weights
- mip_level)) # mip_level
-
- # Layer info (20 bytes)
- f.write(struct.pack('<IIIII',
- kernel_size, # kernel_size
- in_channels, # in_channels
- out_channels, # out_channels
- 0, # weight_offset
- weight_count)) # weight_count
-
- # Weights (u32 packed f16)
- f.write(weights_u32.tobytes())
-
- print(f" → {output_path}")
-
- # Verify
- print("\nVerification:")
- with open(output_path, 'rb') as f:
- data = f.read()
- magic, version, num_layers, total_weights, mip = struct.unpack('<4sIIII', data[:20])
- print(f" Magic: {magic}")
- print(f" Version: {version}")
- print(f" Layers: {num_layers}")
- print(f" Total weights: {total_weights}")
- print(f" Mip level: {mip}")
- print(f" File size: {len(data)} bytes")
-
-
-def main():
- parser = argparse.ArgumentParser(description='Generate identity CNN v2 weights')
- parser.add_argument('output', type=str, nargs='?',
- default='workspaces/main/weights/cnn_v2_identity.bin',
- help='Output .bin file path')
- parser.add_argument('--kernel-size', type=int, default=1,
- help='Kernel size (default: 1×1)')
- parser.add_argument('--mip-level', type=int, default=0,
- help='Mip level for p0-p3 features (default: 0)')
- parser.add_argument('--mix', action='store_true',
- help='Mix mode: 50-50 blend of p0-p3 and p4-p7')
- parser.add_argument('--p47', action='store_true',
- help='Static features only: p4→ch0, p5→ch1, p6→ch2, p7→ch3')
-
- args = parser.parse_args()
-
- print("=== Identity Weight Generator ===\n")
- generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix, args.p47)
- print("\nDone!")
-
-
-if __name__ == '__main__':
- main()
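A quick functional check of the identity mode: a 1×1 kernel with 1.0 on the diagonal of the first four input channels passes the previous layer through unchanged, regardless of the static features. A sketch using PyTorch's conv2d (the semantics this generator targets):

    import torch
    import torch.nn.functional as F

    w = torch.zeros(4, 12, 1, 1)
    for i in range(4):
        w[i, i, 0, 0] = 1.0            # identity over prev-layer channels
    x = torch.rand(1, 12, 8, 8)        # [prev (4) | static p0-p7 (8)]
    y = F.conv2d(x, w)                 # bias=False, as in the generator
    assert torch.allclose(y, x[:, :4])
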
diff --git a/training/train_cnn.py b/training/train_cnn.py
deleted file mode 100755
index 4171dcb..0000000
--- a/training/train_cnn.py
+++ /dev/null
@@ -1,943 +0,0 @@
-#!/usr/bin/env python3
-"""
-CNN Training Script for Image-to-Image Transformation
-
-Trains a convolutional neural network on multiple input/target image pairs.
-
-Usage:
- # Training
- python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
-
- # Inference (generate ground truth)
- python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png
-
-Example:
- python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
- python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth
-"""
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from PIL import Image
-import numpy as np
-import cv2
-import os
-import sys
-import argparse
-import glob
-
-
-class ImagePairDataset(Dataset):
- """Dataset for loading matching input/target image pairs"""
-
- def __init__(self, input_dir, target_dir, transform=None):
- self.input_dir = input_dir
- self.target_dir = target_dir
- self.transform = transform
-
- # Find all images in input directory
- input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
- self.image_pairs = []
-
- for pattern in input_patterns:
- input_files = glob.glob(os.path.join(input_dir, pattern))
- for input_path in input_files:
- filename = os.path.basename(input_path)
- # Try to find matching target with same name but any supported extension
- target_path = None
- for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
- base_name = os.path.splitext(filename)[0]
- candidate = os.path.join(target_dir, f"{base_name}.{ext}")
- if os.path.exists(candidate):
- target_path = candidate
- break
-
- if target_path:
- self.image_pairs.append((input_path, target_path))
-
- if not self.image_pairs:
- raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
-
- print(f"Found {len(self.image_pairs)} matching image pairs")
-
- def __len__(self):
- return len(self.image_pairs)
-
- def __getitem__(self, idx):
- input_path, target_path = self.image_pairs[idx]
-
- # Load RGBD input (4 channels: RGB + Depth)
- input_img = Image.open(input_path).convert('RGBA')
- target_img = Image.open(target_path).convert('RGB')
-
- if self.transform:
- input_img = self.transform(input_img)
- target_img = self.transform(target_img)
-
- return input_img, target_img
-
-
-class PatchDataset(Dataset):
- """Dataset for extracting salient patches from image pairs"""
-
- def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
- detector='harris', transform=None):
- self.input_dir = input_dir
- self.target_dir = target_dir
- self.patch_size = patch_size
- self.patches_per_image = patches_per_image
- self.detector = detector
- self.transform = transform
-
- # Find all image pairs
- input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
- self.image_pairs = []
-
- for pattern in input_patterns:
- input_files = glob.glob(os.path.join(input_dir, pattern))
- for input_path in input_files:
- filename = os.path.basename(input_path)
- target_path = None
- for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
- base_name = os.path.splitext(filename)[0]
- candidate = os.path.join(target_dir, f"{base_name}.{ext}")
- if os.path.exists(candidate):
- target_path = candidate
- break
-
- if target_path:
- self.image_pairs.append((input_path, target_path))
-
- if not self.image_pairs:
- raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
-
- print(f"Found {len(self.image_pairs)} image pairs")
- print(f"Extracting {patches_per_image} patches per image using {detector} detector")
- print(f"Total patches: {len(self.image_pairs) * patches_per_image}")
-
- def __len__(self):
- return len(self.image_pairs) * self.patches_per_image
-
- def _detect_salient_points(self, img_array):
- """Detect salient points using specified detector"""
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
- h, w = gray.shape
- half_patch = self.patch_size // 2
-
- if self.detector == 'harris':
- # Harris corner detection
- corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
- qualityLevel=0.01, minDistance=half_patch)
- elif self.detector == 'fast':
- # FAST feature detection
- fast = cv2.FastFeatureDetector_create(threshold=20)
- keypoints = fast.detect(gray, None)
- corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
- corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
- elif self.detector == 'shi-tomasi':
- # Shi-Tomasi corner detection (goodFeaturesToTrack with different params)
- corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
- qualityLevel=0.01, minDistance=half_patch,
- useHarrisDetector=False)
- elif self.detector == 'gradient':
- # High-gradient regions
- grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
- grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
- gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
-
- # Find top gradient locations
- threshold = np.percentile(gradient_mag, 95)
- y_coords, x_coords = np.where(gradient_mag > threshold)
-
- if len(x_coords) > self.patches_per_image * 2:
- indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
- x_coords = x_coords[indices]
- y_coords = y_coords[indices]
-
- corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
- corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
- else:
- raise ValueError(f"Unknown detector: {self.detector}")
-
- # Fallback to random if no corners found
- if corners is None or len(corners) == 0:
- x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
- y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
- corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
- corners = corners.reshape(-1, 1, 2)
-
- # Filter valid corners (within bounds)
- valid_corners = []
- for corner in corners:
- x, y = int(corner[0][0]), int(corner[0][1])
- if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
- valid_corners.append((x, y))
- if len(valid_corners) >= self.patches_per_image:
- break
-
- # Fill with random if not enough
- while len(valid_corners) < self.patches_per_image:
- x = np.random.randint(half_patch, w - half_patch)
- y = np.random.randint(half_patch, h - half_patch)
- valid_corners.append((x, y))
-
- return valid_corners
-
- def __getitem__(self, idx):
- img_idx = idx // self.patches_per_image
- patch_idx = idx % self.patches_per_image
-
- input_path, target_path = self.image_pairs[img_idx]
-
- # Load images
- input_img = Image.open(input_path).convert('RGBA')
- target_img = Image.open(target_path).convert('RGB')
-
- # Detect salient points (use input image for detection)
- input_array = np.array(input_img)[:, :, :3] # Use RGB for detection
- corners = self._detect_salient_points(input_array)
-
- # Extract patch at specified index
- x, y = corners[patch_idx]
- half_patch = self.patch_size // 2
-
- # Crop patches
- input_patch = input_img.crop((x - half_patch, y - half_patch,
- x + half_patch, y + half_patch))
- target_patch = target_img.crop((x - half_patch, y - half_patch,
- x + half_patch, y + half_patch))
-
- if self.transform:
- input_patch = self.transform(input_patch)
- target_patch = self.transform(target_patch)
-
- return input_patch, target_patch
-
-
-class SimpleCNN(nn.Module):
- """CNN for RGBD→RGB with 7-channel input (RGBD + UV + gray)
-
- Internally computes grayscale, expands to 3-channel RGB output.
- """
-
- def __init__(self, num_layers=1, kernel_sizes=None):
- super(SimpleCNN, self).__init__()
-
- if kernel_sizes is None:
- kernel_sizes = [3] * num_layers
-
- assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
-
- self.kernel_sizes = kernel_sizes
- self.layers = nn.ModuleList()
-
- for i, kernel_size in enumerate(kernel_sizes):
- padding = kernel_size // 2
- if i < num_layers - 1:
- # Inner layers: 7→4 (RGBD output)
- self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
- else:
- # Final layer: 7→1 (grayscale output)
- self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))
-
- def forward(self, x, return_intermediates=False):
- # x: [B,4,H,W] - RGBD input (D = 1/z)
- B, C, H, W = x.shape
-
- intermediates = [] if return_intermediates else None
-
- # Normalize RGBD to [-1,1]
- x_norm = (x - 0.5) * 2.0
-
- # Compute normalized coordinates [-1,1]
- y_coords = torch.linspace(-1, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
- x_coords = torch.linspace(-1, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
-
- # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
- gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1]
- gray = (gray - 0.5) * 2.0 # [-1,1]
-
- # Layer 0
- layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W]
- out = self.layers[0](layer0_input) # [B,4,H,W]
- out = torch.tanh(out) # [-1,1]
- if return_intermediates:
- intermediates.append(out.clone())
-
- # Inner layers
- for i in range(1, len(self.layers)-1):
- layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
- out = self.layers[i](layer_input)
- out = torch.tanh(out)
- if return_intermediates:
- intermediates.append(out.clone())
-
- # Final layer (grayscale→RGB)
- final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
- out = self.layers[-1](final_input) # [B,1,H,W] grayscale
- out = torch.sigmoid(out) # Map to [0,1] with smooth gradients
- final_out = out.expand(-1, 3, -1, -1) # [B,3,H,W] expand to RGB
-
- if return_intermediates:
- return final_out, intermediates
- return final_out
-
-
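A usage sketch for the model above: a 3-layer configuration maps RGBD to RGB and records one intermediate per non-final layer.

    import torch

    model = SimpleCNN(num_layers=3, kernel_sizes=[3, 3, 3])
    x = torch.rand(1, 4, 64, 64)                   # RGBD in [0, 1]
    y, mids = model(x, return_intermediates=True)
    assert y.shape == (1, 3, 64, 64)               # grayscale expanded to RGB
    assert len(mids) == 2                          # layers 0 and 1 recorded
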
-def generate_layer_shader(output_path, num_layers, kernel_sizes):
- """Generate cnn_layer.wgsl with proper layer switches"""
-
- with open(output_path, 'w') as f:
- f.write("// CNN layer shader - uses modular convolution snippets\n")
- f.write("// Supports multi-pass rendering with residual connections\n")
- f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
- f.write("@group(0) @binding(0) var smplr: sampler;\n")
- f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n")
- f.write("#include \"common_uniforms\"\n")
- f.write("#include \"cnn_activation\"\n")
-
- # Include necessary conv functions
- conv_sizes = set(kernel_sizes)
- for ks in sorted(conv_sizes):
- f.write(f"#include \"cnn_conv{ks}x{ks}\"\n")
- f.write("#include \"cnn_weights_generated\"\n\n")
-
- f.write("struct CNNLayerParams {\n")
- f.write(" layer_index: i32,\n")
- f.write(" blend_amount: f32,\n")
- f.write(" _pad: vec2<f32>,\n")
- f.write("};\n\n")
- f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n")
- f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n")
- f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n")
- f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n")
- f.write(" var pos = array<vec2<f32>, 3>(\n")
- f.write(" vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)\n")
- f.write(" );\n")
- f.write(" return vec4<f32>(pos[i], 0.0, 1.0);\n")
- f.write("}\n\n")
- f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
- f.write(" // Match PyTorch linspace\n")
- f.write(" let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);\n")
- f.write(" let original_raw = textureSample(original_input, smplr, uv);\n")
- f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
- f.write(" let gray = (dot(original_raw.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0;\n")
- f.write(" var result = vec4<f32>(0.0);\n\n")
-
- # Generate layer switches
- for layer_idx in range(num_layers):
- is_final = layer_idx == num_layers - 1
- ks = kernel_sizes[layer_idx]
- conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"
-
- if layer_idx == 0:
- conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src"
- f.write(f" // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n")
- f.write(f" if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution, weights_layer{layer_idx});\n")
- f.write(f" result = cnn_tanh(result);\n")
- f.write(f" }}\n")
- elif not is_final:
- f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution, gray, weights_layer{layer_idx});\n")
- f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n")
- f.write(f" }}\n")
- else:
- f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" let sum = {conv_fn}(txt, smplr, uv, uniforms.resolution, gray, weights_layer{layer_idx});\n")
- f.write(f" let gray_out = 1.0 / (1.0 + exp(-sum)); // Sigmoid activation\n")
- f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
- f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n")
- f.write(f" }}\n")
-
- f.write(" return result; // [-1,1]\n")
- f.write("}\n")
-
-
-def export_weights_to_wgsl(model, output_path, kernel_sizes):
- """Export trained weights to WGSL format (vec4-optimized)"""
-
- with open(output_path, 'w') as f:
- f.write("// Auto-generated CNN weights (vec4-optimized)\n")
- f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
-
- for i, layer in enumerate(model.layers):
- weights = layer.weight.data.cpu().numpy()
- bias = layer.bias.data.cpu().numpy()
- out_ch, in_ch, kh, kw = weights.shape
- num_positions = kh * kw
-
- is_final = (i == len(model.layers) - 1)
-
- if is_final:
-                # Final layer: 7→1, structure: array<vec4<f32>, k*k*2> (2 vec4 per position; 18 for 3×3)
-                # Input: [rgba, uv_gray_1] → 2 vec4s per position
- f.write(f"const weights_layer{i}: array<vec4<f32>, {num_positions * 2}> = array(\n")
- for pos in range(num_positions):
- row, col = pos // kw, pos % kw
- # First vec4: [w0, w1, w2, w3] (rgba)
- v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)]
- # Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
- v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)]
- v1.append(f"{bias[0] / num_positions:.6f}")
- f.write(f" vec4<f32>({', '.join(v0)}),\n")
- f.write(f" vec4<f32>({', '.join(v1)})")
- f.write(",\n" if pos < num_positions-1 else "\n")
- f.write(");\n\n")
- else:
-                # Inner layers: 7→4, structure: array<vec4<f32>, k*k*8> (4 filters × 2 vec4 per position; 72 for 3×3)
-                # Each filter: 2 vec4s for [rgba][uv_gray_1] inputs
- num_vec4s = num_positions * 4 * 2
- f.write(f"const weights_layer{i}: array<vec4<f32>, {num_vec4s}> = array(\n")
- for pos in range(num_positions):
- row, col = pos // kw, pos % kw
- for out_c in range(4):
- # First vec4: [w0, w1, w2, w3] (rgba)
- v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)]
- # Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
- v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)]
- v1.append(f"{bias[out_c] / num_positions:.6f}")
- idx = (pos * 4 + out_c) * 2
- f.write(f" vec4<f32>({', '.join(v0)}),\n")
- f.write(f" vec4<f32>({', '.join(v1)})")
- f.write(",\n" if idx < num_vec4s-2 else "\n")
- f.write(");\n\n")
-
-
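The bias[c] / num_positions division above is deliberate: the constant-1 lane of in1 is dotted at every one of the k×k kernel taps, so writing the bias divided by k² at each tap reassembles the full bias in the accumulated sum. A one-line numeric confirmation:

    k, bias_c = 3, 0.42
    assert abs(sum([bias_c / (k * k)] * (k * k)) - bias_c) < 1e-12
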
-def generate_conv_base_function(kernel_size, output_path):
- """Generate cnn_conv{K}x{K}_7to4() function for inner layers (vec4-optimized)"""
-
- k = kernel_size
- num_positions = k * k
- radius = k // 2
-
- with open(output_path, 'a') as f:
- f.write(f"\n// Inner layers: 7→4 channels (vec4-optimized)\n")
- f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
- f.write(f"fn cnn_conv{k}x{k}_7to4(\n")
- f.write(f" tex: texture_2d<f32>,\n")
- f.write(f" samp: sampler,\n")
- f.write(f" uv: vec2<f32>,\n")
- f.write(f" resolution: vec2<f32>,\n")
- f.write(f" gray: f32,\n")
- f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n")
- f.write(f") -> vec4<f32> {{\n")
- f.write(f" let step = 1.0 / resolution;\n")
- f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
- f.write(f" var sum = vec4<f32>(0.0);\n")
- f.write(f" var pos = 0;\n\n")
-
- # Convolution loop
- f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
- f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
- f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
- f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n")
- f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
-
- # Accumulate
- f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
- f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n")
- f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n")
- f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n")
- f.write(f" pos += 8;\n")
- f.write(f" }}\n")
- f.write(f" }}\n\n")
-
- f.write(f" return sum;\n")
- f.write(f"}}\n")
-
-
-def generate_conv_src_function(kernel_size, output_path):
- """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0 (vec4-optimized)"""
-
- k = kernel_size
- num_positions = k * k
- radius = k // 2
-
- with open(output_path, 'a') as f:
- f.write(f"\n// Source layer: 7→4 channels (vec4-optimized)\n")
- f.write(f"// Normalizes [0,1] input to [-1,1] internally\n")
- f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n")
- f.write(f" tex: texture_2d<f32>,\n")
- f.write(f" samp: sampler,\n")
- f.write(f" uv: vec2<f32>,\n")
- f.write(f" resolution: vec2<f32>,\n")
- f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n")
- f.write(f") -> vec4<f32> {{\n")
- f.write(f" let step = 1.0 / resolution;\n\n")
-
- # Normalize center pixel for gray channel
- f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
- f.write(f" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n")
- f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n")
- f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
-
- f.write(f" var sum = vec4<f32>(0.0);\n")
- f.write(f" var pos = 0;\n\n")
-
- # Convolution loop
- f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
- f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
- f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
- f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")
-
- # Accumulate with dot products (unrolled)
- f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
- f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n")
- f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n")
- f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n")
- f.write(f" pos += 8;\n")
- f.write(f" }}\n")
- f.write(f" }}\n\n")
-
- f.write(f" return sum;\n")
- f.write(f"}}\n")
-
-
-def generate_conv_final_function(kernel_size, output_path):
- """Generate cnn_conv{K}x{K}_7to1() function for final layer (vec4-optimized)"""
-
- k = kernel_size
- num_positions = k * k
- radius = k // 2
-
- with open(output_path, 'a') as f:
- f.write(f"\n// Final layer: 7→1 channel (vec4-optimized)\n")
- f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
- f.write(f"// Returns raw sum (activation applied at call site)\n")
- f.write(f"fn cnn_conv{k}x{k}_7to1(\n")
- f.write(f" tex: texture_2d<f32>,\n")
- f.write(f" samp: sampler,\n")
- f.write(f" uv: vec2<f32>,\n")
- f.write(f" resolution: vec2<f32>,\n")
- f.write(f" gray: f32,\n")
- f.write(f" weights: array<vec4<f32>, {num_positions * 2}>\n")
- f.write(f") -> f32 {{\n")
- f.write(f" let step = 1.0 / resolution;\n")
- f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n")
- f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
- f.write(f" var sum = 0.0;\n")
- f.write(f" var pos = 0;\n\n")
-
- # Convolution loop
- f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
- f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
- f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
- f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n\n")
-
- # Accumulate with dot products
- f.write(f" sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);\n")
- f.write(f" pos += 2;\n")
- f.write(f" }}\n")
- f.write(f" }}\n\n")
-
- f.write(f" return sum;\n")
- f.write(f"}}\n")
-
-
-def train(args):
- """Main training loop"""
-
- # Setup device
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- print(f"Using device: {device}")
-
- # Prepare dataset
- if args.patch_size:
- # Patch-based training (preserves natural scale)
- transform = transforms.Compose([
- transforms.ToTensor(),
- ])
- dataset = PatchDataset(args.input, args.target,
- patch_size=args.patch_size,
- patches_per_image=args.patches_per_image,
- detector=args.detector,
- transform=transform)
- else:
- # Full-image training (resize mode)
- transform = transforms.Compose([
- transforms.Resize((256, 256)),
- transforms.ToTensor(),
- ])
- dataset = ImagePairDataset(args.input, args.target, transform=transform)
-
- dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
-
- # Parse kernel sizes
- kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
- if len(kernel_sizes) == 1 and args.layers > 1:
- kernel_sizes = kernel_sizes * args.layers
-
- # Create model
- model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)
-
- # Loss and optimizer
- criterion = nn.MSELoss()
- optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
-
- # Resume from checkpoint
- start_epoch = 0
- if args.resume:
- if os.path.exists(args.resume):
- print(f"Loading checkpoint from {args.resume}...")
- checkpoint = torch.load(args.resume, map_location=device)
- model.load_state_dict(checkpoint['model_state'])
- optimizer.load_state_dict(checkpoint['optimizer_state'])
- start_epoch = checkpoint['epoch'] + 1
- print(f"Resumed from epoch {start_epoch}")
- else:
- print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")
-
-    # Compute valid center region (exclude conv padding borders)
-    border = sum(k // 2 for k in kernel_sizes)  # each k×k layer contributes k//2 px of border
-
- # Early stopping setup
- loss_history = []
- early_stop_triggered = False
-
- # Training loop
- print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
- print(f"Computing loss on center region only (excluding {border}px border)")
- if args.early_stop_patience > 0:
- print(f"Early stopping: patience={args.early_stop_patience}, eps={args.early_stop_eps}")
-
- for epoch in range(start_epoch, args.epochs):
- epoch_loss = 0.0
- for batch_idx, (inputs, targets) in enumerate(dataloader):
- inputs, targets = inputs.to(device), targets.to(device)
-
- optimizer.zero_grad()
- outputs = model(inputs)
-
- # Only compute loss on center pixels with valid neighborhoods
- if border > 0 and outputs.shape[2] > 2*border and outputs.shape[3] > 2*border:
- outputs_center = outputs[:, :, border:-border, border:-border]
- targets_center = targets[:, :, border:-border, border:-border]
- loss = criterion(outputs_center, targets_center)
- else:
- loss = criterion(outputs, targets)
-
- loss.backward()
- optimizer.step()
-
- epoch_loss += loss.item()
-
- avg_loss = epoch_loss / len(dataloader)
- if (epoch + 1) % 10 == 0:
- print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
-
- # Early stopping check
- if args.early_stop_patience > 0:
- loss_history.append(avg_loss)
- if len(loss_history) >= args.early_stop_patience:
- oldest_loss = loss_history[-args.early_stop_patience]
- loss_change = abs(avg_loss - oldest_loss)
- if loss_change < args.early_stop_eps:
- print(f"Early stopping triggered at epoch {epoch+1}")
- print(f"Loss change over last {args.early_stop_patience} epochs: {loss_change:.8f} < {args.early_stop_eps}")
- early_stop_triggered = True
- break
-
- # Save checkpoint
- if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
- checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
- os.makedirs(checkpoint_dir, exist_ok=True)
- checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
- torch.save({
- 'epoch': epoch,
- 'model_state': model.state_dict(),
- 'optimizer_state': optimizer.state_dict(),
- 'loss': avg_loss,
- 'kernel_sizes': kernel_sizes,
- 'num_layers': args.layers
- }, checkpoint_path)
- print(f"Saved checkpoint to {checkpoint_path}")
-
- # Export weights and shader
- output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
- print(f"\nExporting weights to {output_path}...")
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
- export_weights_to_wgsl(model, output_path, kernel_sizes)
-
- # Generate layer shader
- shader_dir = os.path.dirname(output_path)
- shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
- print(f"Generating layer shader to {shader_path}...")
- generate_layer_shader(shader_path, args.layers, kernel_sizes)
-
- # Generate conv shader files for all kernel sizes
- for ks in set(kernel_sizes):
- conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
-
- # Create file with header if it doesn't exist
- if not os.path.exists(conv_path):
- print(f"Creating {conv_path}...")
- with open(conv_path, 'w') as f:
- f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n")
- generate_conv_base_function(ks, conv_path)
- generate_conv_src_function(ks, conv_path)
- generate_conv_final_function(ks, conv_path)
- print(f"Generated complete {conv_path}")
- continue
-
- # File exists, check for missing functions
- with open(conv_path, 'r') as f:
- content = f.read()
-
- # Generate base 7to4 if missing
- if f"cnn_conv{ks}x{ks}_7to4" not in content:
- generate_conv_base_function(ks, conv_path)
- print(f"Added base 7to4 to {conv_path}")
- with open(conv_path, 'r') as f:
- content = f.read()
-
- # Generate _src variant if missing
- if f"cnn_conv{ks}x{ks}_7to4_src" not in content:
- generate_conv_src_function(ks, conv_path)
- print(f"Added _src variant to {conv_path}")
- with open(conv_path, 'r') as f:
- content = f.read()
-
- # Generate 7to1 final layer if missing
- if f"cnn_conv{ks}x{ks}_7to1" not in content:
- generate_conv_final_function(ks, conv_path)
- print(f"Added 7to1 variant to {conv_path}")
-
- print("Training complete!")
-
-
-def export_from_checkpoint(checkpoint_path, output_path=None):
- """Export WGSL files from checkpoint without training"""
-
- if not os.path.exists(checkpoint_path):
- print(f"Error: Checkpoint file '{checkpoint_path}' not found")
- sys.exit(1)
-
- print(f"Loading checkpoint from {checkpoint_path}...")
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
-
- kernel_sizes = checkpoint['kernel_sizes']
- num_layers = checkpoint['num_layers']
-
- # Recreate model
- model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes)
- model.load_state_dict(checkpoint['model_state'])
-
- # Export weights
- output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
- print(f"Exporting weights to {output_path}...")
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
- export_weights_to_wgsl(model, output_path, kernel_sizes)
-
- # Generate layer shader
- shader_dir = os.path.dirname(output_path)
- shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
- print(f"Generating layer shader to {shader_path}...")
- generate_layer_shader(shader_path, num_layers, kernel_sizes)
-
- # Generate conv shader files for all kernel sizes
- for ks in set(kernel_sizes):
- conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
-
- # Create file with header if it doesn't exist
- if not os.path.exists(conv_path):
- print(f"Creating {conv_path}...")
- with open(conv_path, 'w') as f:
- f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n")
- generate_conv_base_function(ks, conv_path)
- generate_conv_src_function(ks, conv_path)
- generate_conv_final_function(ks, conv_path)
- print(f"Generated complete {conv_path}")
- continue
-
- # File exists, check for missing functions
- with open(conv_path, 'r') as f:
- content = f.read()
-
- # Generate base 7to4 if missing
- if f"cnn_conv{ks}x{ks}_7to4" not in content:
- generate_conv_base_function(ks, conv_path)
- print(f"Added base 7to4 to {conv_path}")
- with open(conv_path, 'r') as f:
- content = f.read()
-
- # Generate _src variant if missing
- if f"cnn_conv{ks}x{ks}_7to4_src" not in content:
- generate_conv_src_function(ks, conv_path)
- print(f"Added _src variant to {conv_path}")
- with open(conv_path, 'r') as f:
- content = f.read()
-
- # Generate 7to1 final layer if missing
- if f"cnn_conv{ks}x{ks}_7to1" not in content:
- generate_conv_final_function(ks, conv_path)
- print(f"Added 7to1 variant to {conv_path}")
-
- print("Export complete!")
-
-
-def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False):
- """Run sliding-window inference to match WGSL shader behavior
-
- Outputs RGBA PNG (RGB from model + alpha from input).
- """
-
- if not os.path.exists(checkpoint_path):
- print(f"Error: Checkpoint '{checkpoint_path}' not found")
- sys.exit(1)
-
- if not os.path.exists(input_path):
- print(f"Error: Input image '{input_path}' not found")
- sys.exit(1)
-
- print(f"Loading checkpoint from {checkpoint_path}...")
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
-
- # Reconstruct model
- model = SimpleCNN(
- num_layers=checkpoint['num_layers'],
- kernel_sizes=checkpoint['kernel_sizes']
- )
- model.load_state_dict(checkpoint['model_state'])
-
- # Debug: Zero out all weights and biases
- if zero_weights:
- print("DEBUG: Zeroing out all weights and biases")
- for layer in model.layers:
- with torch.no_grad():
- layer.weight.zero_()
- layer.bias.zero_()
-
- model.eval()
-
- # Load image
- print(f"Loading input image: {input_path}")
- img = Image.open(input_path).convert('RGBA')
- img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W]
- W, H = img.size
-
- # Process full image with sliding window (matches WGSL shader)
- print(f"Processing full image ({W}×{H}) with sliding window...")
- with torch.no_grad():
- if save_intermediates:
- output_tensor, intermediates = model(img_tensor, return_intermediates=True)
- else:
- output_tensor = model(img_tensor) # [1,3,H,W] RGB
-
- # Convert to numpy and append alpha
- output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # [H,W,3] RGB
- alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] alpha from input
- output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] RGBA
-
- # Debug: print first 8 pixels as hex
- if debug_hex:
- output_u8 = (output_rgba * 255).astype(np.uint8)
- print("First 8 pixels (RGBA hex):")
- for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])):
- y, x = i // output_u8.shape[1], i % output_u8.shape[1]
- r, g, b, a = output_u8[y, x]
- print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
-
- # Save final output as RGBA
- print(f"Saving output to: {output_path}")
- os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
- output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA')
- output_img.save(output_path)
-
- # Save intermediates if requested
- if save_intermediates:
- os.makedirs(save_intermediates, exist_ok=True)
- print(f"Saving {len(intermediates)} intermediate layers to: {save_intermediates}")
- for layer_idx, layer_tensor in enumerate(intermediates):
- # Convert [-1,1] to [0,1] for visualization
- layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5
- layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8)
-
- # Debug: print first 8 pixels as hex
- if debug_hex:
- print(f"Layer {layer_idx} first 8 pixels (RGBA hex):")
- for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])):
- y, x = i // layer_u8.shape[1], i % layer_u8.shape[1]
- if layer_u8.shape[2] == 4:
- r, g, b, a = layer_u8[y, x]
- print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
- else:
- r, g, b = layer_u8[y, x]
- print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}")
-
- # Save all 4 channels for intermediate layers
- if layer_data.shape[2] == 4:
- layer_img = Image.fromarray(layer_u8, mode='RGBA')
- else:
- layer_img = Image.fromarray(layer_u8)
- layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png')
- layer_img.save(layer_path)
- print(f" Saved layer {layer_idx} to {layer_path}")
-
- print("Done!")
-
-
-def main():
- parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
- parser.add_argument('--input', help='Input image directory (training) or single image (inference)')
- parser.add_argument('--target', help='Target image directory')
- parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
- parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
- parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
- parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
- parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
- parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)')
- parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
- parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
- parser.add_argument('--resume', help='Resume from checkpoint file')
- parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
- parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
- parser.add_argument('--patch-size', type=int, help='Extract patches of this size (e.g., 32) instead of resizing (default: None = resize to 256x256)')
- parser.add_argument('--patches-per-image', type=int, default=64, help='Number of patches to extract per image (default: 64)')
- parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
- help='Salient point detector for patch extraction (default: harris)')
- parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)')
- parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)')
- parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)')
- parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)')
- parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)')
-
- args = parser.parse_args()
-
- # Inference mode
- if args.infer:
- checkpoint = args.export_only
- if not checkpoint:
- print("Error: --infer requires --export-only <checkpoint>")
- sys.exit(1)
- output_path = args.output or 'inference_output.png'
- patch_size = args.patch_size or 32
- infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex)
- return
-
- # Export-only mode
- if args.export_only:
- export_from_checkpoint(args.export_only, args.output)
- return
-
- # Validate directories for training
- if not args.input or not args.target:
- print("Error: --input and --target required for training (or use --export-only)")
- sys.exit(1)
-
- if not os.path.isdir(args.input):
- print(f"Error: Input directory '{args.input}' does not exist")
- sys.exit(1)
-
- if not os.path.isdir(args.target):
- print(f"Error: Target directory '{args.target}' does not exist")
- sys.exit(1)
-
- train(args)
-
-
-if __name__ == "__main__":
- main()
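-
-
-# Example invocations (illustrative only; paths and filenames are placeholders,
-# flags as defined in main() above):
-#
-#   # Run inference from a checkpoint, dumping intermediates and hex debug:
-#   python train_cnn.py --infer photo.png --export-only ckpt.pth \
-#       --output out.png --save-intermediates debug_layers --debug-hex
-#
-#   # Export WGSL from a checkpoint without training:
-#   python train_cnn.py --export-only ckpt.pth --output cnn.wgsl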
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
deleted file mode 100755
index 9e5df2f..0000000
--- a/training/train_cnn_v2.py
+++ /dev/null
@@ -1,472 +0,0 @@
-#!/usr/bin/env python3
-"""CNN v2 Training Script - Uniform 12D→4D Architecture
-
-Architecture:
-- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(20×uv_y), bias
-- Input RGBD (4D): original image mip 0
-- Every layer: 4D (input RGBD for layer 0, previous output thereafter) + static (8D) = 12D → 4 channels
-- Per-layer kernel sizes (e.g., 1×1, 3×3, 5×5)
-- Uniform layer structure with bias=False (bias in static features)
-"""
-
-import argparse
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-from pathlib import Path
-from PIL import Image
-import time
-import cv2
-
-
-def compute_static_features(rgb, depth=None, mip_level=0):
- """Generate 8D static features (parametric + spatial).
-
- Args:
- rgb: (H, W, 3) RGB image [0, 1]
- depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane)
- mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
-
- Returns:
- (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
-
-    Note: p0-p2 are parametric features taken from the mip-level RGB; p3 uses depth
-        (target alpha channel) or defaults to 1.0 (far plane).
-
-    TODO: Binary format should support arbitrary layout, ordering, and dimensionality
-        for the static feature vector, alongside mip-level indication. Current layout
-        is hardcoded as:
- [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
- Future: Allow experimentation with different feature combinations without shader recompilation.
- Examples: [R, G, B, dx, dy, uv_x, bias] or [mip1.r, mip2.g, laplacian, uv_x, sin20_x, bias]
- """
- h, w = rgb.shape[:2]
-
- # Generate mip level for p0-p3
- if mip_level > 0:
- # Downsample to mip level
- mip_rgb = rgb.copy()
- for _ in range(mip_level):
- mip_rgb = cv2.pyrDown(mip_rgb)
- # Upsample back to original size
- for _ in range(mip_level):
- mip_rgb = cv2.pyrUp(mip_rgb)
- # Crop/pad to exact original size if needed
- if mip_rgb.shape[:2] != (h, w):
- mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
- else:
- mip_rgb = rgb
-
- # Parametric features (p0-p3) from mip level
- p0 = mip_rgb[:, :, 0].astype(np.float32)
- p1 = mip_rgb[:, :, 1].astype(np.float32)
- p2 = mip_rgb[:, :, 2].astype(np.float32)
- p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
-
- # UV coordinates (normalized [0, 1])
- uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
- uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)
-
- # Multi-frequency position encoding
- sin20_y = np.sin(20.0 * uv_y).astype(np.float32)
-
- # Bias dimension (always 1.0) - replaces Conv2d bias parameter
- bias = np.ones((h, w), dtype=np.float32)
-
- # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin20_y, bias]
- features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias], axis=-1)
- return features
-
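-
-def _example_static_features():
-    """Illustrative sketch (not called anywhere): build the 8D static feature
-    stack for a random image and check its layout. Image size is arbitrary."""
-    rgb = np.random.rand(64, 64, 3).astype(np.float32)
-    feat = compute_static_features(rgb, depth=None, mip_level=1)
-    assert feat.shape == (64, 64, 8)
-    # Channel order: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias];
-    # with depth=None, p3 is all ones (far plane) and bias is all ones.
-    return feat
-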
-
-class CNNv2(nn.Module):
- """CNN v2 - Uniform 12D→4D Architecture
-
-    Every layer takes 12D input: 4D (input RGBD for layer 0, previous output
-    thereafter) + 8D static features → 4 channels
- Per-layer kernel sizes supported (e.g., [1, 3, 5])
- Uses bias=False (bias integrated in static features as 1.0)
-
- TODO: Add quantization-aware training (QAT) for 8-bit weights
- - Use torch.quantization.QuantStub/DeQuantStub
- - Train with fake quantization to adapt to 8-bit precision
- - Target: ~1.3 KB weights (vs 2.6 KB with f16)
- """
-
- def __init__(self, kernel_sizes, num_layers=3):
- super().__init__()
- if isinstance(kernel_sizes, int):
- kernel_sizes = [kernel_sizes] * num_layers
- assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
-
- self.kernel_sizes = kernel_sizes
- self.num_layers = num_layers
- self.layers = nn.ModuleList()
-
- # All layers: 12D input (4 RGBD + 8 static) → 4D output
- for kernel_size in kernel_sizes:
- self.layers.append(
- nn.Conv2d(12, 4, kernel_size=kernel_size,
- padding=kernel_size//2, bias=False)
- )
-
- def forward(self, input_rgbd, static_features):
- """Forward pass with uniform 12D→4D layers.
-
- Args:
- input_rgbd: (B, 4, H, W) input image RGBD (mip 0)
- static_features: (B, 8, H, W) static features
-
- Returns:
- (B, 4, H, W) RGBA output [0, 1]
- """
- # Layer 0: input RGBD (4D) + static (8D) = 12D
- x = torch.cat([input_rgbd, static_features], dim=1)
- x = self.layers[0](x)
- x = torch.sigmoid(x) # Soft [0,1] for layer 0
-
- # Layer 1+: previous (4D) + static (8D) = 12D
- for i in range(1, self.num_layers):
- x_input = torch.cat([x, static_features], dim=1)
- x = self.layers[i](x_input)
- if i < self.num_layers - 1:
- x = F.relu(x)
- else:
- x = torch.sigmoid(x) # Soft [0,1] for final layer
-
- return x
-
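-
-def _example_forward_pass():
-    """Illustrative smoke test (not part of the training flow): one forward
-    pass through CNNv2 with dummy tensors to confirm the 12D→4D plumbing.
-    Batch size, resolution, and kernel sizes here are arbitrary."""
-    model = CNNv2(kernel_sizes=[1, 3, 5], num_layers=3)
-    input_rgbd = torch.rand(2, 4, 32, 32)     # (B, 4, H, W) RGBD, mip 0
-    static_feat = torch.rand(2, 8, 32, 32)    # (B, 8, H, W) static features
-    with torch.no_grad():
-        out = model(input_rgbd, static_feat)
-    assert out.shape == (2, 4, 32, 32)        # RGBA, sigmoid → [0, 1]
-    return out
-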
-
-class PatchDataset(Dataset):
- """Patch-based dataset extracting salient regions from images."""
-
- def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
- detector='harris', mip_level=0):
- self.input_paths = sorted(Path(input_dir).glob("*.png"))
- self.target_paths = sorted(Path(target_dir).glob("*.png"))
- self.patch_size = patch_size
- self.patches_per_image = patches_per_image
- self.detector = detector
- self.mip_level = mip_level
-
- assert len(self.input_paths) == len(self.target_paths), \
- f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
-
- print(f"Found {len(self.input_paths)} image pairs")
- print(f"Extracting {patches_per_image} patches per image using {detector} detector")
- print(f"Total patches: {len(self.input_paths) * patches_per_image}")
-
- def __len__(self):
- return len(self.input_paths) * self.patches_per_image
-
- def _detect_salient_points(self, img_array):
- """Detect salient points on original image.
-
- TODO: Add random sampling to training vectors
- - In addition to salient points, incorporate randomly-located samples
- - Default: 10% random samples, 90% salient points
- - Prevents overfitting to only high-gradient regions
- - Improves generalization across entire image
- - Configurable via --random-sample-percent parameter
- """
- gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
- h, w = gray.shape
- half_patch = self.patch_size // 2
-
- corners = None
- if self.detector == 'harris':
- corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
- qualityLevel=0.01, minDistance=half_patch)
- elif self.detector == 'fast':
- fast = cv2.FastFeatureDetector_create(threshold=20)
- keypoints = fast.detect(gray, None)
- corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
- corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
- elif self.detector == 'shi-tomasi':
- corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
- qualityLevel=0.01, minDistance=half_patch,
- useHarrisDetector=False)
- elif self.detector == 'gradient':
- grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
- grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
- gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
- threshold = np.percentile(gradient_mag, 95)
- y_coords, x_coords = np.where(gradient_mag > threshold)
-
- if len(x_coords) > self.patches_per_image * 2:
- indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
- x_coords = x_coords[indices]
- y_coords = y_coords[indices]
-
- corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
- corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
-
- # Fallback to random if no corners found
- if corners is None or len(corners) == 0:
- x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
- y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
- corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
- corners = corners.reshape(-1, 1, 2)
-
- # Filter valid corners
- valid_corners = []
- for corner in corners:
- x, y = int(corner[0][0]), int(corner[0][1])
- if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
- valid_corners.append((x, y))
- if len(valid_corners) >= self.patches_per_image:
- break
-
- # Fill with random if not enough
- while len(valid_corners) < self.patches_per_image:
- x = np.random.randint(half_patch, w - half_patch)
- y = np.random.randint(half_patch, h - half_patch)
- valid_corners.append((x, y))
-
- return valid_corners
-
- def __getitem__(self, idx):
- img_idx = idx // self.patches_per_image
- patch_idx = idx % self.patches_per_image
-
- # Load original images (no resize)
- input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0
- target_pil = Image.open(self.target_paths[img_idx])
- target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
-
- # Detect salient points on original image (use RGB only)
- salient_points = self._detect_salient_points(input_img)
- cx, cy = salient_points[patch_idx]
-
- # Extract patch
- half_patch = self.patch_size // 2
- y1, y2 = cy - half_patch, cy + half_patch
- x1, x2 = cx - half_patch, cx + half_patch
-
- input_patch = input_img[y1:y2, x1:x2]
- target_patch = target_img[y1:y2, x1:x2] # RGBA
-
- # Extract depth from target alpha channel (or default to 1.0)
- depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None
-
- # Compute static features for patch
- static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level)
-
- # Input RGBD (mip 0) - add depth channel
- input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
-
- # Convert to tensors (C, H, W)
- input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
- static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
- target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) # RGBA from image
-
- return input_rgbd, static_feat, target
-
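-
-def _example_patch_loader(input_dir, target_dir):
-    """Illustrative sketch: wire PatchDataset into a DataLoader. The two
-    directory arguments are placeholders and must hold matching *.png pairs."""
-    dataset = PatchDataset(input_dir, target_dir, patch_size=32,
-                           patches_per_image=8, detector='harris', mip_level=0)
-    loader = DataLoader(dataset, batch_size=4, shuffle=True)
-    input_rgbd, static_feat, target = next(iter(loader))
-    # Shapes: (4, 4, 32, 32), (4, 8, 32, 32), (4, 4, 32, 32) respectively.
-    return input_rgbd, static_feat, target
-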
-
-class ImagePairDataset(Dataset):
- """Dataset of input/target image pairs (full-image mode)."""
-
- def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0):
- self.input_paths = sorted(Path(input_dir).glob("*.png"))
- self.target_paths = sorted(Path(target_dir).glob("*.png"))
- self.target_size = target_size
- self.mip_level = mip_level
- assert len(self.input_paths) == len(self.target_paths), \
- f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
-
- def __len__(self):
- return len(self.input_paths)
-
- def __getitem__(self, idx):
- # Load and resize images to fixed size
- input_pil = Image.open(self.input_paths[idx]).convert('RGB')
- target_pil = Image.open(self.target_paths[idx])
-
- # Resize to target size
- input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
- target_pil = target_pil.resize(self.target_size, Image.LANCZOS)
-
- input_img = np.array(input_pil) / 255.0
- target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
-
- # Extract depth from target alpha channel (or default to 1.0)
- depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None
-
- # Compute static features
- static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level)
-
-        # Input RGBD (mip 0) - append a zero depth placeholder (actual depth is carried in static p3)
- h, w = input_img.shape[:2]
- input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1)
-
- # Convert to tensors (C, H, W)
- input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
- static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
- target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) # RGBA from image
-
- return input_rgbd, static_feat, target
-
-
-def train(args):
- """Train CNN v2 model."""
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- print(f"Training on {device}")
-
- # Create dataset (patch-based or full-image)
- if args.full_image:
- print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
- target_size = (args.image_size, args.image_size)
- dataset = ImagePairDataset(args.input, args.target, target_size=target_size, mip_level=args.mip_level)
- dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
- else:
- print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
- dataset = PatchDataset(args.input, args.target,
- patch_size=args.patch_size,
- patches_per_image=args.patches_per_image,
- detector=args.detector,
- mip_level=args.mip_level)
- dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
-
- # Parse kernel sizes
- kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
- if len(kernel_sizes) == 1:
- kernel_sizes = kernel_sizes * args.num_layers
- else:
- # When multiple kernel sizes provided, derive num_layers from list length
- args.num_layers = len(kernel_sizes)
-
- # Create model
- model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device)
- total_params = sum(p.numel() for p in model.parameters())
- kernel_desc = ','.join(map(str, kernel_sizes))
- print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights")
- print(f"Using mip level {args.mip_level} for p0-p3 features")
-
- # Optimizer and loss
- optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
- criterion = nn.MSELoss()
-
- # Training loop
- print(f"\nTraining for {args.epochs} epochs...")
- start_time = time.time()
-
- for epoch in range(1, args.epochs + 1):
- model.train()
- epoch_loss = 0.0
-
- for input_rgbd, static_feat, target in dataloader:
- input_rgbd = input_rgbd.to(device)
- static_feat = static_feat.to(device)
- target = target.to(device)
-
- optimizer.zero_grad()
- output = model(input_rgbd, static_feat)
-
- # Compute loss (grayscale or RGBA)
- if args.grayscale_loss:
- # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B
- output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3]
- target_gray = 0.299 * target[:, 0:1] + 0.587 * target[:, 1:2] + 0.114 * target[:, 2:3]
- loss = criterion(output_gray, target_gray)
- else:
- loss = criterion(output, target)
-
- loss.backward()
- optimizer.step()
-
- epoch_loss += loss.item()
-
- avg_loss = epoch_loss / len(dataloader)
-
- # Print loss at every epoch (overwrite line with \r)
- elapsed = time.time() - start_time
- print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True)
-
- # Save checkpoint
- if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
- print() # Newline before checkpoint message
- checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
- checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
- torch.save({
- 'epoch': epoch,
- 'model_state_dict': model.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'loss': avg_loss,
- 'config': {
- 'kernel_sizes': kernel_sizes,
- 'num_layers': args.num_layers,
- 'mip_level': args.mip_level,
- 'grayscale_loss': args.grayscale_loss,
- 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
- }
- }, checkpoint_path)
- print(f" → Saved checkpoint: {checkpoint_path}")
-
-    # Always save the final checkpoint (overwrites the periodic save for the same epoch, if any)
- print() # Newline after training
- final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth"
- final_checkpoint.parent.mkdir(parents=True, exist_ok=True)
- torch.save({
- 'epoch': args.epochs,
- 'model_state_dict': model.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'loss': avg_loss,
- 'config': {
- 'kernel_sizes': kernel_sizes,
- 'num_layers': args.num_layers,
- 'mip_level': args.mip_level,
- 'grayscale_loss': args.grayscale_loss,
- 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
- }
- }, final_checkpoint)
- print(f" → Saved final checkpoint: {final_checkpoint}")
-
- print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
- return model
-
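-
-def _example_restore(checkpoint_path):
-    """Illustrative sketch: rebuild CNNv2 from a checkpoint written by
-    train() above, using the saved 'config' dict. Path is a placeholder."""
-    ckpt = torch.load(checkpoint_path, map_location='cpu')
-    cfg = ckpt['config']
-    model = CNNv2(kernel_sizes=cfg['kernel_sizes'], num_layers=cfg['num_layers'])
-    model.load_state_dict(ckpt['model_state_dict'])
-    model.eval()
-    return model
-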
-
-def main():
- parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
- parser.add_argument('--input', type=str, required=True, help='Input images directory')
- parser.add_argument('--target', type=str, required=True, help='Target images directory')
-
- # Training mode
- parser.add_argument('--full-image', action='store_true',
- help='Use full-image mode (resize all images)')
- parser.add_argument('--image-size', type=int, default=256,
- help='Full-image mode: resize to this size (default: 256)')
-
- # Patch-based mode (default)
- parser.add_argument('--patch-size', type=int, default=32,
- help='Patch mode: patch size (default: 32)')
- parser.add_argument('--patches-per-image', type=int, default=64,
- help='Patch mode: patches per image (default: 64)')
- parser.add_argument('--detector', type=str, default='harris',
- choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
- help='Patch mode: salient point detector (default: harris)')
- # TODO: Add --random-sample-percent parameter (default: 10)
- # Mix salient points with random samples for better generalization
-
- # Model architecture
- parser.add_argument('--kernel-sizes', type=str, default='3',
- help='Comma-separated kernel sizes per layer (e.g., "3,5,3"), single value replicates (default: 3)')
- parser.add_argument('--num-layers', type=int, default=3,
- help='Number of CNN layers (default: 3)')
- parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3],
- help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)')
-
- # Training parameters
- parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
- parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
- parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
- parser.add_argument('--grayscale-loss', action='store_true',
- help='Compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of RGBA')
- parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
- help='Checkpoint directory')
- parser.add_argument('--checkpoint-every', type=int, default=1000,
- help='Save checkpoint every N epochs (0 = disable)')
-
- args = parser.parse_args()
- train(args)
-
-
-if __name__ == '__main__':
- main()
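-
-
-# Example invocations (illustrative only; directory names are placeholders,
-# flags as defined in main() above):
-#
-#   # Patch-based training with per-layer kernel sizes and half-res p0-p3:
-#   python train_cnn_v2.py --input data/in --target data/out \
-#       --kernel-sizes 1,3,5 --epochs 2000 --mip-level 1
-#
-#   # Full-image mode with grayscale loss:
-#   python train_cnn_v2.py --input data/in --target data/out \
-#       --full-image --image-size 256 --grayscale-loss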