diff options
Diffstat (limited to 'cnn_v2/training')
| -rwxr-xr-x | cnn_v2/training/export_cnn_v2_shader.py | 218 | ||||
| -rwxr-xr-x | cnn_v2/training/export_cnn_v2_weights.py | 288 | ||||
| -rwxr-xr-x | cnn_v2/training/gen_identity_weights.py | 175 | ||||
| -rwxr-xr-x | cnn_v2/training/train_cnn_v2.py | 472 |
4 files changed, 1153 insertions, 0 deletions
diff --git a/cnn_v2/training/export_cnn_v2_shader.py b/cnn_v2/training/export_cnn_v2_shader.py new file mode 100755 index 0000000..8692a62 --- /dev/null +++ b/cnn_v2/training/export_cnn_v2_shader.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture + +Converts PyTorch checkpoints to WGSL compute shaders with f16 weights. +Generates one shader per layer with embedded weight arrays. + +Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size. + This script is for debugging/testing with per-layer shaders. +""" + +import argparse +import numpy as np +import torch +from pathlib import Path + +# Path resolution for running from any directory +SCRIPT_DIR = Path(__file__).parent +PROJECT_ROOT = SCRIPT_DIR.parent.parent + + +def export_layer_shader(layer_idx, weights, kernel_size, output_dir, mip_level=0, is_output_layer=False): + """Generate WGSL compute shader for a single CNN layer. + + Args: + layer_idx: Layer index (0, 1, 2, ...) + weights: (4, 12, k, k) weight tensor (uniform 12D→4D) + kernel_size: Kernel size (3, 5, etc.) + output_dir: Output directory path + mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.) + is_output_layer: True if this is the final RGBA output layer + """ + weights_flat = weights.flatten() + weights_f16 = weights_flat.astype(np.float16) + weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals + + # Format weights as WGSL array + weights_str = ",\n ".join( + ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8]) + for i in range(0, len(weights_f32), 8) + ) + + radius = kernel_size // 2 + if is_output_layer: + activation = "output[c] = clamp(sum, 0.0, 1.0); // Output layer" + elif layer_idx == 0: + activation = "output[c] = clamp(sum, 0.0, 1.0); // Layer 0: clamp [0,1]" + else: + activation = "output[c] = max(0.0, sum); // Middle layers: ReLU" + + shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D) +// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D +// Mip level: {mip_level} (p0-p3 features) + +const KERNEL_SIZE: u32 = {kernel_size}u; +const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static) +const OUT_CHANNELS: u32 = 4u; // Uniform output +const KERNEL_RADIUS: i32 = {radius}; + +// Weights quantized to float16 (stored as f32 in WGSL) +const weights: array<f32, {len(weights_f32)}> = array( + {weights_str} +); + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var layer_input: texture_2d<u32>; +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; + +fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{ + let packed = textureLoad(static_features, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +}} + +fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{ + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + return vec4<f32>(v0.x, v0.y, v1.x, v1.y); +}} + +fn pack_channels(values: vec4<f32>) -> vec4<u32> {{ + return vec4<u32>( + pack2x16float(vec2<f32>(values.x, values.y)), + pack2x16float(vec2<f32>(values.z, values.w)), + 0u, // Unused + 0u // Unused + ); +}} + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) {{ + let coord = vec2<i32>(id.xy); + let dims = textureDimensions(static_features); + + if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{ + return; + }} + + // Load static features (always available) + let static_feat = unpack_static_features(coord); + + // Convolution: 12D input (4 prev + 8 static) → 4D output + var output: vec4<f32> = vec4<f32>(0.0); + for (var c: u32 = 0u; c < 4u; c++) {{ + var sum: f32 = 0.0; + + for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{ + for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{ + let sample_coord = coord + vec2<i32>(kx, ky); + + // Border handling (clamp) + let clamped = vec2<i32>( + clamp(sample_coord.x, 0, i32(dims.x) - 1), + clamp(sample_coord.y, 0, i32(dims.y) - 1) + ); + + // Load features at this spatial location + let static_local = unpack_static_features(clamped); + let layer_local = unpack_layer_channels(clamped); // 4D + + // Weight index calculation + let ky_idx = u32(ky + KERNEL_RADIUS); + let kx_idx = u32(kx + KERNEL_RADIUS); + let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; + + // Accumulate: previous/input channels (4D) + for (var i: u32 = 0u; i < 4u; i++) {{ + let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE + + i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * layer_local[i]; + }} + + // Accumulate: static features (8D) + for (var i: u32 = 0u; i < 8u; i++) {{ + let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE + + (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * static_local[i]; + }} + }} + }} + + {activation} + }} + + // Pack and store + textureStore(output_tex, coord, pack_channels(output)); +}} +""" + + output_path = Path(output_dir) / "cnn_v2" / f"cnn_v2_layer_{layer_idx}.wgsl" + output_path.write_text(shader_code) + print(f" → {output_path}") + + +def export_checkpoint(checkpoint_path, output_dir): + """Export PyTorch checkpoint to WGSL shaders. + + Args: + checkpoint_path: Path to .pth checkpoint + output_dir: Output directory for shaders + """ + print(f"Loading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + state_dict = checkpoint['model_state_dict'] + config = checkpoint['config'] + + kernel_size = config.get('kernel_size', 3) + num_layers = config.get('num_layers', 3) + mip_level = config.get('mip_level', 0) + + print(f"Configuration:") + print(f" Kernel size: {kernel_size}×{kernel_size}") + print(f" Layers: {num_layers}") + print(f" Mip level: {mip_level} (p0-p3 features)") + print(f" Architecture: uniform 12D→4D") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\nExporting shaders to {output_dir}/") + + # All layers uniform: 12D→4D + for i in range(num_layers): + layer_key = f'layers.{i}.weight' + if layer_key not in state_dict: + raise ValueError(f"Missing weights for layer {i}: {layer_key}") + + layer_weights = state_dict[layer_key].detach().numpy() + is_output = (i == num_layers - 1) + + export_layer_shader( + layer_idx=i, + weights=layer_weights, + kernel_size=kernel_size, + output_dir=output_dir, + mip_level=mip_level, + is_output_layer=is_output + ) + + print(f"\nExport complete! Generated {num_layers} shader files.") + + +def main(): + parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders') + parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file') + parser.add_argument('--output-dir', type=str, default=str(PROJECT_ROOT / 'workspaces/main/shaders'), + help='Output directory for shaders') + + args = parser.parse_args() + export_checkpoint(args.checkpoint, args.output_dir) + + +if __name__ == '__main__': + main() diff --git a/cnn_v2/training/export_cnn_v2_weights.py b/cnn_v2/training/export_cnn_v2_weights.py new file mode 100755 index 0000000..d66b980 --- /dev/null +++ b/cnn_v2/training/export_cnn_v2_weights.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +"""CNN v2 Weight Export Script + +Converts PyTorch checkpoints to binary weight format for storage buffer. +Exports single shader template + binary weights asset. +""" + +import argparse +import numpy as np +import torch +import struct +from pathlib import Path + +# Path resolution for running from any directory +SCRIPT_DIR = Path(__file__).parent +PROJECT_ROOT = SCRIPT_DIR.parent.parent + + +def export_weights_binary(checkpoint_path, output_path, quiet=False): + """Export CNN v2 weights to binary format. + + Binary format: + Header (20 bytes): + uint32 magic ('CNN2') + uint32 version (2) + uint32 num_layers + uint32 total_weights (f16 count) + uint32 mip_level (0-3) + + LayerInfo × num_layers (20 bytes each): + uint32 kernel_size + uint32 in_channels + uint32 out_channels + uint32 weight_offset (f16 index) + uint32 weight_count + + Weights (f16 array): + float16[] all_weights + + Args: + checkpoint_path: Path to .pth checkpoint + output_path: Output .bin file path + + Returns: + config dict for shader generation + """ + if not quiet: + print(f"Loading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + state_dict = checkpoint['model_state_dict'] + config = checkpoint['config'] + + # Support both old (kernel_size) and new (kernel_sizes) format + if 'kernel_sizes' in config: + kernel_sizes = config['kernel_sizes'] + elif 'kernel_size' in config: + kernel_size = config['kernel_size'] + num_layers = config.get('num_layers', 3) + kernel_sizes = [kernel_size] * num_layers + else: + kernel_sizes = [3, 3, 3] # fallback + + num_layers = config.get('num_layers', len(kernel_sizes)) + mip_level = config.get('mip_level', 0) + + if not quiet: + print(f"Configuration:") + print(f" Kernel sizes: {kernel_sizes}") + print(f" Layers: {num_layers}") + print(f" Mip level: {mip_level} (p0-p3 features)") + print(f" Architecture: uniform 12D→4D (bias=False)") + + # Collect layer info - all layers uniform 12D→4D + layers = [] + all_weights = [] + weight_offset = 0 + + for i in range(num_layers): + layer_key = f'layers.{i}.weight' + if layer_key not in state_dict: + raise ValueError(f"Missing weights for layer {i}: {layer_key}") + + layer_weights = state_dict[layer_key].detach().numpy() + layer_flat = layer_weights.flatten() + kernel_size = kernel_sizes[i] + + layers.append({ + 'kernel_size': kernel_size, + 'in_channels': 12, # 4 (input/prev) + 8 (static) + 'out_channels': 4, # Uniform output + 'weight_offset': weight_offset, + 'weight_count': len(layer_flat) + }) + all_weights.extend(layer_flat) + weight_offset += len(layer_flat) + + if not quiet: + print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights") + + # Convert to f16 + # TODO: Use 8-bit quantization for 2× size reduction + # Requires quantization-aware training (QAT) to maintain accuracy + all_weights_f16 = np.array(all_weights, dtype=np.float16) + + # Pack f16 pairs into u32 for storage buffer + # Pad to even count if needed + if len(all_weights_f16) % 2 == 1: + all_weights_f16 = np.append(all_weights_f16, np.float16(0.0)) + + # Pack pairs using numpy view + weights_u32 = all_weights_f16.view(np.uint32) + + binary_size = 20 + len(layers) * 20 + len(weights_u32) * 4 + if not quiet: + print(f"\nWeight statistics:") + print(f" Total layers: {len(layers)}") + print(f" Total weights: {len(all_weights_f16)} (f16)") + print(f" Packed: {len(weights_u32)} u32") + print(f" Binary size: {binary_size} bytes") + + # Write binary file + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'wb') as f: + # Header (20 bytes) - version 2 with mip_level + f.write(struct.pack('<4sIIII', + b'CNN2', # magic + 2, # version (bumped to 2) + len(layers), # num_layers + len(all_weights_f16), # total_weights (f16 count) + mip_level)) # mip_level + + # Layer info (20 bytes per layer) + for layer in layers: + f.write(struct.pack('<IIIII', + layer['kernel_size'], + layer['in_channels'], + layer['out_channels'], + layer['weight_offset'], + layer['weight_count'])) + + # Weights (u32 packed f16 pairs) + f.write(weights_u32.tobytes()) + + if quiet: + print(f" Exported {num_layers} layers, {len(all_weights_f16)} weights, {binary_size} bytes → {output_path}") + else: + print(f" → {output_path}") + + return { + 'num_layers': len(layers), + 'layers': layers + } + + +def export_shader_template(config, output_dir): + """Generate single WGSL shader template with storage buffer binding. + + Args: + config: Layer configuration from export_weights_binary() + output_dir: Output directory path + """ + shader_code = """// CNN v2 Compute Shader - Storage Buffer Version +// Reads weights from storage buffer, processes all layers in sequence + +struct CNNv2Header { + magic: u32, // 'CNN2' + version: u32, // 1 + num_layers: u32, // Number of layers + total_weights: u32, // Total f16 weight count +} + +struct CNNv2LayerInfo { + kernel_size: u32, + in_channels: u32, + out_channels: u32, + weight_offset: u32, // Offset in weights array + weight_count: u32, +} + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var layer_input: texture_2d<u32>; +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; +@group(0) @binding(3) var<storage, read> weights: array<u32>; // Packed f16 pairs + +fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(static_features, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> { + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + return vec4<f32>(v0.x, v0.y, v1.x, v1.y); +} + +fn pack_channels(values: vec4<f32>) -> vec4<u32> { + return vec4<u32>( + pack2x16float(vec2<f32>(values.x, values.y)), + pack2x16float(vec2<f32>(values.z, values.w)), + 0u, // Unused + 0u // Unused + ); +} + +fn get_weight(idx: u32) -> f32 { + let pair_idx = idx / 2u; + let packed = weights[8u + pair_idx]; // Skip header (32 bytes = 8 u32) + let unpacked = unpack2x16float(packed); + return select(unpacked.y, unpacked.x, (idx & 1u) == 0u); +} + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) { + let coord = vec2<i32>(id.xy); + let dims = textureDimensions(static_features); + + if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { + return; + } + + // Read header + let header_packed = weights[0]; // magic + version + let counts_packed = weights[1]; // num_layers + total_weights + let num_layers = counts_packed & 0xFFFFu; + + // Load static features + let static_feat = unpack_static_features(coord); + + // Process each layer (hardcoded for 3 layers for now) + // TODO: Dynamic layer loop when needed + + // Example for layer 0 - expand to full multi-layer when tested + let layer_info_offset = 2u; // After header + let layer0_info_base = layer_info_offset; + + // Read layer 0 info (5 u32 values = 20 bytes) + let kernel_size = weights[layer0_info_base]; + let in_channels = weights[layer0_info_base + 1u]; + let out_channels = weights[layer0_info_base + 2u]; + let weight_offset = weights[layer0_info_base + 3u]; + + // Convolution: 12D input (4 prev + 8 static) → 4D output + var output: vec4<f32> = vec4<f32>(0.0); + for (var c: u32 = 0u; c < 4u; c++) { + output[c] = 0.0; // TODO: Actual convolution + } + + textureStore(output_tex, coord, pack_channels(output)); +} +""" + + output_path = Path(output_dir) / "cnn_v2" / "cnn_v2_compute.wgsl" + output_path.write_text(shader_code) + print(f" → {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format') + parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file') + parser.add_argument('--output-weights', type=str, default=str(PROJECT_ROOT / 'workspaces/main/weights/cnn_v2_weights.bin'), + help='Output binary weights file') + parser.add_argument('--output-shader', type=str, default=str(PROJECT_ROOT / 'workspaces/main/shaders'), + help='Output directory for shader template') + parser.add_argument('--quiet', action='store_true', + help='Suppress detailed output') + + args = parser.parse_args() + + if not args.quiet: + print("=== CNN v2 Weight Export ===\n") + config = export_weights_binary(args.checkpoint, args.output_weights, quiet=args.quiet) + if not args.quiet: + print() + # Shader is manually maintained in cnn_v2_compute.wgsl + # export_shader_template(config, args.output_shader) + print("\nExport complete!") + + +if __name__ == '__main__': + main() diff --git a/cnn_v2/training/gen_identity_weights.py b/cnn_v2/training/gen_identity_weights.py new file mode 100755 index 0000000..08eecc6 --- /dev/null +++ b/cnn_v2/training/gen_identity_weights.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +"""Generate Identity CNN v2 Weights + +Creates trivial .bin with 1 layer, 1×1 kernel, identity passthrough. +Output Ch{0,1,2,3} = Input Ch{0,1,2,3} (ignores static features). + +With --mix: Output Ch{i} = 0.5*prev[i] + 0.5*static_p{4+i} + (50-50 blend of prev layer with uv_x, uv_y, sin20_y, bias) + +With --p47: Output Ch{i} = static p{4+i} (uv_x, uv_y, sin20_y, bias) + (p4/uv_x→ch0, p5/uv_y→ch1, p6/sin20_y→ch2, p7/bias→ch3) + +Usage: + ./training/gen_identity_weights.py [output.bin] + ./training/gen_identity_weights.py --mix [output.bin] + ./training/gen_identity_weights.py --p47 [output.bin] +""" + +import argparse +import numpy as np +import struct +from pathlib import Path + +# Path resolution for running from any directory +SCRIPT_DIR = Path(__file__).parent +PROJECT_ROOT = SCRIPT_DIR.parent.parent + + +def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False, p47=False): + """Generate identity weights: output = input (ignores static features). + + If mix=True, 50-50 blend: 0.5*p0+0.5*p4, 0.5*p1+0.5*p5, etc (avoids overflow). + If p47=True, transfers static p4-p7 (uv_x, uv_y, sin20_y, bias) to output channels. + + Input channel layout: [0-3: prev layer, 4-11: static (p0-p7)] + Static features: p0-p3 (RGB+D), p4 (uv_x), p5 (uv_y), p6 (sin20_y), p7 (bias) + + Binary format: + Header (20 bytes): + uint32 magic ('CNN2') + uint32 version (2) + uint32 num_layers (1) + uint32 total_weights (f16 count) + uint32 mip_level + + LayerInfo (20 bytes): + uint32 kernel_size + uint32 in_channels (12) + uint32 out_channels (4) + uint32 weight_offset (0) + uint32 weight_count + + Weights (u32 packed f16): + Identity matrix for first 4 input channels + Zeros for static features (channels 4-11) OR + Mix matrix (p0+p4, p1+p5, p2+p6, p3+p7) if mix=True + """ + # Identity: 4 output channels, 12 input channels + # Weight shape: [out_ch, in_ch, kernel_h, kernel_w] + in_channels = 12 # 4 input + 8 static + out_channels = 4 + + # Identity matrix: diagonal 1.0 for first 4 channels, 0.0 for rest + weights = np.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=np.float32) + + # Center position for kernel + center = kernel_size // 2 + + if p47: + # p47 mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3 (static features only) + # Input channels: [0-3: prev layer, 4-11: static features (p0-p7)] + # p4-p7 are at input channels 8-11 + for i in range(out_channels): + weights[i, i + 8, center, center] = 1.0 + elif mix: + # Mix mode: 50-50 blend (p0+p4, p1+p5, p2+p6, p3+p7) + # p0-p3 are at channels 0-3 (prev layer), p4-p7 at channels 8-11 (static) + for i in range(out_channels): + weights[i, i, center, center] = 0.5 # 0.5*p{i} (prev layer) + weights[i, i + 8, center, center] = 0.5 # 0.5*p{i+4} (static) + else: + # Identity: output ch i = input ch i + for i in range(out_channels): + weights[i, i, center, center] = 1.0 + + # Flatten + weights_flat = weights.flatten() + weight_count = len(weights_flat) + + mode_name = 'p47' if p47 else ('mix' if mix else 'identity') + print(f"Generating {mode_name} weights:") + print(f" Kernel size: {kernel_size}×{kernel_size}") + print(f" Channels: 12D→4D") + print(f" Weights: {weight_count}") + print(f" Mip level: {mip_level}") + if mix: + print(f" Mode: 0.5*prev[i] + 0.5*static_p{{4+i}} (blend with uv/sin/bias)") + elif p47: + print(f" Mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3") + + # Convert to f16 + weights_f16 = np.array(weights_flat, dtype=np.float16) + + # Pad to even count + if len(weights_f16) % 2 == 1: + weights_f16 = np.append(weights_f16, np.float16(0.0)) + + # Pack f16 pairs into u32 + weights_u32 = weights_f16.view(np.uint32) + + print(f" Packed: {len(weights_u32)} u32") + print(f" Binary size: {20 + 20 + len(weights_u32) * 4} bytes") + + # Write binary + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'wb') as f: + # Header (20 bytes) + f.write(struct.pack('<4sIIII', + b'CNN2', # magic + 2, # version + 1, # num_layers + len(weights_f16), # total_weights + mip_level)) # mip_level + + # Layer info (20 bytes) + f.write(struct.pack('<IIIII', + kernel_size, # kernel_size + in_channels, # in_channels + out_channels, # out_channels + 0, # weight_offset + weight_count)) # weight_count + + # Weights (u32 packed f16) + f.write(weights_u32.tobytes()) + + print(f" → {output_path}") + + # Verify + print("\nVerification:") + with open(output_path, 'rb') as f: + data = f.read() + magic, version, num_layers, total_weights, mip = struct.unpack('<4sIIII', data[:20]) + print(f" Magic: {magic}") + print(f" Version: {version}") + print(f" Layers: {num_layers}") + print(f" Total weights: {total_weights}") + print(f" Mip level: {mip}") + print(f" File size: {len(data)} bytes") + + +def main(): + parser = argparse.ArgumentParser(description='Generate identity CNN v2 weights') + parser.add_argument('output', type=str, nargs='?', + default=str(PROJECT_ROOT / 'workspaces/main/weights/cnn_v2_identity.bin'), + help='Output .bin file path') + parser.add_argument('--kernel-size', type=int, default=1, + help='Kernel size (default: 1×1)') + parser.add_argument('--mip-level', type=int, default=0, + help='Mip level for p0-p3 features (default: 0)') + parser.add_argument('--mix', action='store_true', + help='Mix mode: 50-50 blend of p0-p3 and p4-p7') + parser.add_argument('--p47', action='store_true', + help='Static features only: p4→ch0, p5→ch1, p6→ch2, p7→ch3') + + args = parser.parse_args() + + print("=== Identity Weight Generator ===\n") + generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix, args.p47) + print("\nDone!") + + +if __name__ == '__main__': + main() diff --git a/cnn_v2/training/train_cnn_v2.py b/cnn_v2/training/train_cnn_v2.py new file mode 100755 index 0000000..9e5df2f --- /dev/null +++ b/cnn_v2/training/train_cnn_v2.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python3 +"""CNN v2 Training Script - Uniform 12D→4D Architecture + +Architecture: +- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10×uv_x), bias +- Input RGBD (4D): original image mip 0 +- All layers: input RGBD (4D) + static (8D) = 12D → 4 channels +- Per-layer kernel sizes (e.g., 1×1, 3×3, 5×5) +- Uniform layer structure with bias=False (bias in static features) +""" + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +from PIL import Image +import time +import cv2 + + +def compute_static_features(rgb, depth=None, mip_level=0): + """Generate 8D static features (parametric + spatial). + + Args: + rgb: (H, W, 3) RGB image [0, 1] + depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane) + mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth) + + Returns: + (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias] + + Note: p0-p3 are parametric features from mip level. p3 uses depth (alpha channel) or 1.0 + + TODO: Binary format should support arbitrary layout and ordering for feature vector (7D), + alongside mip-level indication. Current layout is hardcoded as: + [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias] + Future: Allow experimentation with different feature combinations without shader recompilation. + Examples: [R, G, B, dx, dy, uv_x, bias] or [mip1.r, mip2.g, laplacian, uv_x, sin20_x, bias] + """ + h, w = rgb.shape[:2] + + # Generate mip level for p0-p3 + if mip_level > 0: + # Downsample to mip level + mip_rgb = rgb.copy() + for _ in range(mip_level): + mip_rgb = cv2.pyrDown(mip_rgb) + # Upsample back to original size + for _ in range(mip_level): + mip_rgb = cv2.pyrUp(mip_rgb) + # Crop/pad to exact original size if needed + if mip_rgb.shape[:2] != (h, w): + mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR) + else: + mip_rgb = rgb + + # Parametric features (p0-p3) from mip level + p0 = mip_rgb[:, :, 0].astype(np.float32) + p1 = mip_rgb[:, :, 1].astype(np.float32) + p2 = mip_rgb[:, :, 2].astype(np.float32) + p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane + + # UV coordinates (normalized [0, 1]) + uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32) + uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32) + + # Multi-frequency position encoding + sin20_y = np.sin(20.0 * uv_y).astype(np.float32) + + # Bias dimension (always 1.0) - replaces Conv2d bias parameter + bias = np.ones((h, w), dtype=np.float32) + + # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin20_y, bias] + features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias], axis=-1) + return features + + +class CNNv2(nn.Module): + """CNN v2 - Uniform 12D→4D Architecture + + All layers: input RGBD (4D) + static (8D) = 12D → 4 channels + Per-layer kernel sizes supported (e.g., [1, 3, 5]) + Uses bias=False (bias integrated in static features as 1.0) + + TODO: Add quantization-aware training (QAT) for 8-bit weights + - Use torch.quantization.QuantStub/DeQuantStub + - Train with fake quantization to adapt to 8-bit precision + - Target: ~1.3 KB weights (vs 2.6 KB with f16) + """ + + def __init__(self, kernel_sizes, num_layers=3): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * num_layers + assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers" + + self.kernel_sizes = kernel_sizes + self.num_layers = num_layers + self.layers = nn.ModuleList() + + # All layers: 12D input (4 RGBD + 8 static) → 4D output + for kernel_size in kernel_sizes: + self.layers.append( + nn.Conv2d(12, 4, kernel_size=kernel_size, + padding=kernel_size//2, bias=False) + ) + + def forward(self, input_rgbd, static_features): + """Forward pass with uniform 12D→4D layers. + + Args: + input_rgbd: (B, 4, H, W) input image RGBD (mip 0) + static_features: (B, 8, H, W) static features + + Returns: + (B, 4, H, W) RGBA output [0, 1] + """ + # Layer 0: input RGBD (4D) + static (8D) = 12D + x = torch.cat([input_rgbd, static_features], dim=1) + x = self.layers[0](x) + x = torch.sigmoid(x) # Soft [0,1] for layer 0 + + # Layer 1+: previous (4D) + static (8D) = 12D + for i in range(1, self.num_layers): + x_input = torch.cat([x, static_features], dim=1) + x = self.layers[i](x_input) + if i < self.num_layers - 1: + x = F.relu(x) + else: + x = torch.sigmoid(x) # Soft [0,1] for final layer + + return x + + +class PatchDataset(Dataset): + """Patch-based dataset extracting salient regions from images.""" + + def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64, + detector='harris', mip_level=0): + self.input_paths = sorted(Path(input_dir).glob("*.png")) + self.target_paths = sorted(Path(target_dir).glob("*.png")) + self.patch_size = patch_size + self.patches_per_image = patches_per_image + self.detector = detector + self.mip_level = mip_level + + assert len(self.input_paths) == len(self.target_paths), \ + f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" + + print(f"Found {len(self.input_paths)} image pairs") + print(f"Extracting {patches_per_image} patches per image using {detector} detector") + print(f"Total patches: {len(self.input_paths) * patches_per_image}") + + def __len__(self): + return len(self.input_paths) * self.patches_per_image + + def _detect_salient_points(self, img_array): + """Detect salient points on original image. + + TODO: Add random sampling to training vectors + - In addition to salient points, incorporate randomly-located samples + - Default: 10% random samples, 90% salient points + - Prevents overfitting to only high-gradient regions + - Improves generalization across entire image + - Configurable via --random-sample-percent parameter + """ + gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + h, w = gray.shape + half_patch = self.patch_size // 2 + + corners = None + if self.detector == 'harris': + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch) + elif self.detector == 'fast': + fast = cv2.FastFeatureDetector_create(threshold=20) + keypoints = fast.detect(gray, None) + corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + elif self.detector == 'shi-tomasi': + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch, + useHarrisDetector=False) + elif self.detector == 'gradient': + grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) + grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) + gradient_mag = np.sqrt(grad_x**2 + grad_y**2) + threshold = np.percentile(gradient_mag, 95) + y_coords, x_coords = np.where(gradient_mag > threshold) + + if len(x_coords) > self.patches_per_image * 2: + indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False) + x_coords = x_coords[indices] + y_coords = y_coords[indices] + + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + + # Fallback to random if no corners found + if corners is None or len(corners) == 0: + x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image) + y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image) + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) + + # Filter valid corners + valid_corners = [] + for corner in corners: + x, y = int(corner[0][0]), int(corner[0][1]) + if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch: + valid_corners.append((x, y)) + if len(valid_corners) >= self.patches_per_image: + break + + # Fill with random if not enough + while len(valid_corners) < self.patches_per_image: + x = np.random.randint(half_patch, w - half_patch) + y = np.random.randint(half_patch, h - half_patch) + valid_corners.append((x, y)) + + return valid_corners + + def __getitem__(self, idx): + img_idx = idx // self.patches_per_image + patch_idx = idx % self.patches_per_image + + # Load original images (no resize) + input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0 + target_pil = Image.open(self.target_paths[img_idx]) + target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha + + # Detect salient points on original image (use RGB only) + salient_points = self._detect_salient_points(input_img) + cx, cy = salient_points[patch_idx] + + # Extract patch + half_patch = self.patch_size // 2 + y1, y2 = cy - half_patch, cy + half_patch + x1, x2 = cx - half_patch, cx + half_patch + + input_patch = input_img[y1:y2, x1:x2] + target_patch = target_img[y1:y2, x1:x2] # RGBA + + # Extract depth from target alpha channel (or default to 1.0) + depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None + + # Compute static features for patch + static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level) + + # Input RGBD (mip 0) - add depth channel + input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1) + + # Convert to tensors (C, H, W) + input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1) + static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) + target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) # RGBA from image + + return input_rgbd, static_feat, target + + +class ImagePairDataset(Dataset): + """Dataset of input/target image pairs (full-image mode).""" + + def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0): + self.input_paths = sorted(Path(input_dir).glob("*.png")) + self.target_paths = sorted(Path(target_dir).glob("*.png")) + self.target_size = target_size + self.mip_level = mip_level + assert len(self.input_paths) == len(self.target_paths), \ + f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" + + def __len__(self): + return len(self.input_paths) + + def __getitem__(self, idx): + # Load and resize images to fixed size + input_pil = Image.open(self.input_paths[idx]).convert('RGB') + target_pil = Image.open(self.target_paths[idx]) + + # Resize to target size + input_pil = input_pil.resize(self.target_size, Image.LANCZOS) + target_pil = target_pil.resize(self.target_size, Image.LANCZOS) + + input_img = np.array(input_pil) / 255.0 + target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha + + # Extract depth from target alpha channel (or default to 1.0) + depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None + + # Compute static features + static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level) + + # Input RGBD (mip 0) - add depth channel + h, w = input_img.shape[:2] + input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1) + + # Convert to tensors (C, H, W) + input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1) + static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) + target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) # RGBA from image + + return input_rgbd, static_feat, target + + +def train(args): + """Train CNN v2 model.""" + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Training on {device}") + + # Create dataset (patch-based or full-image) + if args.full_image: + print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})") + target_size = (args.image_size, args.image_size) + dataset = ImagePairDataset(args.input, args.target, target_size=target_size, mip_level=args.mip_level) + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + else: + print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)") + dataset = PatchDataset(args.input, args.target, + patch_size=args.patch_size, + patches_per_image=args.patches_per_image, + detector=args.detector, + mip_level=args.mip_level) + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + + # Parse kernel sizes + kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')] + if len(kernel_sizes) == 1: + kernel_sizes = kernel_sizes * args.num_layers + else: + # When multiple kernel sizes provided, derive num_layers from list length + args.num_layers = len(kernel_sizes) + + # Create model + model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device) + total_params = sum(p.numel() for p in model.parameters()) + kernel_desc = ','.join(map(str, kernel_sizes)) + print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights") + print(f"Using mip level {args.mip_level} for p0-p3 features") + + # Optimizer and loss + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + criterion = nn.MSELoss() + + # Training loop + print(f"\nTraining for {args.epochs} epochs...") + start_time = time.time() + + for epoch in range(1, args.epochs + 1): + model.train() + epoch_loss = 0.0 + + for input_rgbd, static_feat, target in dataloader: + input_rgbd = input_rgbd.to(device) + static_feat = static_feat.to(device) + target = target.to(device) + + optimizer.zero_grad() + output = model(input_rgbd, static_feat) + + # Compute loss (grayscale or RGBA) + if args.grayscale_loss: + # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B + output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3] + target_gray = 0.299 * target[:, 0:1] + 0.587 * target[:, 1:2] + 0.114 * target[:, 2:3] + loss = criterion(output_gray, target_gray) + else: + loss = criterion(output, target) + + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + + avg_loss = epoch_loss / len(dataloader) + + # Print loss at every epoch (overwrite line with \r) + elapsed = time.time() - start_time + print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True) + + # Save checkpoint + if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0: + print() # Newline before checkpoint message + checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth" + checkpoint_path.parent.mkdir(parents=True, exist_ok=True) + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': avg_loss, + 'config': { + 'kernel_sizes': kernel_sizes, + 'num_layers': args.num_layers, + 'mip_level': args.mip_level, + 'grayscale_loss': args.grayscale_loss, + 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias'] + } + }, checkpoint_path) + print(f" → Saved checkpoint: {checkpoint_path}") + + # Always save final checkpoint + print() # Newline after training + final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth" + final_checkpoint.parent.mkdir(parents=True, exist_ok=True) + torch.save({ + 'epoch': args.epochs, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': avg_loss, + 'config': { + 'kernel_sizes': kernel_sizes, + 'num_layers': args.num_layers, + 'mip_level': args.mip_level, + 'grayscale_loss': args.grayscale_loss, + 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias'] + } + }, final_checkpoint) + print(f" → Saved final checkpoint: {final_checkpoint}") + + print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s") + return model + + +def main(): + parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features') + parser.add_argument('--input', type=str, required=True, help='Input images directory') + parser.add_argument('--target', type=str, required=True, help='Target images directory') + + # Training mode + parser.add_argument('--full-image', action='store_true', + help='Use full-image mode (resize all images)') + parser.add_argument('--image-size', type=int, default=256, + help='Full-image mode: resize to this size (default: 256)') + + # Patch-based mode (default) + parser.add_argument('--patch-size', type=int, default=32, + help='Patch mode: patch size (default: 32)') + parser.add_argument('--patches-per-image', type=int, default=64, + help='Patch mode: patches per image (default: 64)') + parser.add_argument('--detector', type=str, default='harris', + choices=['harris', 'fast', 'shi-tomasi', 'gradient'], + help='Patch mode: salient point detector (default: harris)') + # TODO: Add --random-sample-percent parameter (default: 10) + # Mix salient points with random samples for better generalization + + # Model architecture + parser.add_argument('--kernel-sizes', type=str, default='3', + help='Comma-separated kernel sizes per layer (e.g., "3,5,3"), single value replicates (default: 3)') + parser.add_argument('--num-layers', type=int, default=3, + help='Number of CNN layers (default: 3)') + parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3], + help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)') + + # Training parameters + parser.add_argument('--epochs', type=int, default=5000, help='Training epochs') + parser.add_argument('--batch-size', type=int, default=16, help='Batch size') + parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') + parser.add_argument('--grayscale-loss', action='store_true', + help='Compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of RGBA') + parser.add_argument('--checkpoint-dir', type=str, default='checkpoints', + help='Checkpoint directory') + parser.add_argument('--checkpoint-every', type=int, default=1000, + help='Save checkpoint every N epochs (0 = disable)') + + args = parser.parse_args() + train(args) + + +if __name__ == '__main__': + main() |
