summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rwxr-xr-xtraining/export_cnn_v2_shader.py225
-rwxr-xr-xtraining/export_cnn_v2_weights.py275
-rwxr-xr-xtraining/train_cnn_v2.py383
3 files changed, 883 insertions, 0 deletions
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py
new file mode 100755
index 0000000..3c53ce2
--- /dev/null
+++ b/training/export_cnn_v2_shader.py
@@ -0,0 +1,225 @@
+#!/usr/bin/env python3
+"""CNN v2 Shader Export Script
+
+Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
+Generates one shader per layer with embedded weight arrays.
+"""
+
+import argparse
+import numpy as np
+import torch
+from pathlib import Path
+
+
+def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channels,
+ output_dir, is_output_layer=False):
+ """Generate WGSL compute shader for a single CNN layer.
+
+ Args:
+ layer_idx: Layer index (0, 1, 2)
+ weights: (out_ch, in_ch, k, k) weight tensor
+ kernel_size: Kernel size (1, 3, 5, etc.)
+ in_channels: Input channels (includes 8D static features)
+ out_channels: Output channels
+ output_dir: Output directory path
+ is_output_layer: True if this is the final RGBA output layer
+ """
+ weights_flat = weights.flatten()
+ weights_f16 = weights_flat.astype(np.float16)
+ weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals
+
+ # Format weights as WGSL array
+ weights_str = ",\n ".join(
+ ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8])
+ for i in range(0, len(weights_f32), 8)
+ )
+
+ radius = kernel_size // 2
+ activation = "" if is_output_layer else "output[c] = max(0.0, sum); // ReLU"
+ if is_output_layer:
+ activation = "output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation"
+
+ shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated
+// Kernel: {kernel_size}×{kernel_size}, In: {in_channels}, Out: {out_channels}
+
+const KERNEL_SIZE: u32 = {kernel_size}u;
+const IN_CHANNELS: u32 = {in_channels}u;
+const OUT_CHANNELS: u32 = {out_channels}u;
+const KERNEL_RADIUS: i32 = {radius};
+
+// Weights quantized to float16 (stored as f32 in WGSL)
+const weights: array<f32, {len(weights_f32)}> = array(
+ {weights_str}
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {{
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {{
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{
+ return;
+ }}
+
+ // Load static features (always available)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution
+ var output: array<f32, OUT_CHANNELS>;
+ for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {{
+ var sum: f32 = 0.0;
+
+ for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
+ for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load input features
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ // Weight index calculation
+ let ky_idx = u32(ky + KERNEL_RADIUS);
+ let kx_idx = u32(kx + KERNEL_RADIUS);
+ let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
+
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {{
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
+ }}
+
+ // Accumulate: layer input channels (if layer_idx > 0)
+ let prev_channels = IN_CHANNELS - 8u;
+ for (var i: u32 = 0u; i < prev_channels; i++) {{
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * layer_local[i];
+ }}
+ }}
+ }}
+
+ {activation}
+ }}
+
+ // Pack and store
+ textureStore(output_tex, coord, pack_channels(output));
+}}
+"""
+
+ output_path = Path(output_dir) / f"cnn_v2_layer_{layer_idx}.wgsl"
+ output_path.write_text(shader_code)
+ print(f" → {output_path}")
+
+
+def export_checkpoint(checkpoint_path, output_dir):
+ """Export PyTorch checkpoint to WGSL shaders.
+
+ Args:
+ checkpoint_path: Path to .pth checkpoint
+ output_dir: Output directory for shaders
+ """
+ print(f"Loading checkpoint: {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ state_dict = checkpoint['model_state_dict']
+ config = checkpoint['config']
+
+ print(f"Configuration:")
+ print(f" Kernels: {config['kernels']}")
+ print(f" Channels: {config['channels']}")
+ print(f" Features: {config['features']}")
+
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ print(f"\nExporting shaders to {output_dir}/")
+
+ # Layer 0: 8 → channels[0]
+ layer0_weights = state_dict['layer0.weight'].detach().numpy()
+ export_layer_shader(
+ layer_idx=0,
+ weights=layer0_weights,
+ kernel_size=config['kernels'][0],
+ in_channels=8,
+ out_channels=config['channels'][0],
+ output_dir=output_dir,
+ is_output_layer=False
+ )
+
+ # Layer 1: (8 + channels[0]) → channels[1]
+ layer1_weights = state_dict['layer1.weight'].detach().numpy()
+ export_layer_shader(
+ layer_idx=1,
+ weights=layer1_weights,
+ kernel_size=config['kernels'][1],
+ in_channels=8 + config['channels'][0],
+ out_channels=config['channels'][1],
+ output_dir=output_dir,
+ is_output_layer=False
+ )
+
+ # Layer 2: (8 + channels[1]) → 4 (RGBA)
+ layer2_weights = state_dict['layer2.weight'].detach().numpy()
+ export_layer_shader(
+ layer_idx=2,
+ weights=layer2_weights,
+ kernel_size=config['kernels'][2],
+ in_channels=8 + config['channels'][1],
+ out_channels=4,
+ output_dir=output_dir,
+ is_output_layer=True
+ )
+
+ print(f"\nExport complete! Generated 3 shader files.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders')
+ parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
+ parser.add_argument('--output-dir', type=str, default='workspaces/main/shaders',
+ help='Output directory for shaders')
+
+ args = parser.parse_args()
+ export_checkpoint(args.checkpoint, args.output_dir)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
new file mode 100755
index 0000000..723f572
--- /dev/null
+++ b/training/export_cnn_v2_weights.py
@@ -0,0 +1,275 @@
+#!/usr/bin/env python3
+"""CNN v2 Weight Export Script
+
+Converts PyTorch checkpoints to binary weight format for storage buffer.
+Exports single shader template + binary weights asset.
+"""
+
+import argparse
+import numpy as np
+import torch
+import struct
+from pathlib import Path
+
+
+def export_weights_binary(checkpoint_path, output_path):
+ """Export CNN v2 weights to binary format.
+
+ Binary format:
+ Header (16 bytes):
+ uint32 magic ('CNN2')
+ uint32 version (1)
+ uint32 num_layers
+ uint32 total_weights (f16 count)
+
+ LayerInfo × num_layers (20 bytes each):
+ uint32 kernel_size
+ uint32 in_channels
+ uint32 out_channels
+ uint32 weight_offset (f16 index)
+ uint32 weight_count
+
+ Weights (f16 array):
+ float16[] all_weights
+
+ Args:
+ checkpoint_path: Path to .pth checkpoint
+ output_path: Output .bin file path
+
+ Returns:
+ config dict for shader generation
+ """
+ print(f"Loading checkpoint: {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ state_dict = checkpoint['model_state_dict']
+ config = checkpoint['config']
+
+ print(f"Configuration:")
+ print(f" Kernels: {config['kernels']}")
+ print(f" Channels: {config['channels']}")
+
+ # Collect layer info
+ layers = []
+ all_weights = []
+ weight_offset = 0
+
+ # Layer 0: 8 → channels[0]
+ layer0_weights = state_dict['layer0.weight'].detach().numpy()
+ layer0_flat = layer0_weights.flatten()
+ layers.append({
+ 'kernel_size': config['kernels'][0],
+ 'in_channels': 8,
+ 'out_channels': config['channels'][0],
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer0_flat)
+ })
+ all_weights.extend(layer0_flat)
+ weight_offset += len(layer0_flat)
+
+ # Layer 1: (8 + channels[0]) → channels[1]
+ layer1_weights = state_dict['layer1.weight'].detach().numpy()
+ layer1_flat = layer1_weights.flatten()
+ layers.append({
+ 'kernel_size': config['kernels'][1],
+ 'in_channels': 8 + config['channels'][0],
+ 'out_channels': config['channels'][1],
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer1_flat)
+ })
+ all_weights.extend(layer1_flat)
+ weight_offset += len(layer1_flat)
+
+ # Layer 2: (8 + channels[1]) → 4 (RGBA output)
+ layer2_weights = state_dict['layer2.weight'].detach().numpy()
+ layer2_flat = layer2_weights.flatten()
+ layers.append({
+ 'kernel_size': config['kernels'][2],
+ 'in_channels': 8 + config['channels'][1],
+ 'out_channels': 4,
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer2_flat)
+ })
+ all_weights.extend(layer2_flat)
+ weight_offset += len(layer2_flat)
+
+ # Convert to f16
+ # TODO: Use 8-bit quantization for 2× size reduction
+ # Requires quantization-aware training (QAT) to maintain accuracy
+ all_weights_f16 = np.array(all_weights, dtype=np.float16)
+
+ # Pack f16 pairs into u32 for storage buffer
+ # Pad to even count if needed
+ if len(all_weights_f16) % 2 == 1:
+ all_weights_f16 = np.append(all_weights_f16, np.float16(0.0))
+
+ # Pack pairs using numpy view
+ weights_u32 = all_weights_f16.view(np.uint32)
+
+ print(f"\nWeight statistics:")
+ print(f" Total layers: {len(layers)}")
+ print(f" Total weights: {len(all_weights_f16)} (f16)")
+ print(f" Packed: {len(weights_u32)} u32")
+ print(f" Binary size: {16 + len(layers) * 20 + len(weights_u32) * 4} bytes")
+
+ # Write binary file
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(output_path, 'wb') as f:
+ # Header (16 bytes)
+ f.write(struct.pack('<4sIII',
+ b'CNN2', # magic
+ 1, # version
+ len(layers), # num_layers
+ len(all_weights_f16))) # total_weights (f16 count)
+
+ # Layer info (20 bytes per layer)
+ for layer in layers:
+ f.write(struct.pack('<IIIII',
+ layer['kernel_size'],
+ layer['in_channels'],
+ layer['out_channels'],
+ layer['weight_offset'],
+ layer['weight_count']))
+
+ # Weights (u32 packed f16 pairs)
+ f.write(weights_u32.tobytes())
+
+ print(f" → {output_path}")
+
+ return {
+ 'num_layers': len(layers),
+ 'layers': layers
+ }
+
+
+def export_shader_template(config, output_dir):
+ """Generate single WGSL shader template with storage buffer binding.
+
+ Args:
+ config: Layer configuration from export_weights_binary()
+ output_dir: Output directory path
+ """
+ shader_code = """// CNN v2 Compute Shader - Storage Buffer Version
+// Reads weights from storage buffer, processes all layers in sequence
+
+struct CNNv2Header {
+ magic: u32, // 'CNN2'
+ version: u32, // 1
+ num_layers: u32, // Number of layers
+ total_weights: u32, // Total f16 weight count
+}
+
+struct CNNv2LayerInfo {
+ kernel_size: u32,
+ in_channels: u32,
+ out_channels: u32,
+ weight_offset: u32, // Offset in weights array
+ weight_count: u32,
+}
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(3) var<storage, read> weights: array<u32>; // Packed f16 pairs
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}
+
+fn get_weight(idx: u32) -> f32 {
+ let pair_idx = idx / 2u;
+ let packed = weights[8u + pair_idx]; // Skip header (32 bytes = 8 u32)
+ let unpacked = unpack2x16float(packed);
+ return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Read header
+ let header_packed = weights[0]; // magic + version
+ let counts_packed = weights[1]; // num_layers + total_weights
+ let num_layers = counts_packed & 0xFFFFu;
+
+ // Load static features
+ let static_feat = unpack_static_features(coord);
+
+ // Process each layer (hardcoded for 3 layers for now)
+ // TODO: Dynamic layer loop when needed
+
+ // Example for layer 0 - expand to full multi-layer when tested
+ let layer_info_offset = 2u; // After header
+ let layer0_info_base = layer_info_offset;
+
+ // Read layer 0 info (5 u32 values = 20 bytes)
+ let kernel_size = weights[layer0_info_base];
+ let in_channels = weights[layer0_info_base + 1u];
+ let out_channels = weights[layer0_info_base + 2u];
+ let weight_offset = weights[layer0_info_base + 3u];
+
+ // Convolution (simplified - expand to full kernel loop)
+ var output: array<f32, 8>;
+ for (var c: u32 = 0u; c < min(out_channels, 8u); c++) {
+ output[c] = 0.0; // TODO: Actual convolution
+ }
+
+ textureStore(output_tex, coord, pack_channels(output));
+}
+"""
+
+ output_path = Path(output_dir) / "cnn_v2_compute.wgsl"
+ output_path.write_text(shader_code)
+ print(f" → {output_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format')
+ parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
+ parser.add_argument('--output-weights', type=str, default='workspaces/main/cnn_v2_weights.bin',
+ help='Output binary weights file')
+ parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders',
+ help='Output directory for shader template')
+
+ args = parser.parse_args()
+
+ print("=== CNN v2 Weight Export ===\n")
+ config = export_weights_binary(args.checkpoint, args.output_weights)
+ print()
+ # Shader is manually maintained in cnn_v2_compute.wgsl
+ # export_shader_template(config, args.output_shader)
+ print("\nExport complete!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
new file mode 100755
index 0000000..758b044
--- /dev/null
+++ b/training/train_cnn_v2.py
@@ -0,0 +1,383 @@
+#!/usr/bin/env python3
+"""CNN v2 Training Script - Parametric Static Features
+
+Trains a multi-layer CNN with 7D static feature input:
+- RGBD (4D)
+- UV coordinates (2D)
+- sin(10*uv.x) position encoding (1D)
+- Bias dimension (1D, always 1.0)
+"""
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+from PIL import Image
+import time
+import cv2
+
+
+def compute_static_features(rgb, depth=None):
+ """Generate 7D static features + bias dimension.
+
+ Args:
+ rgb: (H, W, 3) RGB image [0, 1]
+ depth: (H, W) depth map [0, 1], optional
+
+ Returns:
+ (H, W, 8) static features tensor
+ """
+ h, w = rgb.shape[:2]
+
+ # RGBD channels
+ r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
+ d = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+
+ # UV coordinates (normalized [0, 1])
+ uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
+ uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)
+
+ # Multi-frequency position encoding
+ sin10_x = np.sin(10.0 * uv_x).astype(np.float32)
+
+ # Bias dimension (always 1.0)
+ bias = np.ones((h, w), dtype=np.float32)
+
+ # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias]
+ features = np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1)
+ return features
+
+
+class CNNv2(nn.Module):
+ """CNN v2 with parametric static features.
+
+ TODO: Add quantization-aware training (QAT) for 8-bit weights
+ - Use torch.quantization.QuantStub/DeQuantStub
+ - Train with fake quantization to adapt to 8-bit precision
+ - Target: ~1.6 KB weights (vs 3.2 KB with f16)
+ """
+
+ def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]):
+ super().__init__()
+ self.kernels = kernels
+ self.channels = channels
+
+ # Input layer: 8D (7 features + bias) → channels[0]
+ self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0],
+ padding=kernels[0]//2, bias=False)
+
+ # Inner layers: (8 + C_prev) → C_next
+ in_ch_1 = 8 + channels[0]
+ self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1],
+ padding=kernels[1]//2, bias=False)
+
+ # Output layer: (8 + C_last) → 4 (RGBA)
+ in_ch_2 = 8 + channels[1]
+ self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2],
+ padding=kernels[2]//2, bias=False)
+
+ def forward(self, static_features):
+ """Forward pass with static feature concatenation.
+
+ Args:
+ static_features: (B, 8, H, W) static features
+
+ Returns:
+ (B, 4, H, W) RGBA output [0, 1]
+ """
+ # Layer 0: Use full 8D static features
+ x0 = self.layer0(static_features)
+ x0 = F.relu(x0)
+
+ # Layer 1: Concatenate static + layer0 output
+ x1_input = torch.cat([static_features, x0], dim=1)
+ x1 = self.layer1(x1_input)
+ x1 = F.relu(x1)
+
+ # Layer 2: Concatenate static + layer1 output
+ x2_input = torch.cat([static_features, x1], dim=1)
+ output = self.layer2(x2_input)
+
+ return torch.sigmoid(output)
+
+
+class PatchDataset(Dataset):
+ """Patch-based dataset extracting salient regions from images."""
+
+ def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
+ detector='harris'):
+ self.input_paths = sorted(Path(input_dir).glob("*.png"))
+ self.target_paths = sorted(Path(target_dir).glob("*.png"))
+ self.patch_size = patch_size
+ self.patches_per_image = patches_per_image
+ self.detector = detector
+
+ assert len(self.input_paths) == len(self.target_paths), \
+ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
+
+ print(f"Found {len(self.input_paths)} image pairs")
+ print(f"Extracting {patches_per_image} patches per image using {detector} detector")
+ print(f"Total patches: {len(self.input_paths) * patches_per_image}")
+
+ def __len__(self):
+ return len(self.input_paths) * self.patches_per_image
+
+ def _detect_salient_points(self, img_array):
+ """Detect salient points on original image.
+
+ TODO: Add random sampling to training vectors
+ - In addition to salient points, incorporate randomly-located samples
+ - Default: 10% random samples, 90% salient points
+ - Prevents overfitting to only high-gradient regions
+ - Improves generalization across entire image
+ - Configurable via --random-sample-percent parameter
+ """
+ gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
+ h, w = gray.shape
+ half_patch = self.patch_size // 2
+
+ corners = None
+ if self.detector == 'harris':
+ corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+ qualityLevel=0.01, minDistance=half_patch)
+ elif self.detector == 'fast':
+ fast = cv2.FastFeatureDetector_create(threshold=20)
+ keypoints = fast.detect(gray, None)
+ corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+ elif self.detector == 'shi-tomasi':
+ corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+ qualityLevel=0.01, minDistance=half_patch,
+ useHarrisDetector=False)
+ elif self.detector == 'gradient':
+ grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+ grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+ gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
+ threshold = np.percentile(gradient_mag, 95)
+ y_coords, x_coords = np.where(gradient_mag > threshold)
+
+ if len(x_coords) > self.patches_per_image * 2:
+ indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
+ x_coords = x_coords[indices]
+ y_coords = y_coords[indices]
+
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+
+ # Fallback to random if no corners found
+ if corners is None or len(corners) == 0:
+ x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
+ y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2)
+
+ # Filter valid corners
+ valid_corners = []
+ for corner in corners:
+ x, y = int(corner[0][0]), int(corner[0][1])
+ if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
+ valid_corners.append((x, y))
+ if len(valid_corners) >= self.patches_per_image:
+ break
+
+ # Fill with random if not enough
+ while len(valid_corners) < self.patches_per_image:
+ x = np.random.randint(half_patch, w - half_patch)
+ y = np.random.randint(half_patch, h - half_patch)
+ valid_corners.append((x, y))
+
+ return valid_corners
+
+ def __getitem__(self, idx):
+ img_idx = idx // self.patches_per_image
+ patch_idx = idx % self.patches_per_image
+
+ # Load original images (no resize)
+ input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0
+ target_img = np.array(Image.open(self.target_paths[img_idx]).convert('RGB')) / 255.0
+
+ # Detect salient points on original image
+ salient_points = self._detect_salient_points(input_img)
+ cx, cy = salient_points[patch_idx]
+
+ # Extract patch
+ half_patch = self.patch_size // 2
+ y1, y2 = cy - half_patch, cy + half_patch
+ x1, x2 = cx - half_patch, cx + half_patch
+
+ input_patch = input_img[y1:y2, x1:x2]
+ target_patch = target_img[y1:y2, x1:x2]
+
+ # Compute static features for patch
+ static_feat = compute_static_features(input_patch.astype(np.float32))
+
+ # Convert to tensors (C, H, W)
+ static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
+ target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)
+
+ # Pad target to 4 channels (RGBA)
+ target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
+
+ return static_feat, target
+
+
+class ImagePairDataset(Dataset):
+ """Dataset of input/target image pairs (full-image mode)."""
+
+ def __init__(self, input_dir, target_dir, target_size=(256, 256)):
+ self.input_paths = sorted(Path(input_dir).glob("*.png"))
+ self.target_paths = sorted(Path(target_dir).glob("*.png"))
+ self.target_size = target_size
+ assert len(self.input_paths) == len(self.target_paths), \
+ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
+
+ def __len__(self):
+ return len(self.input_paths)
+
+ def __getitem__(self, idx):
+ # Load and resize images to fixed size
+ input_pil = Image.open(self.input_paths[idx]).convert('RGB')
+ target_pil = Image.open(self.target_paths[idx]).convert('RGB')
+
+ # Resize to target size
+ input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
+ target_pil = target_pil.resize(self.target_size, Image.LANCZOS)
+
+ input_img = np.array(input_pil) / 255.0
+ target_img = np.array(target_pil) / 255.0
+
+ # Compute static features
+ static_feat = compute_static_features(input_img.astype(np.float32))
+
+ # Convert to tensors (C, H, W)
+ static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
+ target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)
+
+ # Pad target to 4 channels (RGBA)
+ target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
+
+ return static_feat, target
+
+
+def train(args):
+ """Train CNN v2 model."""
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Training on {device}")
+
+ # Create dataset (patch-based or full-image)
+ if args.full_image:
+ print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
+ target_size = (args.image_size, args.image_size)
+ dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+ else:
+ print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
+ dataset = PatchDataset(args.input, args.target,
+ patch_size=args.patch_size,
+ patches_per_image=args.patches_per_image,
+ detector=args.detector)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+
+ # Create model
+ model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
+ total_params = sum(p.numel() for p in model.parameters())
+ print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights")
+
+ # Optimizer and loss
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ criterion = nn.MSELoss()
+
+ # Training loop
+ print(f"\nTraining for {args.epochs} epochs...")
+ start_time = time.time()
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ epoch_loss = 0.0
+
+ for static_feat, target in dataloader:
+ static_feat = static_feat.to(device)
+ target = target.to(device)
+
+ optimizer.zero_grad()
+ output = model(static_feat)
+ loss = criterion(output, target)
+ loss.backward()
+ optimizer.step()
+
+ epoch_loss += loss.item()
+
+ avg_loss = epoch_loss / len(dataloader)
+
+ # Print loss at every epoch (overwrite line with \r)
+ elapsed = time.time() - start_time
+ print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True)
+
+ # Save checkpoint
+ if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+ print() # Newline before checkpoint message
+ checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
+ checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+ torch.save({
+ 'epoch': epoch,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'config': {
+ 'kernels': args.kernel_sizes,
+ 'channels': args.channels,
+ 'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias']
+ }
+ }, checkpoint_path)
+ print(f" → Saved checkpoint: {checkpoint_path}")
+
+ print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
+ return model
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
+ parser.add_argument('--input', type=str, required=True, help='Input images directory')
+ parser.add_argument('--target', type=str, required=True, help='Target images directory')
+
+ # Training mode
+ parser.add_argument('--full-image', action='store_true',
+ help='Use full-image mode (resize all images)')
+ parser.add_argument('--image-size', type=int, default=256,
+ help='Full-image mode: resize to this size (default: 256)')
+
+ # Patch-based mode (default)
+ parser.add_argument('--patch-size', type=int, default=32,
+ help='Patch mode: patch size (default: 32)')
+ parser.add_argument('--patches-per-image', type=int, default=64,
+ help='Patch mode: patches per image (default: 64)')
+ parser.add_argument('--detector', type=str, default='harris',
+ choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
+ help='Patch mode: salient point detector (default: harris)')
+ # TODO: Add --random-sample-percent parameter (default: 10)
+ # Mix salient points with random samples for better generalization
+
+ # Model architecture
+ parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
+ help='Kernel sizes for 3 layers (default: 1 3 5)')
+ parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
+ help='Output channels for 3 layers (default: 16 8 4)')
+
+ # Training parameters
+ parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
+ parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
+ parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
+ parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
+ help='Checkpoint directory')
+ parser.add_argument('--checkpoint-every', type=int, default=1000,
+ help='Save checkpoint every N epochs (0 = disable)')
+
+ args = parser.parse_args()
+ train(args)
+
+
+if __name__ == '__main__':
+ main()