summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-13 21:03:46 +0100
committerskal <pascal.massimino@gmail.com>2026-02-13 21:03:46 +0100
commit9bcdc6fdaa7e1b4a20ec0c86d521af69f4c13c62 (patch)
tree7ebaebe08d95a049ed9ea29b55663bacb44c4807 /tools
parent60fe2ae74267eba1c33b8e00f0f4d6906cc6eea3 (diff)
CNN v2 web tool: Multiple fixes for feature parity with cnn_test
Changes: - Static shader: Point sampler (nearest filter) instead of linear - Mip handling: Use textureSampleLevel with point sampler (fixes coordinate scaling) - Save PNG: GPU readback via staging buffer (WebGPU canvas lacks toBlob support) - Depth binding: Use input texture as depth (matches C++ simplification) - Header offset: Version-aware calculation (v1=4, v2=5 u32) Known issue: Output still differs from cnn_test (color tones). Root cause TBD. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'tools')
-rw-r--r--tools/cnn_v2_test/index.html127
1 file changed, 87 insertions, 40 deletions
diff --git a/tools/cnn_v2_test/index.html b/tools/cnn_v2_test/index.html
index 9c2506b..682cb2a 100644
--- a/tools/cnn_v2_test/index.html
+++ b/tools/cnn_v2_test/index.html
@@ -408,11 +408,10 @@ fn vs_main(@builtin(vertex_index) idx: u32) -> @builtin(position) vec4<f32> {
// Static features: 7D parametric features (RGBD + UV + sin(10*uv_x) + bias)
const STATIC_SHADER = `
@group(0) @binding(0) var input_tex: texture_2d<f32>;
-@group(0) @binding(1) var input_tex_mip1: texture_2d<f32>;
-@group(0) @binding(2) var input_tex_mip2: texture_2d<f32>;
-@group(0) @binding(3) var depth_tex: texture_2d<f32>;
-@group(0) @binding(4) var output_tex: texture_storage_2d<rgba32uint, write>;
-@group(0) @binding(5) var<uniform> mip_level: u32;
+@group(0) @binding(1) var point_sampler: sampler;
+@group(0) @binding(2) var depth_tex: texture_2d<f32>;
+@group(0) @binding(3) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(4) var<uniform> mip_level: u32;
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
@@ -420,16 +419,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let dims = textureDimensions(input_tex);
if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { return; }
- var rgba: vec4<f32>;
- if (mip_level == 0u) {
- rgba = textureLoad(input_tex, coord, 0);
- } else if (mip_level == 1u) {
- rgba = textureLoad(input_tex_mip1, coord, 0);
- } else if (mip_level == 2u) {
- rgba = textureLoad(input_tex_mip2, coord, 0);
- } else {
- rgba = textureLoad(input_tex_mip2, coord, 0);
- }
+ // Use normalized UV coords with point sampler (no filtering)
+ let uv = (vec2<f32>(coord) + 0.5) / vec2<f32>(dims);
+ let rgba = textureSampleLevel(input_tex, point_sampler, uv, f32(mip_level));
let p0 = rgba.r;
let p1 = rgba.g;
@@ -1145,19 +1137,6 @@ class CNNTester {
// Generate mipmaps
this.generateMipmaps(this.inputTexture, width, height);
- const depthTex = this.device.createTexture({
- size: [width, height],
- format: 'r32float',
- usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST
- });
- const depthData = new Float32Array(width * height).fill(this.depth);
- this.device.queue.writeTexture(
- { texture: depthTex },
- depthData,
- { bytesPerRow: width * 4 },
- [width, height]
- );
-
const staticTex = this.device.createTexture({
size: [width, height],
format: 'rgba32uint',
@@ -1223,15 +1202,22 @@ class CNNTester {
});
this.device.queue.writeBuffer(mipLevelBuffer, 0, new Uint32Array([this.mipLevel]));
+ if (!this.pointSampler) {
+ this.pointSampler = this.device.createSampler({
+ magFilter: 'nearest',
+ minFilter: 'nearest',
+ mipmapFilter: 'nearest'
+ });
+ }
+
const staticBG = this.device.createBindGroup({
layout: staticPipeline.getBindGroupLayout(0),
entries: [
- { binding: 0, resource: this.inputTexture.createView({ baseMipLevel: 0, mipLevelCount: 1 }) },
- { binding: 1, resource: this.inputTexture.createView({ baseMipLevel: 1, mipLevelCount: 1 }) },
- { binding: 2, resource: this.inputTexture.createView({ baseMipLevel: 2, mipLevelCount: 1 }) },
- { binding: 3, resource: depthTex.createView() },
- { binding: 4, resource: staticTex.createView() },
- { binding: 5, resource: { buffer: mipLevelBuffer } }
+ { binding: 0, resource: this.inputTexture.createView() },
+ { binding: 1, resource: this.pointSampler },
+ { binding: 2, resource: this.inputTexture.createView() }, // Use input as depth (matches C++)
+ { binding: 3, resource: staticTex.createView() },
+ { binding: 4, resource: { buffer: mipLevelBuffer } }
]
});
@@ -1735,17 +1721,78 @@ class CNNTester {
return;
}
+ if (!this.resultTexture) {
+ this.log('No result to save', 'error');
+ return;
+ }
+
try {
- const blob = await new Promise((resolve, reject) => {
- this.canvas.toBlob(blob => {
- if (blob) resolve(blob);
- else reject(new Error('Failed to create blob'));
- }, 'image/png');
+ const { width, height } = this.getDimensions();
+
+ // GPU readback from result texture
+ const bytesPerRow = width * 16; // 4×u32 per pixel
+ const paddedBytesPerRow = Math.ceil(bytesPerRow / 256) * 256;
+ const bufferSize = paddedBytesPerRow * height;
+
+ const stagingBuffer = this.device.createBuffer({
+ size: bufferSize,
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
});
+ const encoder = this.device.createCommandEncoder();
+ encoder.copyTextureToBuffer(
+ { texture: this.resultTexture },
+ { buffer: stagingBuffer, bytesPerRow: paddedBytesPerRow, rowsPerImage: height },
+ { width, height, depthOrArrayLayers: 1 }
+ );
+ this.device.queue.submit([encoder.finish()]);
+
+ await stagingBuffer.mapAsync(GPUMapMode.READ);
+ const mapped = new Uint8Array(stagingBuffer.getMappedRange());
+
+ // Unpack f16 to RGBA8
+ const pixels = new Uint8Array(width * height * 4);
+ for (let y = 0; y < height; y++) {
+ const rowOffset = y * paddedBytesPerRow;
+ for (let x = 0; x < width; x++) {
+ const pixelOffset = rowOffset + x * 16;
+ const data = new Uint32Array(mapped.buffer, mapped.byteOffset + pixelOffset, 4);
+
+ // Unpack f16 (first 4 channels only)
+ const unpack = (u32, idx) => {
+ const h = (idx === 0) ? (u32 & 0xFFFF) : ((u32 >> 16) & 0xFFFF);
+ const sign = (h >> 15) & 1;
+ const exp = (h >> 10) & 0x1F;
+ const frac = h & 0x3FF;
+ if (exp === 0) return 0;
+ if (exp === 31) return sign ? 0 : 255;
+ const e = exp - 15;
+ const val = (1 + frac / 1024) * Math.pow(2, e);
+ return Math.max(0, Math.min(255, Math.round(val * 255)));
+ };
+
+ const outIdx = (y * width + x) * 4;
+ pixels[outIdx + 0] = unpack(data[0], 0); // R
+ pixels[outIdx + 1] = unpack(data[0], 1); // G
+ pixels[outIdx + 2] = unpack(data[1], 0); // B
+ pixels[outIdx + 3] = 255; // A
+ }
+ }
+
+ stagingBuffer.unmap();
+ stagingBuffer.destroy();
+
+ // Create blob from pixels
+ const canvas = document.createElement('canvas');
+ canvas.width = width;
+ canvas.height = height;
+ const ctx = canvas.getContext('2d');
+ const imageData = new ImageData(new Uint8ClampedArray(pixels), width, height);
+ ctx.putImageData(imageData, 0, 0);
+
+ const blob = await new Promise(resolve => canvas.toBlob(resolve, 'image/png'));
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
- const { width, height } = this.getDimensions();
const mode = ['cnn', 'original', 'diff'][this.viewMode];
a.href = url;
a.download = `output_${width}x${height}_${mode}.png`;