diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-21 10:50:02 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-21 10:50:02 +0100 |
| commit | 35355b17576e93b035a2a78ecd05771e98f068ee (patch) | |
| tree | a1c1a4563a62ad69c808383fcf0bce1ccf4c5765 /cnn_v3/tools/shaders.js | |
| parent | e343021ac007549c76e58b27a361b11dd3f6a136 (diff) | |
feat(cnn_v3): HTML WebGPU tool (index.html + shaders.js + tester.js)
3-file tool, 939 lines total. Implements full U-Net+FiLM inference in
the browser: Pack→Enc0→Enc1→Bottleneck→Dec1→Dec0 compute passes,
layer visualisation (Feat/Enc0/Enc1/BN/Dec1/Output), FiLM MLP sliders,
drag-drop weights + image/video, Save PNG, diff/blend view modes.
HOW_TO_CNN.md §7 updated to reflect tool is implemented.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'cnn_v3/tools/shaders.js')
| -rw-r--r-- | cnn_v3/tools/shaders.js | 252 |
1 files changed, 252 insertions, 0 deletions
diff --git a/cnn_v3/tools/shaders.js b/cnn_v3/tools/shaders.js new file mode 100644 index 0000000..c3e994d --- /dev/null +++ b/cnn_v3/tools/shaders.js @@ -0,0 +1,252 @@ +'use strict'; +// CNN v3 WGSL shaders — matches cnn_v3/shaders/*.wgsl exactly. +// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672, total=1964 + +const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1092, DEC0_OFF=1672; +const TOTAL_F16=1964, TOTAL_U32=982; + +// Inlined helpers — prepended to shaders that need them. +const H = ` +fn get_w(base:u32,idx:u32)->f32{ + let i=base+idx; let v=unpack2x16float(weights[i>>1u]); + return select(v.y,v.x,(i&1u)==0u); +} +fn unpack8(tex:texture_2d<u32>,c:vec2i)->array<f32,8>{ + let t=textureLoad(tex,c,0); + let a=unpack2x16float(t.x);let b=unpack2x16float(t.y); + let d=unpack2x16float(t.z);let e=unpack2x16float(t.w); + return array<f32,8>(a.x,a.y,b.x,b.y,d.x,d.y,e.x,e.y); +}`; + +// Pack simple image (albedo+mips+alpha) into feat_tex0/1 +const PACK_SHADER=` +@group(0) @binding(0) var inp:texture_2d<f32>; +@group(0) @binding(1) var smp:sampler; +@group(0) @binding(2) var f0:texture_storage_2d<rgba32uint,write>; +@group(0) @binding(3) var f1:texture_storage_2d<rgba32uint,write>; +@compute @workgroup_size(8,8) +fn main(@builtin(global_invocation_id) id:vec3u){ + let c=vec2i(id.xy); let d=vec2i(textureDimensions(inp)); + if(c.x>=d.x||c.y>=d.y){return;} + let uv=(vec2f(c)+.5)/vec2f(d); + let px=textureLoad(inp,c,0); + let alb=px.rgb; let tr=1.-px.a; + let m1=textureSampleLevel(inp,smp,uv,1.).rgb; + let m2=textureSampleLevel(inp,smp,uv,2.).rgb; + textureStore(f0,c,vec4u(pack2x16float(alb.rg),pack2x16float(vec2f(alb.b,0.)), + pack2x16float(vec2f(0.,0.)),pack2x16float(vec2f(0.,0.)))); + textureStore(f1,c,vec4u(pack4x8unorm(vec4f(0.,0.,0.,0.)), + pack4x8unorm(vec4f(m1.r,m1.g,m1.b,m2.r)), + pack4x8unorm(vec4f(m2.g,m2.b,1.,tr)),0u)); +}`; + +// Enc0: Conv(20→4, 3×3, zero-pad) + FiLM + ReLU → rgba16float +// Params (48 bytes): weight_offset u32 _pad×3 gamma vec4f beta vec4f +const ENC0_SHADER=H+` +struct P{wo:u32,_a:u32,_b:u32,_c:u32,g:vec4f,b:vec4f} +@group(0) @binding(0) var t0:texture_2d<u32>; +@group(0) @binding(1) var t1:texture_2d<u32>; +@group(0) @binding(2) var<storage,read> weights:array<u32>; +@group(0) @binding(3) var<uniform> p:P; +@group(0) @binding(4) var out:texture_storage_2d<rgba16float,write>; +fn feat(c:vec2i,d:vec2i)->array<f32,20>{ + if(c.x<0||c.y<0||c.x>=d.x||c.y>=d.y){return array<f32,20>(0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.);} + let a=unpack2x16float(textureLoad(t0,c,0).x); let b=unpack2x16float(textureLoad(t0,c,0).y); + let cc=unpack2x16float(textureLoad(t0,c,0).z);let dd=unpack2x16float(textureLoad(t0,c,0).w); + let e=unpack4x8unorm(textureLoad(t1,c,0).x); let f=unpack4x8unorm(textureLoad(t1,c,0).y); + let g=unpack4x8unorm(textureLoad(t1,c,0).z); + return array<f32,20>(a.x,a.y,b.x,b.y,cc.x,cc.y,dd.x,dd.y,e.x,e.y,e.z,e.w,f.x,f.y,f.z,f.w,g.x,g.y,g.z,g.w); +} +@compute @workgroup_size(8,8) +fn main(@builtin(global_invocation_id) id:vec3u){ + let c=vec2i(id.xy); let d=vec2i(textureDimensions(t0)); + if(c.x>=d.x||c.y>=d.y){return;} + const IN:u32=20u; const OUT:u32=4u; + var o:array<f32,4>; + for(var oc:u32=0u;oc<OUT;oc++){ + var s=get_w(p.wo,OUT*IN*9u+oc); + for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){ + let ft=feat(c+vec2i(kx,ky),d); let ki=u32(ky+1)*3u+u32(kx+1); + for(var i:u32=0u;i<IN;i++){s+=get_w(p.wo,oc*IN*9u+i*9u+ki)*ft[i];} + }} + o[oc]=max(0.,p.g[oc]*s+p.b[oc]); + } + textureStore(out,c,vec4f(o[0],o[1],o[2],o[3])); +}`; + +// Enc1: AvgPool(enc0) + Conv(4→8, 3×3) + FiLM + ReLU → rgba32uint half-res +// Params (80 bytes): wo u32 _pad×3 glo ghi blo bhi vec4f×4 +const ENC1_SHADER=H+` +struct P{wo:u32,_a:u32,_b:u32,_c:u32,gl:vec4f,gh:vec4f,bl:vec4f,bh:vec4f} +@group(0) @binding(0) var e0:texture_2d<f32>; +@group(0) @binding(1) var<storage,read> weights:array<u32>; +@group(0) @binding(2) var<uniform> p:P; +@group(0) @binding(3) var out:texture_storage_2d<rgba32uint,write>; +fn fg(o:u32)->f32{if(o<4u){return p.gl[o];}return p.gh[o-4u];} +fn fb(o:u32)->f32{if(o<4u){return p.bl[o];}return p.bh[o-4u];} +fn avg(hc:vec2i,fd:vec2i)->array<f32,4>{ + let hd=fd/2; if(hc.x<0||hc.y<0||hc.x>=hd.x||hc.y>=hd.y){return array<f32,4>(0.,0.,0.,0.);} + var s=vec4f(0.); + for(var y:i32=0;y<2;y++){for(var x:i32=0;x<2;x++){s+=textureLoad(e0,clamp(hc*2+vec2i(x,y),vec2i(0),fd-vec2i(1)),0);}} + let a=s*.25; return array<f32,4>(a.x,a.y,a.z,a.w); +} +@compute @workgroup_size(8,8) +fn main(@builtin(global_invocation_id) id:vec3u){ + let fd=vec2i(textureDimensions(e0)); let hd=fd/2; let c=vec2i(id.xy); + if(c.x>=hd.x||c.y>=hd.y){return;} + const IN:u32=4u; const OUT:u32=8u; + var o:array<f32,8>; + for(var oc:u32=0u;oc<OUT;oc++){ + var s=get_w(p.wo,OUT*IN*9u+oc); + for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){ + let ft=avg(c+vec2i(kx,ky),fd); let ki=u32(ky+1)*3u+u32(kx+1); + for(var i:u32=0u;i<IN;i++){s+=get_w(p.wo,oc*IN*9u+i*9u+ki)*ft[i];} + }} + o[oc]=max(0.,fg(oc)*s+fb(oc)); + } + textureStore(out,c,vec4u(pack2x16float(vec2f(o[0],o[1])),pack2x16float(vec2f(o[2],o[3])), + pack2x16float(vec2f(o[4],o[5])),pack2x16float(vec2f(o[6],o[7])))); +}`; + +// Bottleneck: AvgPool(enc1) + Conv(8→8, 1×1) + ReLU → rgba32uint quarter-res (no FiLM) +// Params (16 bytes): wo u32 _pad×3 +const BN_SHADER=H+` +struct P{wo:u32,_a:u32,_b:u32,_c:u32} +@group(0) @binding(0) var e1:texture_2d<u32>; +@group(0) @binding(1) var<storage,read> weights:array<u32>; +@group(0) @binding(2) var<uniform> p:P; +@group(0) @binding(3) var out:texture_storage_2d<rgba32uint,write>; +fn avg(qc:vec2i,hd:vec2i)->array<f32,8>{ + let qd=hd/2; if(qc.x<0||qc.y<0||qc.x>=qd.x||qc.y>=qd.y){return array<f32,8>(0.,0.,0.,0.,0.,0.,0.,0.);} + var s:array<f32,8>; + for(var y:i32=0;y<2;y++){for(var x:i32=0;x<2;x++){ + let f=unpack8(e1,clamp(qc*2+vec2i(x,y),vec2i(0),hd-vec2i(1))); + for(var i:u32=0u;i<8u;i++){s[i]+=f[i];} + }} + for(var i:u32=0u;i<8u;i++){s[i]*=.25;} return s; +} +@compute @workgroup_size(8,8) +fn main(@builtin(global_invocation_id) id:vec3u){ + let hd=vec2i(textureDimensions(e1)); let qd=hd/2; let c=vec2i(id.xy); + if(c.x>=qd.x||c.y>=qd.y){return;} + let ft=avg(c,hd); var o:array<f32,8>; + for(var oc:u32=0u;oc<8u;oc++){ + var s=get_w(p.wo,64u+oc); + for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*8u+i)*ft[i];} + o[oc]=max(0.,s); + } + textureStore(out,c,vec4u(pack2x16float(vec2f(o[0],o[1])),pack2x16float(vec2f(o[2],o[3])), + pack2x16float(vec2f(o[4],o[5])),pack2x16float(vec2f(o[6],o[7])))); +}`; + +// Dec1: NearestUp(bn)+cat(enc1_skip) → Conv(16→4,3×3) + FiLM + ReLU → rgba16float half-res +// Params (48 bytes): same layout as enc0 +const DEC1_SHADER=H+` +struct P{wo:u32,_a:u32,_b:u32,_c:u32,g:vec4f,b:vec4f} +@group(0) @binding(0) var bn:texture_2d<u32>; +@group(0) @binding(1) var e1:texture_2d<u32>; +@group(0) @binding(2) var<storage,read> weights:array<u32>; +@group(0) @binding(3) var<uniform> p:P; +@group(0) @binding(4) var out:texture_storage_2d<rgba16float,write>; +fn cat(hc:vec2i,hd:vec2i)->array<f32,16>{ + if(hc.x<0||hc.y<0||hc.x>=hd.x||hc.y>=hd.y){return array<f32,16>(0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.);} + let qd=hd/2; let b=unpack8(bn,clamp(hc/2,vec2i(0),qd-vec2i(1))); + let s=unpack8(e1,hc); + return array<f32,16>(b[0],b[1],b[2],b[3],b[4],b[5],b[6],b[7],s[0],s[1],s[2],s[3],s[4],s[5],s[6],s[7]); +} +@compute @workgroup_size(8,8) +fn main(@builtin(global_invocation_id) id:vec3u){ + let hd=vec2i(textureDimensions(e1)); let c=vec2i(id.xy); + if(c.x>=hd.x||c.y>=hd.y){return;} + const IN:u32=16u; const OUT:u32=4u; + var o:array<f32,4>; + for(var oc:u32=0u;oc<OUT;oc++){ + var s=get_w(p.wo,OUT*IN*9u+oc); + for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){ + let ft=cat(c+vec2i(kx,ky),hd); let ki=u32(ky+1)*3u+u32(kx+1); + for(var i:u32=0u;i<IN;i++){s+=get_w(p.wo,oc*IN*9u+i*9u+ki)*ft[i];} + }} + o[oc]=max(0.,p.g[oc]*s+p.b[oc]); + } + textureStore(out,c,vec4f(o[0],o[1],o[2],o[3])); +}`; + +// Dec0: NearestUp(dec1)+cat(enc0_skip) → Conv(8→4,3×3) + FiLM + ReLU + Sigmoid → rgba16float +// Params (48 bytes): same layout as enc0 +const DEC0_SHADER=H+` +struct P{wo:u32,_a:u32,_b:u32,_c:u32,g:vec4f,b:vec4f} +@group(0) @binding(0) var d1:texture_2d<f32>; +@group(0) @binding(1) var e0:texture_2d<f32>; +@group(0) @binding(2) var<storage,read> weights:array<u32>; +@group(0) @binding(3) var<uniform> p:P; +@group(0) @binding(4) var out:texture_storage_2d<rgba16float,write>; +fn cat(c:vec2i,fd:vec2i)->array<f32,8>{ + if(c.x<0||c.y<0||c.x>=fd.x||c.y>=fd.y){return array<f32,8>(0.,0.,0.,0.,0.,0.,0.,0.);} + let hd=vec2i(textureDimensions(d1)); + let a=textureLoad(d1,clamp(c/2,vec2i(0),hd-vec2i(1)),0); + let b=textureLoad(e0,c,0); + return array<f32,8>(a.x,a.y,a.z,a.w,b.x,b.y,b.z,b.w); +} +@compute @workgroup_size(8,8) +fn main(@builtin(global_invocation_id) id:vec3u){ + let fd=vec2i(textureDimensions(e0)); let c=vec2i(id.xy); + if(c.x>=fd.x||c.y>=fd.y){return;} + const IN:u32=8u; const OUT:u32=4u; + var o:array<f32,4>; + for(var oc:u32=0u;oc<OUT;oc++){ + var s=get_w(p.wo,OUT*IN*9u+oc); + for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){ + let ft=cat(c+vec2i(kx,ky),fd); let ki=u32(ky+1)*3u+u32(kx+1); + for(var i:u32=0u;i<IN;i++){s+=get_w(p.wo,oc*IN*9u+i*9u+ki)*ft[i];} + }} + let v=max(0.,p.g[oc]*s+p.b[oc]); + o[oc]=1./(1.+exp(-v)); + } + textureStore(out,c,vec4f(o[0],o[1],o[2],o[3])); +}`; + +// Display: rgba16float output → canvas (mode 0=cnn,1=orig,2=diff, blend) +const DISP_SHADER=` +@group(0) @binding(0) var otex:texture_2d<f32>; +@group(0) @binding(1) var itex:texture_2d<f32>; +@group(0) @binding(2) var<uniform> pr:vec4f; // x=mode y=blend +@vertex fn vs(@builtin(vertex_index) i:u32)->@builtin(position) vec4f{ + var p=array<vec2f,6>(vec2f(-1.,-1.),vec2f(1.,-1.),vec2f(-1.,1.),vec2f(-1.,1.),vec2f(1.,-1.),vec2f(1.,1.)); + return vec4f(p[i],0.,1.); +} +@fragment fn fs(@builtin(position) pos:vec4f)->@location(0) vec4f{ + let c=vec2i(pos.xy); let m=u32(pr.x); + let orig=textureLoad(itex,c,0).rgb; let cnn=textureLoad(otex,c,0).rgb; + if(m==1u){return vec4f(orig,1.);} + if(m==2u){return vec4f(abs(cnn-orig)*10.,1.);} + return vec4f(mix(orig,cnn,pr.y),1.); +}`; + +// Viz f32: show one channel of rgba16float layer +const VIZ_F32=` +@group(0) @binding(0) var t:texture_2d<f32>; +@group(0) @binding(1) var<uniform> ch:u32; +@vertex fn vs(@builtin(vertex_index) i:u32)->@builtin(position) vec4f{ + var p=array<vec2f,6>(vec2f(-1.,-1.),vec2f(1.,-1.),vec2f(-1.,1.),vec2f(-1.,1.),vec2f(1.,-1.),vec2f(1.,1.)); + return vec4f(p[i],0.,1.); +} +@fragment fn fs(@builtin(position) pos:vec4f)->@location(0) vec4f{ + let v=textureLoad(t,vec2i(pos.xy),0); var a=array<f32,4>(v.x,v.y,v.z,v.w); + let x=clamp(a[min(ch,3u)],0.,1.); return vec4f(x,x,x,1.); +}`; + +// Viz u32: show one f16 channel of rgba32uint layer (8 channels packed) +const VIZ_U32=` +@group(0) @binding(0) var t:texture_2d<u32>; +@group(0) @binding(1) var<uniform> ch:u32; +@vertex fn vs(@builtin(vertex_index) i:u32)->@builtin(position) vec4f{ + var p=array<vec2f,6>(vec2f(-1.,-1.),vec2f(1.,-1.),vec2f(-1.,1.),vec2f(-1.,1.),vec2f(1.,-1.),vec2f(1.,1.)); + return vec4f(p[i],0.,1.); +} +@fragment fn fs(@builtin(position) pos:vec4f)->@location(0) vec4f{ + let t2=textureLoad(t,vec2i(pos.xy),0); + let a=unpack2x16float(t2.x);let b=unpack2x16float(t2.y); + let c=unpack2x16float(t2.z);let d=unpack2x16float(t2.w); + var v=array<f32,8>(a.x,a.y,b.x,b.y,c.x,c.y,d.x,d.y); + let x=clamp(v[min(ch,7u)],0.,1.); return vec4f(x,x,x,1.); +}`; |
