summaryrefslogtreecommitdiff
path: root/cnn_v3/tools/shaders.js
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-21 10:50:02 +0100
committerskal <pascal.massimino@gmail.com>2026-03-21 10:50:02 +0100
commit35355b17576e93b035a2a78ecd05771e98f068ee (patch)
treea1c1a4563a62ad69c808383fcf0bce1ccf4c5765 /cnn_v3/tools/shaders.js
parente343021ac007549c76e58b27a361b11dd3f6a136 (diff)
feat(cnn_v3): HTML WebGPU tool (index.html + shaders.js + tester.js)
3-file tool, 939 lines total. Implements full U-Net+FiLM inference in the browser: Pack→Enc0→Enc1→Bottleneck→Dec1→Dec0 compute passes, layer visualisation (Feat/Enc0/Enc1/BN/Dec1/Output), FiLM MLP sliders, drag-drop weights + image/video, Save PNG, diff/blend view modes. HOW_TO_CNN.md §7 updated to reflect tool is implemented. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'cnn_v3/tools/shaders.js')
-rw-r--r--cnn_v3/tools/shaders.js252
1 files changed, 252 insertions, 0 deletions
diff --git a/cnn_v3/tools/shaders.js b/cnn_v3/tools/shaders.js
new file mode 100644
index 0000000..c3e994d
--- /dev/null
+++ b/cnn_v3/tools/shaders.js
@@ -0,0 +1,252 @@
+'use strict';
+// CNN v3 WGSL shaders — matches cnn_v3/shaders/*.wgsl exactly.
+// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672, total=1964
+
+const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1092, DEC0_OFF=1672;
+const TOTAL_F16=1964, TOTAL_U32=982;
+
+// Inlined helpers — prepended to shaders that need them.
+const H = `
+fn get_w(base:u32,idx:u32)->f32{
+ let i=base+idx; let v=unpack2x16float(weights[i>>1u]);
+ return select(v.y,v.x,(i&1u)==0u);
+}
+fn unpack8(tex:texture_2d<u32>,c:vec2i)->array<f32,8>{
+ let t=textureLoad(tex,c,0);
+ let a=unpack2x16float(t.x);let b=unpack2x16float(t.y);
+ let d=unpack2x16float(t.z);let e=unpack2x16float(t.w);
+ return array<f32,8>(a.x,a.y,b.x,b.y,d.x,d.y,e.x,e.y);
+}`;
+
+// Pack simple image (albedo+mips+alpha) into feat_tex0/1
+const PACK_SHADER=`
+@group(0) @binding(0) var inp:texture_2d<f32>;
+@group(0) @binding(1) var smp:sampler;
+@group(0) @binding(2) var f0:texture_storage_2d<rgba32uint,write>;
+@group(0) @binding(3) var f1:texture_storage_2d<rgba32uint,write>;
+@compute @workgroup_size(8,8)
+fn main(@builtin(global_invocation_id) id:vec3u){
+ let c=vec2i(id.xy); let d=vec2i(textureDimensions(inp));
+ if(c.x>=d.x||c.y>=d.y){return;}
+ let uv=(vec2f(c)+.5)/vec2f(d);
+ let px=textureLoad(inp,c,0);
+ let alb=px.rgb; let tr=1.-px.a;
+ let m1=textureSampleLevel(inp,smp,uv,1.).rgb;
+ let m2=textureSampleLevel(inp,smp,uv,2.).rgb;
+ textureStore(f0,c,vec4u(pack2x16float(alb.rg),pack2x16float(vec2f(alb.b,0.)),
+ pack2x16float(vec2f(0.,0.)),pack2x16float(vec2f(0.,0.))));
+ textureStore(f1,c,vec4u(pack4x8unorm(vec4f(0.,0.,0.,0.)),
+ pack4x8unorm(vec4f(m1.r,m1.g,m1.b,m2.r)),
+ pack4x8unorm(vec4f(m2.g,m2.b,1.,tr)),0u));
+}`;
+
+// Enc0: Conv(20→4, 3×3, zero-pad) + FiLM + ReLU → rgba16float
+// Params (48 bytes): weight_offset u32 _pad×3 gamma vec4f beta vec4f
+const ENC0_SHADER=H+`
+struct P{wo:u32,_a:u32,_b:u32,_c:u32,g:vec4f,b:vec4f}
+@group(0) @binding(0) var t0:texture_2d<u32>;
+@group(0) @binding(1) var t1:texture_2d<u32>;
+@group(0) @binding(2) var<storage,read> weights:array<u32>;
+@group(0) @binding(3) var<uniform> p:P;
+@group(0) @binding(4) var out:texture_storage_2d<rgba16float,write>;
+fn feat(c:vec2i,d:vec2i)->array<f32,20>{
+ if(c.x<0||c.y<0||c.x>=d.x||c.y>=d.y){return array<f32,20>(0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.);}
+ let a=unpack2x16float(textureLoad(t0,c,0).x); let b=unpack2x16float(textureLoad(t0,c,0).y);
+ let cc=unpack2x16float(textureLoad(t0,c,0).z);let dd=unpack2x16float(textureLoad(t0,c,0).w);
+ let e=unpack4x8unorm(textureLoad(t1,c,0).x); let f=unpack4x8unorm(textureLoad(t1,c,0).y);
+ let g=unpack4x8unorm(textureLoad(t1,c,0).z);
+ return array<f32,20>(a.x,a.y,b.x,b.y,cc.x,cc.y,dd.x,dd.y,e.x,e.y,e.z,e.w,f.x,f.y,f.z,f.w,g.x,g.y,g.z,g.w);
+}
+@compute @workgroup_size(8,8)
+fn main(@builtin(global_invocation_id) id:vec3u){
+ let c=vec2i(id.xy); let d=vec2i(textureDimensions(t0));
+ if(c.x>=d.x||c.y>=d.y){return;}
+ const IN:u32=20u; const OUT:u32=4u;
+ var o:array<f32,4>;
+ for(var oc:u32=0u;oc<OUT;oc++){
+ var s=get_w(p.wo,OUT*IN*9u+oc);
+ for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){
+ let ft=feat(c+vec2i(kx,ky),d); let ki=u32(ky+1)*3u+u32(kx+1);
+ for(var i:u32=0u;i<IN;i++){s+=get_w(p.wo,oc*IN*9u+i*9u+ki)*ft[i];}
+ }}
+ o[oc]=max(0.,p.g[oc]*s+p.b[oc]);
+ }
+ textureStore(out,c,vec4f(o[0],o[1],o[2],o[3]));
+}`;
+
+// Enc1: AvgPool(enc0) + Conv(4→8, 3×3) + FiLM + ReLU → rgba32uint half-res
+// Params (80 bytes): wo u32 _pad×3 glo ghi blo bhi vec4f×4
+const ENC1_SHADER=H+`
+struct P{wo:u32,_a:u32,_b:u32,_c:u32,gl:vec4f,gh:vec4f,bl:vec4f,bh:vec4f}
+@group(0) @binding(0) var e0:texture_2d<f32>;
+@group(0) @binding(1) var<storage,read> weights:array<u32>;
+@group(0) @binding(2) var<uniform> p:P;
+@group(0) @binding(3) var out:texture_storage_2d<rgba32uint,write>;
+fn fg(o:u32)->f32{if(o<4u){return p.gl[o];}return p.gh[o-4u];}
+fn fb(o:u32)->f32{if(o<4u){return p.bl[o];}return p.bh[o-4u];}
+fn avg(hc:vec2i,fd:vec2i)->array<f32,4>{
+ let hd=fd/2; if(hc.x<0||hc.y<0||hc.x>=hd.x||hc.y>=hd.y){return array<f32,4>(0.,0.,0.,0.);}
+ var s=vec4f(0.);
+ for(var y:i32=0;y<2;y++){for(var x:i32=0;x<2;x++){s+=textureLoad(e0,clamp(hc*2+vec2i(x,y),vec2i(0),fd-vec2i(1)),0);}}
+ let a=s*.25; return array<f32,4>(a.x,a.y,a.z,a.w);
+}
+@compute @workgroup_size(8,8)
+fn main(@builtin(global_invocation_id) id:vec3u){
+ let fd=vec2i(textureDimensions(e0)); let hd=fd/2; let c=vec2i(id.xy);
+ if(c.x>=hd.x||c.y>=hd.y){return;}
+ const IN:u32=4u; const OUT:u32=8u;
+ var o:array<f32,8>;
+ for(var oc:u32=0u;oc<OUT;oc++){
+ var s=get_w(p.wo,OUT*IN*9u+oc);
+ for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){
+ let ft=avg(c+vec2i(kx,ky),fd); let ki=u32(ky+1)*3u+u32(kx+1);
+ for(var i:u32=0u;i<IN;i++){s+=get_w(p.wo,oc*IN*9u+i*9u+ki)*ft[i];}
+ }}
+ o[oc]=max(0.,fg(oc)*s+fb(oc));
+ }
+ textureStore(out,c,vec4u(pack2x16float(vec2f(o[0],o[1])),pack2x16float(vec2f(o[2],o[3])),
+ pack2x16float(vec2f(o[4],o[5])),pack2x16float(vec2f(o[6],o[7]))));
+}`;
+
+// Bottleneck: AvgPool(enc1) + Conv(8→8, 1×1) + ReLU → rgba32uint quarter-res (no FiLM)
+// Params (16 bytes): wo u32 _pad×3
+const BN_SHADER=H+`
+struct P{wo:u32,_a:u32,_b:u32,_c:u32}
+@group(0) @binding(0) var e1:texture_2d<u32>;
+@group(0) @binding(1) var<storage,read> weights:array<u32>;
+@group(0) @binding(2) var<uniform> p:P;
+@group(0) @binding(3) var out:texture_storage_2d<rgba32uint,write>;
+fn avg(qc:vec2i,hd:vec2i)->array<f32,8>{
+ let qd=hd/2; if(qc.x<0||qc.y<0||qc.x>=qd.x||qc.y>=qd.y){return array<f32,8>(0.,0.,0.,0.,0.,0.,0.,0.);}
+ var s:array<f32,8>;
+ for(var y:i32=0;y<2;y++){for(var x:i32=0;x<2;x++){
+ let f=unpack8(e1,clamp(qc*2+vec2i(x,y),vec2i(0),hd-vec2i(1)));
+ for(var i:u32=0u;i<8u;i++){s[i]+=f[i];}
+ }}
+ for(var i:u32=0u;i<8u;i++){s[i]*=.25;} return s;
+}
+@compute @workgroup_size(8,8)
+fn main(@builtin(global_invocation_id) id:vec3u){
+ let hd=vec2i(textureDimensions(e1)); let qd=hd/2; let c=vec2i(id.xy);
+ if(c.x>=qd.x||c.y>=qd.y){return;}
+ let ft=avg(c,hd); var o:array<f32,8>;
+ for(var oc:u32=0u;oc<8u;oc++){
+ var s=get_w(p.wo,64u+oc);
+ for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*8u+i)*ft[i];}
+ o[oc]=max(0.,s);
+ }
+ textureStore(out,c,vec4u(pack2x16float(vec2f(o[0],o[1])),pack2x16float(vec2f(o[2],o[3])),
+ pack2x16float(vec2f(o[4],o[5])),pack2x16float(vec2f(o[6],o[7]))));
+}`;
+
+// Dec1: NearestUp(bn)+cat(enc1_skip) → Conv(16→4,3×3) + FiLM + ReLU → rgba16float half-res
+// Params (48 bytes): same layout as enc0
+const DEC1_SHADER=H+`
+struct P{wo:u32,_a:u32,_b:u32,_c:u32,g:vec4f,b:vec4f}
+@group(0) @binding(0) var bn:texture_2d<u32>;
+@group(0) @binding(1) var e1:texture_2d<u32>;
+@group(0) @binding(2) var<storage,read> weights:array<u32>;
+@group(0) @binding(3) var<uniform> p:P;
+@group(0) @binding(4) var out:texture_storage_2d<rgba16float,write>;
+fn cat(hc:vec2i,hd:vec2i)->array<f32,16>{
+ if(hc.x<0||hc.y<0||hc.x>=hd.x||hc.y>=hd.y){return array<f32,16>(0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.);}
+ let qd=hd/2; let b=unpack8(bn,clamp(hc/2,vec2i(0),qd-vec2i(1)));
+ let s=unpack8(e1,hc);
+ return array<f32,16>(b[0],b[1],b[2],b[3],b[4],b[5],b[6],b[7],s[0],s[1],s[2],s[3],s[4],s[5],s[6],s[7]);
+}
+@compute @workgroup_size(8,8)
+fn main(@builtin(global_invocation_id) id:vec3u){
+ let hd=vec2i(textureDimensions(e1)); let c=vec2i(id.xy);
+ if(c.x>=hd.x||c.y>=hd.y){return;}
+ const IN:u32=16u; const OUT:u32=4u;
+ var o:array<f32,4>;
+ for(var oc:u32=0u;oc<OUT;oc++){
+ var s=get_w(p.wo,OUT*IN*9u+oc);
+ for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){
+ let ft=cat(c+vec2i(kx,ky),hd); let ki=u32(ky+1)*3u+u32(kx+1);
+ for(var i:u32=0u;i<IN;i++){s+=get_w(p.wo,oc*IN*9u+i*9u+ki)*ft[i];}
+ }}
+ o[oc]=max(0.,p.g[oc]*s+p.b[oc]);
+ }
+ textureStore(out,c,vec4f(o[0],o[1],o[2],o[3]));
+}`;
+
+// Dec0: NearestUp(dec1)+cat(enc0_skip) → Conv(8→4,3×3) + FiLM + ReLU + Sigmoid → rgba16float
+// Params (48 bytes): same layout as enc0
+const DEC0_SHADER=H+`
+struct P{wo:u32,_a:u32,_b:u32,_c:u32,g:vec4f,b:vec4f}
+@group(0) @binding(0) var d1:texture_2d<f32>;
+@group(0) @binding(1) var e0:texture_2d<f32>;
+@group(0) @binding(2) var<storage,read> weights:array<u32>;
+@group(0) @binding(3) var<uniform> p:P;
+@group(0) @binding(4) var out:texture_storage_2d<rgba16float,write>;
+fn cat(c:vec2i,fd:vec2i)->array<f32,8>{
+ if(c.x<0||c.y<0||c.x>=fd.x||c.y>=fd.y){return array<f32,8>(0.,0.,0.,0.,0.,0.,0.,0.);}
+ let hd=vec2i(textureDimensions(d1));
+ let a=textureLoad(d1,clamp(c/2,vec2i(0),hd-vec2i(1)),0);
+ let b=textureLoad(e0,c,0);
+ return array<f32,8>(a.x,a.y,a.z,a.w,b.x,b.y,b.z,b.w);
+}
+@compute @workgroup_size(8,8)
+fn main(@builtin(global_invocation_id) id:vec3u){
+ let fd=vec2i(textureDimensions(e0)); let c=vec2i(id.xy);
+ if(c.x>=fd.x||c.y>=fd.y){return;}
+ const IN:u32=8u; const OUT:u32=4u;
+ var o:array<f32,4>;
+ for(var oc:u32=0u;oc<OUT;oc++){
+ var s=get_w(p.wo,OUT*IN*9u+oc);
+ for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){
+ let ft=cat(c+vec2i(kx,ky),fd); let ki=u32(ky+1)*3u+u32(kx+1);
+ for(var i:u32=0u;i<IN;i++){s+=get_w(p.wo,oc*IN*9u+i*9u+ki)*ft[i];}
+ }}
+ let v=max(0.,p.g[oc]*s+p.b[oc]);
+ o[oc]=1./(1.+exp(-v));
+ }
+ textureStore(out,c,vec4f(o[0],o[1],o[2],o[3]));
+}`;
+
+// Display: rgba16float output → canvas (mode 0=cnn,1=orig,2=diff, blend)
+const DISP_SHADER=`
+@group(0) @binding(0) var otex:texture_2d<f32>;
+@group(0) @binding(1) var itex:texture_2d<f32>;
+@group(0) @binding(2) var<uniform> pr:vec4f; // x=mode y=blend
+@vertex fn vs(@builtin(vertex_index) i:u32)->@builtin(position) vec4f{
+ var p=array<vec2f,6>(vec2f(-1.,-1.),vec2f(1.,-1.),vec2f(-1.,1.),vec2f(-1.,1.),vec2f(1.,-1.),vec2f(1.,1.));
+ return vec4f(p[i],0.,1.);
+}
+@fragment fn fs(@builtin(position) pos:vec4f)->@location(0) vec4f{
+ let c=vec2i(pos.xy); let m=u32(pr.x);
+ let orig=textureLoad(itex,c,0).rgb; let cnn=textureLoad(otex,c,0).rgb;
+ if(m==1u){return vec4f(orig,1.);}
+ if(m==2u){return vec4f(abs(cnn-orig)*10.,1.);}
+ return vec4f(mix(orig,cnn,pr.y),1.);
+}`;
+
+// Viz f32: show one channel of rgba16float layer
+const VIZ_F32=`
+@group(0) @binding(0) var t:texture_2d<f32>;
+@group(0) @binding(1) var<uniform> ch:u32;
+@vertex fn vs(@builtin(vertex_index) i:u32)->@builtin(position) vec4f{
+ var p=array<vec2f,6>(vec2f(-1.,-1.),vec2f(1.,-1.),vec2f(-1.,1.),vec2f(-1.,1.),vec2f(1.,-1.),vec2f(1.,1.));
+ return vec4f(p[i],0.,1.);
+}
+@fragment fn fs(@builtin(position) pos:vec4f)->@location(0) vec4f{
+ let v=textureLoad(t,vec2i(pos.xy),0); var a=array<f32,4>(v.x,v.y,v.z,v.w);
+ let x=clamp(a[min(ch,3u)],0.,1.); return vec4f(x,x,x,1.);
+}`;
+
+// Viz u32: show one f16 channel of rgba32uint layer (8 channels packed)
+const VIZ_U32=`
+@group(0) @binding(0) var t:texture_2d<u32>;
+@group(0) @binding(1) var<uniform> ch:u32;
+@vertex fn vs(@builtin(vertex_index) i:u32)->@builtin(position) vec4f{
+ var p=array<vec2f,6>(vec2f(-1.,-1.),vec2f(1.,-1.),vec2f(-1.,1.),vec2f(-1.,1.),vec2f(1.,-1.),vec2f(1.,1.));
+ return vec4f(p[i],0.,1.);
+}
+@fragment fn fs(@builtin(position) pos:vec4f)->@location(0) vec4f{
+ let t2=textureLoad(t,vec2i(pos.xy),0);
+ let a=unpack2x16float(t2.x);let b=unpack2x16float(t2.y);
+ let c=unpack2x16float(t2.z);let d=unpack2x16float(t2.w);
+ var v=array<f32,8>(a.x,a.y,b.x,b.y,c.x,c.y,d.x,d.y);
+ let x=clamp(v[min(ch,7u)],0.,1.); return vec4f(x,x,x,1.);
+}`;