summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/debug/cur/layer_0.pngbin0 -> 406194 bytes
-rw-r--r--training/debug/cur/layer_1.pngbin0 -> 238358 bytes
-rw-r--r--training/debug/cur/toto.pngbin0 -> 90164 bytes
-rwxr-xr-xtraining/debug/debug.sh45
-rw-r--r--training/debug/ref/layer_0.pngbin0 -> 356038 bytes
-rw-r--r--training/debug/ref/layer_1.pngbin0 -> 222247 bytes
-rw-r--r--training/debug/ref/toto.pngbin0 -> 107009 bytes
-rw-r--r--training/debug/training/checkpoints/checkpoint_epoch_10.pthbin0 -> 6395 bytes
-rw-r--r--training/debug/training/checkpoints/checkpoint_epoch_100.pthbin0 -> 6417 bytes
-rw-r--r--training/debug/training/checkpoints/checkpoint_epoch_50.pthbin0 -> 6395 bytes
-rw-r--r--training/ground_truth.pngbin127405 -> 0 bytes
-rw-r--r--training/layers/chk_10000_5x3x3.ptbin0 -> 20381 bytes
-rw-r--r--training/layers/chk_5000_3x3x3.ptbin0 -> 14911 bytes
-rw-r--r--training/pass1_3x5x3.pthbin20287 -> 0 bytes
-rw-r--r--training/patch_32x32.pngbin5259 -> 0 bytes
-rw-r--r--training/toto.pngbin0 -> 103619 bytes
-rwxr-xr-xtraining/train_cnn.py77
17 files changed, 105 insertions, 17 deletions
diff --git a/training/debug/cur/layer_0.png b/training/debug/cur/layer_0.png
new file mode 100644
index 0000000..0cb977b
--- /dev/null
+++ b/training/debug/cur/layer_0.png
Binary files differ
diff --git a/training/debug/cur/layer_1.png b/training/debug/cur/layer_1.png
new file mode 100644
index 0000000..801aad2
--- /dev/null
+++ b/training/debug/cur/layer_1.png
Binary files differ
diff --git a/training/debug/cur/toto.png b/training/debug/cur/toto.png
new file mode 100644
index 0000000..9caff40
--- /dev/null
+++ b/training/debug/cur/toto.png
Binary files differ
diff --git a/training/debug/debug.sh b/training/debug/debug.sh
new file mode 100755
index 0000000..083082b
--- /dev/null
+++ b/training/debug/debug.sh
@@ -0,0 +1,45 @@
+#!/bin/sh
+
+pwd=`pwd`
+
+img=../input/img_003.png
+
+# img=/Users/skal/black_512x512_rgba.png
+#img=/Users/skal/rgba_0_0_0_0.png
+check_pt=../checkpoints/checkpoint_epoch_10000.pth
+#check_pt=../chk_5000_3x3x3.pt
+
+#../train_cnn.py --layers 3 --kernel_sizes 3,3,3 --epochs 10000 --batch_size 8 --input ../input/ --target ../target_2/ --checkpoint-every 1000
+#../train_cnn.py --export-only ${check_pt}
+#../train_cnn.py --export-only ${check_pt} --infer ${img} --output test/toto.png
+
+#../train_cnn.py --layers 2 --kernel_sizes 1,1 --epochs 10 --batch_size 5 --input ../input/ --target ../target_2/ --checkpoint-every 10
+#../train_cnn.py --export-only ${check_pt}
+#../train_cnn.py --export-only ${check_pt} --infer ${img} --output test/toto.png
+
+## XXX uncomment!
+../train_cnn.py --export-only ${check_pt} \
+ --infer ${img} \
+ --output ref/toto.png --save-intermediates ref/ # --debug-hex
+
+echo "== GENERATE SHADERS =="
+echo
+cd ../../
+./training/train_cnn.py --export-only ${pwd}/${check_pt}
+
+echo "== COMPILE =="
+echo
+cmake --build build -j4 --target cnn_test
+cd ${pwd}
+
+echo "== RUN =="
+echo
+rm -f cur/toto.png
+../../build/cnn_test ${img} cur/toto.png --save-intermediates cur/ --layers 3 # --debug-hex
+
+open cur/*.png ref/*.png
+
+echo "open cur/*.png ref/*.png"
+
+#pngcrush -rem gAMA -rem sRGB cur/toto.png toto.png && mv toto.png cur/toto.png
+#pngcrush -rem gAMA -rem sRGB cur/layer_0.png toto.png && mv toto.png cur/layer_0.png
diff --git a/training/debug/ref/layer_0.png b/training/debug/ref/layer_0.png
new file mode 100644
index 0000000..3e0eebe
--- /dev/null
+++ b/training/debug/ref/layer_0.png
Binary files differ
diff --git a/training/debug/ref/layer_1.png b/training/debug/ref/layer_1.png
new file mode 100644
index 0000000..d858f80
--- /dev/null
+++ b/training/debug/ref/layer_1.png
Binary files differ
diff --git a/training/debug/ref/toto.png b/training/debug/ref/toto.png
new file mode 100644
index 0000000..f869a7c
--- /dev/null
+++ b/training/debug/ref/toto.png
Binary files differ
diff --git a/training/debug/training/checkpoints/checkpoint_epoch_10.pth b/training/debug/training/checkpoints/checkpoint_epoch_10.pth
new file mode 100644
index 0000000..54ba5c5
--- /dev/null
+++ b/training/debug/training/checkpoints/checkpoint_epoch_10.pth
Binary files differ
diff --git a/training/debug/training/checkpoints/checkpoint_epoch_100.pth b/training/debug/training/checkpoints/checkpoint_epoch_100.pth
new file mode 100644
index 0000000..f94e9f8
--- /dev/null
+++ b/training/debug/training/checkpoints/checkpoint_epoch_100.pth
Binary files differ
diff --git a/training/debug/training/checkpoints/checkpoint_epoch_50.pth b/training/debug/training/checkpoints/checkpoint_epoch_50.pth
new file mode 100644
index 0000000..a602f4b
--- /dev/null
+++ b/training/debug/training/checkpoints/checkpoint_epoch_50.pth
Binary files differ
diff --git a/training/ground_truth.png b/training/ground_truth.png
deleted file mode 100644
index 6e1f2aa..0000000
--- a/training/ground_truth.png
+++ /dev/null
Binary files differ
diff --git a/training/layers/chk_10000_5x3x3.pt b/training/layers/chk_10000_5x3x3.pt
new file mode 100644
index 0000000..1840b53
--- /dev/null
+++ b/training/layers/chk_10000_5x3x3.pt
Binary files differ
diff --git a/training/layers/chk_5000_3x3x3.pt b/training/layers/chk_5000_3x3x3.pt
new file mode 100644
index 0000000..db05d57
--- /dev/null
+++ b/training/layers/chk_5000_3x3x3.pt
Binary files differ
diff --git a/training/pass1_3x5x3.pth b/training/pass1_3x5x3.pth
deleted file mode 100644
index a7fa8e3..0000000
--- a/training/pass1_3x5x3.pth
+++ /dev/null
Binary files differ
diff --git a/training/patch_32x32.png b/training/patch_32x32.png
deleted file mode 100644
index a665065..0000000
--- a/training/patch_32x32.png
+++ /dev/null
Binary files differ
diff --git a/training/toto.png b/training/toto.png
new file mode 100644
index 0000000..2044840
--- /dev/null
+++ b/training/toto.png
Binary files differ
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 1ea42a3..4171dcb 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -218,7 +218,10 @@ class PatchDataset(Dataset):
class SimpleCNN(nn.Module):
- """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""
+ """CNN for RGBD→RGB with 7-channel input (RGBD + UV + gray)
+
+ Internally computes grayscale, expands to 3-channel RGB output.
+ """
def __init__(self, num_layers=1, kernel_sizes=None):
super(SimpleCNN, self).__init__()
@@ -272,11 +275,11 @@ class SimpleCNN(nn.Module):
if return_intermediates:
intermediates.append(out.clone())
- # Final layer (grayscale output)
+ # Final layer (grayscale→RGB)
final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
- out = self.layers[-1](final_input) # [B,1,H,W]
+ out = self.layers[-1](final_input) # [B,1,H,W] grayscale
out = torch.sigmoid(out) # Map to [0,1] with smooth gradients
- final_out = out.expand(-1, 3, -1, -1)
+ final_out = out.expand(-1, 3, -1, -1) # [B,3,H,W] expand to RGB
if return_intermediates:
return final_out, intermediates
@@ -378,7 +381,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)]
# Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)]
- v1.append(f"{bias[0]:.6f}")
+ v1.append(f"{bias[0] / num_positions:.6f}")
f.write(f" vec4<f32>({', '.join(v0)}),\n")
f.write(f" vec4<f32>({', '.join(v1)})")
f.write(",\n" if pos < num_positions-1 else "\n")
@@ -395,7 +398,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)]
# Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)]
- v1.append(f"{bias[out_c]:.6f}")
+ v1.append(f"{bias[out_c] / num_positions:.6f}")
idx = (pos * 4 + out_c) * 2
f.write(f" vec4<f32>({', '.join(v0)}),\n")
f.write(f" vec4<f32>({', '.join(v1)})")
@@ -776,8 +779,11 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print("Export complete!")
-def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None):
- """Run sliding-window inference to match WGSL shader behavior"""
+def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False):
+ """Run sliding-window inference to match WGSL shader behavior
+
+ Outputs RGBA PNG (RGB from model + alpha from input).
+ """
if not os.path.exists(checkpoint_path):
print(f"Error: Checkpoint '{checkpoint_path}' not found")
@@ -796,6 +802,15 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
kernel_sizes=checkpoint['kernel_sizes']
)
model.load_state_dict(checkpoint['model_state'])
+
+ # Debug: Zero out all weights and biases
+ if zero_weights:
+ print("DEBUG: Zeroing out all weights and biases")
+ for layer in model.layers:
+ with torch.no_grad():
+ layer.weight.zero_()
+ layer.bias.zero_()
+
model.eval()
# Load image
@@ -810,15 +825,26 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
if save_intermediates:
output_tensor, intermediates = model(img_tensor, return_intermediates=True)
else:
- output_tensor = model(img_tensor) # [1,3,H,W]
+ output_tensor = model(img_tensor) # [1,3,H,W] RGB
- # Convert to numpy
- output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()
+ # Convert to numpy and append alpha
+ output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # [H,W,3] RGB
+ alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] alpha from input
+ output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] RGBA
- # Save final output
+ # Debug: print first 8 pixels as hex
+ if debug_hex:
+ output_u8 = (output_rgba * 255).astype(np.uint8)
+ print("First 8 pixels (RGBA hex):")
+ for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])):
+ y, x = i // output_u8.shape[1], i % output_u8.shape[1]
+ r, g, b, a = output_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
+
+ # Save final output as RGBA
print(f"Saving output to: {output_path}")
os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
- output_img = Image.fromarray((output * 255).astype(np.uint8))
+ output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA')
output_img.save(output_path)
# Save intermediates if requested
@@ -828,10 +854,25 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
for layer_idx, layer_tensor in enumerate(intermediates):
# Convert [-1,1] to [0,1] for visualization
layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5
- # Take first channel for 4-channel intermediate layers
+ layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8)
+
+ # Debug: print first 8 pixels as hex
+ if debug_hex:
+ print(f"Layer {layer_idx} first 8 pixels (RGBA hex):")
+ for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])):
+ y, x = i // layer_u8.shape[1], i % layer_u8.shape[1]
+ if layer_u8.shape[2] == 4:
+ r, g, b, a = layer_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
+ else:
+ r, g, b = layer_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}")
+
+ # Save all 4 channels for intermediate layers
if layer_data.shape[2] == 4:
- layer_data = layer_data[:, :, :3] # Show RGB only
- layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8))
+ layer_img = Image.fromarray(layer_u8, mode='RGBA')
+ else:
+ layer_img = Image.fromarray(layer_u8)
layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png')
layer_img.save(layer_path)
print(f" Saved layer {layer_idx} to {layer_path}")
@@ -861,6 +902,8 @@ def main():
parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)')
parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)')
parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)')
+ parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)')
+ parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)')
args = parser.parse_args()
@@ -872,7 +915,7 @@ def main():
sys.exit(1)
output_path = args.output or 'inference_output.png'
patch_size = args.patch_size or 32
- infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates)
+ infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex)
return
# Export-only mode