summaryrefslogtreecommitdiff
path: root/workspaces/main/shaders
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-10 12:48:43 +0100
committerskal <pascal.massimino@gmail.com>2026-02-10 12:48:43 +0100
commit6944733a6a2f05c18e7e0b73f847a4c9144801fd (patch)
tree10713cd41a0e038a016a2e6b357471690f232834 /workspaces/main/shaders
parentcc9cbeb75353181193e3afb880dc890aa8bf8985 (diff)
feat: Add multi-layer CNN support with framebuffer capture and blend control
Implements automatic layer chaining and generic framebuffer capture API for multi-layer neural network effects with proper original input preservation. Key changes: - Effect::needs_framebuffer_capture() - generic API for pre-render capture - MainSequence: auto-capture to "captured_frame" auxiliary texture - CNNEffect: multi-layer support via layer_index/total_layers params - seq_compiler: expands "layers=N" to N chained effect instances - Shader: @binding(4) original_input available to all layers - Training: generates layer switches and original input binding - Blend: mix(original, result, blend_amount) uses layer 0 input Timeline syntax: CNNEffect layers=3 blend=0.7 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'workspaces/main/shaders')
-rw-r--r--workspaces/main/shaders/cnn/cnn_layer.wgsl25
-rw-r--r--workspaces/main/shaders/cnn/cnn_weights_generated.wgsl194
2 files changed, 196 insertions, 23 deletions
diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl
index b2bab26..2285ef9 100644
--- a/workspaces/main/shaders/cnn/cnn_layer.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl
@@ -1,5 +1,6 @@
// CNN layer shader - uses modular convolution snippets
// Supports multi-pass rendering with residual connections
+// DO NOT EDIT - Generated by train_cnn.py
@group(0) @binding(0) var smplr: sampler;
@group(0) @binding(1) var txt: texture_2d<f32>;
@@ -11,12 +12,13 @@
struct CNNLayerParams {
layer_index: i32,
- use_residual: i32,
+ blend_amount: f32,
_pad: vec2<f32>,
};
@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;
@group(0) @binding(3) var<uniform> params: CNNLayerParams;
+@group(0) @binding(4) var original_input: texture_2d<f32>;
@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {
var pos = array<vec2<f32>, 3>(
@@ -27,6 +29,8 @@ struct CNNLayerParams {
@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {
let uv = p.xy / uniforms.resolution;
+ let input = textureSample(txt, smplr, uv);
+ let original = textureSample(original_input, smplr, uv);
var result = vec4<f32>(0.0);
// Layer 0 uses coordinate-aware convolution
@@ -35,12 +39,19 @@ struct CNNLayerParams {
rgba_weights_layer0, coord_weights_layer0, bias_layer0);
result = cnn_tanh(result);
}
-
- // Residual connection
- if (params.use_residual != 0) {
- let input = textureSample(txt, smplr, uv);
- result = input + result * 0.3;
+ else if (params.layer_index == 1) {
+ result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
+ weights_layer1, bias_layer1);
+ result = cnn_tanh(result);
+ }
+ else if (params.layer_index == 2) {
+ result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
+ weights_layer2, bias_layer2);
+ }
+ else {
+ result = input;
}
- return result;
+ // Blend with ORIGINAL input from layer 0
+ return mix(original, result, params.blend_amount);
}
diff --git a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
index e0a7dc4..6052ac5 100644
--- a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
@@ -1,23 +1,185 @@
-// Generated CNN weights and biases
-// DO NOT EDIT MANUALLY - regenerate with scripts/train_cnn.py
+// Auto-generated CNN weights
+// DO NOT EDIT - Generated by train_cnn.py
-// Placeholder identity-like weights for initial testing
-// Layer 0: 3x3 convolution with coordinate awareness
const rgba_weights_layer0: array<mat4x4<f32>, 9> = array(
- mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
- mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
- mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
- mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
- mat4x4<f32>(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0),
- mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
- mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
- mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
- mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
+ mat4x4<f32>(
+ -0.181929, -0.244329, -0.354404, 0.0,
+ -0.291597, -0.195653, 0.081896, 0.0,
+ 0.081595, 0.164081, -0.236318, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ 0.731888, 0.717648, 0.524081, 0.0,
+ -0.029760, -0.208000, 0.008438, 0.0,
+ 0.442082, 0.354681, 0.049288, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.623141, -0.695759, -0.087885, 0.0,
+ 0.043135, 0.071979, 0.213065, 0.0,
+ 0.011581, 0.110995, 0.034100, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ 0.170016, 0.188298, 0.134083, 0.0,
+ -0.222954, -0.088011, 0.015668, 0.0,
+ 0.921836, 0.437158, 0.061577, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ 1.431940, 1.148113, 1.238067, 0.0,
+ -0.212535, 0.366860, 0.320956, 0.0,
+ 0.771192, 0.765570, 0.029189, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ 0.171088, 0.000155, 0.212552, 0.0,
+ 0.029536, 0.447892, 0.041381, 0.0,
+ 0.011807, -0.167281, -0.200702, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.668151, -0.813927, -0.132108, 0.0,
+ -0.156250, 0.179112, -0.069585, 0.0,
+ 0.403347, 0.482877, 0.182611, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.609871, -0.768480, -0.590538, 0.0,
+ -0.171854, 0.150167, 0.105694, 0.0,
+ -0.059052, 0.066999, -0.244222, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.112983, -0.066299, 0.117696, 0.0,
+ -0.172541, 0.095008, -0.160754, 0.0,
+ -0.369667, -0.000628, 0.163602, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ )
);
const coord_weights_layer0 = mat2x4<f32>(
- 0.0, 0.0, 0.0, 0.0,
- 0.0, 0.0, 0.0, 0.0
+ 0.059076, -0.026617, -0.005155, 0.0,
+ 0.135407, -0.090329, 0.058216, 0.0
);
-const bias_layer0 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+const bias_layer0 = vec4<f32>(-0.526177, -0.569862, -1.370040, 0.0);
+
+const weights_layer1: array<mat4x4<f32>, 9> = array(
+ mat4x4<f32>(
+ 0.180029, -1.107249, 0.570741, 0.0,
+ -0.098536, 0.079545, -0.083257, 0.0,
+ -0.020066, 0.333084, 0.039506, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ 3.068946, -1.783570, -0.550517, 0.0,
+ -0.296369, -0.080958, 0.040260, 0.0,
+ -0.093713, -0.212577, -0.110011, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ 2.282564, -0.538192, -0.793214, 0.0,
+ -0.395788, 0.130881, 0.078571, 0.0,
+ -0.041375, 0.061666, 0.045651, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.267284, -1.971639, -0.099616, 0.0,
+ -0.084432, 0.139794, 0.007091, 0.0,
+ -0.103042, -0.104340, 0.067299, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -5.233469, -2.252747, -3.555217, 0.0,
+ 0.647940, -0.178858, 0.351633, 0.0,
+ -0.014237, -0.505881, 0.165940, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.121700, -0.677386, -2.435040, 0.0,
+ 0.084806, -0.028000, 0.380387, 0.0,
+ -0.020906, -0.279161, 0.041915, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ 2.982562, -0.298441, -0.147775, 0.0,
+ -0.291832, 0.102875, -0.128590, 0.0,
+ -0.091786, 0.104389, -0.188678, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -4.434978, -0.261830, -2.436411, 0.0,
+ 0.349188, -0.245908, 0.272592, 0.0,
+ 0.010322, -0.148525, -0.031531, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ 0.129886, 1.516168, -0.755576, 0.0,
+ 0.133138, -0.260276, 0.028059, 0.0,
+ 0.001185, 0.141547, -0.003606, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ )
+);
+
+const bias_layer1 = vec4<f32>(1.367986, -1.148709, -0.650040, 0.0);
+
+const weights_layer2: array<mat4x4<f32>, 9> = array(
+ mat4x4<f32>(
+ -0.137003, -0.289376, 0.625000, 0.0,
+ -0.120120, -0.238968, 0.448432, 0.0,
+ -0.142094, -0.253706, 0.458181, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.337017, -0.757585, 0.135953, 0.0,
+ -0.304432, -0.553491, 0.419907, 0.0,
+ -0.313585, -0.467667, 0.615326, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.161089, -0.328735, 0.612679, 0.0,
+ -0.137144, -0.172882, 0.176362, 0.0,
+ -0.153195, -0.061571, 0.173977, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.227814, -0.544193, -0.564658, 0.0,
+ -0.211743, -0.430586, 0.080349, 0.0,
+ -0.214442, -0.417501, 0.880266, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.435370, -0.295169, -0.865976, 0.0,
+ -0.423147, -0.274780, 0.323049, 0.0,
+ -0.411180, -0.062517, 1.099769, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.199573, -0.488030, -0.396440, 0.0,
+ -0.187844, -0.360516, -0.156646, 0.0,
+ -0.188681, -0.292304, -0.134645, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.123218, -0.287990, 0.154656, 0.0,
+ -0.112954, -0.282778, 0.498742, 0.0,
+ -0.139083, -0.319337, 1.112621, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.267477, -0.691374, -0.028960, 0.0,
+ -0.246348, -0.585583, 0.401194, 0.0,
+ -0.253279, -0.562875, 1.105818, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ ),
+ mat4x4<f32>(
+ -0.083133, -0.131627, 0.460039, 0.0,
+ -0.071126, -0.108601, 0.163545, 0.0,
+ -0.092579, -0.110020, 0.131282, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ )
+);
+
+const bias_layer2 = vec4<f32>(-1.805686, -0.798340, 0.462318, 0.0);
+