From 161a59fa50bb92e3664c389fa03b95aefe349b3f Mon Sep 17 00:00:00 2001 From: skal Date: Sun, 15 Feb 2026 18:44:17 +0100 Subject: refactor(cnn): isolate CNN v2 to cnn_v2/ subdirectory Move all CNN v2 files to dedicated cnn_v2/ directory to prepare for CNN v3 development. Zero functional changes. Structure: - cnn_v2/src/ - C++ effect implementation - cnn_v2/shaders/ - WGSL shaders (6 files) - cnn_v2/weights/ - Binary weights (3 files) - cnn_v2/training/ - Python training scripts (4 files) - cnn_v2/scripts/ - Shell scripts (train_cnn_v2_full.sh) - cnn_v2/tools/ - Validation tools (HTML) - cnn_v2/docs/ - Documentation (4 markdown files) Changes: - Update CMake source list to cnn_v2/src/cnn_v2_effect.cc - Update assets.txt with relative paths to cnn_v2/ - Update includes to ../../cnn_v2/src/cnn_v2_effect.h - Add PROJECT_ROOT resolution to Python/shell scripts - Update doc references in HOWTO.md, TODO.md - Add cnn_v2/README.md Verification: 34/34 tests passing, demo runs correctly. Co-Authored-By: Claude Sonnet 4.5 --- training/export_cnn_v2_shader.py | 214 ----------------- training/export_cnn_v2_weights.py | 284 ----------------------- training/gen_identity_weights.py | 171 -------------- training/train_cnn_v2.py | 472 -------------------------------------- 4 files changed, 1141 deletions(-) delete mode 100755 training/export_cnn_v2_shader.py delete mode 100755 training/export_cnn_v2_weights.py delete mode 100755 training/gen_identity_weights.py delete mode 100755 training/train_cnn_v2.py (limited to 'training') diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py deleted file mode 100755 index 1c74ad0..0000000 --- a/training/export_cnn_v2_shader.py +++ /dev/null @@ -1,214 +0,0 @@ -#!/usr/bin/env python3 -"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture - -Converts PyTorch checkpoints to WGSL compute shaders with f16 weights. -Generates one shader per layer with embedded weight arrays. - -Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size. - This script is for debugging/testing with per-layer shaders. -""" - -import argparse -import numpy as np -import torch -from pathlib import Path - - -def export_layer_shader(layer_idx, weights, kernel_size, output_dir, mip_level=0, is_output_layer=False): - """Generate WGSL compute shader for a single CNN layer. - - Args: - layer_idx: Layer index (0, 1, 2, ...) - weights: (4, 12, k, k) weight tensor (uniform 12D→4D) - kernel_size: Kernel size (3, 5, etc.) - output_dir: Output directory path - mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.) - is_output_layer: True if this is the final RGBA output layer - """ - weights_flat = weights.flatten() - weights_f16 = weights_flat.astype(np.float16) - weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals - - # Format weights as WGSL array - weights_str = ",\n ".join( - ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8]) - for i in range(0, len(weights_f32), 8) - ) - - radius = kernel_size // 2 - if is_output_layer: - activation = "output[c] = clamp(sum, 0.0, 1.0); // Output layer" - elif layer_idx == 0: - activation = "output[c] = clamp(sum, 0.0, 1.0); // Layer 0: clamp [0,1]" - else: - activation = "output[c] = max(0.0, sum); // Middle layers: ReLU" - - shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D) -// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D -// Mip level: {mip_level} (p0-p3 features) - -const KERNEL_SIZE: u32 = {kernel_size}u; -const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static) -const OUT_CHANNELS: u32 = 4u; // Uniform output -const KERNEL_RADIUS: i32 = {radius}; - -// Weights quantized to float16 (stored as f32 in WGSL) -const weights: array = array( - {weights_str} -); - -@group(0) @binding(0) var static_features: texture_2d; -@group(0) @binding(1) var layer_input: texture_2d; -@group(0) @binding(2) var output_tex: texture_storage_2d; - -fn unpack_static_features(coord: vec2) -> array {{ - let packed = textureLoad(static_features, coord, 0); - let v0 = unpack2x16float(packed.x); - let v1 = unpack2x16float(packed.y); - let v2 = unpack2x16float(packed.z); - let v3 = unpack2x16float(packed.w); - return array(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); -}} - -fn unpack_layer_channels(coord: vec2) -> vec4 {{ - let packed = textureLoad(layer_input, coord, 0); - let v0 = unpack2x16float(packed.x); - let v1 = unpack2x16float(packed.y); - return vec4(v0.x, v0.y, v1.x, v1.y); -}} - -fn pack_channels(values: vec4) -> vec4 {{ - return vec4( - pack2x16float(vec2(values.x, values.y)), - pack2x16float(vec2(values.z, values.w)), - 0u, // Unused - 0u // Unused - ); -}} - -@compute @workgroup_size(8, 8) -fn main(@builtin(global_invocation_id) id: vec3) {{ - let coord = vec2(id.xy); - let dims = textureDimensions(static_features); - - if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{ - return; - }} - - // Load static features (always available) - let static_feat = unpack_static_features(coord); - - // Convolution: 12D input (4 prev + 8 static) → 4D output - var output: vec4 = vec4(0.0); - for (var c: u32 = 0u; c < 4u; c++) {{ - var sum: f32 = 0.0; - - for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{ - for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{ - let sample_coord = coord + vec2(kx, ky); - - // Border handling (clamp) - let clamped = vec2( - clamp(sample_coord.x, 0, i32(dims.x) - 1), - clamp(sample_coord.y, 0, i32(dims.y) - 1) - ); - - // Load features at this spatial location - let static_local = unpack_static_features(clamped); - let layer_local = unpack_layer_channels(clamped); // 4D - - // Weight index calculation - let ky_idx = u32(ky + KERNEL_RADIUS); - let kx_idx = u32(kx + KERNEL_RADIUS); - let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; - - // Accumulate: previous/input channels (4D) - for (var i: u32 = 0u; i < 4u; i++) {{ - let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE + - i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; - sum += weights[w_idx] * layer_local[i]; - }} - - // Accumulate: static features (8D) - for (var i: u32 = 0u; i < 8u; i++) {{ - let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE + - (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; - sum += weights[w_idx] * static_local[i]; - }} - }} - }} - - {activation} - }} - - // Pack and store - textureStore(output_tex, coord, pack_channels(output)); -}} -""" - - output_path = Path(output_dir) / "cnn_v2" / f"cnn_v2_layer_{layer_idx}.wgsl" - output_path.write_text(shader_code) - print(f" → {output_path}") - - -def export_checkpoint(checkpoint_path, output_dir): - """Export PyTorch checkpoint to WGSL shaders. - - Args: - checkpoint_path: Path to .pth checkpoint - output_dir: Output directory for shaders - """ - print(f"Loading checkpoint: {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location='cpu') - - state_dict = checkpoint['model_state_dict'] - config = checkpoint['config'] - - kernel_size = config.get('kernel_size', 3) - num_layers = config.get('num_layers', 3) - mip_level = config.get('mip_level', 0) - - print(f"Configuration:") - print(f" Kernel size: {kernel_size}×{kernel_size}") - print(f" Layers: {num_layers}") - print(f" Mip level: {mip_level} (p0-p3 features)") - print(f" Architecture: uniform 12D→4D") - - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - print(f"\nExporting shaders to {output_dir}/") - - # All layers uniform: 12D→4D - for i in range(num_layers): - layer_key = f'layers.{i}.weight' - if layer_key not in state_dict: - raise ValueError(f"Missing weights for layer {i}: {layer_key}") - - layer_weights = state_dict[layer_key].detach().numpy() - is_output = (i == num_layers - 1) - - export_layer_shader( - layer_idx=i, - weights=layer_weights, - kernel_size=kernel_size, - output_dir=output_dir, - mip_level=mip_level, - is_output_layer=is_output - ) - - print(f"\nExport complete! Generated {num_layers} shader files.") - - -def main(): - parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders') - parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file') - parser.add_argument('--output-dir', type=str, default='workspaces/main/shaders', - help='Output directory for shaders') - - args = parser.parse_args() - export_checkpoint(args.checkpoint, args.output_dir) - - -if __name__ == '__main__': - main() diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py deleted file mode 100755 index f64bd8d..0000000 --- a/training/export_cnn_v2_weights.py +++ /dev/null @@ -1,284 +0,0 @@ -#!/usr/bin/env python3 -"""CNN v2 Weight Export Script - -Converts PyTorch checkpoints to binary weight format for storage buffer. -Exports single shader template + binary weights asset. -""" - -import argparse -import numpy as np -import torch -import struct -from pathlib import Path - - -def export_weights_binary(checkpoint_path, output_path, quiet=False): - """Export CNN v2 weights to binary format. - - Binary format: - Header (20 bytes): - uint32 magic ('CNN2') - uint32 version (2) - uint32 num_layers - uint32 total_weights (f16 count) - uint32 mip_level (0-3) - - LayerInfo × num_layers (20 bytes each): - uint32 kernel_size - uint32 in_channels - uint32 out_channels - uint32 weight_offset (f16 index) - uint32 weight_count - - Weights (f16 array): - float16[] all_weights - - Args: - checkpoint_path: Path to .pth checkpoint - output_path: Output .bin file path - - Returns: - config dict for shader generation - """ - if not quiet: - print(f"Loading checkpoint: {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location='cpu') - - state_dict = checkpoint['model_state_dict'] - config = checkpoint['config'] - - # Support both old (kernel_size) and new (kernel_sizes) format - if 'kernel_sizes' in config: - kernel_sizes = config['kernel_sizes'] - elif 'kernel_size' in config: - kernel_size = config['kernel_size'] - num_layers = config.get('num_layers', 3) - kernel_sizes = [kernel_size] * num_layers - else: - kernel_sizes = [3, 3, 3] # fallback - - num_layers = config.get('num_layers', len(kernel_sizes)) - mip_level = config.get('mip_level', 0) - - if not quiet: - print(f"Configuration:") - print(f" Kernel sizes: {kernel_sizes}") - print(f" Layers: {num_layers}") - print(f" Mip level: {mip_level} (p0-p3 features)") - print(f" Architecture: uniform 12D→4D (bias=False)") - - # Collect layer info - all layers uniform 12D→4D - layers = [] - all_weights = [] - weight_offset = 0 - - for i in range(num_layers): - layer_key = f'layers.{i}.weight' - if layer_key not in state_dict: - raise ValueError(f"Missing weights for layer {i}: {layer_key}") - - layer_weights = state_dict[layer_key].detach().numpy() - layer_flat = layer_weights.flatten() - kernel_size = kernel_sizes[i] - - layers.append({ - 'kernel_size': kernel_size, - 'in_channels': 12, # 4 (input/prev) + 8 (static) - 'out_channels': 4, # Uniform output - 'weight_offset': weight_offset, - 'weight_count': len(layer_flat) - }) - all_weights.extend(layer_flat) - weight_offset += len(layer_flat) - - if not quiet: - print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights") - - # Convert to f16 - # TODO: Use 8-bit quantization for 2× size reduction - # Requires quantization-aware training (QAT) to maintain accuracy - all_weights_f16 = np.array(all_weights, dtype=np.float16) - - # Pack f16 pairs into u32 for storage buffer - # Pad to even count if needed - if len(all_weights_f16) % 2 == 1: - all_weights_f16 = np.append(all_weights_f16, np.float16(0.0)) - - # Pack pairs using numpy view - weights_u32 = all_weights_f16.view(np.uint32) - - binary_size = 20 + len(layers) * 20 + len(weights_u32) * 4 - if not quiet: - print(f"\nWeight statistics:") - print(f" Total layers: {len(layers)}") - print(f" Total weights: {len(all_weights_f16)} (f16)") - print(f" Packed: {len(weights_u32)} u32") - print(f" Binary size: {binary_size} bytes") - - # Write binary file - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - with open(output_path, 'wb') as f: - # Header (20 bytes) - version 2 with mip_level - f.write(struct.pack('<4sIIII', - b'CNN2', # magic - 2, # version (bumped to 2) - len(layers), # num_layers - len(all_weights_f16), # total_weights (f16 count) - mip_level)) # mip_level - - # Layer info (20 bytes per layer) - for layer in layers: - f.write(struct.pack('; -@group(0) @binding(1) var layer_input: texture_2d; -@group(0) @binding(2) var output_tex: texture_storage_2d; -@group(0) @binding(3) var weights: array; // Packed f16 pairs - -fn unpack_static_features(coord: vec2) -> array { - let packed = textureLoad(static_features, coord, 0); - let v0 = unpack2x16float(packed.x); - let v1 = unpack2x16float(packed.y); - let v2 = unpack2x16float(packed.z); - let v3 = unpack2x16float(packed.w); - return array(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); -} - -fn unpack_layer_channels(coord: vec2) -> vec4 { - let packed = textureLoad(layer_input, coord, 0); - let v0 = unpack2x16float(packed.x); - let v1 = unpack2x16float(packed.y); - return vec4(v0.x, v0.y, v1.x, v1.y); -} - -fn pack_channels(values: vec4) -> vec4 { - return vec4( - pack2x16float(vec2(values.x, values.y)), - pack2x16float(vec2(values.z, values.w)), - 0u, // Unused - 0u // Unused - ); -} - -fn get_weight(idx: u32) -> f32 { - let pair_idx = idx / 2u; - let packed = weights[8u + pair_idx]; // Skip header (32 bytes = 8 u32) - let unpacked = unpack2x16float(packed); - return select(unpacked.y, unpacked.x, (idx & 1u) == 0u); -} - -@compute @workgroup_size(8, 8) -fn main(@builtin(global_invocation_id) id: vec3) { - let coord = vec2(id.xy); - let dims = textureDimensions(static_features); - - if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { - return; - } - - // Read header - let header_packed = weights[0]; // magic + version - let counts_packed = weights[1]; // num_layers + total_weights - let num_layers = counts_packed & 0xFFFFu; - - // Load static features - let static_feat = unpack_static_features(coord); - - // Process each layer (hardcoded for 3 layers for now) - // TODO: Dynamic layer loop when needed - - // Example for layer 0 - expand to full multi-layer when tested - let layer_info_offset = 2u; // After header - let layer0_info_base = layer_info_offset; - - // Read layer 0 info (5 u32 values = 20 bytes) - let kernel_size = weights[layer0_info_base]; - let in_channels = weights[layer0_info_base + 1u]; - let out_channels = weights[layer0_info_base + 2u]; - let weight_offset = weights[layer0_info_base + 3u]; - - // Convolution: 12D input (4 prev + 8 static) → 4D output - var output: vec4 = vec4(0.0); - for (var c: u32 = 0u; c < 4u; c++) { - output[c] = 0.0; // TODO: Actual convolution - } - - textureStore(output_tex, coord, pack_channels(output)); -} -""" - - output_path = Path(output_dir) / "cnn_v2" / "cnn_v2_compute.wgsl" - output_path.write_text(shader_code) - print(f" → {output_path}") - - -def main(): - parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format') - parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file') - parser.add_argument('--output-weights', type=str, default='workspaces/main/weights/cnn_v2_weights.bin', - help='Output binary weights file') - parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders', - help='Output directory for shader template') - parser.add_argument('--quiet', action='store_true', - help='Suppress detailed output') - - args = parser.parse_args() - - if not args.quiet: - print("=== CNN v2 Weight Export ===\n") - config = export_weights_binary(args.checkpoint, args.output_weights, quiet=args.quiet) - if not args.quiet: - print() - # Shader is manually maintained in cnn_v2_compute.wgsl - # export_shader_template(config, args.output_shader) - print("\nExport complete!") - - -if __name__ == '__main__': - main() diff --git a/training/gen_identity_weights.py b/training/gen_identity_weights.py deleted file mode 100755 index 7865d68..0000000 --- a/training/gen_identity_weights.py +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env python3 -"""Generate Identity CNN v2 Weights - -Creates trivial .bin with 1 layer, 1×1 kernel, identity passthrough. -Output Ch{0,1,2,3} = Input Ch{0,1,2,3} (ignores static features). - -With --mix: Output Ch{i} = 0.5*prev[i] + 0.5*static_p{4+i} - (50-50 blend of prev layer with uv_x, uv_y, sin20_y, bias) - -With --p47: Output Ch{i} = static p{4+i} (uv_x, uv_y, sin20_y, bias) - (p4/uv_x→ch0, p5/uv_y→ch1, p6/sin20_y→ch2, p7/bias→ch3) - -Usage: - ./training/gen_identity_weights.py [output.bin] - ./training/gen_identity_weights.py --mix [output.bin] - ./training/gen_identity_weights.py --p47 [output.bin] -""" - -import argparse -import numpy as np -import struct -from pathlib import Path - - -def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False, p47=False): - """Generate identity weights: output = input (ignores static features). - - If mix=True, 50-50 blend: 0.5*p0+0.5*p4, 0.5*p1+0.5*p5, etc (avoids overflow). - If p47=True, transfers static p4-p7 (uv_x, uv_y, sin20_y, bias) to output channels. - - Input channel layout: [0-3: prev layer, 4-11: static (p0-p7)] - Static features: p0-p3 (RGB+D), p4 (uv_x), p5 (uv_y), p6 (sin20_y), p7 (bias) - - Binary format: - Header (20 bytes): - uint32 magic ('CNN2') - uint32 version (2) - uint32 num_layers (1) - uint32 total_weights (f16 count) - uint32 mip_level - - LayerInfo (20 bytes): - uint32 kernel_size - uint32 in_channels (12) - uint32 out_channels (4) - uint32 weight_offset (0) - uint32 weight_count - - Weights (u32 packed f16): - Identity matrix for first 4 input channels - Zeros for static features (channels 4-11) OR - Mix matrix (p0+p4, p1+p5, p2+p6, p3+p7) if mix=True - """ - # Identity: 4 output channels, 12 input channels - # Weight shape: [out_ch, in_ch, kernel_h, kernel_w] - in_channels = 12 # 4 input + 8 static - out_channels = 4 - - # Identity matrix: diagonal 1.0 for first 4 channels, 0.0 for rest - weights = np.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=np.float32) - - # Center position for kernel - center = kernel_size // 2 - - if p47: - # p47 mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3 (static features only) - # Input channels: [0-3: prev layer, 4-11: static features (p0-p7)] - # p4-p7 are at input channels 8-11 - for i in range(out_channels): - weights[i, i + 8, center, center] = 1.0 - elif mix: - # Mix mode: 50-50 blend (p0+p4, p1+p5, p2+p6, p3+p7) - # p0-p3 are at channels 0-3 (prev layer), p4-p7 at channels 8-11 (static) - for i in range(out_channels): - weights[i, i, center, center] = 0.5 # 0.5*p{i} (prev layer) - weights[i, i + 8, center, center] = 0.5 # 0.5*p{i+4} (static) - else: - # Identity: output ch i = input ch i - for i in range(out_channels): - weights[i, i, center, center] = 1.0 - - # Flatten - weights_flat = weights.flatten() - weight_count = len(weights_flat) - - mode_name = 'p47' if p47 else ('mix' if mix else 'identity') - print(f"Generating {mode_name} weights:") - print(f" Kernel size: {kernel_size}×{kernel_size}") - print(f" Channels: 12D→4D") - print(f" Weights: {weight_count}") - print(f" Mip level: {mip_level}") - if mix: - print(f" Mode: 0.5*prev[i] + 0.5*static_p{{4+i}} (blend with uv/sin/bias)") - elif p47: - print(f" Mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3") - - # Convert to f16 - weights_f16 = np.array(weights_flat, dtype=np.float16) - - # Pad to even count - if len(weights_f16) % 2 == 1: - weights_f16 = np.append(weights_f16, np.float16(0.0)) - - # Pack f16 pairs into u32 - weights_u32 = weights_f16.view(np.uint32) - - print(f" Packed: {len(weights_u32)} u32") - print(f" Binary size: {20 + 20 + len(weights_u32) * 4} bytes") - - # Write binary - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - with open(output_path, 'wb') as f: - # Header (20 bytes) - f.write(struct.pack('<4sIIII', - b'CNN2', # magic - 2, # version - 1, # num_layers - len(weights_f16), # total_weights - mip_level)) # mip_level - - # Layer info (20 bytes) - f.write(struct.pack(' 0: - # Downsample to mip level - mip_rgb = rgb.copy() - for _ in range(mip_level): - mip_rgb = cv2.pyrDown(mip_rgb) - # Upsample back to original size - for _ in range(mip_level): - mip_rgb = cv2.pyrUp(mip_rgb) - # Crop/pad to exact original size if needed - if mip_rgb.shape[:2] != (h, w): - mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR) - else: - mip_rgb = rgb - - # Parametric features (p0-p3) from mip level - p0 = mip_rgb[:, :, 0].astype(np.float32) - p1 = mip_rgb[:, :, 1].astype(np.float32) - p2 = mip_rgb[:, :, 2].astype(np.float32) - p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane - - # UV coordinates (normalized [0, 1]) - uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32) - uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32) - - # Multi-frequency position encoding - sin20_y = np.sin(20.0 * uv_y).astype(np.float32) - - # Bias dimension (always 1.0) - replaces Conv2d bias parameter - bias = np.ones((h, w), dtype=np.float32) - - # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin20_y, bias] - features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias], axis=-1) - return features - - -class CNNv2(nn.Module): - """CNN v2 - Uniform 12D→4D Architecture - - All layers: input RGBD (4D) + static (8D) = 12D → 4 channels - Per-layer kernel sizes supported (e.g., [1, 3, 5]) - Uses bias=False (bias integrated in static features as 1.0) - - TODO: Add quantization-aware training (QAT) for 8-bit weights - - Use torch.quantization.QuantStub/DeQuantStub - - Train with fake quantization to adapt to 8-bit precision - - Target: ~1.3 KB weights (vs 2.6 KB with f16) - """ - - def __init__(self, kernel_sizes, num_layers=3): - super().__init__() - if isinstance(kernel_sizes, int): - kernel_sizes = [kernel_sizes] * num_layers - assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers" - - self.kernel_sizes = kernel_sizes - self.num_layers = num_layers - self.layers = nn.ModuleList() - - # All layers: 12D input (4 RGBD + 8 static) → 4D output - for kernel_size in kernel_sizes: - self.layers.append( - nn.Conv2d(12, 4, kernel_size=kernel_size, - padding=kernel_size//2, bias=False) - ) - - def forward(self, input_rgbd, static_features): - """Forward pass with uniform 12D→4D layers. - - Args: - input_rgbd: (B, 4, H, W) input image RGBD (mip 0) - static_features: (B, 8, H, W) static features - - Returns: - (B, 4, H, W) RGBA output [0, 1] - """ - # Layer 0: input RGBD (4D) + static (8D) = 12D - x = torch.cat([input_rgbd, static_features], dim=1) - x = self.layers[0](x) - x = torch.sigmoid(x) # Soft [0,1] for layer 0 - - # Layer 1+: previous (4D) + static (8D) = 12D - for i in range(1, self.num_layers): - x_input = torch.cat([x, static_features], dim=1) - x = self.layers[i](x_input) - if i < self.num_layers - 1: - x = F.relu(x) - else: - x = torch.sigmoid(x) # Soft [0,1] for final layer - - return x - - -class PatchDataset(Dataset): - """Patch-based dataset extracting salient regions from images.""" - - def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64, - detector='harris', mip_level=0): - self.input_paths = sorted(Path(input_dir).glob("*.png")) - self.target_paths = sorted(Path(target_dir).glob("*.png")) - self.patch_size = patch_size - self.patches_per_image = patches_per_image - self.detector = detector - self.mip_level = mip_level - - assert len(self.input_paths) == len(self.target_paths), \ - f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" - - print(f"Found {len(self.input_paths)} image pairs") - print(f"Extracting {patches_per_image} patches per image using {detector} detector") - print(f"Total patches: {len(self.input_paths) * patches_per_image}") - - def __len__(self): - return len(self.input_paths) * self.patches_per_image - - def _detect_salient_points(self, img_array): - """Detect salient points on original image. - - TODO: Add random sampling to training vectors - - In addition to salient points, incorporate randomly-located samples - - Default: 10% random samples, 90% salient points - - Prevents overfitting to only high-gradient regions - - Improves generalization across entire image - - Configurable via --random-sample-percent parameter - """ - gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) - h, w = gray.shape - half_patch = self.patch_size // 2 - - corners = None - if self.detector == 'harris': - corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, - qualityLevel=0.01, minDistance=half_patch) - elif self.detector == 'fast': - fast = cv2.FastFeatureDetector_create(threshold=20) - keypoints = fast.detect(gray, None) - corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]]) - corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None - elif self.detector == 'shi-tomasi': - corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, - qualityLevel=0.01, minDistance=half_patch, - useHarrisDetector=False) - elif self.detector == 'gradient': - grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) - grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) - gradient_mag = np.sqrt(grad_x**2 + grad_y**2) - threshold = np.percentile(gradient_mag, 95) - y_coords, x_coords = np.where(gradient_mag > threshold) - - if len(x_coords) > self.patches_per_image * 2: - indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False) - x_coords = x_coords[indices] - y_coords = y_coords[indices] - - corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) - corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None - - # Fallback to random if no corners found - if corners is None or len(corners) == 0: - x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image) - y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image) - corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) - corners = corners.reshape(-1, 1, 2) - - # Filter valid corners - valid_corners = [] - for corner in corners: - x, y = int(corner[0][0]), int(corner[0][1]) - if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch: - valid_corners.append((x, y)) - if len(valid_corners) >= self.patches_per_image: - break - - # Fill with random if not enough - while len(valid_corners) < self.patches_per_image: - x = np.random.randint(half_patch, w - half_patch) - y = np.random.randint(half_patch, h - half_patch) - valid_corners.append((x, y)) - - return valid_corners - - def __getitem__(self, idx): - img_idx = idx // self.patches_per_image - patch_idx = idx % self.patches_per_image - - # Load original images (no resize) - input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0 - target_pil = Image.open(self.target_paths[img_idx]) - target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha - - # Detect salient points on original image (use RGB only) - salient_points = self._detect_salient_points(input_img) - cx, cy = salient_points[patch_idx] - - # Extract patch - half_patch = self.patch_size // 2 - y1, y2 = cy - half_patch, cy + half_patch - x1, x2 = cx - half_patch, cx + half_patch - - input_patch = input_img[y1:y2, x1:x2] - target_patch = target_img[y1:y2, x1:x2] # RGBA - - # Extract depth from target alpha channel (or default to 1.0) - depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None - - # Compute static features for patch - static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level) - - # Input RGBD (mip 0) - add depth channel - input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1) - - # Convert to tensors (C, H, W) - input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1) - static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) - target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) # RGBA from image - - return input_rgbd, static_feat, target - - -class ImagePairDataset(Dataset): - """Dataset of input/target image pairs (full-image mode).""" - - def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0): - self.input_paths = sorted(Path(input_dir).glob("*.png")) - self.target_paths = sorted(Path(target_dir).glob("*.png")) - self.target_size = target_size - self.mip_level = mip_level - assert len(self.input_paths) == len(self.target_paths), \ - f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" - - def __len__(self): - return len(self.input_paths) - - def __getitem__(self, idx): - # Load and resize images to fixed size - input_pil = Image.open(self.input_paths[idx]).convert('RGB') - target_pil = Image.open(self.target_paths[idx]) - - # Resize to target size - input_pil = input_pil.resize(self.target_size, Image.LANCZOS) - target_pil = target_pil.resize(self.target_size, Image.LANCZOS) - - input_img = np.array(input_pil) / 255.0 - target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha - - # Extract depth from target alpha channel (or default to 1.0) - depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None - - # Compute static features - static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level) - - # Input RGBD (mip 0) - add depth channel - h, w = input_img.shape[:2] - input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1) - - # Convert to tensors (C, H, W) - input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1) - static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) - target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) # RGBA from image - - return input_rgbd, static_feat, target - - -def train(args): - """Train CNN v2 model.""" - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Training on {device}") - - # Create dataset (patch-based or full-image) - if args.full_image: - print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})") - target_size = (args.image_size, args.image_size) - dataset = ImagePairDataset(args.input, args.target, target_size=target_size, mip_level=args.mip_level) - dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) - else: - print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)") - dataset = PatchDataset(args.input, args.target, - patch_size=args.patch_size, - patches_per_image=args.patches_per_image, - detector=args.detector, - mip_level=args.mip_level) - dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) - - # Parse kernel sizes - kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')] - if len(kernel_sizes) == 1: - kernel_sizes = kernel_sizes * args.num_layers - else: - # When multiple kernel sizes provided, derive num_layers from list length - args.num_layers = len(kernel_sizes) - - # Create model - model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device) - total_params = sum(p.numel() for p in model.parameters()) - kernel_desc = ','.join(map(str, kernel_sizes)) - print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights") - print(f"Using mip level {args.mip_level} for p0-p3 features") - - # Optimizer and loss - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - criterion = nn.MSELoss() - - # Training loop - print(f"\nTraining for {args.epochs} epochs...") - start_time = time.time() - - for epoch in range(1, args.epochs + 1): - model.train() - epoch_loss = 0.0 - - for input_rgbd, static_feat, target in dataloader: - input_rgbd = input_rgbd.to(device) - static_feat = static_feat.to(device) - target = target.to(device) - - optimizer.zero_grad() - output = model(input_rgbd, static_feat) - - # Compute loss (grayscale or RGBA) - if args.grayscale_loss: - # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B - output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3] - target_gray = 0.299 * target[:, 0:1] + 0.587 * target[:, 1:2] + 0.114 * target[:, 2:3] - loss = criterion(output_gray, target_gray) - else: - loss = criterion(output, target) - - loss.backward() - optimizer.step() - - epoch_loss += loss.item() - - avg_loss = epoch_loss / len(dataloader) - - # Print loss at every epoch (overwrite line with \r) - elapsed = time.time() - start_time - print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True) - - # Save checkpoint - if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0: - print() # Newline before checkpoint message - checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth" - checkpoint_path.parent.mkdir(parents=True, exist_ok=True) - torch.save({ - 'epoch': epoch, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'loss': avg_loss, - 'config': { - 'kernel_sizes': kernel_sizes, - 'num_layers': args.num_layers, - 'mip_level': args.mip_level, - 'grayscale_loss': args.grayscale_loss, - 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias'] - } - }, checkpoint_path) - print(f" → Saved checkpoint: {checkpoint_path}") - - # Always save final checkpoint - print() # Newline after training - final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth" - final_checkpoint.parent.mkdir(parents=True, exist_ok=True) - torch.save({ - 'epoch': args.epochs, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'loss': avg_loss, - 'config': { - 'kernel_sizes': kernel_sizes, - 'num_layers': args.num_layers, - 'mip_level': args.mip_level, - 'grayscale_loss': args.grayscale_loss, - 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias'] - } - }, final_checkpoint) - print(f" → Saved final checkpoint: {final_checkpoint}") - - print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s") - return model - - -def main(): - parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features') - parser.add_argument('--input', type=str, required=True, help='Input images directory') - parser.add_argument('--target', type=str, required=True, help='Target images directory') - - # Training mode - parser.add_argument('--full-image', action='store_true', - help='Use full-image mode (resize all images)') - parser.add_argument('--image-size', type=int, default=256, - help='Full-image mode: resize to this size (default: 256)') - - # Patch-based mode (default) - parser.add_argument('--patch-size', type=int, default=32, - help='Patch mode: patch size (default: 32)') - parser.add_argument('--patches-per-image', type=int, default=64, - help='Patch mode: patches per image (default: 64)') - parser.add_argument('--detector', type=str, default='harris', - choices=['harris', 'fast', 'shi-tomasi', 'gradient'], - help='Patch mode: salient point detector (default: harris)') - # TODO: Add --random-sample-percent parameter (default: 10) - # Mix salient points with random samples for better generalization - - # Model architecture - parser.add_argument('--kernel-sizes', type=str, default='3', - help='Comma-separated kernel sizes per layer (e.g., "3,5,3"), single value replicates (default: 3)') - parser.add_argument('--num-layers', type=int, default=3, - help='Number of CNN layers (default: 3)') - parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3], - help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)') - - # Training parameters - parser.add_argument('--epochs', type=int, default=5000, help='Training epochs') - parser.add_argument('--batch-size', type=int, default=16, help='Batch size') - parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') - parser.add_argument('--grayscale-loss', action='store_true', - help='Compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of RGBA') - parser.add_argument('--checkpoint-dir', type=str, default='checkpoints', - help='Checkpoint directory') - parser.add_argument('--checkpoint-every', type=int, default=1000, - help='Save checkpoint every N epochs (0 = disable)') - - args = parser.parse_args() - train(args) - - -if __name__ == '__main__': - main() -- cgit v1.2.3