| author | skal <pascal.massimino@gmail.com> | 2026-02-10 16:44:39 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 16:44:39 +0100 |
| commit | 61104d5b9e1774c11f0dba3b6d6018dabc2bce8f (patch) | |
| tree | 882e642721984cc921cbe5678fe7905721a2ad40 /training | |
| parent | 3942653de11542acc4892470243a8a6bf8d5c4f7 (diff) | |
feat: CNN RGBD→grayscale with 7-channel augmented input
Upgrade the CNN architecture to take RGBD input and produce grayscale
output, using 7-channel layer inputs (RGBD + UV coords + grayscale).
Architecture changes:
- Inner layers: Conv2d(7→4) output RGBD
- Final layer: Conv2d(7→1) output grayscale
- All inputs normalized to [-1,1] for tanh activation
- Removed CoordConv2d in favor of unified 7-channel input (assembled as sketched below)
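For reference, a minimal PyTorch sketch of the 7-channel input assembly described above (channel ordering, normalization, and Rec.709 coefficients taken from the diff below; the batch and image sizes are illustrative):

```python
import torch
import torch.nn as nn

B, H, W = 1, 64, 64
rgbd = torch.rand(B, 4, H, W)                      # RGBD in [0,1], D = 1/z
x_norm = (rgbd - 0.5) * 2.0                        # normalize to [-1,1] for tanh

# UV coordinate planes, normalized to [-1,1]
u = torch.linspace(0, 1, W).view(1, 1, 1, W).expand(B, 1, H, W)
v = torch.linspace(0, 1, H).view(1, 1, H, 1).expand(B, 1, H, W)
u, v = (u - 0.5) * 2.0, (v - 0.5) * 2.0

# Grayscale from the original RGB (Rec.709), normalized to [-1,1]
gray = 0.2126 * rgbd[:, 0:1] + 0.7152 * rgbd[:, 1:2] + 0.0722 * rgbd[:, 2:3]
gray = (gray - 0.5) * 2.0

seven = torch.cat([x_norm, u, v, gray], dim=1)     # [B,7,H,W] augmented input
inner = nn.Conv2d(7, 4, kernel_size=3, padding=1)  # inner layer: 7→4 (RGBD out)
final = nn.Conv2d(7, 1, kernel_size=3, padding=1)  # final layer: 7→1 (gray out)
out4 = torch.tanh(inner(seven))                    # [B,4,H,W] in [-1,1]
out1 = final(torch.cat([out4, u, v, gray], dim=1)) # [B,1,H,W] grayscale
```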
Training (train_cnn.py):
- SimpleCNN: 7→4 (inner), 7→1 (final) architecture
- Forward: Normalize RGBD/coords/gray to [-1,1]
- Weight export: array<array<f32, 8>, 36> (inner), array<array<f32, 8>, 9> (final); layout sketched below
- Dataset: Load RGBA (RGBD) input
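A sketch of the flattened weight layout named above, mirroring export_weights_to_wgsl in the diff (random values stand in for trained weights): each of the 36 inner-layer entries holds the 7 tap weights plus the bias, indexed as pos * 4 + out_c.

```python
import numpy as np

# Inner layer shapes from the diff: weight [4,7,3,3], bias [4]
weights = np.random.randn(4, 7, 3, 3).astype(np.float32)
bias = np.random.randn(4).astype(np.float32)
kh, kw = weights.shape[2], weights.shape[3]

entries = []  # 36 entries of 8 floats: (9 tap positions) x (4 output channels)
for pos in range(kh * kw):
    row, col = pos // kw, pos % kw
    for out_c in range(4):
        # 7 input-channel weights for this tap, then the bias as the 8th float
        entries.append(list(weights[out_c, :, row, col]) + [bias[out_c]])

assert len(entries) == 36 and len(entries[0]) == 8
# The WGSL side indexes the same way: entry = weights_layerN[pos * 4 + out_c]
```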
Shaders (cnn_conv3x3.wgsl):
- Added cnn_conv3x3_7to4: 7-channel input → RGBD output
- Added cnn_conv3x3_7to1: 7-channel input → grayscale output
- Both normalize inputs and use flattened weight arrays (per-texel reference sketched below)
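The WGSL helpers themselves live in cnn_conv3x3.wgsl and are not part of this diff; as a reading aid, here is a NumPy reference for what one cnn_conv3x3_7to4 texel plausibly computes under the weight layout above (exact sampling and bias handling in the shader are assumptions):

```python
import numpy as np

def conv7to4_texel(patch7, entries):
    # patch7:  (9, 7) -- the 7 normalized inputs at each 3x3 tap position
    # entries: (36, 8) -- flattened weights; entry = entries[pos * 4 + out_c]
    out = np.zeros(4, dtype=np.float32)
    for pos in range(9):
        for out_c in range(4):
            out[out_c] += entries[pos * 4 + out_c, :7] @ patch7[pos]
    for out_c in range(4):
        # Bias is stored redundantly in every entry; add it once per channel
        out[out_c] += entries[out_c, 7]
    return np.tanh(out)  # [-1,1]; denormalized to [0,1] for texture storage

patch7 = np.random.randn(9, 7).astype(np.float32)
entries = np.random.randn(36, 8).astype(np.float32)
rgbd_out = conv7to4_texel(patch7, entries)  # 4 values in [-1,1]
```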
Documentation:
- CNN_EFFECT.md: Updated architecture, training, weight format
- CNN_RGBD_GRAYSCALE_SUMMARY.md: Implementation summary
- HOWTO.md: Added training command example
Next: Train with RGBD input data
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
| -rwxr-xr-x | training/train_cnn.py | 210 |
1 file changed, 84 insertions, 126 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 1cd6579..0495c65 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -62,7 +62,8 @@ class ImagePairDataset(Dataset):
     def __getitem__(self, idx):
         input_path, target_path = self.image_pairs[idx]
-        input_img = Image.open(input_path).convert('RGB')
+        # Load RGBD input (4 channels: RGB + Depth)
+        input_img = Image.open(input_path).convert('RGBA')
         target_img = Image.open(target_path).convert('RGB')

         if self.transform:
@@ -72,27 +73,8 @@ class ImagePairDataset(Dataset):
         return input_img, target_img

-class CoordConv2d(nn.Module):
-    """Conv2d that accepts coordinate input separate from spatial patches"""
-
-    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
-        super().__init__()
-        self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
-        self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01)
-        self.bias = nn.Parameter(torch.zeros(out_channels))
-
-    def forward(self, x, coords):
-        # x: [B, C, H, W] image
-        # coords: [B, 2, H, W] coordinate grid
-        out = self.conv_rgba(x)
-        B, C, H, W = out.shape
-        coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights)
-        out = out + coord_contrib + self.bias.view(1, -1, 1, 1)
-        return out
-
-
 class SimpleCNN(nn.Module):
-    """Simple CNN for image-to-image transformation"""
+    """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""

     def __init__(self, num_layers=1, kernel_sizes=None):
         super(SimpleCNN, self).__init__()
@@ -107,26 +89,48 @@ class SimpleCNN(nn.Module):
         for i, kernel_size in enumerate(kernel_sizes):
             padding = kernel_size // 2
-            if i == 0:
-                self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding))
+            if i < num_layers - 1:
+                # Inner layers: 7→4 (RGBD output)
+                self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
             else:
-                self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True))
+                # Final layer: 7→1 (grayscale output)
+                self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))

     def forward(self, x):
+        # x: [B,4,H,W] - RGBD input (D = 1/z)
         B, C, H, W = x.shape
+
+        # Normalize RGBD to [-1,1]
+        x_norm = (x - 0.5) * 2.0
+
+        # Compute coordinates [0,1] then normalize to [-1,1]
         y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
         x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
-        coords = torch.cat([x_coords, y_coords], dim=1)
+        y_coords = (y_coords - 0.5) * 2.0  # [-1,1]
+        x_coords = (x_coords - 0.5) * 2.0  # [-1,1]

-        out = self.layers[0](x, coords)
-        out = torch.tanh(out)
+        # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
+        gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3]  # [B,1,H,W] in [0,1]
+        gray = (gray - 0.5) * 2.0  # [-1,1]

-        for i in range(1, len(self.layers)):
-            out = self.layers[i](out)
-            if i < len(self.layers) - 1:
-                out = torch.tanh(out)
+        # Layer 0
+        layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1)  # [B,7,H,W]
+        out = self.layers[0](layer0_input)  # [B,4,H,W]
+        out = torch.tanh(out)  # [-1,1]

-        return out
+        # Inner layers
+        for i in range(1, len(self.layers)-1):
+            layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+            out = self.layers[i](layer_input)
+            out = torch.tanh(out)
+
+        # Final layer (grayscale output)
+        final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+        out = self.layers[-1](final_input)  # [B,1,H,W] in [-1,1]
+
+        # Denormalize to [0,1] and expand to RGB for visualization
+        out = (out + 1.0) * 0.5
+        return out.expand(-1, 3, -1, -1)

 def generate_layer_shader(output_path, num_layers, kernel_sizes):
@@ -169,25 +173,35 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
     # Generate layer switches
     for layer_idx in range(num_layers):
-        ks = kernel_sizes[layer_idx]
+        is_final = layer_idx == num_layers - 1
         if layer_idx == 0:
-            f.write(f"    // Layer 0 uses coordinate-aware convolution\n")
+            f.write(f"    // Layer 0: 7→4 (RGBD output)\n")
             f.write(f"    if (params.layer_index == {layer_idx}) {{\n")
-            f.write(f"        result = cnn_conv{ks}x{ks}_with_coord(txt, smplr, uv, uniforms.resolution,\n")
-            f.write(f"            rgba_weights_layer{layer_idx}, coord_weights_layer{layer_idx}, bias_layer{layer_idx});\n")
-            f.write(f"        result = cnn_tanh(result);\n")
+            f.write(f"        result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n")
+            f.write(f"            original, weights_layer{layer_idx});\n")
+            f.write(f"        result = cnn_tanh(result); // Output in [-1,1]\n")
+            f.write(f"        // Denormalize to [0,1] for texture storage\n")
+            f.write(f"        result = (result + 1.0) * 0.5;\n")
+            f.write(f"    }}\n")
+        elif not is_final:
+            f.write(f"    else if (params.layer_index == {layer_idx}) {{\n")
+            f.write(f"        result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n")
+            f.write(f"            original, weights_layer{layer_idx});\n")
+            f.write(f"        result = cnn_tanh(result); // Output in [-1,1]\n")
+            f.write(f"        // Denormalize to [0,1] for texture storage\n")
+            f.write(f"        result = (result + 1.0) * 0.5;\n")
             f.write(f"    }}\n")
         else:
-            is_last = layer_idx == num_layers - 1
-            f.write(f"    {'else ' if layer_idx > 0 else ''}if (params.layer_index == {layer_idx}) {{\n")
-            f.write(f"        result = cnn_conv{ks}x{ks}(txt, smplr, uv, uniforms.resolution,\n")
-            f.write(f"            weights_layer{layer_idx}, bias_layer{layer_idx});\n")
-            if not is_last:
-                f.write(f"        result = cnn_tanh(result);\n")
+            f.write(f"    else if (params.layer_index == {layer_idx}) {{\n")
+            f.write(f"        let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,\n")
+            f.write(f"            original, weights_layer{layer_idx});\n")
+            f.write(f"        // Denormalize from [-1,1] to [0,1]\n")
+            f.write(f"        let gray_01 = (gray_out + 1.0) * 0.5;\n")
+            f.write(f"        result = vec4<f32>(gray_01, gray_01, gray_01, 1.0); // Expand to RGB\n")
             f.write(f"    }}\n")

     # Add else clause for invalid layer index
-    if num_layers > 1:
+    if num_layers > 0:
         f.write(f"    else {{\n")
         f.write(f"        result = input;\n")
         f.write(f"    }}\n")
@@ -204,96 +218,40 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
         f.write("// Auto-generated CNN weights\n")
         f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")

-        layer_idx = 0
         for i, layer in enumerate(model.layers):
-            if isinstance(layer, CoordConv2d):
-                # Export RGBA weights
-                weights = layer.conv_rgba.weight.data.cpu().numpy()
-                kernel_size = kernel_sizes[layer_idx]
-                out_ch, in_ch, kh, kw = weights.shape
-                num_positions = kh * kw
-
-                f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
-                for pos in range(num_positions):
-                    row = pos // kw
-                    col = pos % kw
-                    f.write("    mat4x4<f32>(\n")
-                    for out_c in range(4):
-                        vals = []
-                        for in_c in range(4):
-                            if out_c < out_ch and in_c < in_ch:
-                                vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
-                            else:
-                                vals.append("0.0")
-                        f.write(f"        {', '.join(vals)},\n")
-                    f.write("    )")
-                    if pos < num_positions - 1:
-                        f.write(",\n")
-                    else:
-                        f.write("\n")
-                f.write(");\n\n")
+            weights = layer.weight.data.cpu().numpy()
+            bias = layer.bias.data.cpu().numpy()
+            out_ch, in_ch, kh, kw = weights.shape
+            num_positions = kh * kw

-                # Export coordinate weights
-                coord_w = layer.coord_weights.data.cpu().numpy()
-                f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n")
-                for c in range(2):
-                    vals = []
-                    for out_c in range(4):
-                        if out_c < coord_w.shape[0]:
-                            vals.append(f"{coord_w[out_c, c]:.6f}")
-                        else:
-                            vals.append("0.0")
-                    f.write(f"    {', '.join(vals)}")
-                    if c < 1:
-                        f.write(",\n")
-                    else:
-                        f.write("\n")
-                f.write(");\n\n")
+            is_final = (i == len(model.layers) - 1)

-                # Export bias
-                bias = layer.bias.data.cpu().numpy()
-                bias_vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.0" for i in range(4)]
-                f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
-                f.write(", ".join(bias_vals))
+            if is_final:
+                # Final layer: 7→1, structure: array<array<f32, 8>, 9>
+                # [w0, w1, w2, w3, w4, w5, w6, bias]
+                f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_positions}> = array(\n")
+                for pos in range(num_positions):
+                    row, col = pos // kw, pos % kw
+                    vals = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(7)]
+                    vals.append(f"{bias[0]:.6f}")  # Append bias as 8th element
+                    f.write(f"    array<f32, 8>({', '.join(vals)})")
+                    f.write(",\n" if pos < num_positions-1 else "\n")
                 f.write(");\n\n")
-
-                layer_idx += 1
-            elif isinstance(layer, nn.Conv2d):
-                # Standard conv layer
-                weights = layer.weight.data.cpu().numpy()
-                kernel_size = kernel_sizes[layer_idx]
-                out_ch, in_ch, kh, kw = weights.shape
-                num_positions = kh * kw
-
-                f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+            else:
+                # Inner layers: 7→4, structure: array<array<f32, 8>, 36>
+                # Flattened: [pos0_ch0[7w+bias], pos0_ch1[7w+bias], ..., pos8_ch3[7w+bias]]
+                num_entries = num_positions * 4
+                f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_entries}> = array(\n")
                 for pos in range(num_positions):
-                    row = pos // kw
-                    col = pos % kw
-                    f.write("    mat4x4<f32>(\n")
+                    row, col = pos // kw, pos % kw
                     for out_c in range(4):
-                        vals = []
-                        for in_c in range(4):
-                            if out_c < out_ch and in_c < in_ch:
-                                vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
-                            else:
-                                vals.append("0.0")
-                        f.write(f"        {', '.join(vals)},\n")
-                    f.write("    )")
-                    if pos < num_positions - 1:
-                        f.write(",\n")
-                    else:
-                        f.write("\n")
+                        vals = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(7)]
+                        vals.append(f"{bias[out_c]:.6f}")  # Append bias
+                        idx = pos * 4 + out_c
+                        f.write(f"    array<f32, 8>({', '.join(vals)})")
+                        f.write(",\n" if idx < num_entries-1 else "\n")
                 f.write(");\n\n")

-                # Export bias
-                bias = layer.bias.data.cpu().numpy()
-                bias_vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.0" for i in range(4)]
-                f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
-                f.write(", ".join(bias_vals))
-                f.write(");\n\n")
-
-                layer_idx += 1
-
 def train(args):
     """Main training loop"""
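For completeness, a hedged usage sketch of the two generators touched by this diff (the import path and output file names are hypothetical examples; the real entry point is train(args)):

```python
# Assuming these are importable from training/train_cnn.py
from train_cnn import SimpleCNN, generate_layer_shader, export_weights_to_wgsl

kernel_sizes = [3, 3, 3]  # three 3x3 layers: 7→4, 7→4, 7→1
model = SimpleCNN(num_layers=3, kernel_sizes=kernel_sizes)

# Emit the per-layer WGSL switch and the flattened weight arrays
generate_layer_shader("cnn_layers.wgsl", 3, kernel_sizes)
export_weights_to_wgsl(model, "cnn_weights.wgsl", kernel_sizes)
```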
