diff options
Diffstat (limited to 'cnn_v3/tools/shaders.js')
| -rw-r--r-- | cnn_v3/tools/shaders.js | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/cnn_v3/tools/shaders.js b/cnn_v3/tools/shaders.js index 6c49864..36f53c8 100644 --- a/cnn_v3/tools/shaders.js +++ b/cnn_v3/tools/shaders.js @@ -1,9 +1,10 @@ 'use strict'; // CNN v3 WGSL shaders — matches cnn_v3/shaders/*.wgsl exactly. -// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672, total=1964 +// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1604, dec0=2184, total=2476 +// BN is now Conv(8→8, 3×3, dilation=2): 8*8*9+8=584 weights (was 72 for 1×1) -const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1092, DEC0_OFF=1672; -const TOTAL_F16=1964, TOTAL_U32=982; +const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1604, DEC0_OFF=2184; +const TOTAL_F16=2476, TOTAL_U32=1238; // Inlined helpers — prepended to shaders that need them. const H = ` @@ -108,7 +109,7 @@ fn main(@builtin(global_invocation_id) id:vec3u){ pack2x16float(vec2f(o[4],o[5])),pack2x16float(vec2f(o[6],o[7])))); }`; -// Bottleneck: AvgPool(enc1) + Conv(8→8, 1×1) + ReLU → rgba32uint quarter-res (no FiLM) +// Bottleneck: AvgPool(enc1) + Conv(8→8, 3×3, dilation=2) + ReLU → rgba32uint quarter-res (no FiLM) // Params (16 bytes): wo u32 _pad×3 const BN_SHADER=H+` struct P{wo:u32,_a:u32,_b:u32,_c:u32} @@ -129,10 +130,13 @@ fn avg(qc:vec2i,hd:vec2i)->array<f32,8>{ fn main(@builtin(global_invocation_id) id:vec3u){ let hd=vec2i(textureDimensions(e1)); let qd=hd/2; let c=vec2i(id.xy); if(c.x>=qd.x||c.y>=qd.y){return;} - let ft=avg(c,hd); var o:array<f32,8>; + var o:array<f32,8>; for(var oc:u32=0u;oc<8u;oc++){ - var s=get_w(p.wo,64u+oc); - for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*8u+i)*ft[i];} + var s=get_w(p.wo,576u+oc); + for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){ + let ft=avg(c+vec2i(kx,ky)*2,hd); let ki=u32(ky+1)*3u+u32(kx+1); + for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*72u+i*9u+ki)*ft[i];} + }} o[oc]=max(0.,s); } textureStore(out,c,vec4u(pack2x16float(vec2f(o[0],o[1])),pack2x16float(vec2f(o[2],o[3])), |
