summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-10 16:44:39 +0100
committerskal <pascal.massimino@gmail.com>2026-02-10 16:44:39 +0100
commit61104d5b9e1774c11f0dba3b6d6018dabc2bce8f (patch)
tree882e642721984cc921cbe5678fe7905721a2ad40 /training
parent3942653de11542acc4892470243a8a6bf8d5c4f7 (diff)
feat: CNN RGBD→grayscale with 7-channel augmented input
Upgrade CNN architecture to process RGBD input, output grayscale, with 7-channel layer inputs (RGBD + UV coords + grayscale). Architecture changes: - Inner layers: Conv2d(7→4) output RGBD - Final layer: Conv2d(7→1) output grayscale - All inputs normalized to [-1,1] for tanh activation - Removed CoordConv2d in favor of unified 7-channel input Training (train_cnn.py): - SimpleCNN: 7→4 (inner), 7→1 (final) architecture - Forward: Normalize RGBD/coords/gray to [-1,1] - Weight export: array<array<f32, 8>, 36> (inner), array<f32, 8>, 9> (final) - Dataset: Load RGBA (RGBD) input Shaders (cnn_conv3x3.wgsl): - Added cnn_conv3x3_7to4: 7-channel input → RGBD output - Added cnn_conv3x3_7to1: 7-channel input → grayscale output - Both normalize inputs and use flattened weight arrays Documentation: - CNN_EFFECT.md: Updated architecture, training, weight format - CNN_RGBD_GRAYSCALE_SUMMARY.md: Implementation summary - HOWTO.md: Added training command example Next: Train with RGBD input data Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn.py210
1 files changed, 84 insertions, 126 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 1cd6579..0495c65 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -62,7 +62,8 @@ class ImagePairDataset(Dataset):
def __getitem__(self, idx):
input_path, target_path = self.image_pairs[idx]
- input_img = Image.open(input_path).convert('RGB')
+ # Load RGBD input (4 channels: RGB + Depth)
+ input_img = Image.open(input_path).convert('RGBA')
target_img = Image.open(target_path).convert('RGB')
if self.transform:
@@ -72,27 +73,8 @@ class ImagePairDataset(Dataset):
return input_img, target_img
-class CoordConv2d(nn.Module):
- """Conv2d that accepts coordinate input separate from spatial patches"""
-
- def __init__(self, in_channels, out_channels, kernel_size, padding=0):
- super().__init__()
- self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
- self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01)
- self.bias = nn.Parameter(torch.zeros(out_channels))
-
- def forward(self, x, coords):
- # x: [B, C, H, W] image
- # coords: [B, 2, H, W] coordinate grid
- out = self.conv_rgba(x)
- B, C, H, W = out.shape
- coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights)
- out = out + coord_contrib + self.bias.view(1, -1, 1, 1)
- return out
-
-
class SimpleCNN(nn.Module):
- """Simple CNN for image-to-image transformation"""
+ """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""
def __init__(self, num_layers=1, kernel_sizes=None):
super(SimpleCNN, self).__init__()
@@ -107,26 +89,48 @@ class SimpleCNN(nn.Module):
for i, kernel_size in enumerate(kernel_sizes):
padding = kernel_size // 2
- if i == 0:
- self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding))
+ if i < num_layers - 1:
+ # Inner layers: 7→4 (RGBD output)
+ self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
else:
- self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True))
+ # Final layer: 7→1 (grayscale output)
+ self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))
def forward(self, x):
+ # x: [B,4,H,W] - RGBD input (D = 1/z)
B, C, H, W = x.shape
+
+ # Normalize RGBD to [-1,1]
+ x_norm = (x - 0.5) * 2.0
+
+ # Compute coordinates [0,1] then normalize to [-1,1]
y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
- coords = torch.cat([x_coords, y_coords], dim=1)
+ y_coords = (y_coords - 0.5) * 2.0 # [-1,1]
+ x_coords = (x_coords - 0.5) * 2.0 # [-1,1]
- out = self.layers[0](x, coords)
- out = torch.tanh(out)
+ # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
+ gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1]
+ gray = (gray - 0.5) * 2.0 # [-1,1]
- for i in range(1, len(self.layers)):
- out = self.layers[i](out)
- if i < len(self.layers) - 1:
- out = torch.tanh(out)
+ # Layer 0
+ layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W]
+ out = self.layers[0](layer0_input) # [B,4,H,W]
+ out = torch.tanh(out) # [-1,1]
- return out
+ # Inner layers
+ for i in range(1, len(self.layers)-1):
+ layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+ out = self.layers[i](layer_input)
+ out = torch.tanh(out)
+
+ # Final layer (grayscale output)
+ final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+ out = self.layers[-1](final_input) # [B,1,H,W] in [-1,1]
+
+ # Denormalize to [0,1] and expand to RGB for visualization
+ out = (out + 1.0) * 0.5
+ return out.expand(-1, 3, -1, -1)
def generate_layer_shader(output_path, num_layers, kernel_sizes):
@@ -169,25 +173,35 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
# Generate layer switches
for layer_idx in range(num_layers):
- ks = kernel_sizes[layer_idx]
+ is_final = layer_idx == num_layers - 1
if layer_idx == 0:
- f.write(f" // Layer 0 uses coordinate-aware convolution\n")
+ f.write(f" // Layer 0: 7→4 (RGBD output)\n")
f.write(f" if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = cnn_conv{ks}x{ks}_with_coord(txt, smplr, uv, uniforms.resolution,\n")
- f.write(f" rgba_weights_layer{layer_idx}, coord_weights_layer{layer_idx}, bias_layer{layer_idx});\n")
- f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n")
+ f.write(f" // Denormalize to [0,1] for texture storage\n")
+ f.write(f" result = (result + 1.0) * 0.5;\n")
+ f.write(f" }}\n")
+ elif not is_final:
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n")
+ f.write(f" // Denormalize to [0,1] for texture storage\n")
+ f.write(f" result = (result + 1.0) * 0.5;\n")
f.write(f" }}\n")
else:
- is_last = layer_idx == num_layers - 1
- f.write(f" {'else ' if layer_idx > 0 else ''}if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = cnn_conv{ks}x{ks}(txt, smplr, uv, uniforms.resolution,\n")
- f.write(f" weights_layer{layer_idx}, bias_layer{layer_idx});\n")
- if not is_last:
- f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" // Denormalize from [-1,1] to [0,1]\n")
+ f.write(f" let gray_01 = (gray_out + 1.0) * 0.5;\n")
+ f.write(f" result = vec4<f32>(gray_01, gray_01, gray_01, 1.0); // Expand to RGB\n")
f.write(f" }}\n")
# Add else clause for invalid layer index
- if num_layers > 1:
+ if num_layers > 0:
f.write(f" else {{\n")
f.write(f" result = input;\n")
f.write(f" }}\n")
@@ -204,96 +218,40 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
f.write("// Auto-generated CNN weights\n")
f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
- layer_idx = 0
for i, layer in enumerate(model.layers):
- if isinstance(layer, CoordConv2d):
- # Export RGBA weights
- weights = layer.conv_rgba.weight.data.cpu().numpy()
- kernel_size = kernel_sizes[layer_idx]
- out_ch, in_ch, kh, kw = weights.shape
- num_positions = kh * kw
-
- f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
- for pos in range(num_positions):
- row = pos // kw
- col = pos % kw
- f.write(" mat4x4<f32>(\n")
- for out_c in range(4):
- vals = []
- for in_c in range(4):
- if out_c < out_ch and in_c < in_ch:
- vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
- else:
- vals.append("0.0")
- f.write(f" {', '.join(vals)},\n")
- f.write(" )")
- if pos < num_positions - 1:
- f.write(",\n")
- else:
- f.write("\n")
- f.write(");\n\n")
+ weights = layer.weight.data.cpu().numpy()
+ bias = layer.bias.data.cpu().numpy()
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
- # Export coordinate weights
- coord_w = layer.coord_weights.data.cpu().numpy()
- f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n")
- for c in range(2):
- vals = []
- for out_c in range(4):
- if out_c < coord_w.shape[0]:
- vals.append(f"{coord_w[out_c, c]:.6f}")
- else:
- vals.append("0.0")
- f.write(f" {', '.join(vals)}")
- if c < 1:
- f.write(",\n")
- else:
- f.write("\n")
- f.write(");\n\n")
+ is_final = (i == len(model.layers) - 1)
- # Export bias
- bias = layer.bias.data.cpu().numpy()
- bias_vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.0" for i in range(4)]
- f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
- f.write(", ".join(bias_vals))
+ if is_final:
+ # Final layer: 7→1, structure: array<array<f32, 8>, 9>
+ # [w0, w1, w2, w3, w4, w5, w6, bias]
+ f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_positions}> = array(\n")
+ for pos in range(num_positions):
+ row, col = pos // kw, pos % kw
+ vals = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(7)]
+ vals.append(f"{bias[0]:.6f}") # Append bias as 8th element
+ f.write(f" array<f32, 8>({', '.join(vals)})")
+ f.write(",\n" if pos < num_positions-1 else "\n")
f.write(");\n\n")
-
- layer_idx += 1
- elif isinstance(layer, nn.Conv2d):
- # Standard conv layer
- weights = layer.weight.data.cpu().numpy()
- kernel_size = kernel_sizes[layer_idx]
- out_ch, in_ch, kh, kw = weights.shape
- num_positions = kh * kw
-
- f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+ else:
+ # Inner layers: 7→4, structure: array<array<f32, 8>, 36>
+ # Flattened: [pos0_ch0[7w+bias], pos0_ch1[7w+bias], ..., pos8_ch3[7w+bias]]
+ num_entries = num_positions * 4
+ f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_entries}> = array(\n")
for pos in range(num_positions):
- row = pos // kw
- col = pos % kw
- f.write(" mat4x4<f32>(\n")
+ row, col = pos // kw, pos % kw
for out_c in range(4):
- vals = []
- for in_c in range(4):
- if out_c < out_ch and in_c < in_ch:
- vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
- else:
- vals.append("0.0")
- f.write(f" {', '.join(vals)},\n")
- f.write(" )")
- if pos < num_positions - 1:
- f.write(",\n")
- else:
- f.write("\n")
+ vals = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(7)]
+ vals.append(f"{bias[out_c]:.6f}") # Append bias
+ idx = pos * 4 + out_c
+ f.write(f" array<f32, 8>({', '.join(vals)})")
+ f.write(",\n" if idx < num_entries-1 else "\n")
f.write(");\n\n")
- # Export bias
- bias = layer.bias.data.cpu().numpy()
- bias_vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.0" for i in range(4)]
- f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
- f.write(", ".join(bias_vals))
- f.write(");\n\n")
-
- layer_idx += 1
-
def train(args):
"""Main training loop"""