summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
Diffstat (limited to 'tools')
-rw-r--r--tools/cnn_v2_test/index.html96
1 files changed, 70 insertions, 26 deletions
diff --git a/tools/cnn_v2_test/index.html b/tools/cnn_v2_test/index.html
index 6d3f223..9c2506b 100644
--- a/tools/cnn_v2_test/index.html
+++ b/tools/cnn_v2_test/index.html
@@ -323,6 +323,7 @@
<input type="range" id="depth" min="0" max="1" step="0.01" value="1.0">
<span id="depthValue">1.0</span>
</div>
+ <button id="savePngBtn">Save PNG</button>
</div>
</div>
<video id="videoSource" muted loop></video>
@@ -407,10 +408,11 @@ fn vs_main(@builtin(vertex_index) idx: u32) -> @builtin(position) vec4<f32> {
// Static features: 7D parametric features (RGBD + UV + sin(10*uv_x) + bias)
const STATIC_SHADER = `
@group(0) @binding(0) var input_tex: texture_2d<f32>;
-@group(0) @binding(1) var input_sampler: sampler;
-@group(0) @binding(2) var depth_tex: texture_2d<f32>;
-@group(0) @binding(3) var output_tex: texture_storage_2d<rgba32uint, write>;
-@group(0) @binding(4) var<uniform> mip_level: u32;
+@group(0) @binding(1) var input_tex_mip1: texture_2d<f32>;
+@group(0) @binding(2) var input_tex_mip2: texture_2d<f32>;
+@group(0) @binding(3) var depth_tex: texture_2d<f32>;
+@group(0) @binding(4) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(5) var<uniform> mip_level: u32;
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
@@ -418,18 +420,32 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let dims = textureDimensions(input_tex);
if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { return; }
- let uv = (vec2<f32>(coord) + 0.5) / vec2<f32>(dims);
- let rgba = textureSampleLevel(input_tex, input_sampler, uv, f32(mip_level));
- let d = textureLoad(depth_tex, coord, 0).r;
+ var rgba: vec4<f32>;
+ if (mip_level == 0u) {
+ rgba = textureLoad(input_tex, coord, 0);
+ } else if (mip_level == 1u) {
+ rgba = textureLoad(input_tex_mip1, coord, 0);
+ } else if (mip_level == 2u) {
+ rgba = textureLoad(input_tex_mip2, coord, 0);
+ } else {
+ rgba = textureLoad(input_tex_mip2, coord, 0);
+ }
+
+ let p0 = rgba.r;
+ let p1 = rgba.g;
+ let p2 = rgba.b;
+ let p3 = textureLoad(depth_tex, coord, 0).r;
+
let uv_x = f32(coord.x) / f32(dims.x);
let uv_y = f32(coord.y) / f32(dims.y);
let sin20_y = sin(20.0 * uv_y);
+ let bias = 1.0;
let packed = vec4<u32>(
- pack2x16float(vec2<f32>(rgba.r, rgba.g)),
- pack2x16float(vec2<f32>(rgba.b, d)),
+ pack2x16float(vec2<f32>(p0, p1)),
+ pack2x16float(vec2<f32>(p2, p3)),
pack2x16float(vec2<f32>(uv_x, uv_y)),
- pack2x16float(vec2<f32>(sin20_y, 1.0))
+ pack2x16float(vec2<f32>(sin20_y, bias))
);
textureStore(output_tex, coord, packed);
}`;
@@ -790,7 +806,7 @@ class CNNTester {
}
this.log(` Weight buffer: ${weights.length} u32 (${nonZero} non-zero)`);
- return { layers, weights, mipLevel, fileSize: buffer.byteLength };
+ return { version, layers, weights, mipLevel, fileSize: buffer.byteLength };
}
unpackF16(packed) {
@@ -924,11 +940,12 @@ class CNNTester {
updateWeightsPanel() {
const panel = document.getElementById('weightsInfo');
- const { layers, mipLevel, fileSize } = this.weights;
+ const { version, layers, mipLevel, fileSize } = this.weights;
let html = `
<div style="margin-bottom: 12px;">
<div><strong>File Size:</strong> ${(fileSize / 1024).toFixed(2)} KB</div>
+ <div><strong>Version:</strong> ${version}</div>
<div><strong>CNN Layers:</strong> ${layers.length}</div>
<div><strong>Mip Level:</strong> ${mipLevel} (p0-p3 features)</div>
<div style="font-size: 9px; color: #808080; margin-top: 4px;">Static features (input) + ${layers.length} conv layers</div>
@@ -1206,22 +1223,15 @@ class CNNTester {
});
this.device.queue.writeBuffer(mipLevelBuffer, 0, new Uint32Array([this.mipLevel]));
- if (!this.linearSampler) {
- this.linearSampler = this.device.createSampler({
- magFilter: 'linear',
- minFilter: 'linear',
- mipmapFilter: 'linear'
- });
- }
-
const staticBG = this.device.createBindGroup({
layout: staticPipeline.getBindGroupLayout(0),
entries: [
- { binding: 0, resource: this.inputTexture.createView() },
- { binding: 1, resource: this.linearSampler },
- { binding: 2, resource: depthTex.createView() },
- { binding: 3, resource: staticTex.createView() },
- { binding: 4, resource: { buffer: mipLevelBuffer } }
+ { binding: 0, resource: this.inputTexture.createView({ baseMipLevel: 0, mipLevelCount: 1 }) },
+ { binding: 1, resource: this.inputTexture.createView({ baseMipLevel: 1, mipLevelCount: 1 }) },
+ { binding: 2, resource: this.inputTexture.createView({ baseMipLevel: 2, mipLevelCount: 1 }) },
+ { binding: 3, resource: depthTex.createView() },
+ { binding: 4, resource: staticTex.createView() },
+ { binding: 5, resource: { buffer: mipLevelBuffer } }
]
});
@@ -1247,7 +1257,9 @@ class CNNTester {
const isOutput = i === this.weights.layers.length - 1;
// Calculate absolute weight offset in f16 units (add header offset)
- const headerOffsetU32 = 4 + this.weights.layers.length * 5; // Header + layer info in u32
+ // Version 1: 4 u32 header, Version 2: 5 u32 header
+ const headerSizeU32 = (this.weights.version === 1) ? 4 : 5;
+ const headerOffsetU32 = headerSizeU32 + this.weights.layers.length * 5; // Header + layer info in u32
const absoluteWeightOffset = headerOffsetU32 * 2 + layer.weightOffset; // Convert to f16 units
const paramsData = new Uint32Array(7);
@@ -1716,6 +1728,37 @@ class CNNTester {
this.device.queue.submit([encoder.finish()]);
}
+
+ async savePNG() {
+ if (!this.image && !this.isVideo) {
+ this.log('No image loaded', 'error');
+ return;
+ }
+
+ try {
+ const blob = await new Promise((resolve, reject) => {
+ this.canvas.toBlob(blob => {
+ if (blob) resolve(blob);
+ else reject(new Error('Failed to create blob'));
+ }, 'image/png');
+ });
+
+ const url = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ const { width, height } = this.getDimensions();
+ const mode = ['cnn', 'original', 'diff'][this.viewMode];
+ a.href = url;
+ a.download = `output_${width}x${height}_${mode}.png`;
+ a.click();
+ URL.revokeObjectURL(url);
+
+ this.log(`Saved PNG: ${a.download}`);
+ this.setStatus(`Saved: ${a.download}`);
+ } catch (err) {
+ this.log(`Failed to save PNG: ${err.message}`, 'error');
+ this.setStatus(`Save failed: ${err.message}`, true);
+ }
+ }
}
const tester = new CNNTester();
@@ -1815,6 +1858,7 @@ document.getElementById('mipLevel').addEventListener('change', e => {
document.getElementById('playPauseBtn').addEventListener('click', () => tester.togglePlayPause());
document.getElementById('stepBackBtn').addEventListener('click', () => tester.stepFrame(-1));
document.getElementById('stepForwardBtn').addEventListener('click', () => tester.stepFrame(1));
+document.getElementById('savePngBtn').addEventListener('click', () => tester.savePNG());
document.addEventListener('keydown', e => {
if (e.code === 'Space') {