summaryrefslogtreecommitdiff
path: root/training/train_cnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-xtraining/train_cnn.py181
1 files changed, 164 insertions, 17 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 2250e9c..16f8e7a 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -5,10 +5,15 @@ CNN Training Script for Image-to-Image Transformation
Trains a convolutional neural network on multiple input/target image pairs.
Usage:
+ # Training
python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
+ # Inference (generate ground truth)
+ python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png
+
Example:
python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
+ python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth
"""
import torch
@@ -126,10 +131,8 @@ class SimpleCNN(nn.Module):
# Final layer (grayscale output)
final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
- out = self.layers[-1](final_input) # [B,1,H,W] in [-1,1]
-
- # Denormalize to [0,1] and expand to RGB for visualization
- out = (out + 1.0) * 0.5
+ out = self.layers[-1](final_input) # [B,1,H,W]
+ out = torch.clamp(out, 0.0, 1.0) # Clip to [0,1]
return out.expand(-1, 3, -1, -1)
@@ -167,8 +170,6 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
f.write("}\n\n")
f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
f.write(" let uv = p.xy / uniforms.resolution;\n")
- f.write(" let input_raw = textureSample(txt, smplr, uv);\n")
- f.write(" let input = (input_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
f.write(" let original_raw = textureSample(original_input, smplr, uv);\n")
f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
f.write(" var result = vec4<f32>(0.0);\n\n")
@@ -180,11 +181,12 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"
if layer_idx == 0:
- f.write(f" // Layer 0: 7→4 (RGBD output)\n")
+ conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src"
+ f.write(f" // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n")
f.write(f" if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
- f.write(f" original, weights_layer{layer_idx});\n")
- f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n")
+ f.write(f" result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result);\n")
f.write(f" }}\n")
elif not is_final:
f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
@@ -196,18 +198,21 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
f.write(f" original, weights_layer{layer_idx});\n")
- f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0); // Keep in [-1,1]\n")
+ f.write(f" // gray_out already in [0,1] from clipped training\n")
+ f.write(f" let original_denorm = (original + 1.0) * 0.5;\n")
+ f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
+ f.write(f" let blended = mix(original_denorm, result, params.blend_amount);\n")
+ f.write(f" return blended; // [0,1]\n")
f.write(f" }}\n")
# Add else clause for invalid layer index
if num_layers > 0:
f.write(f" else {{\n")
- f.write(f" result = input;\n")
+ f.write(f" return textureSample(txt, smplr, uv);\n")
f.write(f" }}\n")
- f.write("\n // Blend with ORIGINAL input from layer 0 and denormalize for display\n")
- f.write(" let blended = mix(original, result, params.blend_amount);\n")
- f.write(" return (blended + 1.0) * 0.5; // Denormalize to [0,1] for display\n")
+ f.write("\n // Non-final layers: denormalize for display\n")
+ f.write(" return (result + 1.0) * 0.5; // [-1,1] → [0,1]\n")
f.write("}\n")
@@ -253,6 +258,62 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
f.write(");\n\n")
+def generate_conv_src_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Source layer: 7→4 channels (RGBD output)\n")
+ f.write(f"// Normalizes [0,1] input to [-1,1] internally\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" weights: array<array<f32, 8>, {num_positions * 4}>\n")
+ f.write(f") -> vec4<f32> {{\n")
+ f.write(f" let step = 1.0 / resolution;\n\n")
+
+ # Normalize center pixel for gray channel
+ f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
+ f.write(f" let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+
+ f.write(f" var sum = vec4<f32>(0.0);\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")
+
+ # 7-channel input
+ f.write(f" let inputs = array<f32, 7>(\n")
+ f.write(f" rgbd.r, rgbd.g, rgbd.b, rgbd.a,\n")
+ f.write(f" uv_norm.x, uv_norm.y, gray\n")
+ f.write(f" );\n\n")
+
+ # Accumulate
+ f.write(f" for (var out_c = 0; out_c < 4; out_c++) {{\n")
+ f.write(f" let idx = pos * 4 + out_c;\n")
+ f.write(f" var channel_sum = weights[idx][7];\n")
+ f.write(f" for (var in_c = 0; in_c < 7; in_c++) {{\n")
+ f.write(f" channel_sum += weights[idx][in_c] * inputs[in_c];\n")
+ f.write(f" }}\n")
+ f.write(f" sum[out_c] += channel_sum;\n")
+ f.write(f" }}\n")
+ f.write(f" pos++;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
+
+
def train(args):
"""Main training loop"""
@@ -340,6 +401,24 @@ def train(args):
print(f"Generating layer shader to {shader_path}...")
generate_layer_shader(shader_path, args.layers, kernel_sizes)
+ # Generate _src variants for kernel sizes (skip 3x3, already exists)
+ for ks in set(kernel_sizes):
+ if ks == 3:
+ continue
+ conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+ if not os.path.exists(conv_path):
+ print(f"Warning: {conv_path} not found, skipping _src generation")
+ continue
+
+ # Check if _src already exists
+ with open(conv_path, 'r') as f:
+ content = f.read()
+ if f"cnn_conv{ks}x{ks}_7to4_src" in content:
+ continue
+
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+
print("Training complete!")
@@ -372,26 +451,94 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print(f"Generating layer shader to {shader_path}...")
generate_layer_shader(shader_path, num_layers, kernel_sizes)
+ # Generate _src variants for kernel sizes (skip 3x3, already exists)
+ for ks in set(kernel_sizes):
+ if ks == 3:
+ continue
+ conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+ if not os.path.exists(conv_path):
+ print(f"Warning: {conv_path} not found, skipping _src generation")
+ continue
+
+ # Check if _src already exists
+ with open(conv_path, 'r') as f:
+ content = f.read()
+ if f"cnn_conv{ks}x{ks}_7to4_src" in content:
+ continue
+
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+
print("Export complete!")
+def infer_from_checkpoint(checkpoint_path, input_path, output_path):
+ """Run inference on single image to generate ground truth"""
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ if not os.path.exists(input_path):
+ print(f"Error: Input image '{input_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ # Reconstruct model
+ model = SimpleCNN(
+ num_layers=checkpoint['num_layers'],
+ kernel_sizes=checkpoint['kernel_sizes']
+ )
+ model.load_state_dict(checkpoint['model_state'])
+ model.eval()
+
+ # Load image [0,1]
+ print(f"Loading input image: {input_path}")
+ img = Image.open(input_path).convert('RGBA')
+ img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W]
+
+ # Inference
+ print("Running inference...")
+ with torch.no_grad():
+ out = model(img_tensor) # [1,3,H,W] in [0,1]
+
+ # Save
+ print(f"Saving output to: {output_path}")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ transforms.ToPILImage()(out.squeeze(0)).save(output_path)
+ print("Done!")
+
+
def main():
parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
- parser.add_argument('--input', help='Input image directory')
+ parser.add_argument('--input', help='Input image directory (training) or single image (inference)')
parser.add_argument('--target', help='Target image directory')
parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
- parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+ parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)')
parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
parser.add_argument('--resume', help='Resume from checkpoint file')
parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
+ parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
args = parser.parse_args()
+ # Inference mode
+ if args.infer:
+ checkpoint = args.export_only
+ if not checkpoint:
+ print("Error: --infer requires --export-only <checkpoint>")
+ sys.exit(1)
+ output_path = args.output or 'inference_output.png'
+ infer_from_checkpoint(checkpoint, args.infer, output_path)
+ return
+
# Export-only mode
if args.export_only:
export_from_checkpoint(args.export_only, args.output)