summaryrefslogtreecommitdiff
path: root/tools/cnn_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tools/cnn_test.cc')
-rw-r--r--tools/cnn_test.cc28
1 files changed, 25 insertions, 3 deletions
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc
index 62a60f4..39ed436 100644
--- a/tools/cnn_test.cc
+++ b/tools/cnn_test.cc
@@ -41,6 +41,7 @@ struct Args {
const char* output_path = nullptr;
float blend = 1.0f;
bool output_png = true; // Default to PNG
+ const char* save_intermediates = nullptr;
};
// Parse command-line arguments
@@ -70,6 +71,8 @@ static bool parse_args(int argc, char** argv, Args* args) {
argv[i]);
return false;
}
+ } else if (strcmp(argv[i], "--save-intermediates") == 0 && i + 1 < argc) {
+ args->save_intermediates = argv[++i];
} else if (strcmp(argv[i], "--help") == 0) {
return false;
} else {
@@ -85,9 +88,10 @@ static bool parse_args(int argc, char** argv, Args* args) {
static void print_usage(const char* prog) {
fprintf(stderr, "Usage: %s input.png output.png [OPTIONS]\n", prog);
fprintf(stderr, "\nOPTIONS:\n");
- fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n");
- fprintf(stderr, " --format ppm|png Output format (default: png)\n");
- fprintf(stderr, " --help Show this help\n");
+ fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n");
+ fprintf(stderr, " --format ppm|png Output format (default: png)\n");
+ fprintf(stderr, " --save-intermediates DIR Save intermediate layers to directory\n");
+ fprintf(stderr, " --help Show this help\n");
}
// Load PNG and upload to GPU texture
@@ -485,6 +489,24 @@ int main(int argc, char** argv) {
wgpuRenderPassEncoderRelease(pass);
wgpuCommandEncoderRelease(encoder);
wgpuBindGroupRelease(bind_group);
+
+ // Save intermediate layer if requested
+ if (args.save_intermediates) {
+ char layer_path[512];
+ snprintf(layer_path, sizeof(layer_path), "%s/layer_%d.png",
+ args.save_intermediates, layer);
+ printf("Saving intermediate layer %d to '%s'...\n", layer, layer_path);
+
+ // Readback RGBA16Float texture
+ std::vector<uint8_t> pixels = texture_readback_fp16_to_u8(
+ device, queue, intermediate_textures[dst_idx], width, height);
+
+ if (!pixels.empty()) {
+ save_png(layer_path, pixels, width, height);
+ } else {
+ fprintf(stderr, "Warning: failed to read intermediate layer %d\n", layer);
+ }
+ }
}
// Update for next layer: output becomes input