author     skal <pascal.massimino@gmail.com>  2026-02-13 22:42:45 +0100
committer  skal <pascal.massimino@gmail.com>  2026-02-13 22:42:45 +0100
commit     f81a30d15e1e7db0492f45a0b9bec6aaa20ae5c2 (patch)
tree       deb202a7d995895ec90e8ddc8c3fbf92082ea434
parent     7c1f937222d0e36294ebd25db949c6227aed6985 (diff)
CNN v2: Use alpha channel for p3 depth feature + layer visualization
Training changes (train_cnn_v2.py):
- p3 now uses the target image's alpha channel (a depth proxy for 2D images)
- Default changed from 0.0 → 1.0 (far-plane semantics)
- Both PatchDataset and ImagePairDataset updated

Test tool (cnn_test.cc):
- New load_depth_from_alpha() extracts the PNG alpha channel → p3 texture
- Fixed bind group layout: use UnfilterableFloat for the R32Float depth texture
- Added --save-intermediates support for CNN v2:
  * Each layer_N.png shows 4 channels side by side (1812×345 grayscale)
  * layers_composite.png stacks all layers vertically (1812×1380)
  * static_features.png shows 4 feature channels side by side
- Per-channel visualization enables debugging layer-by-layer differences

HTML tool (index.html):
- Extract the alpha channel from the input image → depth texture
- Matches the training data distribution for validation

Note: the current weights were trained with p3=0 and are now mismatched.
Both tools use p3=alpha consistently, so outputs remain comparable for
debugging; a retrain is required for optimal quality.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
-rw-r--r--  tools/cnn_test.cc             240
-rw-r--r--  tools/cnn_v2_test/index.html   31
-rwxr-xr-x  training/train_cnn_v2.py       16
3 files changed, 279 insertions(+), 8 deletions(-)
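
To make the two conventions above concrete before the diff, here is a minimal NumPy sketch of (a) the alpha-as-depth proxy for p3, defaulting to 1.0 (far plane) when the image has no alpha channel, and (b) the channels-side-by-side grayscale layout used for layer_N.png. This is illustrative only; the helper names are hypothetical and not part of train_cnn_v2.py or cnn_test.cc.

    import numpy as np

    def depth_from_alpha(img: np.ndarray) -> np.ndarray:
        """Alpha channel as a depth proxy; 1.0 (far plane) if the image has no alpha."""
        h, w = img.shape[:2]
        if img.ndim == 3 and img.shape[2] == 4:
            return img[:, :, 3].astype(np.float32)  # alpha assumed already in [0, 1]
        return np.ones((h, w), dtype=np.float32)    # far-plane default (0.0 before this commit)

    def channels_side_by_side(layer: np.ndarray) -> np.ndarray:
        """Lay an (H, W, C) layer out as an (H, C*W) grayscale strip, one channel per block.
        With C=4 and a 453x345 input, this matches the 1812x345 layer_N.png layout above."""
        return np.concatenate([layer[:, :, c] for c in range(layer.shape[2])], axis=1)
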
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc
index 5823110..3fad2ff 100644
--- a/tools/cnn_test.cc
+++ b/tools/cnn_test.cc
@@ -169,6 +169,66 @@ static WGPUTexture load_texture(WGPUDevice device, WGPUQueue queue,
return texture;
}
+// Load PNG alpha channel as depth texture (or 1.0 if no alpha)
+static WGPUTexture load_depth_from_alpha(WGPUDevice device, WGPUQueue queue,
+ const char* path, int width,
+ int height) {
+ int w, h, channels;
+ uint8_t* data = stbi_load(path, &w, &h, &channels, 4);
+ if (!data || w != width || h != height) {
+ fprintf(stderr, "Error: failed to load depth from '%s'\n", path);
+ if (data) stbi_image_free(data);
+ return nullptr;
+ }
+
+ // Extract alpha channel (or use 1.0 if original was RGB)
+ std::vector<float> depth_data(width * height);
+ bool has_alpha = (channels == 4);
+ for (int i = 0; i < width * height; ++i) {
+ // Alpha is in data[i*4+3] (0-255), convert to float [0, 1]
+ // If no alpha channel, default to 1.0 (far plane)
+ depth_data[i] = has_alpha ? (data[i * 4 + 3] / 255.0f) : 1.0f;
+ }
+ stbi_image_free(data);
+
+ // Create R32Float depth texture
+ const WGPUTextureDescriptor depth_desc = {
+ .usage = WGPUTextureUsage_TextureBinding | WGPUTextureUsage_CopyDst,
+ .dimension = WGPUTextureDimension_2D,
+ .size = {static_cast<uint32_t>(width), static_cast<uint32_t>(height), 1},
+ .format = WGPUTextureFormat_R32Float,
+ .mipLevelCount = 1,
+ .sampleCount = 1,
+ };
+ WGPUTexture depth_texture = wgpuDeviceCreateTexture(device, &depth_desc);
+ if (!depth_texture) {
+ fprintf(stderr, "Error: failed to create depth texture\n");
+ return nullptr;
+ }
+
+ // Write depth data
+ const WGPUTexelCopyTextureInfo dst = {
+ .texture = depth_texture,
+ .mipLevel = 0
+ };
+ const WGPUTexelCopyBufferLayout layout = {
+ .bytesPerRow = static_cast<uint32_t>(width * sizeof(float)),
+ .rowsPerImage = static_cast<uint32_t>(height)
+ };
+ const WGPUExtent3D size = {
+ static_cast<uint32_t>(width),
+ static_cast<uint32_t>(height),
+ 1
+ };
+ wgpuQueueWriteTexture(queue, &dst, depth_data.data(),
+ depth_data.size() * sizeof(float), &layout, &size);
+
+ printf("Loaded depth from alpha: %dx%d (%s alpha)\n", width, height,
+ has_alpha ? "has" : "no");
+
+ return depth_texture;
+}
+
// Create CNN render pipeline (5 bindings)
// Takes both intermediate format (RGBA16Float) and final format (BGRA8Unorm)
static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device,
@@ -245,6 +305,57 @@ static bool save_png(const char* path, const std::vector<uint8_t>& pixels,
return true;
}
+// Create a vertical grayscale composite of the per-layer outputs.
+// Each layer PNG is already 4x wide (4 channels side by side); stack them vertically.
+static bool save_layer_composite(const char* dir, int width, int height, int num_layers) {
+ // Each layer PNG is already 4x wide with 4 channels side-by-side
+ int layer_width = width * 4;
+
+ // Load all layer images (they're already grayscale)
+ std::vector<std::vector<uint8_t>> layers(num_layers);
+ for (int i = 0; i < num_layers; ++i) {
+ char path[512];
+ snprintf(path, sizeof(path), "%s/layer_%d.png", dir, i);
+
+ int w, h, channels;
+ uint8_t* data = stbi_load(path, &w, &h, &channels, 1); // Load as grayscale
+ if (!data || w != layer_width || h != height) {
+ if (data) stbi_image_free(data);
+ fprintf(stderr, "Warning: failed to load layer %d for composite (expected %dx%d, got %dx%d)\n",
+ i, layer_width, height, w, h);
+ return false;
+ }
+
+ layers[i].assign(data, data + (layer_width * height));
+ stbi_image_free(data);
+ }
+
+ // Stack layers vertically
+ int composite_height = height * num_layers;
+ std::vector<uint8_t> composite(layer_width * composite_height);
+
+ for (int layer = 0; layer < num_layers; ++layer) {
+ for (int y = 0; y < height; ++y) {
+ int src_row_offset = y * layer_width;
+ int dst_row_offset = (layer * height + y) * layer_width;
+ memcpy(&composite[dst_row_offset], &layers[layer][src_row_offset], layer_width);
+ }
+ }
+
+ // Save as grayscale PNG (stacked vertically)
+ char composite_path[512];
+ snprintf(composite_path, sizeof(composite_path), "%s/layers_composite.png", dir);
+ if (!stbi_write_png(composite_path, layer_width, composite_height, 1,
+ composite.data(), layer_width)) {
+ fprintf(stderr, "Error: failed to write composite PNG\n");
+ return false;
+ }
+
+ printf("Saved layer composite to '%s' (%dx%d, 4 layers stacked vertically)\n",
+ composite_path, layer_width, composite_height);
+ return true;
+}
+
// Save PPM output (fallback)
static bool save_ppm(const char* path, const std::vector<uint8_t>& pixels,
int width, int height) {
@@ -282,6 +393,7 @@ struct CNNv2LayerParams {
uint32_t weight_offset;
uint32_t is_output_layer;
float blend_amount;
+ uint32_t is_layer_0;
};
struct CNNv2StaticFeatureParams {
@@ -433,6 +545,41 @@ static std::vector<uint8_t> readback_rgba32uint_to_bgra8(
return result;
}
+// Read RGBA32Uint and create 4x wide grayscale composite (each channel side-by-side)
+static std::vector<uint8_t> readback_rgba32uint_to_composite(
+ WGPUDevice device, WGPUQueue queue, WGPUTexture texture,
+ int width, int height) {
+
+ // First get BGRA8 data
+ std::vector<uint8_t> bgra = readback_rgba32uint_to_bgra8(device, queue, texture, width, height);
+ if (bgra.empty()) return {};
+
+ // Create 4x wide grayscale image (one channel per horizontal strip)
+ int composite_width = width * 4;
+ std::vector<uint8_t> composite(composite_width * height);
+
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ int src_idx = (y * width + x) * 4;
+ uint8_t b = bgra[src_idx + 0];
+ uint8_t g = bgra[src_idx + 1];
+ uint8_t r = bgra[src_idx + 2];
+ uint8_t a = bgra[src_idx + 3];
+
+ // Each channel's 8-bit value is used directly as the grayscale intensity
+ auto to_gray = [](uint8_t val) -> uint8_t { return val; };
+
+ // Place each channel in its horizontal strip
+ composite[y * composite_width + (0 * width + x)] = to_gray(r); // Channel 0
+ composite[y * composite_width + (1 * width + x)] = to_gray(g); // Channel 1
+ composite[y * composite_width + (2 * width + x)] = to_gray(b); // Channel 2
+ composite[y * composite_width + (3 * width + x)] = to_gray(a); // Channel 3
+ }
+ }
+
+ return composite;
+}
+
// Process image with CNN v2
static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
WGPUInstance instance, WGPUTexture input_texture,
@@ -523,6 +670,18 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
WGPUTextureView static_features_view =
wgpuTextureCreateView(static_features_tex, nullptr);
+ // Load depth from input alpha channel (or 1.0 if no alpha)
+ WGPUTexture depth_texture =
+ load_depth_from_alpha(device, queue, args.input_path, width, height);
+ if (!depth_texture) {
+ wgpuTextureViewRelease(static_features_view);
+ wgpuTextureRelease(static_features_tex);
+ wgpuBufferRelease(weights_buffer);
+ wgpuTextureViewRelease(input_view);
+ return false;
+ }
+ WGPUTextureView depth_view = wgpuTextureCreateView(depth_texture, nullptr);
+
// Create layer textures (ping-pong)
WGPUTexture layer_textures[2] = {
wgpuDeviceCreateTexture(device, &static_desc),
@@ -543,6 +702,8 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
fprintf(stderr, "Error: CNN v2 shaders not available\n");
wgpuTextureViewRelease(static_features_view);
wgpuTextureRelease(static_features_tex);
+ wgpuTextureViewRelease(depth_view);
+ wgpuTextureRelease(depth_texture);
wgpuTextureViewRelease(layer_views[0]);
wgpuTextureViewRelease(layer_views[1]);
wgpuTextureRelease(layer_textures[0]);
@@ -600,7 +761,7 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
static_bgl_entries[3].binding = 3;
static_bgl_entries[3].visibility = WGPUShaderStage_Compute;
- static_bgl_entries[3].texture.sampleType = WGPUTextureSampleType_Float;
+ static_bgl_entries[3].texture.sampleType = WGPUTextureSampleType_UnfilterableFloat;
static_bgl_entries[3].texture.viewDimension = WGPUTextureViewDimension_2D;
static_bgl_entries[4].binding = 4;
@@ -651,7 +812,7 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
static_bg_entries[2].binding = 2;
static_bg_entries[2].textureView = input_view;
static_bg_entries[3].binding = 3;
- static_bg_entries[3].textureView = input_view; // Depth (use input)
+ static_bg_entries[3].textureView = depth_view; // Depth from alpha channel (matches training)
static_bg_entries[4].binding = 4;
static_bg_entries[4].textureView = static_features_view;
static_bg_entries[5].binding = 5;
@@ -769,6 +930,43 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
wgpuComputePassEncoderEnd(static_pass);
wgpuComputePassEncoderRelease(static_pass);
+ // Save static features if requested
+ if (args.save_intermediates) {
+ // Submit and wait for static features to complete
+ WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, nullptr);
+ wgpuQueueSubmit(queue, 1, &cmd);
+ wgpuCommandBufferRelease(cmd);
+ wgpuDevicePoll(device, true, nullptr);
+
+ // Create new encoder for layers
+ encoder = wgpuDeviceCreateCommandEncoder(device, nullptr);
+
+ char layer_path[512];
+ snprintf(layer_path, sizeof(layer_path), "%s/static_features.png",
+ args.save_intermediates);
+ printf("Saving static features to '%s'...\n", layer_path);
+
+ // Static features pack 8 f16 channels into 4×u32; for now only the
+ // first 4 channels are shown, using the same 4x wide layout as the layers.
+ // TODO: show all 8 channels in an 8x wide composite.
+ std::vector<uint8_t> composite = readback_rgba32uint_to_composite(
+ device, queue, static_features_tex, width, height);
+ if (!composite.empty()) {
+ int composite_width = width * 4;
+ if (!stbi_write_png(layer_path, composite_width, height, 1,
+ composite.data(), composite_width)) {
+ fprintf(stderr, "Error: failed to write static features PNG\n");
+ }
+ }
+ }
+
// Pass 2-N: CNN layers
for (size_t i = 0; i < layer_info.size(); ++i) {
const CNNv2LayerInfo& info = layer_info[i];
@@ -785,6 +983,7 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
params.weight_offset = info.weight_offset;
params.is_output_layer = (i == layer_info.size() - 1) ? 1 : 0;
params.blend_amount = args.blend;
+ params.is_layer_0 = (i == 0) ? 1 : 0;
wgpuQueueWriteBuffer(queue, layer_params_buffers[i], 0, &params,
sizeof(params));
@@ -831,6 +1030,36 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
wgpuComputePassEncoderEnd(layer_pass);
wgpuComputePassEncoderRelease(layer_pass);
wgpuBindGroupRelease(layer_bg);
+
+ // Save intermediate layer if requested
+ if (args.save_intermediates) {
+ // Submit and wait for layer to complete
+ WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, nullptr);
+ wgpuQueueSubmit(queue, 1, &cmd);
+ wgpuCommandBufferRelease(cmd);
+ wgpuDevicePoll(device, true, nullptr);
+
+ // Create new encoder for next layer
+ encoder = wgpuDeviceCreateCommandEncoder(device, nullptr);
+
+ char layer_path[512];
+ snprintf(layer_path, sizeof(layer_path), "%s/layer_%zu.png",
+ args.save_intermediates, i);
+ printf("Saving intermediate layer %zu to '%s'...\n", i, layer_path);
+
+ // Read back RGBA32Uint and create 4-channel grayscale composite
+ WGPUTexture output_tex = layer_textures[(i + 1) % 2];
+ std::vector<uint8_t> composite = readback_rgba32uint_to_composite(
+ device, queue, output_tex, width, height);
+
+ if (!composite.empty()) {
+ int composite_width = width * 4;
+ if (!stbi_write_png(layer_path, composite_width, height, 1,
+ composite.data(), composite_width)) {
+ fprintf(stderr, "Error: failed to write layer PNG\n");
+ }
+ }
+ }
}
WGPUCommandBuffer commands = wgpuCommandEncoderFinish(encoder, nullptr);
@@ -840,6 +1069,11 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
wgpuDevicePoll(device, true, nullptr);
+ // Create layer composite if intermediates were saved
+ if (args.save_intermediates) {
+ save_layer_composite(args.save_intermediates, width, height, layer_info.size());
+ }
+
// Readback final result (from last layer's output texture)
printf("Reading pixels from GPU...\n");
size_t final_layer_idx = (layer_info.size()) % 2;
@@ -856,6 +1090,8 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
wgpuBufferRelease(static_params_buffer);
wgpuTextureViewRelease(static_features_view);
wgpuTextureRelease(static_features_tex);
+ wgpuTextureViewRelease(depth_view);
+ wgpuTextureRelease(depth_texture);
wgpuTextureViewRelease(layer_views[0]);
wgpuTextureViewRelease(layer_views[1]);
wgpuTextureRelease(layer_textures[0]);
diff --git a/tools/cnn_v2_test/index.html b/tools/cnn_v2_test/index.html
index 9636ecf..ca89fb4 100644
--- a/tools/cnn_v2_test/index.html
+++ b/tools/cnn_v2_test/index.html
@@ -1211,12 +1211,41 @@ class CNNTester {
});
}
+ // Extract depth from alpha channel (or 1.0 if no alpha)
+ const depthTex = this.device.createTexture({
+ size: [width, height, 1],
+ format: 'r32float',
+ usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST
+ });
+
+ // Read image data to extract alpha channel
+ const tempCanvas = document.createElement('canvas');
+ tempCanvas.width = width;
+ tempCanvas.height = height;
+ const tempCtx = tempCanvas.getContext('2d');
+ tempCtx.drawImage(source, 0, 0, width, height);
+ const imageData = tempCtx.getImageData(0, 0, width, height);
+ const pixels = imageData.data;
+
+ // Extract alpha channel (RGBA format: every 4th byte)
+ const depthData = new Float32Array(width * height);
+ for (let i = 0; i < width * height; i++) {
+ depthData[i] = pixels[i * 4 + 3] / 255.0; // Alpha channel [0, 255] → [0, 1]
+ }
+
+ this.device.queue.writeTexture(
+ { texture: depthTex },
+ depthData,
+ { bytesPerRow: width * 4 },
+ [width, height, 1]
+ );
+
const staticBG = this.device.createBindGroup({
layout: staticPipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: this.inputTexture.createView() },
{ binding: 1, resource: this.pointSampler },
- { binding: 2, resource: this.inputTexture.createView() }, // Use input as depth (matches C++)
+ { binding: 2, resource: depthTex.createView() }, // Depth from alpha (matches training)
{ binding: 3, resource: staticTex.createView() },
{ binding: 4, resource: { buffer: mipLevelBuffer } }
]
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index abe07bc..70229ce 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -26,13 +26,13 @@ def compute_static_features(rgb, depth=None, mip_level=0):
Args:
rgb: (H, W, 3) RGB image [0, 1]
- depth: (H, W) depth map [0, 1], optional
+ depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane)
mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
Returns:
(H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
- Note: p0-p3 are parametric features generated from specified mip level
+ Note: p0-p3 are parametric features from the specified mip level; p3 uses depth (the alpha channel), defaulting to 1.0
TODO: Binary format should support arbitrary layout and ordering for feature vector (7D),
alongside mip-level indication. Current layout is hardcoded as:
@@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0):
p0 = mip_rgb[:, :, 0].astype(np.float32)
p1 = mip_rgb[:, :, 1].astype(np.float32)
p2 = mip_rgb[:, :, 2].astype(np.float32)
- p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+ p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
# UV coordinates (normalized [0, 1])
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
@@ -244,8 +244,11 @@ class PatchDataset(Dataset):
input_patch = input_img[y1:y2, x1:x2]
target_patch = target_img[y1:y2, x1:x2] # RGBA
+ # Extract depth from target alpha channel (or default to 1.0)
+ depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None
+
# Compute static features for patch
- static_feat = compute_static_features(input_patch.astype(np.float32), mip_level=self.mip_level)
+ static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level)
# Input RGBD (mip 0) - add depth channel
input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
@@ -284,8 +287,11 @@ class ImagePairDataset(Dataset):
input_img = np.array(input_pil) / 255.0
target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
+ # Extract depth from target alpha channel (or default to 1.0)
+ depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None
+
# Compute static features
- static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level)
+ static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level)
# Input RGBD (mip 0) - add depth channel
h, w = input_img.shape[:2]