diff options
Diffstat (limited to 'training')
17 files changed, 105 insertions, 17 deletions
diff --git a/training/debug/cur/layer_0.png b/training/debug/cur/layer_0.png Binary files differnew file mode 100644 index 0000000..0cb977b --- /dev/null +++ b/training/debug/cur/layer_0.png diff --git a/training/debug/cur/layer_1.png b/training/debug/cur/layer_1.png Binary files differnew file mode 100644 index 0000000..801aad2 --- /dev/null +++ b/training/debug/cur/layer_1.png diff --git a/training/debug/cur/toto.png b/training/debug/cur/toto.png Binary files differnew file mode 100644 index 0000000..9caff40 --- /dev/null +++ b/training/debug/cur/toto.png diff --git a/training/debug/debug.sh b/training/debug/debug.sh new file mode 100755 index 0000000..083082b --- /dev/null +++ b/training/debug/debug.sh @@ -0,0 +1,45 @@ +#!/bin/sh + +pwd=`pwd` + +img=../input/img_003.png + +# img=/Users/skal/black_512x512_rgba.png +#img=/Users/skal/rgba_0_0_0_0.png +check_pt=../checkpoints/checkpoint_epoch_10000.pth +#check_pt=../chk_5000_3x3x3.pt + +#../train_cnn.py --layers 3 --kernel_sizes 3,3,3 --epochs 10000 --batch_size 8 --input ../input/ --target ../target_2/ --checkpoint-every 1000 +#../train_cnn.py --export-only ${check_pt} +#../train_cnn.py --export-only ${check_pt} --infer ${img} --output test/toto.png + +#../train_cnn.py --layers 2 --kernel_sizes 1,1 --epochs 10 --batch_size 5 --input ../input/ --target ../target_2/ --checkpoint-every 10 +#../train_cnn.py --export-only ${check_pt} +#../train_cnn.py --export-only ${check_pt} --infer ${img} --output test/toto.png + +## XXX uncomment! +../train_cnn.py --export-only ${check_pt} \ + --infer ${img} \ + --output ref/toto.png --save-intermediates ref/ # --debug-hex + +echo "== GENERATE SHADERS ==" +echo +cd ../../ +./training/train_cnn.py --export-only ${pwd}/${check_pt} + +echo "== COMPILE ==" +echo +cmake --build build -j4 --target cnn_test +cd ${pwd} + +echo "== RUN ==" +echo +rm -f cur/toto.png +../../build/cnn_test ${img} cur/toto.png --save-intermediates cur/ --layers 3 # --debug-hex + +open cur/*.png ref/*.png + +echo "open cur/*.png ref/*.png" + +#pngcrush -rem gAMA -rem sRGB cur/toto.png toto.png && mv toto.png cur/toto.png +#pngcrush -rem gAMA -rem sRGB cur/layer_0.png toto.png && mv toto.png cur/layer_0.png diff --git a/training/debug/ref/layer_0.png b/training/debug/ref/layer_0.png Binary files differnew file mode 100644 index 0000000..3e0eebe --- /dev/null +++ b/training/debug/ref/layer_0.png diff --git a/training/debug/ref/layer_1.png b/training/debug/ref/layer_1.png Binary files differnew file mode 100644 index 0000000..d858f80 --- /dev/null +++ b/training/debug/ref/layer_1.png diff --git a/training/debug/ref/toto.png b/training/debug/ref/toto.png Binary files differnew file mode 100644 index 0000000..f869a7c --- /dev/null +++ b/training/debug/ref/toto.png diff --git a/training/debug/training/checkpoints/checkpoint_epoch_10.pth b/training/debug/training/checkpoints/checkpoint_epoch_10.pth Binary files differnew file mode 100644 index 0000000..54ba5c5 --- /dev/null +++ b/training/debug/training/checkpoints/checkpoint_epoch_10.pth diff --git a/training/debug/training/checkpoints/checkpoint_epoch_100.pth b/training/debug/training/checkpoints/checkpoint_epoch_100.pth Binary files differnew file mode 100644 index 0000000..f94e9f8 --- /dev/null +++ b/training/debug/training/checkpoints/checkpoint_epoch_100.pth diff --git a/training/debug/training/checkpoints/checkpoint_epoch_50.pth b/training/debug/training/checkpoints/checkpoint_epoch_50.pth Binary files differnew file mode 100644 index 0000000..a602f4b --- /dev/null +++ b/training/debug/training/checkpoints/checkpoint_epoch_50.pth diff --git a/training/ground_truth.png b/training/ground_truth.png Binary files differdeleted file mode 100644 index 6e1f2aa..0000000 --- a/training/ground_truth.png +++ /dev/null diff --git a/training/layers/chk_10000_5x3x3.pt b/training/layers/chk_10000_5x3x3.pt Binary files differnew file mode 100644 index 0000000..1840b53 --- /dev/null +++ b/training/layers/chk_10000_5x3x3.pt diff --git a/training/layers/chk_5000_3x3x3.pt b/training/layers/chk_5000_3x3x3.pt Binary files differnew file mode 100644 index 0000000..db05d57 --- /dev/null +++ b/training/layers/chk_5000_3x3x3.pt diff --git a/training/pass1_3x5x3.pth b/training/pass1_3x5x3.pth Binary files differdeleted file mode 100644 index a7fa8e3..0000000 --- a/training/pass1_3x5x3.pth +++ /dev/null diff --git a/training/patch_32x32.png b/training/patch_32x32.png Binary files differdeleted file mode 100644 index a665065..0000000 --- a/training/patch_32x32.png +++ /dev/null diff --git a/training/toto.png b/training/toto.png Binary files differnew file mode 100644 index 0000000..2044840 --- /dev/null +++ b/training/toto.png diff --git a/training/train_cnn.py b/training/train_cnn.py index 1ea42a3..4171dcb 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -218,7 +218,10 @@ class PatchDataset(Dataset): class SimpleCNN(nn.Module): - """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)""" + """CNN for RGBD→RGB with 7-channel input (RGBD + UV + gray) + + Internally computes grayscale, expands to 3-channel RGB output. + """ def __init__(self, num_layers=1, kernel_sizes=None): super(SimpleCNN, self).__init__() @@ -272,11 +275,11 @@ class SimpleCNN(nn.Module): if return_intermediates: intermediates.append(out.clone()) - # Final layer (grayscale output) + # Final layer (grayscale→RGB) final_input = torch.cat([out, x_coords, y_coords, gray], dim=1) - out = self.layers[-1](final_input) # [B,1,H,W] + out = self.layers[-1](final_input) # [B,1,H,W] grayscale out = torch.sigmoid(out) # Map to [0,1] with smooth gradients - final_out = out.expand(-1, 3, -1, -1) + final_out = out.expand(-1, 3, -1, -1) # [B,3,H,W] expand to RGB if return_intermediates: return final_out, intermediates @@ -378,7 +381,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)] # Second vec4: [w4, w5, w6, bias] (uv, gray, 1) v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)] - v1.append(f"{bias[0]:.6f}") + v1.append(f"{bias[0] / num_positions:.6f}") f.write(f" vec4<f32>({', '.join(v0)}),\n") f.write(f" vec4<f32>({', '.join(v1)})") f.write(",\n" if pos < num_positions-1 else "\n") @@ -395,7 +398,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)] # Second vec4: [w4, w5, w6, bias] (uv, gray, 1) v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)] - v1.append(f"{bias[out_c]:.6f}") + v1.append(f"{bias[out_c] / num_positions:.6f}") idx = (pos * 4 + out_c) * 2 f.write(f" vec4<f32>({', '.join(v0)}),\n") f.write(f" vec4<f32>({', '.join(v1)})") @@ -776,8 +779,11 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print("Export complete!") -def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None): - """Run sliding-window inference to match WGSL shader behavior""" +def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False): + """Run sliding-window inference to match WGSL shader behavior + + Outputs RGBA PNG (RGB from model + alpha from input). + """ if not os.path.exists(checkpoint_path): print(f"Error: Checkpoint '{checkpoint_path}' not found") @@ -796,6 +802,15 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 kernel_sizes=checkpoint['kernel_sizes'] ) model.load_state_dict(checkpoint['model_state']) + + # Debug: Zero out all weights and biases + if zero_weights: + print("DEBUG: Zeroing out all weights and biases") + for layer in model.layers: + with torch.no_grad(): + layer.weight.zero_() + layer.bias.zero_() + model.eval() # Load image @@ -810,15 +825,26 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 if save_intermediates: output_tensor, intermediates = model(img_tensor, return_intermediates=True) else: - output_tensor = model(img_tensor) # [1,3,H,W] + output_tensor = model(img_tensor) # [1,3,H,W] RGB - # Convert to numpy - output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() + # Convert to numpy and append alpha + output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # [H,W,3] RGB + alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] alpha from input + output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] RGBA - # Save final output + # Debug: print first 8 pixels as hex + if debug_hex: + output_u8 = (output_rgba * 255).astype(np.uint8) + print("First 8 pixels (RGBA hex):") + for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])): + y, x = i // output_u8.shape[1], i % output_u8.shape[1] + r, g, b, a = output_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}") + + # Save final output as RGBA print(f"Saving output to: {output_path}") os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) - output_img = Image.fromarray((output * 255).astype(np.uint8)) + output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA') output_img.save(output_path) # Save intermediates if requested @@ -828,10 +854,25 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 for layer_idx, layer_tensor in enumerate(intermediates): # Convert [-1,1] to [0,1] for visualization layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5 - # Take first channel for 4-channel intermediate layers + layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8) + + # Debug: print first 8 pixels as hex + if debug_hex: + print(f"Layer {layer_idx} first 8 pixels (RGBA hex):") + for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])): + y, x = i // layer_u8.shape[1], i % layer_u8.shape[1] + if layer_u8.shape[2] == 4: + r, g, b, a = layer_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}") + else: + r, g, b = layer_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}") + + # Save all 4 channels for intermediate layers if layer_data.shape[2] == 4: - layer_data = layer_data[:, :, :3] # Show RGB only - layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8)) + layer_img = Image.fromarray(layer_u8, mode='RGBA') + else: + layer_img = Image.fromarray(layer_u8) layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png') layer_img.save(layer_path) print(f" Saved layer {layer_idx} to {layer_path}") @@ -861,6 +902,8 @@ def main(): parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)') parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)') parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)') + parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)') + parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)') args = parser.parse_args() @@ -872,7 +915,7 @@ def main(): sys.exit(1) output_path = args.output or 'inference_output.png' patch_size = args.patch_size or 32 - infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates) + infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex) return # Export-only mode |
