path: root/training/export_cnn_v2_shader.py
author    skal <pascal.massimino@gmail.com>    2026-02-13 12:32:36 +0100
committer skal <pascal.massimino@gmail.com>    2026-02-13 12:32:36 +0100
commit    561d1dc446db7d1d3e02b92b43abedf1a5017850 (patch)
tree      ef9302dc1f9b6b9f8a12225580f2a3b07602656b /training/export_cnn_v2_shader.py
parent    c27b34279c0d1c2a8f1dbceb0e154b585b5c6916 (diff)
CNN v2: Refactor to uniform 12D→4D architecture
**Architecture changes:**
- Static features (8D): p0-p3 (parametric) + uv_x, uv_y, sin(10×uv_x), bias
- Input RGBD (4D): fed separately to all layers
- All layers: uniform 12D→4D (4 prev/input + 8 static → 4 output)
- Bias integrated in static features (bias=False in PyTorch)

**Weight calculations:**
- 3 layers × (12 × 3×3 × 4) = 1296 weights
- f16: 2.6 KB (vs old variable arch: ~6.4 KB)

**Updated files:**

*Training (Python):*
- train_cnn_v2.py: Uniform model, takes input_rgbd + static_features
- export_cnn_v2_weights.py: Binary export for storage buffers
- export_cnn_v2_shader.py: Per-layer shader export (debugging)

*Shaders (WGSL):*
- cnn_v2_static.wgsl: p0-p3 parametric features (mips/gradients)
- cnn_v2_compute.wgsl: 12D input, 4D output, vec4 packing

*Tools:*
- HTML tool (cnn_v2_test): Updated for 12D→4D, layer visualization

*Docs:*
- CNN_V2.md: Updated architecture, training, validation sections
- HOWTO.md: Reference HTML tool for validation

*Removed:*
- validate_cnn_v2.sh: Obsolete (used CNN v1 tool)

All code consistent with bias=False (bias in static features as 1.0).

handoff(Claude): CNN v2 architecture finalized and documented
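For quick cross-checking of the numbers quoted in the message, a minimal sketch of the weight budget and the 8D static feature vector; `static_features` and `p` are illustrative names, not code from the repo:

```python
import numpy as np

# Weight budget: 3 layers, each a (4 out, 12 in, 3, 3) conv with no bias term
layers, out_ch, in_ch, k = 3, 4, 12, 3
n_weights = layers * out_ch * in_ch * k * k   # 3 * (12 * 3*3 * 4) = 1296
size_bytes = n_weights * 2                    # f16: 2592 bytes, ~2.6 KB

# 8D static feature vector: p0-p3 (parametric) + uv_x, uv_y, sin(10*uv_x),
# plus a constant 1.0 that stands in for the missing conv bias
def static_features(p, uv_x, uv_y):
    return np.array([p[0], p[1], p[2], p[3],
                     uv_x, uv_y, np.sin(10.0 * uv_x), 1.0], dtype=np.float32)
```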
Diffstat (limited to 'training/export_cnn_v2_shader.py')
-rwxr-xr-x  training/export_cnn_v2_shader.py  127
1 file changed, 54 insertions, 73 deletions
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py
index add28d2..ad5749c 100755
--- a/training/export_cnn_v2_shader.py
+++ b/training/export_cnn_v2_shader.py
@@ -1,8 +1,11 @@
#!/usr/bin/env python3
-"""CNN v2 Shader Export Script
+"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture
Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
Generates one shader per layer with embedded weight arrays.
+
+Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size.
+ This script is for debugging/testing with per-layer shaders.
"""
import argparse
@@ -11,16 +14,13 @@ import torch
from pathlib import Path
-def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channels,
- output_dir, is_output_layer=False):
+def export_layer_shader(layer_idx, weights, kernel_size, output_dir, is_output_layer=False):
"""Generate WGSL compute shader for a single CNN layer.
Args:
- layer_idx: Layer index (0, 1, 2)
- weights: (out_ch, in_ch, k, k) weight tensor
- kernel_size: Kernel size (1, 3, 5, etc.)
- in_channels: Input channels (includes 8D static features)
- out_channels: Output channels
+ layer_idx: Layer index (0, 1, 2, ...)
+ weights: (4, 12, k, k) weight tensor (uniform 12D→4D)
+ kernel_size: Kernel size (3, 5, etc.)
output_dir: Output directory path
is_output_layer: True if this is the final RGBA output layer
"""
@@ -39,12 +39,12 @@ def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channe
if is_output_layer:
activation = "output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation"
- shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated
-// Kernel: {kernel_size}×{kernel_size}, In: {in_channels}, Out: {out_channels}
+ shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
+// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
const KERNEL_SIZE: u32 = {kernel_size}u;
-const IN_CHANNELS: u32 = {in_channels}u;
-const OUT_CHANNELS: u32 = {out_channels}u;
+const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static)
+const OUT_CHANNELS: u32 = 4u; // Uniform output
const KERNEL_RADIUS: i32 = {radius};
// Weights quantized to float16 (stored as f32 in WGSL)
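The "quantized to float16, stored as f32" step referenced above presumably happens on the Python side before the weights are embedded; a sketch of that idea (not code from the repo):

```python
import numpy as np

# Round each weight through float16, then emit it back as an f32 literal
# in the generated WGSL so the shader matches the f16-precision storage path.
def quantize_f16(w):
    return np.asarray(w, dtype=np.float16).astype(np.float32)
```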
@@ -65,21 +65,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}}
-fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {{
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{
let packed = textureLoad(layer_input, coord, 0);
let v0 = unpack2x16float(packed.x);
let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {{
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {{
return vec4<u32>(
- pack2x16float(vec2<f32>(values[0], values[1])),
- pack2x16float(vec2<f32>(values[2], values[3])),
- pack2x16float(vec2<f32>(values[4], values[5])),
- pack2x16float(vec2<f32>(values[6], values[7]))
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
);
}}
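For reference, a Python equivalent of the WGSL `pack2x16float` / `unpack2x16float` round-trip used above, showing why the 4D output now fits in two u32 components with the other two left at 0u (illustrative only):

```python
import numpy as np

def pack2x16float(a, b):
    # Two f16 bit patterns packed into one u32: a in the low 16 bits, b in the high
    h = np.array([a, b], dtype=np.float16).view(np.uint16)
    return int(h[0]) | (int(h[1]) << 16)

def unpack2x16float(u):
    # Inverse: split the u32 back into two f16 values, widened to f32
    h = np.array([u & 0xFFFF, (u >> 16) & 0xFFFF], dtype=np.uint16)
    return h.view(np.float16).astype(np.float32)
```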
@@ -95,9 +93,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
// Load static features (always available)
let static_feat = unpack_static_features(coord);
- // Convolution
- var output: array<f32, OUT_CHANNELS>;
- for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {{
+ // Convolution: 12D input (4 prev + 8 static) → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {{
var sum: f32 = 0.0;
for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
@@ -110,28 +108,27 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
clamp(sample_coord.y, 0, i32(dims.y) - 1)
);
- // Load input features
+ // Load features at this spatial location
let static_local = unpack_static_features(clamped);
- let layer_local = unpack_layer_channels(clamped);
+ let layer_local = unpack_layer_channels(clamped); // 4D
// Weight index calculation
let ky_idx = u32(ky + KERNEL_RADIUS);
let kx_idx = u32(kx + KERNEL_RADIUS);
let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
- // Accumulate: static features (8D)
- for (var i: u32 = 0u; i < 8u; i++) {{
- let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ // Accumulate: previous/input channels (4D)
+ for (var i: u32 = 0u; i < 4u; i++) {{
+ let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * static_local[i];
+ sum += weights[w_idx] * layer_local[i];
}}
- // Accumulate: layer input channels (if layer_idx > 0)
- let prev_channels = IN_CHANNELS - 8u;
- for (var i: u32 = 0u; i < prev_channels; i++) {{
- let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
- (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * layer_local[i];
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {{
+ let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
+ (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
}}
}}
}}
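The weight indexing above assumes the (4, 12, k, k) PyTorch tensor is flattened in C order, with the 4 previous/input channels at in-indices 0..3 and the 8 static features at 4..11. A minimal NumPy sketch of the same per-pixel accumulation, useful for validating the shader against the checkpoint (not repo code):

```python
import numpy as np

def conv_at_pixel(weights_flat, layer_local, static_local, k):
    # weights_flat: (4*12*k*k,) from weights.reshape(-1)
    # layer_local: (k*k, 4) samples, static_local: (k*k, 8) samples
    out = np.zeros(4, dtype=np.float32)
    for c in range(4):
        s = 0.0
        for spatial_idx in range(k * k):
            for i in range(4):   # previous/input channels
                w_idx = c * 12 * k * k + i * k * k + spatial_idx
                s += weights_flat[w_idx] * layer_local[spatial_idx, i]
            for i in range(8):   # static features
                w_idx = c * 12 * k * k + (4 + i) * k * k + spatial_idx
                s += weights_flat[w_idx] * static_local[spatial_idx, i]
        out[c] = s
    return out
```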
@@ -162,53 +159,37 @@ def export_checkpoint(checkpoint_path, output_dir):
state_dict = checkpoint['model_state_dict']
config = checkpoint['config']
+ kernel_size = config.get('kernel_size', 3)
+ num_layers = config.get('num_layers', 3)
+
print(f"Configuration:")
- print(f" Kernels: {config['kernels']}")
- print(f" Channels: {config['channels']}")
- print(f" Features: {config['features']}")
+ print(f" Kernel size: {kernel_size}×{kernel_size}")
+ print(f" Layers: {num_layers}")
+ print(f" Architecture: uniform 12D→4D")
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\nExporting shaders to {output_dir}/")
- # Layer 0: 8 → channels[0]
- layer0_weights = state_dict['layer0.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=0,
- weights=layer0_weights,
- kernel_size=config['kernels'][0],
- in_channels=8,
- out_channels=config['channels'][0],
- output_dir=output_dir,
- is_output_layer=False
- )
+ # All layers uniform: 12D→4D
+ for i in range(num_layers):
+ layer_key = f'layers.{i}.weight'
+ if layer_key not in state_dict:
+ raise ValueError(f"Missing weights for layer {i}: {layer_key}")
- # Layer 1: (8 + channels[0]) → channels[1]
- layer1_weights = state_dict['layer1.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=1,
- weights=layer1_weights,
- kernel_size=config['kernels'][1],
- in_channels=8 + config['channels'][0],
- out_channels=config['channels'][1],
- output_dir=output_dir,
- is_output_layer=False
- )
+ layer_weights = state_dict[layer_key].detach().numpy()
+ is_output = (i == num_layers - 1)
- # Layer 2: (8 + channels[1]) → 4 (RGBA)
- layer2_weights = state_dict['layer2.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=2,
- weights=layer2_weights,
- kernel_size=config['kernels'][2],
- in_channels=8 + config['channels'][1],
- out_channels=4,
- output_dir=output_dir,
- is_output_layer=True
- )
+ export_layer_shader(
+ layer_idx=i,
+ weights=layer_weights,
+ kernel_size=kernel_size,
+ output_dir=output_dir,
+ is_output_layer=is_output
+ )
- print(f"\nExport complete! Generated 3 shader files.")
+ print(f"\nExport complete! Generated {num_layers} shader files.")
def main():
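The loop above expects state_dict keys of the form `layers.{i}.weight` with shape (4, 12, k, k) and no bias term (the bias rides along as the constant 1.0 static feature). A hypothetical PyTorch module matching that layout; names are illustrative and not necessarily those used in train_cnn_v2.py:

```python
import torch.nn as nn

class CNNv2(nn.Module):
    def __init__(self, num_layers=3, kernel_size=3):
        super().__init__()
        # Uniform 12D->4D layers; state_dict keys become 'layers.{i}.weight'
        self.layers = nn.ModuleList(
            nn.Conv2d(12, 4, kernel_size, padding=kernel_size // 2, bias=False)
            for _ in range(num_layers)
        )
```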