diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-11 17:23:07 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-11 17:23:07 +0100 |
| commit | 46e0935ba3b241dcd0e965e492ef8fa270b537ea (patch) | |
| tree | a277ce8d7daa0ca6fa7cca7f757366a70767b7cc /tools/cnn_test.cc | |
| parent | 606a3e8027e901b5a3f9e68444d931982080bdd9 (diff) | |
Diffstat (limited to 'tools/cnn_test.cc')
| -rw-r--r-- | tools/cnn_test.cc | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc index 39ed436..3c96800 100644 --- a/tools/cnn_test.cc +++ b/tools/cnn_test.cc @@ -42,6 +42,7 @@ struct Args { float blend = 1.0f; bool output_png = true; // Default to PNG const char* save_intermediates = nullptr; + int num_layers = 3; // Default to 3 layers }; // Parse command-line arguments @@ -73,6 +74,12 @@ static bool parse_args(int argc, char** argv, Args* args) { } } else if (strcmp(argv[i], "--save-intermediates") == 0 && i + 1 < argc) { args->save_intermediates = argv[++i]; + } else if (strcmp(argv[i], "--layers") == 0 && i + 1 < argc) { + args->num_layers = atoi(argv[++i]); + if (args->num_layers < 1 || args->num_layers > 10) { + fprintf(stderr, "Error: layers must be in range [1, 10]\n"); + return false; + } } else if (strcmp(argv[i], "--help") == 0) { return false; } else { @@ -90,6 +97,7 @@ static void print_usage(const char* prog) { fprintf(stderr, "\nOPTIONS:\n"); fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n"); fprintf(stderr, " --format ppm|png Output format (default: png)\n"); + fprintf(stderr, " --layers N Number of CNN layers (1-10, default: 3)\n"); fprintf(stderr, " --save-intermediates DIR Save intermediate layers to directory\n"); fprintf(stderr, " --help Show this help\n"); } @@ -360,8 +368,8 @@ int main(int argc, char** argv) { WGPUSampler sampler = SamplerCache::Get().get_or_create(device, SamplerCache::clamp()); - // Multi-layer processing (fixed 3 layers) - const int NUM_LAYERS = 3; + // Multi-layer processing + const int NUM_LAYERS = args.num_layers; int dst_idx = 0; // Index of texture to render to // First layer reads from input, subsequent layers read from previous output |
