diff options
Diffstat (limited to 'training/export_cnn_v2_weights.py')
| -rwxr-xr-x | training/export_cnn_v2_weights.py | 52 |
1 files changed, 32 insertions, 20 deletions
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py index 1086516..f64bd8d 100755 --- a/training/export_cnn_v2_weights.py +++ b/training/export_cnn_v2_weights.py @@ -12,7 +12,7 @@ import struct from pathlib import Path -def export_weights_binary(checkpoint_path, output_path): +def export_weights_binary(checkpoint_path, output_path, quiet=False): """Export CNN v2 weights to binary format. Binary format: @@ -40,7 +40,8 @@ def export_weights_binary(checkpoint_path, output_path): Returns: config dict for shader generation """ - print(f"Loading checkpoint: {checkpoint_path}") + if not quiet: + print(f"Loading checkpoint: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict = checkpoint['model_state_dict'] @@ -59,11 +60,12 @@ def export_weights_binary(checkpoint_path, output_path): num_layers = config.get('num_layers', len(kernel_sizes)) mip_level = config.get('mip_level', 0) - print(f"Configuration:") - print(f" Kernel sizes: {kernel_sizes}") - print(f" Layers: {num_layers}") - print(f" Mip level: {mip_level} (p0-p3 features)") - print(f" Architecture: uniform 12D→4D (bias=False)") + if not quiet: + print(f"Configuration:") + print(f" Kernel sizes: {kernel_sizes}") + print(f" Layers: {num_layers}") + print(f" Mip level: {mip_level} (p0-p3 features)") + print(f" Architecture: uniform 12D→4D (bias=False)") # Collect layer info - all layers uniform 12D→4D layers = [] @@ -89,7 +91,8 @@ def export_weights_binary(checkpoint_path, output_path): all_weights.extend(layer_flat) weight_offset += len(layer_flat) - print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights") + if not quiet: + print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights") # Convert to f16 # TODO: Use 8-bit quantization for 2× size reduction @@ -104,11 +107,13 @@ def export_weights_binary(checkpoint_path, output_path): # Pack pairs using numpy view weights_u32 = all_weights_f16.view(np.uint32) - print(f"\nWeight statistics:") - print(f" Total layers: {len(layers)}") - print(f" Total weights: {len(all_weights_f16)} (f16)") - print(f" Packed: {len(weights_u32)} u32") - print(f" Binary size: {20 + len(layers) * 20 + len(weights_u32) * 4} bytes") + binary_size = 20 + len(layers) * 20 + len(weights_u32) * 4 + if not quiet: + print(f"\nWeight statistics:") + print(f" Total layers: {len(layers)}") + print(f" Total weights: {len(all_weights_f16)} (f16)") + print(f" Packed: {len(weights_u32)} u32") + print(f" Binary size: {binary_size} bytes") # Write binary file output_path = Path(output_path) @@ -135,7 +140,10 @@ def export_weights_binary(checkpoint_path, output_path): # Weights (u32 packed f16 pairs) f.write(weights_u32.tobytes()) - print(f" → {output_path}") + if quiet: + print(f" Exported {num_layers} layers, {len(all_weights_f16)} weights, {binary_size} bytes → {output_path}") + else: + print(f" → {output_path}") return { 'num_layers': len(layers), @@ -257,15 +265,19 @@ def main(): help='Output binary weights file') parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders', help='Output directory for shader template') + parser.add_argument('--quiet', action='store_true', + help='Suppress detailed output') args = parser.parse_args() - print("=== CNN v2 Weight Export ===\n") - config = export_weights_binary(args.checkpoint, args.output_weights) - print() - # Shader is manually maintained in cnn_v2_compute.wgsl - # export_shader_template(config, args.output_shader) - print("\nExport complete!") + if not args.quiet: + print("=== CNN v2 Weight Export ===\n") + config = export_weights_binary(args.checkpoint, args.output_weights, quiet=args.quiet) + if not args.quiet: + print() + # Shader is manually maintained in cnn_v2_compute.wgsl + # export_shader_template(config, args.output_shader) + print("\nExport complete!") if __name__ == '__main__': |
