summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-11 16:06:16 +0100
committerskal <pascal.massimino@gmail.com>2026-02-11 16:06:16 +0100
commit09eba6004eb5faa5273e310ca560bfd41e1bc901 (patch)
tree670ee9a4e5bdbe91a84bec459bbec475d33e3414
parent3d2ff01e45bf0229d609ffdf84080f0b722f1f24 (diff)
fix: Register cnn_conv1x1 snippet and add verification
- Add cnn_conv1x1 to shader composer registration
- Add VerifyIncludes() to detect missing snippet registrations
- STRIP_ALL-protected verification warns about unregistered includes
- Fixes cnn_test runtime failure loading cnn_layer.wgsl

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
-rw-r--r--src/gpu/effects/shader_composer.cc28
-rw-r--r--src/gpu/effects/shader_composer.h3
-rw-r--r--src/gpu/effects/shaders.cc5
-rw-r--r--workspaces/main/shaders/cnn/cnn_conv1x1.wgsl100
4 files changed, 136 insertions, 0 deletions
diff --git a/src/gpu/effects/shader_composer.cc b/src/gpu/effects/shader_composer.cc
index b746f8b..fe3ad74 100644
--- a/src/gpu/effects/shader_composer.cc
+++ b/src/gpu/effects/shader_composer.cc
@@ -86,3 +86,31 @@ ShaderComposer::Compose(const std::vector<std::string>& dependencies,
return ss.str();
}
+
+// Scans every registered snippet for lines that begin with `#include ` and
+// collects the quoted names that have no entry in snippets_. If any are
+// found, they are listed in a warning on stderr. The body is compiled out
+// when STRIP_ALL is defined, leaving an empty function.
+void ShaderComposer::VerifyIncludes() const {
+#if !defined(STRIP_ALL)
+  std::set<std::string> unresolved;
+  for (const auto& [snippet_name, snippet_code] : snippets_) {
+    std::istringstream reader(snippet_code);
+    for (std::string text; std::getline(reader, text);) {
+      // Only directives at the very start of a line are considered.
+      if (text.rfind("#include ", 0) != 0) {
+        continue;
+      }
+      // The included name is whatever sits between the first two quotes.
+      const size_t open_quote = text.find('"');
+      if (open_quote == std::string::npos) {
+        continue;
+      }
+      const size_t close_quote = text.find('"', open_quote + 1);
+      if (close_quote == std::string::npos) {
+        continue;
+      }
+      const std::string target =
+          text.substr(open_quote + 1, close_quote - open_quote - 1);
+      if (snippets_.count(target) == 0) {
+        unresolved.insert(target);
+      }
+    }
+  }
+  if (unresolved.empty()) {
+    return;
+  }
+  fprintf(stderr, "WARNING: Unregistered shader snippets:\n");
+  for (const auto& entry : unresolved) {
+    fprintf(stderr, " - %s\n", entry.c_str());
+  }
+#endif
+}
diff --git a/src/gpu/effects/shader_composer.h b/src/gpu/effects/shader_composer.h
index 9eb43f4..d0972f2 100644
--- a/src/gpu/effects/shader_composer.h
+++ b/src/gpu/effects/shader_composer.h
@@ -24,6 +24,9 @@ class ShaderComposer {
const std::string& main_code,
const CompositionMap& substitutions = {});
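+ // Checks that every `#include "name"` line found in a registered snippet
+ // refers to a registered snippet; unknown names are listed in a warning on
+ // stderr. The check is compiled out when STRIP_ALL is defined.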
+ void VerifyIncludes() const;
+
private:
ShaderComposer() = default;
diff --git a/src/gpu/effects/shaders.cc b/src/gpu/effects/shaders.cc
index 5f78298..d79f3d3 100644
--- a/src/gpu/effects/shaders.cc
+++ b/src/gpu/effects/shaders.cc
@@ -53,11 +53,16 @@ void InitShaderComposer() {
register_if_exists("ray_triangle", AssetId::ASSET_SHADER_RAY_TRIANGLE);
register_if_exists("cnn_activation", AssetId::ASSET_SHADER_CNN_ACTIVATION);
+ register_if_exists("cnn_conv1x1", AssetId::ASSET_SHADER_CNN_CONV1X1);
register_if_exists("cnn_conv3x3", AssetId::ASSET_SHADER_CNN_CONV3X3);
register_if_exists("cnn_conv5x5", AssetId::ASSET_SHADER_CNN_CONV5X5);
register_if_exists("cnn_conv7x7", AssetId::ASSET_SHADER_CNN_CONV7X7);
register_if_exists("cnn_weights_generated",
AssetId::ASSET_SHADER_CNN_WEIGHTS);
+
+#if !defined(STRIP_ALL)
+ sc.VerifyIncludes();
+#endif
}
// Helper to get asset string or empty string
diff --git a/workspaces/main/shaders/cnn/cnn_conv1x1.wgsl b/workspaces/main/shaders/cnn/cnn_conv1x1.wgsl
new file mode 100644
index 0000000..d468182
--- /dev/null
+++ b/workspaces/main/shaders/cnn/cnn_conv1x1.wgsl
@@ -0,0 +1,100 @@
+// 1x1 convolution (vec4-optimized)
+
+// Inner layers: 7→4 channels (vec4-optimized)
+// Assumes 'tex' is already normalized to [-1,1]
+//
+// A 1x1 kernel has exactly one tap, so 'tex' is sampled once, at 'uv'.
+// The 7 inputs are the 4 sampled channels plus (uv_norm.x, uv_norm.y, gray);
+// the trailing 1.0 in 'in1' makes the .w of each odd-indexed weight an
+// additive constant.
+//
+// Weight layout, for output channel c in 0..3:
+//   weights[2*c]     is dotted with the sampled texel
+//   weights[2*c + 1] is dotted with 'in1'
+//
+// 'resolution' is not read (there are no neighbouring taps to offset to);
+// the parameter is kept so existing call sites remain valid.
+fn cnn_conv1x1_7to4(
+    tex: texture_2d<f32>,
+    samp: sampler,
+    uv: vec2<f32>,
+    resolution: vec2<f32>,
+    gray: f32,
+    weights: array<vec4<f32>, 8>
+) -> vec4<f32> {
+    let rgbd = textureSample(tex, samp, uv);
+    // Map uv from [0,1] to [-1,1] before feeding it to the layer.
+    let in1 = vec4<f32>((uv - 0.5) * 2.0, gray, 1.0);
+
+    return vec4<f32>(
+        dot(weights[0], rgbd) + dot(weights[1], in1),
+        dot(weights[2], rgbd) + dot(weights[3], in1),
+        dot(weights[4], rgbd) + dot(weights[5], in1),
+        dot(weights[6], rgbd) + dot(weights[7], in1)
+    );
+}
+
+// Source layer: 7→4 channels (vec4-optimized)
+// Normalizes [0,1] input to [-1,1] internally
+//
+// A 1x1 kernel has exactly one tap, so 'tex' is sampled once, at 'uv'; the
+// same normalized texel provides both the convolution input and the luma
+// value 'gray' (weights 0.2126, 0.7152, 0.0722 on r, g, b).
+//
+// Weight layout, for output channel c in 0..3:
+//   weights[2*c]     is dotted with the normalized texel
+//   weights[2*c + 1] is dotted with (uv_norm.x, uv_norm.y, gray, 1.0)
+//
+// 'resolution' is not read (there are no neighbouring taps to offset to);
+// the parameter is kept so existing call sites remain valid.
+fn cnn_conv1x1_7to4_src(
+    tex: texture_2d<f32>,
+    samp: sampler,
+    uv: vec2<f32>,
+    resolution: vec2<f32>,
+    weights: array<vec4<f32>, 8>
+) -> vec4<f32> {
+    // Single fetch, remapped from [0,1] to [-1,1].
+    let rgbd = (textureSample(tex, samp, uv) - 0.5) * 2.0;
+    let gray = dot(rgbd.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));
+    let in1 = vec4<f32>((uv - 0.5) * 2.0, gray, 1.0);
+
+    return vec4<f32>(
+        dot(weights[0], rgbd) + dot(weights[1], in1),
+        dot(weights[2], rgbd) + dot(weights[3], in1),
+        dot(weights[4], rgbd) + dot(weights[5], in1),
+        dot(weights[6], rgbd) + dot(weights[7], in1)
+    );
+}
+
+// Final layer: 7→1 channel (vec4-optimized)
+// Assumes 'tex' is already normalized to [-1,1]
+// Returns raw sum (activation applied at call site)
+//
+// A 1x1 kernel has exactly one tap, so 'tex' is sampled once, at 'uv'.
+// weights[0] is dotted with the sampled texel and weights[1] with
+// (uv_norm.x, uv_norm.y, gray, 1.0).
+//
+// 'resolution' is not read (there are no neighbouring taps to offset to);
+// the parameter is kept so existing call sites remain valid.
+fn cnn_conv1x1_7to1(
+    tex: texture_2d<f32>,
+    samp: sampler,
+    uv: vec2<f32>,
+    resolution: vec2<f32>,
+    gray: f32,
+    weights: array<vec4<f32>, 2>
+) -> f32 {
+    let rgbd = textureSample(tex, samp, uv);
+    // Map uv from [0,1] to [-1,1] before feeding it to the layer.
+    let in1 = vec4<f32>((uv - 0.5) * 2.0, gray, 1.0);
+
+    return dot(weights[0], rgbd) + dot(weights[1], in1);
+}