summary | refs | log | tree | commit | diff
path: root/training/export_cnn_v2_weights.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/export_cnn_v2_weights.py')
-rwxr-xr-x training/export_cnn_v2_weights.py 52
1 files changed, 32 insertions, 20 deletions
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
index 1086516..f64bd8d 100755
--- a/training/export_cnn_v2_weights.py
+++ b/training/export_cnn_v2_weights.py
@@ -12,7 +12,7 @@ import struct
from pathlib import Path
-def export_weights_binary(checkpoint_path, output_path):
+def export_weights_binary(checkpoint_path, output_path, quiet=False):
"""Export CNN v2 weights to binary format.
Binary format:
@@ -40,7 +40,8 @@ def export_weights_binary(checkpoint_path, output_path):
Returns:
config dict for shader generation
"""
- print(f"Loading checkpoint: {checkpoint_path}")
+ if not quiet:
+ print(f"Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict = checkpoint['model_state_dict']
@@ -59,11 +60,12 @@ def export_weights_binary(checkpoint_path, output_path):
num_layers = config.get('num_layers', len(kernel_sizes))
mip_level = config.get('mip_level', 0)
- print(f"Configuration:")
- print(f" Kernel sizes: {kernel_sizes}")
- print(f" Layers: {num_layers}")
- print(f" Mip level: {mip_level} (p0-p3 features)")
- print(f" Architecture: uniform 12D→4D (bias=False)")
+ if not quiet:
+ print(f"Configuration:")
+ print(f" Kernel sizes: {kernel_sizes}")
+ print(f" Layers: {num_layers}")
+ print(f" Mip level: {mip_level} (p0-p3 features)")
+ print(f" Architecture: uniform 12D→4D (bias=False)")
# Collect layer info - all layers uniform 12D→4D
layers = []
@@ -89,7 +91,8 @@ def export_weights_binary(checkpoint_path, output_path):
all_weights.extend(layer_flat)
weight_offset += len(layer_flat)
- print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights")
+ if not quiet:
+ print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights")
# Convert to f16
# TODO: Use 8-bit quantization for 2× size reduction
@@ -104,11 +107,13 @@ def export_weights_binary(checkpoint_path, output_path):
# Pack pairs using numpy view
weights_u32 = all_weights_f16.view(np.uint32)
- print(f"\nWeight statistics:")
- print(f" Total layers: {len(layers)}")
- print(f" Total weights: {len(all_weights_f16)} (f16)")
- print(f" Packed: {len(weights_u32)} u32")
- print(f" Binary size: {20 + len(layers) * 20 + len(weights_u32) * 4} bytes")
+ binary_size = 20 + len(layers) * 20 + len(weights_u32) * 4
+ if not quiet:
+ print(f"\nWeight statistics:")
+ print(f" Total layers: {len(layers)}")
+ print(f" Total weights: {len(all_weights_f16)} (f16)")
+ print(f" Packed: {len(weights_u32)} u32")
+ print(f" Binary size: {binary_size} bytes")
# Write binary file
output_path = Path(output_path)
@@ -135,7 +140,10 @@ def export_weights_binary(checkpoint_path, output_path):
# Weights (u32 packed f16 pairs)
f.write(weights_u32.tobytes())
- print(f" → {output_path}")
+ if quiet:
+ print(f" Exported {num_layers} layers, {len(all_weights_f16)} weights, {binary_size} bytes → {output_path}")
+ else:
+ print(f" → {output_path}")
return {
'num_layers': len(layers),
@@ -257,15 +265,19 @@ def main():
help='Output binary weights file')
parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders',
help='Output directory for shader template')
+ parser.add_argument('--quiet', action='store_true',
+ help='Suppress detailed output')
args = parser.parse_args()
- print("=== CNN v2 Weight Export ===\n")
- config = export_weights_binary(args.checkpoint, args.output_weights)
- print()
- # Shader is manually maintained in cnn_v2_compute.wgsl
- # export_shader_template(config, args.output_shader)
- print("\nExport complete!")
+ if not args.quiet:
+ print("=== CNN v2 Weight Export ===\n")
+ config = export_weights_binary(args.checkpoint, args.output_weights, quiet=args.quiet)
+ if not args.quiet:
+ print()
+ # Shader is manually maintained in cnn_v2_compute.wgsl
+ # export_shader_template(config, args.output_shader)
+ print("\nExport complete!")
if __name__ == '__main__':