summaryrefslogtreecommitdiff
path: root/cnn_v3/tools/shaders.js
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/tools/shaders.js')
-rw-r--r--cnn_v3/tools/shaders.js18
1 files changed, 11 insertions, 7 deletions
diff --git a/cnn_v3/tools/shaders.js b/cnn_v3/tools/shaders.js
index 6c49864..36f53c8 100644
--- a/cnn_v3/tools/shaders.js
+++ b/cnn_v3/tools/shaders.js
@@ -1,9 +1,10 @@
'use strict';
// CNN v3 WGSL shaders — matches cnn_v3/shaders/*.wgsl exactly.
-// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672, total=1964
+// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1604, dec0=2184, total=2476
+// BN is now Conv(8→8, 3×3, dilation=2): 8*8*9+8=584 weights (was 72 for 1×1)
-const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1092, DEC0_OFF=1672;
-const TOTAL_F16=1964, TOTAL_U32=982;
+const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1604, DEC0_OFF=2184;
+const TOTAL_F16=2476, TOTAL_U32=1238;
// Inlined helpers — prepended to shaders that need them.
const H = `
@@ -108,7 +109,7 @@ fn main(@builtin(global_invocation_id) id:vec3u){
pack2x16float(vec2f(o[4],o[5])),pack2x16float(vec2f(o[6],o[7]))));
}`;
-// Bottleneck: AvgPool(enc1) + Conv(8→8, 1×1) + ReLU → rgba32uint quarter-res (no FiLM)
+// Bottleneck: AvgPool(enc1) + Conv(8→8, 3×3, dilation=2) + ReLU → rgba32uint quarter-res (no FiLM)
// Params (16 bytes): wo u32 _pad×3
const BN_SHADER=H+`
struct P{wo:u32,_a:u32,_b:u32,_c:u32}
@@ -129,10 +130,13 @@ fn avg(qc:vec2i,hd:vec2i)->array<f32,8>{
fn main(@builtin(global_invocation_id) id:vec3u){
let hd=vec2i(textureDimensions(e1)); let qd=hd/2; let c=vec2i(id.xy);
if(c.x>=qd.x||c.y>=qd.y){return;}
- let ft=avg(c,hd); var o:array<f32,8>;
+ var o:array<f32,8>;
for(var oc:u32=0u;oc<8u;oc++){
- var s=get_w(p.wo,64u+oc);
- for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*8u+i)*ft[i];}
+ var s=get_w(p.wo,576u+oc);
+ for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){
+ let ft=avg(c+vec2i(kx,ky)*2,hd); let ki=u32(ky+1)*3u+u32(kx+1);
+ for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*72u+i*9u+ki)*ft[i];}
+ }}
o[oc]=max(0.,s);
}
textureStore(out,c,vec4u(pack2x16float(vec2f(o[0],o[1])),pack2x16float(vec2f(o[2],o[3])),