summaryrefslogtreecommitdiff
path: root/training/export_cnn_v2_shader.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/export_cnn_v2_shader.py')
-rwxr-xr-xtraining/export_cnn_v2_shader.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py
index dc475d8..1c74ad0 100755
--- a/training/export_cnn_v2_shader.py
+++ b/training/export_cnn_v2_shader.py
@@ -14,7 +14,7 @@ import torch
from pathlib import Path
-def export_layer_shader(layer_idx, weights, kernel_size, output_dir, is_output_layer=False):
+def export_layer_shader(layer_idx, weights, kernel_size, output_dir, mip_level=0, is_output_layer=False):
"""Generate WGSL compute shader for a single CNN layer.
Args:
@@ -22,6 +22,7 @@ def export_layer_shader(layer_idx, weights, kernel_size, output_dir, is_output_l
weights: (4, 12, k, k) weight tensor (uniform 12D→4D)
kernel_size: Kernel size (3, 5, etc.)
output_dir: Output directory path
+ mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.)
is_output_layer: True if this is the final RGBA output layer
"""
weights_flat = weights.flatten()
@@ -44,6 +45,7 @@ def export_layer_shader(layer_idx, weights, kernel_size, output_dir, is_output_l
shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
+// Mip level: {mip_level} (p0-p3 features)
const KERNEL_SIZE: u32 = {kernel_size}u;
const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static)
@@ -164,10 +166,12 @@ def export_checkpoint(checkpoint_path, output_dir):
kernel_size = config.get('kernel_size', 3)
num_layers = config.get('num_layers', 3)
+ mip_level = config.get('mip_level', 0)
print(f"Configuration:")
print(f" Kernel size: {kernel_size}×{kernel_size}")
print(f" Layers: {num_layers}")
+ print(f" Mip level: {mip_level} (p0-p3 features)")
print(f" Architecture: uniform 12D→4D")
output_dir = Path(output_dir)
@@ -189,6 +193,7 @@ def export_checkpoint(checkpoint_path, output_dir):
weights=layer_weights,
kernel_size=kernel_size,
output_dir=output_dir,
+ mip_level=mip_level,
is_output_layer=is_output
)