summaryrefslogtreecommitdiff
path: root/tools/cnn_test.cc
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-11 17:23:07 +0100
committerskal <pascal.massimino@gmail.com>2026-02-11 17:23:07 +0100
commit46e0935ba3b241dcd0e965e492ef8fa270b537ea (patch)
treea277ce8d7daa0ca6fa7cca7f757366a70767b7cc /tools/cnn_test.cc
parent606a3e8027e901b5a3f9e68444d931982080bdd9 (diff)
update cnn codeHEADmain
Diffstat (limited to 'tools/cnn_test.cc')
-rw-r--r--tools/cnn_test.cc12
1 files changed, 10 insertions, 2 deletions
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc
index 39ed436..3c96800 100644
--- a/tools/cnn_test.cc
+++ b/tools/cnn_test.cc
@@ -42,6 +42,7 @@ struct Args {
float blend = 1.0f;
bool output_png = true; // Default to PNG
const char* save_intermediates = nullptr;
+ int num_layers = 3; // Default to 3 layers
};
// Parse command-line arguments
@@ -73,6 +74,12 @@ static bool parse_args(int argc, char** argv, Args* args) {
}
} else if (strcmp(argv[i], "--save-intermediates") == 0 && i + 1 < argc) {
args->save_intermediates = argv[++i];
+ } else if (strcmp(argv[i], "--layers") == 0 && i + 1 < argc) {
+ args->num_layers = atoi(argv[++i]);
+ if (args->num_layers < 1 || args->num_layers > 10) {
+ fprintf(stderr, "Error: layers must be in range [1, 10]\n");
+ return false;
+ }
} else if (strcmp(argv[i], "--help") == 0) {
return false;
} else {
@@ -90,6 +97,7 @@ static void print_usage(const char* prog) {
fprintf(stderr, "\nOPTIONS:\n");
fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n");
fprintf(stderr, " --format ppm|png Output format (default: png)\n");
+ fprintf(stderr, " --layers N Number of CNN layers (1-10, default: 3)\n");
fprintf(stderr, " --save-intermediates DIR Save intermediate layers to directory\n");
fprintf(stderr, " --help Show this help\n");
}
@@ -360,8 +368,8 @@ int main(int argc, char** argv) {
WGPUSampler sampler =
SamplerCache::Get().get_or_create(device, SamplerCache::clamp());
- // Multi-layer processing (fixed 3 layers)
- const int NUM_LAYERS = 3;
+ // Multi-layer processing
+ const int NUM_LAYERS = args.num_layers;
int dst_idx = 0; // Index of texture to render to
// First layer reads from input, subsequent layers read from previous output