diff options
Diffstat (limited to 'training/export_cnn_v2_weights.py')
| -rwxr-xr-x | training/export_cnn_v2_weights.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py index 723f572..8a2fcdc 100755 --- a/training/export_cnn_v2_weights.py +++ b/training/export_cnn_v2_weights.py @@ -248,7 +248,7 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { } """ - output_path = Path(output_dir) / "cnn_v2_compute.wgsl" + output_path = Path(output_dir) / "cnn_v2" / "cnn_v2_compute.wgsl" output_path.write_text(shader_code) print(f" → {output_path}") @@ -256,7 +256,7 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { def main(): parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format') parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file') - parser.add_argument('--output-weights', type=str, default='workspaces/main/cnn_v2_weights.bin', + parser.add_argument('--output-weights', type=str, default='workspaces/main/weights/cnn_v2_weights.bin', help='Output binary weights file') parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders', help='Output directory for shader template') |
