summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/gpu/texture_readback.cc158
-rw-r--r--src/gpu/texture_readback.h10
-rw-r--r--tools/cnn_test.cc28
-rwxr-xr-xtraining/train_cnn.py42
4 files changed, 229 insertions, 9 deletions
diff --git a/src/gpu/texture_readback.cc b/src/gpu/texture_readback.cc
index 0eb63d7..f3e4056 100644
--- a/src/gpu/texture_readback.cc
+++ b/src/gpu/texture_readback.cc
@@ -142,4 +142,162 @@ std::vector<uint8_t> read_texture_pixels(
return pixels;
}
+// Half-float (FP16) to float conversion.
+// Expands IEEE 754 binary16 bits (1 sign, 5 exponent, 10 mantissa) into the
+// equivalent binary32 value, handling zeros, denormals, Inf and NaN.
+static float fp16_to_float(uint16_t h) {
+  uint32_t sign = (h & 0x8000) << 16;  // sign bit moves from bit 15 to bit 31
+  uint32_t exp = (h & 0x7C00) >> 10;   // 5-bit biased exponent
+  uint32_t mant = (h & 0x03FF);        // 10-bit mantissa
+
+  if (exp == 0) {
+    if (mant == 0) {
+      // Zero (preserves the sign: returns bit-exact +0.0f or -0.0f)
+      uint32_t bits = sign;
+      float result;
+      memcpy(&result, &bits, sizeof(float));
+      return result;
+    }
+    // Denormalized: renormalize until the implicit leading 1 sits at bit 10
+    exp = 1;
+    while ((mant & 0x400) == 0) {
+      mant <<= 1;
+      exp--;  // may wrap below 0; the unsigned wrap cancels in (exp + 112) below
+    }
+    mant &= 0x3FF;  // drop the now-implicit leading 1
+  } else if (exp == 31) {
+    // Inf or NaN (mantissa shift keeps the NaN payload)
+    uint32_t bits = sign | 0x7F800000 | (mant << 13);
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+  }
+
+  // Rebias exponent (float bias 127 - half bias 15 = 112), widen mantissa 10->23
+  uint32_t bits = sign | ((exp + 112) << 23) | (mant << 13);
+  float result;
+  memcpy(&result, &bits, sizeof(float));
+  return result;
+}
+
+// Read back an RGBA16Float texture into BGRA8 bytes (width * height * 4).
+// FP16 channel values in [-1,1] map to [0,255]; values outside that range
+// (or NaN) are clamped first, because casting an out-of-range float to
+// uint8_t is undefined behavior. Returns an empty vector on failure.
+std::vector<uint8_t> texture_readback_fp16_to_u8(
+    WGPUDevice device,
+    WGPUQueue queue,
+    WGPUTexture texture,
+    int width,
+    int height) {
+
+  // Align bytes per row to 256 (WebGPU COPY_BYTES_PER_ROW_ALIGNMENT)
+  const uint32_t bytes_per_pixel = 8;  // RGBA16Float = 4 × 2 bytes
+  const uint32_t unaligned_bytes_per_row = width * bytes_per_pixel;
+  const uint32_t aligned_bytes_per_row =
+      ((unaligned_bytes_per_row + 255) / 256) * 256;
+
+  // size_t math so very large textures cannot overflow 32-bit arithmetic
+  const size_t buffer_size = static_cast<size_t>(aligned_bytes_per_row) * height;
+
+  // Create CPU-readable staging buffer (destination of the texture copy)
+  const WGPUBufferDescriptor buffer_desc = {
+      .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead,
+      .size = buffer_size,
+  };
+  WGPUBuffer staging = wgpuDeviceCreateBuffer(device, &buffer_desc);
+  if (!staging) {
+    return {};
+  }
+
+  // Copy texture to buffer
+  WGPUCommandEncoder encoder = wgpuDeviceCreateCommandEncoder(device, nullptr);
+  const WGPUTexelCopyTextureInfo src = {
+      .texture = texture,
+      .mipLevel = 0,
+      .origin = {0, 0, 0},
+  };
+  const WGPUTexelCopyBufferInfo dst = {
+      .buffer = staging,
+      .layout =
+          {
+              .bytesPerRow = aligned_bytes_per_row,
+              .rowsPerImage = static_cast<uint32_t>(height),
+          },
+  };
+  const WGPUExtent3D copy_size = {static_cast<uint32_t>(width),
+                                  static_cast<uint32_t>(height), 1};
+  wgpuCommandEncoderCopyTextureToBuffer(encoder, &src, &dst, &copy_size);
+
+  WGPUCommandBuffer commands = wgpuCommandEncoderFinish(encoder, nullptr);
+  wgpuQueueSubmit(queue, 1, &commands);
+  wgpuCommandBufferRelease(commands);
+  wgpuCommandEncoderRelease(encoder);
+  wgpuDevicePoll(device, true, nullptr);  // block until the GPU copy completes
+
+  // Map buffer (callback signature differs between wgpu-native API versions)
+#if defined(DEMO_CROSS_COMPILE_WIN32)
+  MapState map_state = {};
+  auto map_cb = [](WGPUBufferMapAsyncStatus status, void* userdata) {
+    MapState* state = static_cast<MapState*>(userdata);
+    state->status = status;
+    state->done = true;
+  };
+  wgpuBufferMapAsync(staging, WGPUMapMode_Read, 0, buffer_size, map_cb,
+                     &map_state);
+#else
+  MapState map_state = {};
+  auto map_cb = [](WGPUMapAsyncStatus status, WGPUStringView message,
+                   void* userdata, void* user2) {
+    (void)message;
+    (void)user2;
+    MapState* state = static_cast<MapState*>(userdata);
+    state->status = status;
+    state->done = true;
+  };
+  WGPUBufferMapCallbackInfo map_info = {};
+  map_info.mode = WGPUCallbackMode_AllowProcessEvents;
+  map_info.callback = map_cb;
+  map_info.userdata1 = &map_state;
+  wgpuBufferMapAsync(staging, WGPUMapMode_Read, 0, buffer_size, map_info);
+#endif
+
+  for (int i = 0; i < 100 && !map_state.done; ++i) {  // bounded wait, no hang
+    wgpuDevicePoll(device, true, nullptr);
+  }
+
+  if (!map_state.done || map_state.status != WGPUMapAsyncStatus_Success) {
+    // NOTE(review): if the map is still pending, the callback holds a pointer
+    // to the dead stack MapState — confirm Release cancels pending callbacks.
+    wgpuBufferRelease(staging);
+    return {};
+  }
+
+  // Convert FP16 to U8 ([-1,1] → [0,255]), swizzling RGBA → BGRA
+  const uint16_t* mapped_data = static_cast<const uint16_t*>(
+      wgpuBufferGetConstMappedRange(staging, 0, buffer_size));
+
+  std::vector<uint8_t> pixels(width * height * 4);
+  if (mapped_data) {
+    for (int y = 0; y < height; ++y) {
+      const uint16_t* src_row =
+          reinterpret_cast<const uint16_t*>(
+              reinterpret_cast<const uint8_t*>(mapped_data) +
+              static_cast<size_t>(y) * aligned_bytes_per_row);
+      for (int x = 0; x < width; ++x) {
+        float r = fp16_to_float(src_row[x * 4 + 0]);
+        float g = fp16_to_float(src_row[x * 4 + 1]);
+        float b = fp16_to_float(src_row[x * 4 + 2]);
+        float a = fp16_to_float(src_row[x * 4 + 3]);
+
+        // [-1,1] → [0,1], clamped NaN-safely (the > / < tests are false for
+        // NaN, so NaN maps to 0) before the uint8_t cast, which is UB for
+        // values outside [0,255].
+        r = (r + 1.0f) * 0.5f; r = (r > 0.0f) ? ((r < 1.0f) ? r : 1.0f) : 0.0f;
+        g = (g + 1.0f) * 0.5f; g = (g > 0.0f) ? ((g < 1.0f) ? g : 1.0f) : 0.0f;
+        b = (b + 1.0f) * 0.5f; b = (b > 0.0f) ? ((b < 1.0f) ? b : 1.0f) : 0.0f;
+        a = (a + 1.0f) * 0.5f; a = (a > 0.0f) ? ((a < 1.0f) ? a : 1.0f) : 0.0f;
+
+        int idx = (y * width + x) * 4;
+        pixels[idx + 0] = static_cast<uint8_t>(b * 255.0f + 0.5f);  // B (round)
+        pixels[idx + 1] = static_cast<uint8_t>(g * 255.0f + 0.5f);  // G
+        pixels[idx + 2] = static_cast<uint8_t>(r * 255.0f + 0.5f);  // R
+        pixels[idx + 3] = static_cast<uint8_t>(a * 255.0f + 0.5f);  // A
+      }
+    }
+  }
+
+  wgpuBufferUnmap(staging);
+  wgpuBufferRelease(staging);
+  return pixels;
+}
+
#endif // !defined(STRIP_ALL)
diff --git a/src/gpu/texture_readback.h b/src/gpu/texture_readback.h
index 1bf770f..8230e13 100644
--- a/src/gpu/texture_readback.h
+++ b/src/gpu/texture_readback.h
@@ -20,4 +20,14 @@ std::vector<uint8_t> read_texture_pixels(
int width,
int height);
+// Read RGBA16Float texture and convert to BGRA8Unorm for saving
+// Converts [-1,1] float range to [0,255] uint8 range
+// Returns: width * height * 4 bytes (BGRA8)
+std::vector<uint8_t> texture_readback_fp16_to_u8(
+ WGPUDevice device,
+ WGPUQueue queue,
+ WGPUTexture texture,
+ int width,
+ int height);
+
#endif // !defined(STRIP_ALL)
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc
index 62a60f4..39ed436 100644
--- a/tools/cnn_test.cc
+++ b/tools/cnn_test.cc
@@ -41,6 +41,7 @@ struct Args {
const char* output_path = nullptr;
float blend = 1.0f;
bool output_png = true; // Default to PNG
+ const char* save_intermediates = nullptr;
};
// Parse command-line arguments
@@ -70,6 +71,8 @@ static bool parse_args(int argc, char** argv, Args* args) {
argv[i]);
return false;
}
+ } else if (strcmp(argv[i], "--save-intermediates") == 0 && i + 1 < argc) {
+ args->save_intermediates = argv[++i];
} else if (strcmp(argv[i], "--help") == 0) {
return false;
} else {
@@ -85,9 +88,10 @@ static bool parse_args(int argc, char** argv, Args* args) {
 static void print_usage(const char* prog) {
 fprintf(stderr, "Usage: %s input.png output.png [OPTIONS]\n", prog);
 fprintf(stderr, "\nOPTIONS:\n");
-  fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n");
-  fprintf(stderr, " --format ppm|png Output format (default: png)\n");
-  fprintf(stderr, " --help Show this help\n");
+  // Column widths re-aligned so the longer --save-intermediates row lines up.
+  fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n");
+  fprintf(stderr, " --format ppm|png Output format (default: png)\n");
+  fprintf(stderr, " --save-intermediates DIR Save intermediate layers to directory\n");
+  fprintf(stderr, " --help Show this help\n");
 }
// Load PNG and upload to GPU texture
@@ -485,6 +489,24 @@ int main(int argc, char** argv) {
wgpuRenderPassEncoderRelease(pass);
wgpuCommandEncoderRelease(encoder);
wgpuBindGroupRelease(bind_group);
+
+ // Save intermediate layer if requested
+ if (args.save_intermediates) {
+ char layer_path[512];
+ snprintf(layer_path, sizeof(layer_path), "%s/layer_%d.png",
+ args.save_intermediates, layer);
+ printf("Saving intermediate layer %d to '%s'...\n", layer, layer_path);
+
+ // Readback RGBA16Float texture
+ std::vector<uint8_t> pixels = texture_readback_fp16_to_u8(
+ device, queue, intermediate_textures[dst_idx], width, height);
+
+ if (!pixels.empty()) {
+ save_png(layer_path, pixels, width, height);
+ } else {
+ fprintf(stderr, "Warning: failed to read intermediate layer %d\n", layer);
+ }
+ }
}
// Update for next layer: output becomes input
diff --git a/training/train_cnn.py b/training/train_cnn.py
index dc14192..ef7a0ae 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -240,10 +240,12 @@ class SimpleCNN(nn.Module):
# Final layer: 7→1 (grayscale output)
self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))
- def forward(self, x):
+ def forward(self, x, return_intermediates=False):
# x: [B,4,H,W] - RGBD input (D = 1/z)
B, C, H, W = x.shape
+ intermediates = [] if return_intermediates else None
+
# Normalize RGBD to [-1,1]
x_norm = (x - 0.5) * 2.0
@@ -261,18 +263,26 @@ class SimpleCNN(nn.Module):
layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W]
out = self.layers[0](layer0_input) # [B,4,H,W]
out = torch.tanh(out) # [-1,1]
+ if return_intermediates:
+ intermediates.append(out.clone())
# Inner layers
for i in range(1, len(self.layers)-1):
layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
out = self.layers[i](layer_input)
out = torch.tanh(out)
+ if return_intermediates:
+ intermediates.append(out.clone())
# Final layer (grayscale output)
final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
out = self.layers[-1](final_input) # [B,1,H,W]
out = torch.sigmoid(out) # Map to [0,1] with smooth gradients
- return out.expand(-1, 3, -1, -1)
+ final_out = out.expand(-1, 3, -1, -1)
+
+ if return_intermediates:
+ return final_out, intermediates
+ return final_out
def generate_layer_shader(output_path, num_layers, kernel_sizes):
@@ -693,7 +703,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print("Export complete!")
-def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32):
+def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None):
"""Run sliding-window inference to match WGSL shader behavior"""
if not os.path.exists(checkpoint_path):
@@ -724,16 +734,35 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
# Process full image with sliding window (matches WGSL shader)
print(f"Processing full image ({W}×{H}) with sliding window...")
with torch.no_grad():
- output_tensor = model(img_tensor) # [1,3,H,W]
+ if save_intermediates:
+ output_tensor, intermediates = model(img_tensor, return_intermediates=True)
+ else:
+ output_tensor = model(img_tensor) # [1,3,H,W]
# Convert to numpy
output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()
- # Save
+ # Save final output
print(f"Saving output to: {output_path}")
os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
output_img = Image.fromarray((output * 255).astype(np.uint8))
output_img.save(output_path)
+
+ # Save intermediates if requested
+ if save_intermediates:
+ os.makedirs(save_intermediates, exist_ok=True)
+ print(f"Saving {len(intermediates)} intermediate layers to: {save_intermediates}")
+ for layer_idx, layer_tensor in enumerate(intermediates):
+ # Convert [-1,1] to [0,1] for visualization
+ layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5
+ # Take first channel for 4-channel intermediate layers
+ if layer_data.shape[2] == 4:
+ layer_data = layer_data[:, :, :3] # Show RGB only
+ layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8))
+ layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png')
+ layer_img.save(layer_path)
+ print(f" Saved layer {layer_idx} to {layer_path}")
+
print("Done!")
@@ -758,6 +787,7 @@ def main():
help='Salient point detector for patch extraction (default: harris)')
parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)')
parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)')
+ parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)')
args = parser.parse_args()
@@ -769,7 +799,7 @@ def main():
sys.exit(1)
output_path = args.output or 'inference_output.png'
patch_size = args.patch_size or 32
- infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size)
+ infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates)
return
# Export-only mode