diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-11 10:51:06 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-11 10:51:06 +0100 |
| commit | 4da0a3a5369142078fd7c681e3f0f1817bd6e2f3 (patch) | |
| tree | d69429d6800dad0bb819f164122df634543796a5 /tools/cnn_test.cc | |
| parent | 7dd1ac57178055aa8407777d4fb03787e21e6f66 (diff) | |
add --save-intermediates to train.py and cnn_test
Diffstat (limited to 'tools/cnn_test.cc')
| -rw-r--r-- | tools/cnn_test.cc | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc index 62a60f4..39ed436 100644 --- a/tools/cnn_test.cc +++ b/tools/cnn_test.cc @@ -41,6 +41,7 @@ struct Args { const char* output_path = nullptr; float blend = 1.0f; bool output_png = true; // Default to PNG + const char* save_intermediates = nullptr; }; // Parse command-line arguments @@ -70,6 +71,8 @@ static bool parse_args(int argc, char** argv, Args* args) { argv[i]); return false; } + } else if (strcmp(argv[i], "--save-intermediates") == 0 && i + 1 < argc) { + args->save_intermediates = argv[++i]; } else if (strcmp(argv[i], "--help") == 0) { return false; } else { @@ -85,9 +88,10 @@ static bool parse_args(int argc, char** argv, Args* args) { static void print_usage(const char* prog) { fprintf(stderr, "Usage: %s input.png output.png [OPTIONS]\n", prog); fprintf(stderr, "\nOPTIONS:\n"); - fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n"); - fprintf(stderr, " --format ppm|png Output format (default: png)\n"); - fprintf(stderr, " --help Show this help\n"); + fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n"); + fprintf(stderr, " --format ppm|png Output format (default: png)\n"); + fprintf(stderr, " --save-intermediates DIR Save intermediate layers to directory\n"); + fprintf(stderr, " --help Show this help\n"); } // Load PNG and upload to GPU texture @@ -485,6 +489,24 @@ int main(int argc, char** argv) { wgpuRenderPassEncoderRelease(pass); wgpuCommandEncoderRelease(encoder); wgpuBindGroupRelease(bind_group); + + // Save intermediate layer if requested + if (args.save_intermediates) { + char layer_path[512]; + snprintf(layer_path, sizeof(layer_path), "%s/layer_%d.png", + args.save_intermediates, layer); + printf("Saving intermediate layer %d to '%s'...\n", layer, layer_path); + + // Readback RGBA16Float texture + std::vector<uint8_t> pixels = texture_readback_fp16_to_u8( + device, queue, intermediate_textures[dst_idx], width, height); + + if (!pixels.empty()) { + save_png(layer_path, pixels, width, height); + } else { + fprintf(stderr, "Warning: failed to read intermediate layer %d\n", layer); + } + } } // Update for next layer: output becomes input |
