diff options
Diffstat (limited to 'workspaces/main')
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl (renamed from workspaces/main/shaders/cnn_v2_compute.wgsl) | 12 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_layer_0.wgsl (renamed from workspaces/main/shaders/cnn_v2_layer_0.wgsl) | 0 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_layer_1.wgsl (renamed from workspaces/main/shaders/cnn_v2_layer_1.wgsl) | 0 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_layer_2.wgsl (renamed from workspaces/main/shaders/cnn_v2_layer_2.wgsl) | 0 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_layer_template.wgsl (renamed from workspaces/main/shaders/cnn_v2_layer_template.wgsl) | 0 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl (renamed from workspaces/main/shaders/cnn_v2_static.wgsl) | 4 |
6 files changed, 14 insertions, 2 deletions
diff --git a/workspaces/main/shaders/cnn_v2_compute.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl index b19a692..1e1704d 100644 --- a/workspaces/main/shaders/cnn_v2_compute.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl @@ -9,6 +9,7 @@ struct LayerParams { out_channels: u32, weight_offset: u32, // Offset in f16 units is_output_layer: u32, // 1 if final layer (sigmoid), 0 otherwise (relu) + blend_amount: f32, // [0,1] blend with original } @group(0) @binding(0) var static_features: texture_2d<u32>; // 8-channel static features @@ -16,6 +17,7 @@ struct LayerParams { @group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // Current layer output @group(0) @binding(3) var<storage, read> weights_buffer: array<u32>; // Packed f16 weights @group(0) @binding(4) var<uniform> params: LayerParams; +@group(0) @binding(5) var original_input: texture_2d<f32>; // Original RGB input for blending fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { let packed = textureLoad(static_features, coord, 0); @@ -133,5 +135,15 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { output[c] = 0.0; } + // Blend with original on final layer + if (is_output) { + let original = textureLoad(original_input, coord, 0).rgb; + let result_rgb = vec3<f32>(output[0], output[1], output[2]); + let blended = mix(original, result_rgb, params.blend_amount); + output[0] = blended.r; + output[1] = blended.g; + output[2] = blended.b; + } + textureStore(output_tex, coord, pack_channels(output)); } diff --git a/workspaces/main/shaders/cnn_v2_layer_0.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_layer_0.wgsl index 8e14957..8e14957 100644 --- a/workspaces/main/shaders/cnn_v2_layer_0.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_layer_0.wgsl diff --git a/workspaces/main/shaders/cnn_v2_layer_1.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_layer_1.wgsl index f490d13..f490d13 100644 --- a/workspaces/main/shaders/cnn_v2_layer_1.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_layer_1.wgsl diff --git a/workspaces/main/shaders/cnn_v2_layer_2.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_layer_2.wgsl index 2f9836a..2f9836a 100644 --- a/workspaces/main/shaders/cnn_v2_layer_2.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_layer_2.wgsl diff --git a/workspaces/main/shaders/cnn_v2_layer_template.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_layer_template.wgsl index 1bf6819..1bf6819 100644 --- a/workspaces/main/shaders/cnn_v2_layer_template.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_layer_template.wgsl diff --git a/workspaces/main/shaders/cnn_v2_static.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl index c3a2de7..dd07f19 100644 --- a/workspaces/main/shaders/cnn_v2_static.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl @@ -25,9 +25,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // Sample depth let d = textureLoad(depth_tex, coord, 0).r; - // UV coordinates (normalized [0,1]) + // UV coordinates (normalized [0,1], bottom-left origin) let uv_x = f32(coord.x) / f32(dims.x); - let uv_y = f32(coord.y) / f32(dims.y); + let uv_y = 1.0 - (f32(coord.y) / f32(dims.y)); // Multi-frequency position encoding let sin10_x = sin(10.0 * uv_x); |
