Diffstat (limited to 'training/export_cnn_v2_weights.py')
-rwxr-xr-x  training/export_cnn_v2_weights.py  |  85
1 file changed, 34 insertions, 51 deletions
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
index 8a2fcdc..07254fc 100755
--- a/training/export_cnn_v2_weights.py
+++ b/training/export_cnn_v2_weights.py
@@ -45,53 +45,38 @@ def export_weights_binary(checkpoint_path, output_path):
state_dict = checkpoint['model_state_dict']
config = checkpoint['config']
+ kernel_size = config.get('kernel_size', 3)
+ num_layers = config.get('num_layers', 3)
+
print(f"Configuration:")
- print(f" Kernels: {config['kernels']}")
- print(f" Channels: {config['channels']}")
+ print(f" Kernel size: {kernel_size}×{kernel_size}")
+ print(f" Layers: {num_layers}")
+ print(f" Architecture: uniform 12D→4D (bias=False)")
- # Collect layer info
+ # Collect layer info - all layers uniform 12D→4D
layers = []
all_weights = []
weight_offset = 0
- # Layer 0: 8 → channels[0]
- layer0_weights = state_dict['layer0.weight'].detach().numpy()
- layer0_flat = layer0_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][0],
- 'in_channels': 8,
- 'out_channels': config['channels'][0],
- 'weight_offset': weight_offset,
- 'weight_count': len(layer0_flat)
- })
- all_weights.extend(layer0_flat)
- weight_offset += len(layer0_flat)
+ for i in range(num_layers):
+ layer_key = f'layers.{i}.weight'
+ if layer_key not in state_dict:
+ raise ValueError(f"Missing weights for layer {i}: {layer_key}")
+
+ layer_weights = state_dict[layer_key].detach().numpy()
+ layer_flat = layer_weights.flatten()
- # Layer 1: (8 + channels[0]) → channels[1]
- layer1_weights = state_dict['layer1.weight'].detach().numpy()
- layer1_flat = layer1_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][1],
- 'in_channels': 8 + config['channels'][0],
- 'out_channels': config['channels'][1],
- 'weight_offset': weight_offset,
- 'weight_count': len(layer1_flat)
- })
- all_weights.extend(layer1_flat)
- weight_offset += len(layer1_flat)
+ layers.append({
+ 'kernel_size': kernel_size,
+ 'in_channels': 12, # 4 (input/prev) + 8 (static)
+ 'out_channels': 4, # Uniform output
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer_flat)
+ })
+ all_weights.extend(layer_flat)
+ weight_offset += len(layer_flat)
- # Layer 2: (8 + channels[1]) → 4 (RGBA output)
- layer2_weights = state_dict['layer2.weight'].detach().numpy()
- layer2_flat = layer2_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][2],
- 'in_channels': 8 + config['channels'][1],
- 'out_channels': 4,
- 'weight_offset': weight_offset,
- 'weight_count': len(layer2_flat)
- })
- all_weights.extend(layer2_flat)
- weight_offset += len(layer2_flat)
+ print(f" Layer {i}: 12D→4D, {len(layer_flat)} weights")
# Convert to f16
# TODO: Use 8-bit quantization for 2× size reduction
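For context, a minimal sketch (not taken from this file) of how the collected
layer table and weights could be serialized after the f16 conversion; the
header layout, the u32 field order, and the write_weights_file name are
assumptions, loosely inferred from the layer0_info_base reads in the shader
hunk below.

# Hypothetical serialization sketch -- layout and field order are assumptions.
import struct

import numpy as np

def write_weights_file(layers, all_weights, output_path):
    """Write a per-layer info table followed by the weights as float16."""
    with open(output_path, 'wb') as f:
        # Header: number of layers as a little-endian u32.
        f.write(struct.pack('<I', len(layers)))
        # Per-layer record: kernel_size, in_channels, out_channels,
        # weight_offset, weight_count (all u32).
        for layer in layers:
            f.write(struct.pack('<5I',
                                layer['kernel_size'],
                                layer['in_channels'],
                                layer['out_channels'],
                                layer['weight_offset'],
                                layer['weight_count']))
        # Weights: flatten and narrow float32 -> float16 (the 8-bit
        # quantization mentioned in the TODO above would halve this again).
        f.write(np.asarray(all_weights, dtype=np.float16).tobytes())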
@@ -183,21 +168,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}
-fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
let packed = textureLoad(layer_input, coord, 0);
let v0 = unpack2x16float(packed.x);
let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
return vec4<u32>(
- pack2x16float(vec2<f32>(values[0], values[1])),
- pack2x16float(vec2<f32>(values[2], values[3])),
- pack2x16float(vec2<f32>(values[4], values[5])),
- pack2x16float(vec2<f32>(values[6], values[7]))
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
);
}
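A small CPU-side Python check that mirrors the shader's pack2x16float /
unpack2x16float usage (illustrative only; pack_two and unpack_two are names
invented for this sketch): with the uniform 4-channel layers, only the first
two u32 words of each texel carry data, which is why .z and .w are written as
0u above.

import numpy as np

def pack_two(a: float, b: float) -> int:
    """Emulate pack2x16float: two f32 values as half floats in one u32 (first value in the low bits)."""
    halves = np.array([a, b], dtype=np.float16).view(np.uint16)
    return int(halves[0]) | (int(halves[1]) << 16)

def unpack_two(word: int):
    """Emulate unpack2x16float: split a u32 back into two f32 values."""
    halves = np.array([word & 0xFFFF, (word >> 16) & 0xFFFF], dtype=np.uint16)
    lo, hi = halves.view(np.float16).astype(np.float32)
    return float(lo), float(hi)

# Four channels -> two packed words; the remaining two texel words stay 0.
rgba = (0.25, 0.5, 0.75, 1.0)
packed = (pack_two(rgba[0], rgba[1]), pack_two(rgba[2], rgba[3]), 0, 0)
assert unpack_two(packed[0]) == (0.25, 0.5)
assert unpack_two(packed[1]) == (0.75, 1.0)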
@@ -238,9 +221,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let out_channels = weights[layer0_info_base + 2u];
let weight_offset = weights[layer0_info_base + 3u];
- // Convolution (simplified - expand to full kernel loop)
- var output: array<f32, 8>;
- for (var c: u32 = 0u; c < min(out_channels, 8u); c++) {
+ // Convolution: 12D input (4 prev + 8 static) → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {
output[c] = 0.0; // TODO: Actual convolution
}
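For the TODO above, a reference of what each uniform layer computes, written
as a PyTorch sketch; the concatenation order (4 previous-layer channels first,
then the 8 static channels) and the uniform_layer name are assumptions based
on the in_channels comment in the export script, not verified against the
training code.

import torch
import torch.nn.functional as F

def uniform_layer(prev: torch.Tensor, static: torch.Tensor,
                  weight: torch.Tensor) -> torch.Tensor:
    """One uniform 12D->4D layer, matching the exported [4, 12, k, k] weights.

    prev:   [N, 4, H, W]  previous layer output (or the 4D network input)
    static: [N, 8, H, W]  per-pixel static features
    weight: [4, 12, k, k] exported layer weights, bias=False
    """
    x = torch.cat([prev, static], dim=1)   # assumed order: prev first, then static
    pad = weight.shape[-1] // 2            # "same" padding for odd kernel sizes
    return F.conv2d(x, weight, bias=None, padding=pad)

# Example shapes for a 3x3 kernel:
# uniform_layer(torch.zeros(1, 4, 64, 64), torch.zeros(1, 8, 64, 64),
#               torch.zeros(4, 12, 3, 3)).shape  ->  torch.Size([1, 4, 64, 64])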