Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-x  training/train_cnn.py  57
1 file changed, 48 insertions(+), 9 deletions(-)
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 1ea42a3..c775325 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -378,7 +378,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)]
# Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)]
- v1.append(f"{bias[0]:.6f}")
+ v1.append(f"{bias[0] / num_positions:.6f}")
f.write(f" vec4<f32>({', '.join(v0)}),\n")
f.write(f" vec4<f32>({', '.join(v1)})")
f.write(",\n" if pos < num_positions-1 else "\n")
@@ -395,7 +395,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)]
# Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)]
- v1.append(f"{bias[out_c]:.6f}")
+ v1.append(f"{bias[out_c] / num_positions:.6f}")
idx = (pos * 4 + out_c) * 2
f.write(f" vec4<f32>({', '.join(v0)}),\n")
f.write(f" vec4<f32>({', '.join(v1)})")
@@ -776,7 +776,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print("Export complete!")
-def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None):
+def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False):
"""Run sliding-window inference to match WGSL shader behavior"""
if not os.path.exists(checkpoint_path):
@@ -796,6 +796,15 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
kernel_sizes=checkpoint['kernel_sizes']
)
model.load_state_dict(checkpoint['model_state'])
+
+ # Debug: Zero out all weights and biases
+ if zero_weights:
+ print("DEBUG: Zeroing out all weights and biases")
+ for layer in model.layers:
+ with torch.no_grad():
+ layer.weight.zero_()
+ layer.bias.zero_()
+
model.eval()
# Load image
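Zeroing every weight and bias turns each convolution into a constant-zero map, presumably to get a deterministic baseline that can be compared against the WGSL shader running with the same zeroed parameters. A small sketch of the effect on a single layer (the 7-in/4-out shape mirrors the export code above but is otherwise hypothetical):

import torch
import torch.nn as nn

# Hypothetical layer shaped like the exported ones: 7 input channels, 4 output channels.
conv = nn.Conv2d(7, 4, kernel_size=3, padding=1)

with torch.no_grad():
    conv.weight.zero_()
    conv.bias.zero_()

x = torch.randn(1, 7, 8, 8)
y = conv(x)
assert torch.all(y == 0)   # every pre-activation output collapses to exactly zero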
@@ -815,10 +824,23 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
# Convert to numpy
output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()
- # Save final output
+ # Append alpha channel from input
+ alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1]
+ output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4]
+
+ # Debug: print first 8 pixels as hex
+ if debug_hex:
+ output_u8 = (output_rgba * 255).astype(np.uint8)
+ print("First 8 pixels (RGBA hex):")
+ for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])):
+ y, x = i // output_u8.shape[1], i % output_u8.shape[1]
+ r, g, b, a = output_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
+
+ # Save final output as RGBA
print(f"Saving output to: {output_path}")
os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
- output_img = Image.fromarray((output * 255).astype(np.uint8))
+ output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA')
output_img.save(output_path)
# Save intermediates if requested
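The hex dump walks the image row-major. For reference, a self-contained sketch of the same index-to-coordinate mapping and 0xRRGGBBAA formatting, using a small made-up RGBA array:

import numpy as np

# Hypothetical 2x4 RGBA image, already quantized to uint8.
output_u8 = np.arange(2 * 4 * 4, dtype=np.uint8).reshape(2, 4, 4)

h, w = output_u8.shape[:2]
for i in range(min(8, h * w)):
    y, x = divmod(i, w)              # row-major: linear index -> (row, column)
    r, g, b, a = output_u8[y, x]
    print(f"[{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")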
@@ -828,10 +850,25 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
for layer_idx, layer_tensor in enumerate(intermediates):
# Convert [-1,1] to [0,1] for visualization
layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5
- # Take first channel for 4-channel intermediate layers
+ layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8)
+
+ # Debug: print first 8 pixels as hex
+ if debug_hex:
+ print(f"Layer {layer_idx} first 8 pixels (RGBA hex):")
+ for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])):
+ y, x = i // layer_u8.shape[1], i % layer_u8.shape[1]
+ if layer_u8.shape[2] == 4:
+ r, g, b, a = layer_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
+ else:
+ r, g, b = layer_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}")
+
+ # Save all 4 channels for intermediate layers
if layer_data.shape[2] == 4:
- layer_data = layer_data[:, :, :3] # Show RGB only
- layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8))
+ layer_img = Image.fromarray(layer_u8, mode='RGBA')
+ else:
+ layer_img = Image.fromarray(layer_u8)
layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png')
layer_img.save(layer_path)
print(f" Saved layer {layer_idx} to {layer_path}")
@@ -861,6 +898,8 @@ def main():
parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)')
parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)')
parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)')
+ parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)')
+ parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)')
args = parser.parse_args()
@@ -872,7 +911,7 @@ def main():
sys.exit(1)
output_path = args.output or 'inference_output.png'
patch_size = args.patch_size or 32
- infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates)
+ infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex)
return
# Export-only mode