summary | refs | log | tree | commit | diff
path: root/cnn_v3/tools/tester.js
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/tools/tester.js')
-rw-r--r-- cnn_v3/tools/tester.js 88
1 files changed, 51 insertions, 37 deletions
diff --git a/cnn_v3/tools/tester.js b/cnn_v3/tools/tester.js
index 69f358b..ebe888a 100644
--- a/cnn_v3/tools/tester.js
+++ b/cnn_v3/tools/tester.js
@@ -105,9 +105,9 @@ class CNNv3Tester {
const u32 = new Uint32Array(buf);
if (u32.length < TOTAL_U32) throw new Error(`Too small: ${u32.length} u32, need ${TOTAL_U32}`);
const layers = [
- {n:'enc0',off:ENC0_OFF,cnt:724},{n:'enc1',off:ENC1_OFF,cnt:296},
- {n:'bn', off:BN_OFF, cnt:584},{n:'dec1',off:DEC1_OFF,cnt:580},
- {n:'dec0',off:DEC0_OFF,cnt:292},
+ {n:'enc0',off:ENC0_OFF,cnt:1448},{n:'enc1',off:ENC1_OFF,cnt:1168},
+ {n:'bn', off:BN_OFF, cnt:2320},{n:'dec1',off:DEC1_OFF,cnt:2312},
+ {n:'dec0',off:DEC0_OFF,cnt:580},
];
let html=`<div style="margin-bottom:7px"><b>Size:</b> ${(buf.byteLength/1024).toFixed(1)} KB &nbsp; <b>Weights:</b> ${TOTAL_F16} f16</div>
<table><thead><tr><th>Layer</th><th>Offset</th><th>Count</th><th>Min</th><th>Max</th></tr></thead><tbody>`;
@@ -126,11 +126,11 @@ class CNNv3Tester {
parseFilm(buf) {
const f32=new Float32Array(buf);
- if (f32.length < 776) throw new Error(`FiLM too small: ${f32.length}`);
+ if (f32.length < 1320) throw new Error(`FiLM too small: ${f32.length}`);
let o=0;
- const l0w=f32.slice(o,o+=80), l0b=f32.slice(o,o+=16);
- const l1w=f32.slice(o,o+=640),l1b=f32.slice(o,o+=40);
- this.log(`FiLM MLP: L0(16×5) L1(40×16), ${f32.length} f32`);
+ const l0w=f32.slice(o,o+=80), l0b=f32.slice(o,o+=16);
+ const l1w=f32.slice(o,o+=1152),l1b=f32.slice(o,o+=72);
+ this.log(`FiLM MLP: L0(16×5) L1(72×16), ${f32.length} f32`);
return {l0w,l0b,l1w,l1b};
}
@@ -138,22 +138,24 @@ class CNNv3Tester {
const {l0w,l0b,l1w,l1b}=this.filmMlp;
const h=new Float32Array(16);
for(let j=0;j<16;j++){let s=l0b[j];for(let i=0;i<5;i++)s+=l0w[j*5+i]*cond[i];h[j]=Math.max(0,s);}
- const o=new Float32Array(40);
- for(let j=0;j<40;j++){let s=l1b[j];for(let i=0;i<16;i++)s+=l1w[j*16+i]*h[i];o[j]=s;}
+ const o=new Float32Array(72);
+ for(let j=0;j<72;j++){let s=l1b[j];for(let i=0;i<16;i++)s+=l1w[j*16+i]*h[i];o[j]=s;}
return o;
}
filmParams() {
- const I4=[1,1,1,1],Z4=[0,0,0,0],I8=[1,1,1,1,1,1,1,1],Z8=[0,0,0,0,0,0,0,0];
- if (!this.filmMlp) return {ge0:I4,be0:Z4,ge1:I8,be1:Z8,gd1:I4,bd1:Z4,gd0:I4,bd0:Z4};
+ const I4=Array(4).fill(1),Z4=Array(4).fill(0);
+ const I8=Array(8).fill(1),Z8=Array(8).fill(0);
+ const I16=Array(16).fill(1),Z16=Array(16).fill(0);
+ if (!this.filmMlp) return {ge0:I8,be0:Z8,ge1:I16,be1:Z16,gd1:I8,bd1:Z8,gd0:I4,bd0:Z4};
const v=document.getElementById.bind(document);
const cond=[v('sBP').value,v('sBN').value,v('sAI').value,v('sP0').value,v('sP1').value].map(Number);
const f=this.filmFwd(cond);
return {
- ge0:[...f.slice(0,4)], be0:[...f.slice(4,8)],
- ge1:[...f.slice(8,16)],be1:[...f.slice(16,24)],
- gd1:[...f.slice(24,28)],bd1:[...f.slice(28,32)],
- gd0:[...f.slice(32,36)],bd0:[...f.slice(36,40)],
+ ge0:[...f.slice(0,8)], be0:[...f.slice(8,16)],
+ ge1:[...f.slice(16,32)],be1:[...f.slice(32,48)],
+ gd1:[...f.slice(48,56)],bd1:[...f.slice(56,64)],
+ gd0:[...f.slice(64,68)],bd0:[...f.slice(68,72)],
};
}
@@ -177,6 +179,14 @@ class CNNv3Tester {
for(let i=0;i<4;i++)v.setFloat32(64+i*4,b[i+4],true);
return buf;
}
+ // Params16 (144 bytes): wo u32 _pad×3 gamma[16] beta[16] vec4f×8
+ u16(wo,g,b){
+ const buf=new ArrayBuffer(144),v=new DataView(buf);
+ v.setUint32(0,wo,true);
+ for(let i=0;i<16;i++)v.setFloat32(16+i*4,g[i],true);
+ for(let i=0;i<16;i++)v.setFloat32(80+i*4,b[i],true);
+ return buf;
+ }
// ParamsBN (16 bytes): wo u32 _pad×3
ubn(wo){const buf=new ArrayBuffer(16);new DataView(buf).setUint32(0,wo,true);return buf;}
@@ -330,8 +340,10 @@ class CNNv3Tester {
const mk=(fmt,tw,th)=>this.device.createTexture({size:[tw,th],format:fmt,
usage:GPUTextureUsage.STORAGE_BINDING|GPUTextureUsage.TEXTURE_BINDING|GPUTextureUsage.COPY_SRC});
const f0=mk('rgba32uint',w,h),f1=mk('rgba32uint',w,h);
- const e0=mk('rgba16float',w,h),e1=mk('rgba32uint',W2,H2);
- const bn=mk('rgba32uint',W4,H4),d1=mk('rgba16float',W2,H2),ot=mk('rgba16float',w,h);
+ const e0=mk('rgba32uint',w,h); // 8ch
+ const e1_lo=mk('rgba32uint',W2,H2),e1_hi=mk('rgba32uint',W2,H2); // 16ch split
+ const bn_lo=mk('rgba32uint',W4,H4),bn_hi=mk('rgba32uint',W4,H4); // 16ch split
+ const d1=mk('rgba32uint',W2,H2),ot=mk('rgba16float',w,h); // d1=8ch
// Weights GPU buffer (cached)
if(!this.weightsGPU){
@@ -346,11 +358,11 @@ class CNNv3Tester {
const b=this.device.createBuffer({size:data.byteLength,usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST});
this.device.queue.writeBuffer(b,0,data); return b;
};
- const uE0=wu(this.u4(ENC0_OFF,fp.ge0,fp.be0));
- const uE1=wu(this.u8(ENC1_OFF,fp.ge1,fp.be1));
+ const uE0=wu(this.u8( ENC0_OFF,fp.ge0,fp.be0));
+ const uE1=wu(this.u16(ENC1_OFF,fp.ge1,fp.be1));
const uBN=wu(this.ubn(BN_OFF));
- const uD1=wu(this.u4(DEC1_OFF,fp.gd1,fp.bd1));
- const uD0=wu(this.u4(DEC0_OFF,fp.gd0,fp.bd0));
+ const uD1=wu(this.u8( DEC1_OFF,fp.gd1,fp.bd1));
+ const uD0=wu(this.u4( DEC0_OFF,fp.gd0,fp.bd0));
const dispData=new ArrayBuffer(16);
const dispView=new DataView(dispData);
@@ -366,9 +378,9 @@ class CNNv3Tester {
cp(this.getPack(), bg(this.getPack(), rv(this.inputTex),this.linearSampler,rv(f0),rv(f1)), ceil8(w),ceil8(h));
cp(this.getEnc0(), bg(this.getEnc0(), rv(f0),rv(f1),{buffer:wg},{buffer:uE0},rv(e0)), ceil8(w),ceil8(h));
- cp(this.getEnc1(), bg(this.getEnc1(), rv(e0),{buffer:wg},{buffer:uE1},rv(e1)), ceil8(W2),ceil8(H2));
- cp(this.getBN(), bg(this.getBN(), rv(e1),{buffer:wg},{buffer:uBN},rv(bn)), ceil8(W4),ceil8(H4));
- cp(this.getDec1(), bg(this.getDec1(), rv(bn),rv(e1),{buffer:wg},{buffer:uD1},rv(d1)), ceil8(W2),ceil8(H2));
+ cp(this.getEnc1(), bg(this.getEnc1(), rv(e0),{buffer:wg},{buffer:uE1},rv(e1_lo),rv(e1_hi)), ceil8(W2),ceil8(H2));
+ cp(this.getBN(), bg(this.getBN(), rv(e1_lo),rv(e1_hi),{buffer:wg},{buffer:uBN},rv(bn_lo),rv(bn_hi)), ceil8(W4),ceil8(H4));
+ cp(this.getDec1(), bg(this.getDec1(), rv(bn_lo),rv(bn_hi),rv(e1_lo),rv(e1_hi),{buffer:wg},{buffer:uD1},rv(d1)), ceil8(W2),ceil8(H2));
cp(this.getDec0(), bg(this.getDec0(), rv(d1),rv(e0),{buffer:wg},{buffer:uD0},rv(ot)), ceil8(w),ceil8(h));
const dbg=bg(this.getDisp(),rv(ot),rv(this.inputTex),{buffer:uDp});
@@ -387,7 +399,7 @@ class CNNv3Tester {
// Store for layer viz & redisplay
this.destroyLayerTex();
- this.layerTextures={feat0:f0,feat1:f1,enc0:e0,enc1:e1,bn,dec1:d1,dec0:ot};
+ this.layerTextures={feat0:f0,feat1:f1,enc0:e0,enc1:e1_lo,bn:bn_lo,dec1:d1,dec0:ot};
this.lastResult={ot,itex:this.inputTex,uDp,dispPL:this.getDisp(),w,h};
this.updateVizPanel();
this.refreshZoom();
@@ -442,10 +454,10 @@ class CNNv3Tester {
updateVizPanel() {
const DEFS=[
{id:'feat0', lbl:'Feat', t:'u32',nch:8, ch:['alb.r','alb.g','alb.b','nrm.x','nrm.y','depth','dgx','dgy']},
- {id:'enc0', lbl:'Enc0', t:'f32',nch:4, ch:['c0','c1','c2','c3']},
+ {id:'enc0', lbl:'Enc0', t:'u32',nch:8, ch:['c0','c1','c2','c3','c4','c5','c6','c7']},
{id:'enc1', lbl:'Enc1', t:'u32',nch:8, ch:['c0','c1','c2','c3','c4','c5','c6','c7']},
{id:'bn', lbl:'BN', t:'u32',nch:8, ch:['c0','c1','c2','c3','c4','c5','c6','c7']},
- {id:'dec1', lbl:'Dec1', t:'f32',nch:4, ch:['c0','c1','c2','c3']},
+ {id:'dec1', lbl:'Dec1', t:'u32',nch:8, ch:['c0','c1','c2','c3','c4','c5','c6','c7']},
{id:'dec0', lbl:'Dec0', t:'f32',nch:4, ch:['R','G','B','A']},
];
this.vizDefs=DEFS;
@@ -753,8 +765,10 @@ class CNNv3Tester {
const mk = (fmt, tw, th) => this.device.createTexture({size:[tw,th], format:fmt,
usage:GPUTextureUsage.STORAGE_BINDING|GPUTextureUsage.TEXTURE_BINDING|GPUTextureUsage.COPY_SRC});
- const e0=mk('rgba16float',w,h), e1=mk('rgba32uint',W2,H2);
- const bn=mk('rgba32uint',W4,H4), d1=mk('rgba16float',W2,H2), ot=mk('rgba16float',w,h);
+ const e0=mk('rgba32uint',w,h); // 8ch
+ const e1_lo=mk('rgba32uint',W2,H2),e1_hi=mk('rgba32uint',W2,H2); // 16ch split
+ const bn_lo=mk('rgba32uint',W4,H4),bn_hi=mk('rgba32uint',W4,H4); // 16ch split
+ const d1=mk('rgba32uint',W2,H2), ot=mk('rgba16float',w,h); // d1=8ch
if (!this.weightsGPU) {
this.weightsGPU = this.device.createBuffer({size:this.weightsBuffer.byteLength,
@@ -767,11 +781,11 @@ class CNNv3Tester {
const b = this.device.createBuffer({size:data.byteLength, usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST});
this.device.queue.writeBuffer(b, 0, data); return b;
};
- const uE0=wu(this.u4(ENC0_OFF,fp.ge0,fp.be0));
- const uE1=wu(this.u8(ENC1_OFF,fp.ge1,fp.be1));
+ const uE0=wu(this.u8( ENC0_OFF,fp.ge0,fp.be0));
+ const uE1=wu(this.u16(ENC1_OFF,fp.ge1,fp.be1));
const uBN=wu(this.ubn(BN_OFF));
- const uD1=wu(this.u4(DEC1_OFF,fp.gd1,fp.bd1));
- const uD0=wu(this.u4(DEC0_OFF,fp.gd0,fp.bd0));
+ const uD1=wu(this.u8( DEC1_OFF,fp.gd1,fp.bd1));
+ const uD0=wu(this.u4( DEC0_OFF,fp.gd0,fp.bd0));
const dispData=new ArrayBuffer(16);
new DataView(dispData).setFloat32(4, this.blend, true);
const uDp=wu(dispData);
@@ -784,9 +798,9 @@ class CNNv3Tester {
const ceil8 = (n) => Math.ceil(n/8);
cp(this.getEnc0(), bg(this.getEnc0(), rv(f0),rv(f1),{buffer:wg},{buffer:uE0},rv(e0)), ceil8(w), ceil8(h));
- cp(this.getEnc1(), bg(this.getEnc1(), rv(e0),{buffer:wg},{buffer:uE1},rv(e1)), ceil8(W2), ceil8(H2));
- cp(this.getBN(), bg(this.getBN(), rv(e1),{buffer:wg},{buffer:uBN},rv(bn)), ceil8(W4), ceil8(H4));
- cp(this.getDec1(), bg(this.getDec1(), rv(bn),rv(e1),{buffer:wg},{buffer:uD1},rv(d1)), ceil8(W2), ceil8(H2));
+ cp(this.getEnc1(), bg(this.getEnc1(), rv(e0),{buffer:wg},{buffer:uE1},rv(e1_lo),rv(e1_hi)), ceil8(W2), ceil8(H2));
+ cp(this.getBN(), bg(this.getBN(), rv(e1_lo),rv(e1_hi),{buffer:wg},{buffer:uBN},rv(bn_lo),rv(bn_hi)), ceil8(W4), ceil8(H4));
+ cp(this.getDec1(), bg(this.getDec1(), rv(bn_lo),rv(bn_hi),rv(e1_lo),rv(e1_hi),{buffer:wg},{buffer:uD1},rv(d1)), ceil8(W2), ceil8(H2));
cp(this.getDec0(), bg(this.getDec0(), rv(d1),rv(e0),{buffer:wg},{buffer:uD0},rv(ot)), ceil8(w), ceil8(h));
const dbg = bg(this.getDisp(), rv(ot), rv(this.inputTex), {buffer:uDp});
@@ -807,7 +821,7 @@ class CNNv3Tester {
}
this.destroyLayerTex();
- this.layerTextures = {feat0:f0, feat1:f1, enc0:e0, enc1:e1, bn, dec1:d1, output:ot};
+ this.layerTextures = {feat0:f0, feat1:f1, enc0:e0, enc1:e1_lo, bn:bn_lo, dec1:d1, output:ot};
this.lastResult = {ot, itex:this.inputTex, uDp, dispPL:this.getDisp(), w, h};
this.updateVizPanel();
this.refreshZoom();