diff options
Diffstat (limited to 'cnn_v3/tools')
| -rw-r--r-- | cnn_v3/tools/index.html | 17 | ||||
| -rw-r--r-- | cnn_v3/tools/shaders.js | 48 | ||||
| -rw-r--r-- | cnn_v3/tools/tester.js | 277 |
3 files changed, 341 insertions, 1 deletions
diff --git a/cnn_v3/tools/index.html b/cnn_v3/tools/index.html index 8494fef..1398ca5 100644 --- a/cnn_v3/tools/index.html +++ b/cnn_v3/tools/index.html @@ -64,6 +64,7 @@ video{display:none} <div class="left"> <input type="file" id="wFile" accept=".bin" style="display:none"> <input type="file" id="fFile" accept=".bin" style="display:none"> + <input type="file" id="sFile" webkitdirectory style="display:none" onchange="tester.loadSampleDir(this.files)"> <div class="dz" id="wDrop" onclick="document.getElementById('wFile').click()">Drop cnn_v3_weights.bin</div> <div class="dz" id="fDrop" onclick="document.getElementById('fFile').click()">Drop cnn_v3_film_mlp.bin (optional)</div> @@ -79,6 +80,10 @@ video{display:none} <div id="fullHelp" style="display:none;margin-top:6px;font-size:9px;color:#555;line-height:1.6"> Drop PNGs: *albedo*/color · *normal* · *depth* · *matid*/index · *shadow* · *transp*/alpha </div> + <div style="margin-top:8px;border-top:1px solid #333;padding-top:8px"> + <button onclick="document.getElementById('sFile').click()" style="width:100%">↑ Load sample directory</button> + <div id="sampleSt" style="font-size:9px;color:#555;margin-top:3px"></div> + </div> </div> </div> @@ -121,7 +126,17 @@ video{display:none} <div class="sep"></div> <button onclick="tester.savePNG()">Save PNG</button> </div> - <canvas id="canvas"></canvas> + <div style="display:flex;gap:12px;align-items:flex-start"> + <div style="display:flex;flex-direction:column;align-items:center;gap:3px"> + <canvas id="canvas"></canvas> + <span id="cnnLabel" style="font-size:9px;color:#555"></span> + </div> + <div id="targetPane" style="display:none;flex-direction:column;align-items:center;gap:3px"> + <canvas id="targetCanvas" style="max-width:100%;max-height:100%;image-rendering:pixelated;box-shadow:0 4px 12px rgba(0,0,0,.5)"></canvas> + <span style="font-size:9px;color:#555">target.png</span> + <span id="psnrSt" style="font-size:9px;color:#4a9eff"></span> + </div> + </div> </div> <div class="right"> diff --git a/cnn_v3/tools/shaders.js b/cnn_v3/tools/shaders.js index c3e994d..d5b1fb4 100644 --- a/cnn_v3/tools/shaders.js +++ b/cnn_v3/tools/shaders.js @@ -250,3 +250,51 @@ const VIZ_U32=` var v=array<f32,8>(a.x,a.y,b.x,b.y,c.x,c.y,d.x,d.y); let x=clamp(v[min(ch,7u)],0.,1.); return vec4f(x,x,x,1.); }`; + +// Full G-buffer pack: assembles feat_tex0/feat_tex1 from individual G-buffer images. +// Bindings: albedo(0) normal(1) depth(2) matid(3) shadow(4) transp(5) sampler(6) f0(7) f1(8) +// All source textures are rgba8unorm (browser-loaded images, R channel for depth/matid/shadow/transp). +// Matches gbuf_pack.wgsl packing exactly so the CNN sees the same layout. +const FULL_PACK_SHADER=` +@group(0) @binding(0) var albedo: texture_2d<f32>; +@group(0) @binding(1) var normal: texture_2d<f32>; +@group(0) @binding(2) var depth: texture_2d<f32>; +@group(0) @binding(3) var matid: texture_2d<f32>; +@group(0) @binding(4) var shadow: texture_2d<f32>; +@group(0) @binding(5) var transp: texture_2d<f32>; +@group(0) @binding(6) var smp: sampler; +@group(0) @binding(7) var f0: texture_storage_2d<rgba32uint,write>; +@group(0) @binding(8) var f1: texture_storage_2d<rgba32uint,write>; +fn ld(c:vec2i,d:vec2i)->f32{return textureLoad(depth,clamp(c,vec2i(0),d-vec2i(1)),0).r;} +fn b2(tl:vec2i,d:vec2i)->vec3f{ + var s=vec3f(0.); + for(var y:i32=0;y<2;y++){for(var x:i32=0;x<2;x++){s+=textureLoad(albedo,clamp(tl+vec2i(x,y),vec2i(0),d-vec2i(1)),0).rgb;}} + return s*.25;} +fn b4(tl:vec2i,d:vec2i)->vec3f{ + var s=vec3f(0.); + for(var y:i32=0;y<4;y++){for(var x:i32=0;x<4;x++){s+=textureLoad(albedo,clamp(tl+vec2i(x,y),vec2i(0),d-vec2i(1)),0).rgb;}} + return s*(1./16.);} +@compute @workgroup_size(8,8) +fn main(@builtin(global_invocation_id) id:vec3u){ + let c=vec2i(id.xy); let d=vec2i(textureDimensions(albedo)); + if(c.x>=d.x||c.y>=d.y){return;} + let alb=textureLoad(albedo,c,0).rgb; + let nrm=textureLoad(normal,c,0).rg; + let oct=nrm*2.-vec2f(1.); // [0,1] -> [-1,1] + let dv=ld(c,d); + let dzdx=(ld(c+vec2i(1,0),d)-ld(c-vec2i(1,0),d))*.5; + let dzdy=(ld(c+vec2i(0,1),d)-ld(c-vec2i(0,1),d))*.5; + textureStore(f0,c,vec4u( + pack2x16float(alb.rg), + pack2x16float(vec2f(alb.b,oct.x)), + pack2x16float(vec2f(oct.y,dv)), + pack2x16float(vec2f(dzdx,dzdy)))); + let mid=textureLoad(matid,c,0).r; + let shd=textureLoad(shadow,c,0).r; + let trp=textureLoad(transp,c,0).r; + let m1=b2(c-vec2i(0),d); let m2=b4(c-vec2i(1),d); + textureStore(f1,c,vec4u( + pack4x8unorm(vec4f(mid,0.,0.,0.)), + pack4x8unorm(vec4f(m1.r,m1.g,m1.b,m2.r)), + pack4x8unorm(vec4f(m2.g,m2.b,shd,trp)), + 0u));}`; diff --git a/cnn_v3/tools/tester.js b/cnn_v3/tools/tester.js index f056444..c1faec9 100644 --- a/cnn_v3/tools/tester.js +++ b/cnn_v3/tools/tester.js @@ -13,6 +13,7 @@ class CNNv3Tester { this.image = null; this.isVideo = false; this.viewMode= 0; // 0=cnn 1=orig 2=diff + this.targetBitmap = null; // set when a sample dir with target.png is loaded this.blend = 1.0; this.layerTextures = {}; this.lastResult = null; @@ -525,6 +526,282 @@ class CNNv3Tester { return(s?-1:1)*Math.pow(2,e-15)*(1+m/1024);}; return [f(lo),f(hi)]; } + + // ── Full G-buffer pack pipeline ─────────────────────────────────────────── + + getFullPack() { + return this.pl('fullpack', () => this.computePL(FULL_PACK_SHADER, 'main')); + } + + // Create a 1×1 rgba8unorm fallback texture with given RGBA bytes [0-255]. + makeFallbackTex(r, g, b, a) { + const tex = this.device.createTexture({size:[1,1], format:'rgba8unorm', + usage: GPUTextureUsage.TEXTURE_BINDING|GPUTextureUsage.COPY_DST}); + this.device.queue.writeTexture({texture:tex}, new Uint8Array([r,g,b,a]), + {bytesPerRow:4,rowsPerImage:1}, [1,1]); + return tex; + } + + // Load an image File as a GPU rgba8unorm texture. Returns {tex, w, h}. + async loadGpuTex(file) { + const bmp = await createImageBitmap(file); + const w = bmp.width, h = bmp.height; + const tex = this.device.createTexture({size:[w,h], format:'rgba8unorm', + usage: GPUTextureUsage.TEXTURE_BINDING|GPUTextureUsage.COPY_DST|GPUTextureUsage.RENDER_ATTACHMENT}); + this.device.queue.copyExternalImageToTexture({source:bmp}, {texture:tex}, [w,h]); + bmp.close(); + return {tex, w, h}; + } + + // ── Load sample directory ───────────────────────────────────────────────── + + async loadSampleDir(files) { + if (!files || files.length === 0) return; + if (!this.weightsU32) { this.setStatus('Load weights first', true); return; } + + this.setMode('full'); + const st = document.getElementById('sampleSt'); + st.textContent = 'Loading…'; + + // Match files by name pattern + const match = (pat) => { + for (const f of files) { + const n = f.name.toLowerCase(); + if (pat.some(p => n.includes(p))) return f; + } + return null; + }; + + const fAlbedo = match(['albedo', 'color']); + const fNormal = match(['normal', 'nrm']); + const fDepth = match(['depth']); + const fMatid = match(['matid', 'index', 'mat_id']); + const fShadow = match(['shadow']); + const fTransp = match(['transp', 'alpha']); + const fTarget = match(['target', 'output', 'ground_truth']); + + if (!fAlbedo) { + st.textContent = '✗ No albedo.png found'; + this.setStatus('No albedo.png in sample dir', true); + return; + } + + try { + const t0 = performance.now(); + + // Load primary albedo to get dimensions + const {tex: albTex, w, h} = await this.loadGpuTex(fAlbedo); + this.canvas.width = w; this.canvas.height = h; + this.context.configure({device:this.device, format:this.format}); + + // Load optional channels — fall back to neutral 1×1 textures + const nrmTex = fNormal ? (await this.loadGpuTex(fNormal)).tex + : this.makeFallbackTex(128, 128, 0, 255); // oct-encoded (0,0) normal + const dptTex = fDepth ? (await this.loadGpuTex(fDepth)).tex + : this.makeFallbackTex(0, 0, 0, 255); + const midTex = fMatid ? (await this.loadGpuTex(fMatid)).tex + : this.makeFallbackTex(0, 0, 0, 255); + const shdTex = fShadow ? (await this.loadGpuTex(fShadow)).tex + : this.makeFallbackTex(255, 255, 255, 255); // fully lit + const trpTex = fTransp ? (await this.loadGpuTex(fTransp)).tex + : this.makeFallbackTex(0, 0, 0, 255); // fully opaque + + // Load target if present + if (this.targetBitmap) { this.targetBitmap.close(); this.targetBitmap = null; } + if (fTarget) { + this.targetBitmap = await createImageBitmap(fTarget); + this.showTarget(); + } else { + document.getElementById('targetPane').style.display = 'none'; + } + + // Pack G-buffer into feat0/feat1 + const mk = (fmt, tw, th) => this.device.createTexture({size:[tw,th], format:fmt, + usage:GPUTextureUsage.STORAGE_BINDING|GPUTextureUsage.TEXTURE_BINDING|GPUTextureUsage.COPY_SRC}); + const f0 = mk('rgba32uint', w, h); + const f1 = mk('rgba32uint', w, h); + + const ceil8 = (n) => Math.ceil(n/8); + const pl = this.getFullPack(); + const bg = this.device.createBindGroup({layout: pl.getBindGroupLayout(0), + entries: [ + {binding:0, resource: albTex.createView()}, + {binding:1, resource: nrmTex.createView()}, + {binding:2, resource: dptTex.createView()}, + {binding:3, resource: midTex.createView()}, + {binding:4, resource: shdTex.createView()}, + {binding:5, resource: trpTex.createView()}, + {binding:6, resource: this.linearSampler}, + {binding:7, resource: f0.createView()}, + {binding:8, resource: f1.createView()}, + ]}); + + const enc = this.device.createCommandEncoder(); + const cp = enc.beginComputePass(); + cp.setPipeline(pl); cp.setBindGroup(0, bg); + cp.dispatchWorkgroups(ceil8(w), ceil8(h)); + cp.end(); + this.device.queue.submit([enc.finish()]); + await this.device.queue.onSubmittedWorkDone(); + + // Cleanup source textures + [albTex, nrmTex, dptTex, midTex, shdTex, trpTex].forEach(t => t.destroy()); + + const found = [fAlbedo, fNormal, fDepth, fMatid, fShadow, fTransp] + .filter(Boolean).map(f => f.name).join(', '); + st.textContent = `✓ ${found}`; + this.log(`Sample packed: ${w}×${h}, ${((performance.now()-t0)).toFixed(0)}ms`); + + // Run inference from packed feat textures + await this.runFromFeat(f0, f1, w, h); + f0.destroy(); f1.destroy(); + + } catch(e) { + st.textContent = `✗ ${e.message}`; + this.setStatus(`Sample error: ${e.message}`, true); + this.log(`Sample error: ${e.message}`, 'err'); + } + } + + // Show target.png in the #targetPane alongside main canvas. + showTarget() { + if (!this.targetBitmap) return; + const tc = document.getElementById('targetCanvas'); + tc.width = this.targetBitmap.width; + tc.height = this.targetBitmap.height; + const ctx2d = tc.getContext('2d'); + ctx2d.drawImage(this.targetBitmap, 0, 0); + document.getElementById('targetPane').style.display = 'flex'; + } + + // Run CNN inference starting from pre-packed feat_tex0 / feat_tex1. + // Used by loadSampleDir() to skip the photo-pack step. + async runFromFeat(f0, f1, w, h) { + if (!this.weightsU32 || !this.device) return; + const t0 = performance.now(); + const W2=w>>1, H2=h>>1, W4=W2>>1, H4=H2>>1; + + this.context.configure({device:this.device, format:this.format}); + + // Create a neutral "original" texture so the display shader can still + // render Orig/Diff modes (just black for sample mode). + if (this.inputTex) this.inputTex.destroy(); + this.inputTex = this.device.createTexture({size:[w,h], format:'rgba8unorm', + usage:GPUTextureUsage.TEXTURE_BINDING|GPUTextureUsage.COPY_DST|GPUTextureUsage.RENDER_ATTACHMENT}); + // Leave it cleared to black — Diff mode against target would need more work + + const mk = (fmt, tw, th) => this.device.createTexture({size:[tw,th], format:fmt, + usage:GPUTextureUsage.STORAGE_BINDING|GPUTextureUsage.TEXTURE_BINDING|GPUTextureUsage.COPY_SRC}); + const e0=mk('rgba16float',w,h), e1=mk('rgba32uint',W2,H2); + const bn=mk('rgba32uint',W4,H4), d1=mk('rgba16float',W2,H2), ot=mk('rgba16float',w,h); + + if (!this.weightsGPU) { + this.weightsGPU = this.device.createBuffer({size:this.weightsBuffer.byteLength, + usage:GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST}); + this.device.queue.writeBuffer(this.weightsGPU, 0, this.weightsBuffer); + } + const wg = this.weightsGPU; + const fp = this.filmParams(); + const wu = (data) => { + const b = this.device.createBuffer({size:data.byteLength, usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST}); + this.device.queue.writeBuffer(b, 0, data); return b; + }; + const uE0=wu(this.u4(ENC0_OFF,fp.ge0,fp.be0)); + const uE1=wu(this.u8(ENC1_OFF,fp.ge1,fp.be1)); + const uBN=wu(this.ubn(BN_OFF)); + const uD1=wu(this.u4(DEC1_OFF,fp.gd1,fp.bd1)); + const uD0=wu(this.u4(DEC0_OFF,fp.gd0,fp.bd0)); + const dispData=new ArrayBuffer(16); + new DataView(dispData).setFloat32(4, this.blend, true); + const uDp=wu(dispData); + + const enc = this.device.createCommandEncoder(); + const bg = (pl,...entries) => this.device.createBindGroup({layout:pl.getBindGroupLayout(0), + entries:entries.map((r,i)=>({binding:i,resource:r}))}); + const rv = (t) => t.createView(); + const cp = (pl,bgr,wx,wy) => {const p=enc.beginComputePass();p.setPipeline(pl);p.setBindGroup(0,bgr);p.dispatchWorkgroups(wx,wy);p.end();}; + const ceil8 = (n) => Math.ceil(n/8); + + cp(this.getEnc0(), bg(this.getEnc0(), rv(f0),rv(f1),{buffer:wg},{buffer:uE0},rv(e0)), ceil8(w), ceil8(h)); + cp(this.getEnc1(), bg(this.getEnc1(), rv(e0),{buffer:wg},{buffer:uE1},rv(e1)), ceil8(W2), ceil8(H2)); + cp(this.getBN(), bg(this.getBN(), rv(e1),{buffer:wg},{buffer:uBN},rv(bn)), ceil8(W4), ceil8(H4)); + cp(this.getDec1(), bg(this.getDec1(), rv(bn),rv(e1),{buffer:wg},{buffer:uD1},rv(d1)), ceil8(W2), ceil8(H2)); + cp(this.getDec0(), bg(this.getDec0(), rv(d1),rv(e0),{buffer:wg},{buffer:uD0},rv(ot)), ceil8(w), ceil8(h)); + + const dbg = bg(this.getDisp(), rv(ot), rv(this.inputTex), {buffer:uDp}); + const rp = enc.beginRenderPass({colorAttachments:[{ + view:this.context.getCurrentTexture().createView(), loadOp:'clear', storeOp:'store'}]}); + rp.setPipeline(this.getDisp()); rp.setBindGroup(0, dbg); rp.draw(6); rp.end(); + + this.device.queue.submit([enc.finish()]); + await this.device.queue.onSubmittedWorkDone(); + + [uE0,uE1,uBN,uD1,uD0].forEach(b => b.destroy()); + + // Compute PSNR against target if available + let psnrStr = ''; + if (this.targetBitmap) { + this.showTarget(); + try { psnrStr = await this.computePSNR(ot, w, h); } catch(_) {} + } + + this.destroyLayerTex(); + this.layerTextures = {feat0:f0, feat1:f1, enc0:e0, enc1:e1, bn, dec1:d1, output:ot}; + this.lastResult = {ot, itex:this.inputTex, uDp, dispPL:this.getDisp(), w, h}; + this.updateVizPanel(); + + const ms = (performance.now()-t0).toFixed(1); + document.getElementById('cnnLabel').textContent = `CNN output (${ms}ms)`; + if (psnrStr) document.getElementById('psnrSt').textContent = psnrStr; + this.setStatus(`Sample: ${ms}ms · ${w}×${h}`); + this.log(`runFromFeat: ${ms}ms`); + } + + // Compute PSNR between CNN rgba16float output texture and target.png bitmap. + async computePSNR(outTex, w, h) { + const bpr = Math.ceil(w * 8 / 256) * 256; + const stg = this.device.createBuffer({size:bpr*h, + usage:GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ}); + const enc = this.device.createCommandEncoder(); + enc.copyTextureToBuffer({texture:outTex}, {buffer:stg, bytesPerRow:bpr, rowsPerImage:h}, [w,h]); + this.device.queue.submit([enc.finish()]); + await stg.mapAsync(GPUMapMode.READ); + const raw = new DataView(stg.getMappedRange()); + + // Decode output pixels from f16 + const f16 = (bits) => { + const s=(bits>>15)&1, e=(bits>>10)&0x1F, m=bits&0x3FF; + if(e===0) return 0; if(e===31) return s?0:1; + return Math.max(0,Math.min(1,(s?-1:1)*Math.pow(2,e-15)*(1+m/1024))); + }; + const cnnPx = new Float32Array(w*h*3); + for (let y=0;y<h;y++) for (let x=0;x<w;x++) { + const src=y*bpr+x*8, pi=(y*w+x)*3; + cnnPx[pi] = f16(raw.getUint16(src, true)); + cnnPx[pi+1]= f16(raw.getUint16(src+2, true)); + cnnPx[pi+2]= f16(raw.getUint16(src+4, true)); + } + stg.unmap(); stg.destroy(); + + // Read target pixels via offscreen canvas + const oc = document.createElement('canvas'); + oc.width = w; oc.height = h; + const ctx2d = oc.getContext('2d'); + ctx2d.drawImage(this.targetBitmap, 0, 0, w, h); + const tgtData = ctx2d.getImageData(0, 0, w, h).data; + + let mse = 0; + const n = w * h * 3; + for (let i=0; i<w*h; i++) { + const dr = cnnPx[i*3] - tgtData[i*4] /255; + const dg = cnnPx[i*3+1] - tgtData[i*4+1]/255; + const db = cnnPx[i*3+2] - tgtData[i*4+2]/255; + mse += dr*dr + dg*dg + db*db; + } + mse /= n; + const psnr = mse > 0 ? (10 * Math.log10(1 / mse)).toFixed(2) : '∞'; + return `MSE=${mse.toFixed(5)} PSNR=${psnr}dB`; + } } // ── UI helpers ─────────────────────────────────────────────────────────────── |
