diff options
| -rw-r--r-- | src/gpu/texture_readback.cc | 158 | ||||
| -rw-r--r-- | src/gpu/texture_readback.h | 10 | ||||
| -rw-r--r-- | tools/cnn_test.cc | 28 | ||||
| -rwxr-xr-x | training/train_cnn.py | 42 |
4 files changed, 229 insertions, 9 deletions
diff --git a/src/gpu/texture_readback.cc b/src/gpu/texture_readback.cc index 0eb63d7..f3e4056 100644 --- a/src/gpu/texture_readback.cc +++ b/src/gpu/texture_readback.cc @@ -142,4 +142,162 @@ std::vector<uint8_t> read_texture_pixels( return pixels; } +// Half-float (FP16) to float conversion +static float fp16_to_float(uint16_t h) { + uint32_t sign = (h & 0x8000) << 16; + uint32_t exp = (h & 0x7C00) >> 10; + uint32_t mant = (h & 0x03FF); + + if (exp == 0) { + if (mant == 0) { + // Zero + uint32_t bits = sign; + float result; + memcpy(&result, &bits, sizeof(float)); + return result; + } + // Denormalized + exp = 1; + while ((mant & 0x400) == 0) { + mant <<= 1; + exp--; + } + mant &= 0x3FF; + } else if (exp == 31) { + // Inf or NaN + uint32_t bits = sign | 0x7F800000 | (mant << 13); + float result; + memcpy(&result, &bits, sizeof(float)); + return result; + } + + uint32_t bits = sign | ((exp + 112) << 23) | (mant << 13); + float result; + memcpy(&result, &bits, sizeof(float)); + return result; +} + +std::vector<uint8_t> texture_readback_fp16_to_u8( + WGPUDevice device, + WGPUQueue queue, + WGPUTexture texture, + int width, + int height) { + + // Align bytes per row to 256 + const uint32_t bytes_per_pixel = 8; // RGBA16Float = 4 × 2 bytes + const uint32_t unaligned_bytes_per_row = width * bytes_per_pixel; + const uint32_t aligned_bytes_per_row = + ((unaligned_bytes_per_row + 255) / 256) * 256; + + const size_t buffer_size = aligned_bytes_per_row * height; + + // Create staging buffer + const WGPUBufferDescriptor buffer_desc = { + .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead, + .size = buffer_size, + }; + WGPUBuffer staging = wgpuDeviceCreateBuffer(device, &buffer_desc); + if (!staging) { + return {}; + } + + // Copy texture to buffer + WGPUCommandEncoder encoder = wgpuDeviceCreateCommandEncoder(device, nullptr); + const WGPUTexelCopyTextureInfo src = { + .texture = texture, + .mipLevel = 0, + .origin = {0, 0, 0}, + }; + const WGPUTexelCopyBufferInfo dst = 
{ + .buffer = staging, + .layout = + { + .bytesPerRow = aligned_bytes_per_row, + .rowsPerImage = static_cast<uint32_t>(height), + }, + }; + const WGPUExtent3D copy_size = {static_cast<uint32_t>(width), + static_cast<uint32_t>(height), 1}; + wgpuCommandEncoderCopyTextureToBuffer(encoder, &src, &dst, &copy_size); + + WGPUCommandBuffer commands = wgpuCommandEncoderFinish(encoder, nullptr); + wgpuQueueSubmit(queue, 1, &commands); + wgpuCommandBufferRelease(commands); + wgpuCommandEncoderRelease(encoder); + wgpuDevicePoll(device, true, nullptr); + + // Map buffer +#if defined(DEMO_CROSS_COMPILE_WIN32) + MapState map_state = {}; + auto map_cb = [](WGPUBufferMapAsyncStatus status, void* userdata) { + MapState* state = static_cast<MapState*>(userdata); + state->status = status; + state->done = true; + }; + wgpuBufferMapAsync(staging, WGPUMapMode_Read, 0, buffer_size, map_cb, + &map_state); +#else + MapState map_state = {}; + auto map_cb = [](WGPUMapAsyncStatus status, WGPUStringView message, + void* userdata, void* user2) { + (void)message; + (void)user2; + MapState* state = static_cast<MapState*>(userdata); + state->status = status; + state->done = true; + }; + WGPUBufferMapCallbackInfo map_info = {}; + map_info.mode = WGPUCallbackMode_AllowProcessEvents; + map_info.callback = map_cb; + map_info.userdata1 = &map_state; + wgpuBufferMapAsync(staging, WGPUMapMode_Read, 0, buffer_size, map_info); +#endif + + for (int i = 0; i < 100 && !map_state.done; ++i) { + wgpuDevicePoll(device, true, nullptr); + } + + if (!map_state.done || map_state.status != WGPUMapAsyncStatus_Success) { + wgpuBufferRelease(staging); + return {}; + } + + // Convert FP16 to U8 ([-1,1] → [0,255]) + const uint16_t* mapped_data = static_cast<const uint16_t*>( + wgpuBufferGetConstMappedRange(staging, 0, buffer_size)); + + std::vector<uint8_t> pixels(width * height * 4); + if (mapped_data) { + for (int y = 0; y < height; ++y) { + const uint16_t* src_row = + reinterpret_cast<const uint16_t*>( + 
reinterpret_cast<const uint8_t*>(mapped_data) + + y * aligned_bytes_per_row); + for (int x = 0; x < width; ++x) { + float r = fp16_to_float(src_row[x * 4 + 0]); + float g = fp16_to_float(src_row[x * 4 + 1]); + float b = fp16_to_float(src_row[x * 4 + 2]); + float a = fp16_to_float(src_row[x * 4 + 3]); + + // Convert [-1,1] → [0,1] → [0,255] + r = (r + 1.0f) * 0.5f; + g = (g + 1.0f) * 0.5f; + b = (b + 1.0f) * 0.5f; + a = (a + 1.0f) * 0.5f; + + int idx = (y * width + x) * 4; + pixels[idx + 0] = static_cast<uint8_t>(b * 255.0f); // B + pixels[idx + 1] = static_cast<uint8_t>(g * 255.0f); // G + pixels[idx + 2] = static_cast<uint8_t>(r * 255.0f); // R + pixels[idx + 3] = static_cast<uint8_t>(a * 255.0f); // A + } + } + } + + wgpuBufferUnmap(staging); + wgpuBufferRelease(staging); + return pixels; +} + #endif // !defined(STRIP_ALL) diff --git a/src/gpu/texture_readback.h b/src/gpu/texture_readback.h index 1bf770f..8230e13 100644 --- a/src/gpu/texture_readback.h +++ b/src/gpu/texture_readback.h @@ -20,4 +20,14 @@ std::vector<uint8_t> read_texture_pixels( int width, int height); +// Read RGBA16Float texture and convert to BGRA8Unorm for saving +// Converts [-1,1] float range to [0,255] uint8 range +// Returns: width * height * 4 bytes (BGRA8) +std::vector<uint8_t> texture_readback_fp16_to_u8( + WGPUDevice device, + WGPUQueue queue, + WGPUTexture texture, + int width, + int height); + #endif // !defined(STRIP_ALL) diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc index 62a60f4..39ed436 100644 --- a/tools/cnn_test.cc +++ b/tools/cnn_test.cc @@ -41,6 +41,7 @@ struct Args { const char* output_path = nullptr; float blend = 1.0f; bool output_png = true; // Default to PNG + const char* save_intermediates = nullptr; }; // Parse command-line arguments @@ -70,6 +71,8 @@ static bool parse_args(int argc, char** argv, Args* args) { argv[i]); return false; } + } else if (strcmp(argv[i], "--save-intermediates") == 0 && i + 1 < argc) { + args->save_intermediates = argv[++i]; } else if 
(strcmp(argv[i], "--help") == 0) { return false; } else { @@ -85,9 +88,10 @@ static bool parse_args(int argc, char** argv, Args* args) { static void print_usage(const char* prog) { fprintf(stderr, "Usage: %s input.png output.png [OPTIONS]\n", prog); fprintf(stderr, "\nOPTIONS:\n"); - fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n"); - fprintf(stderr, " --format ppm|png Output format (default: png)\n"); - fprintf(stderr, " --help Show this help\n"); + fprintf(stderr, " --blend F Final blend amount (0.0-1.0, default: 1.0)\n"); + fprintf(stderr, " --format ppm|png Output format (default: png)\n"); + fprintf(stderr, " --save-intermediates DIR Save intermediate layers to directory\n"); + fprintf(stderr, " --help Show this help\n"); } // Load PNG and upload to GPU texture @@ -485,6 +489,24 @@ int main(int argc, char** argv) { wgpuRenderPassEncoderRelease(pass); wgpuCommandEncoderRelease(encoder); wgpuBindGroupRelease(bind_group); + + // Save intermediate layer if requested + if (args.save_intermediates) { + char layer_path[512]; + snprintf(layer_path, sizeof(layer_path), "%s/layer_%d.png", + args.save_intermediates, layer); + printf("Saving intermediate layer %d to '%s'...\n", layer, layer_path); + + // Readback RGBA16Float texture + std::vector<uint8_t> pixels = texture_readback_fp16_to_u8( + device, queue, intermediate_textures[dst_idx], width, height); + + if (!pixels.empty()) { + save_png(layer_path, pixels, width, height); + } else { + fprintf(stderr, "Warning: failed to read intermediate layer %d\n", layer); + } + } } // Update for next layer: output becomes input diff --git a/training/train_cnn.py b/training/train_cnn.py index dc14192..ef7a0ae 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -240,10 +240,12 @@ class SimpleCNN(nn.Module): # Final layer: 7→1 (grayscale output) self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True)) - def forward(self, x): + def forward(self, x, 
return_intermediates=False): # x: [B,4,H,W] - RGBD input (D = 1/z) B, C, H, W = x.shape + intermediates = [] if return_intermediates else None + # Normalize RGBD to [-1,1] x_norm = (x - 0.5) * 2.0 @@ -261,18 +263,26 @@ class SimpleCNN(nn.Module): layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W] out = self.layers[0](layer0_input) # [B,4,H,W] out = torch.tanh(out) # [-1,1] + if return_intermediates: + intermediates.append(out.clone()) # Inner layers for i in range(1, len(self.layers)-1): layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1) out = self.layers[i](layer_input) out = torch.tanh(out) + if return_intermediates: + intermediates.append(out.clone()) # Final layer (grayscale output) final_input = torch.cat([out, x_coords, y_coords, gray], dim=1) out = self.layers[-1](final_input) # [B,1,H,W] out = torch.sigmoid(out) # Map to [0,1] with smooth gradients - return out.expand(-1, 3, -1, -1) + final_out = out.expand(-1, 3, -1, -1) + + if return_intermediates: + return final_out, intermediates + return final_out def generate_layer_shader(output_path, num_layers, kernel_sizes): @@ -693,7 +703,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print("Export complete!") -def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32): +def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None): """Run sliding-window inference to match WGSL shader behavior""" if not os.path.exists(checkpoint_path): @@ -724,16 +734,35 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 # Process full image with sliding window (matches WGSL shader) print(f"Processing full image ({W}×{H}) with sliding window...") with torch.no_grad(): - output_tensor = model(img_tensor) # [1,3,H,W] + if save_intermediates: + output_tensor, intermediates = model(img_tensor, return_intermediates=True) + else: + output_tensor = model(img_tensor) # 
[1,3,H,W] # Convert to numpy output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() - # Save + # Save final output print(f"Saving output to: {output_path}") os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) output_img = Image.fromarray((output * 255).astype(np.uint8)) output_img.save(output_path) + + # Save intermediates if requested + if save_intermediates: + os.makedirs(save_intermediates, exist_ok=True) + print(f"Saving {len(intermediates)} intermediate layers to: {save_intermediates}") + for layer_idx, layer_tensor in enumerate(intermediates): + # Convert [-1,1] to [0,1] for visualization + layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5 + # Take first channel for 4-channel intermediate layers + if layer_data.shape[2] == 4: + layer_data = layer_data[:, :, :3] # Show RGB only + layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8)) + layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png') + layer_img.save(layer_path) + print(f" Saved layer {layer_idx} to {layer_path}") + print("Done!") @@ -758,6 +787,7 @@ def main(): help='Salient point detector for patch extraction (default: harris)') parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)') parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)') + parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)') args = parser.parse_args() @@ -769,7 +799,7 @@ def main(): sys.exit(1) output_path = args.output or 'inference_output.png' patch_size = args.patch_size or 32 - infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size) + infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates) return # Export-only mode |
