| author | skal <pascal.massimino@gmail.com> | 2026-02-13 12:32:36 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-13 12:32:36 +0100 |
| commit | 561d1dc446db7d1d3e02b92b43abedf1a5017850 | |
| tree | ef9302dc1f9b6b9f8a12225580f2a3b07602656b /training/train_cnn_v2.py | |
| parent | c27b34279c0d1c2a8f1dbceb0e154b585b5c6916 | |
CNN v2: Refactor to uniform 12D→4D architecture
**Architecture changes:**
- Static features (8D): p0-p3 (parametric) + uv_x, uv_y, sin(10×uv_x), bias
- Input RGBD (4D): fed separately to all layers
- All layers: uniform 12D→4D (4 prev/input + 8 static → 4 output)
- Bias integrated into the static features (bias=False in PyTorch); see the shape sketch below
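
A shape-only sketch of this wiring (illustrative, not the repo's code; activations and clamps omitted, tensor sizes made up):

```python
# Shape-only sketch of the uniform 12D→4D wiring (illustrative values).
import torch
import torch.nn as nn

B, H, W = 1, 64, 64
input_rgbd = torch.rand(B, 4, H, W)                  # mip-0 RGBD
static = torch.cat([torch.rand(B, 7, H, W),          # p0-p3, uv_x, uv_y, sin(10*uv_x)
                    torch.ones(B, 1, H, W)], dim=1)  # last plane: bias = 1.0

# Every layer has the same shape: Conv2d(12, 4) with no bias parameter.
layers = nn.ModuleList(nn.Conv2d(12, 4, kernel_size=3, padding=1, bias=False)
                       for _ in range(3))

x = input_rgbd
for conv in layers:
    x = conv(torch.cat([x, static], dim=1))          # 4 dynamic + 8 static = 12D in, 4D out
print(x.shape)                                       # torch.Size([1, 4, 64, 64])
```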
**Weight calculations:**
- 3 layers × (12 input channels × 3×3 kernel × 4 output channels) = 3 × 432 = 1296 weights
- f16: 2.6 KB (vs the old variable architecture: ~6.4 KB); see the quick check below
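
The budget can be verified with a few lines of arithmetic:

```python
# Quick arithmetic check of the weight budget quoted above.
in_ch, k, out_ch, num_layers = 12, 3, 4, 3
per_layer = in_ch * k * k * out_ch     # 12 * 9 * 4 = 432 weights per layer
total = num_layers * per_layer         # 3 * 432 = 1296 weights
bytes_f16 = total * 2                  # 2592 bytes ≈ 2.6 KB at 2 bytes/weight
print(per_layer, total, bytes_f16)
```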
**Updated files:**
*Training (Python):*
- train_cnn_v2.py: Uniform model, takes input_rgbd + static_features
- export_cnn_v2_weights.py: Binary export for storage buffers
- export_cnn_v2_shader.py: Per-layer shader export (debugging)
*Shaders (WGSL):*
- cnn_v2_static.wgsl: p0-p3 parametric features (mips/gradients)
- cnn_v2_compute.wgsl: 12D input, 4D output, vec4 packing
*Tools:*
- HTML tool (cnn_v2_test): Updated for 12D→4D, layer visualization
*Docs:*
- CNN_V2.md: Updated architecture, training, validation sections
- HOWTO.md: Reference HTML tool for validation
*Removed:*
- validate_cnn_v2.sh: Obsolete (used CNN v1 tool)
All code is consistent with bias=False (the bias is carried in the static features as a constant 1.0 plane).
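
The bias-as-feature trick is exactly equivalent to an explicit per-channel Conv2d bias, at least for 1×1 kernels (with larger kernels and zero padding, the equivalence holds away from image borders). A tiny check, hypothetical code not from the repo:

```python
# Tiny equivalence check (hypothetical): feeding a constant-1.0 plane
# through a bias-free 1x1 conv reproduces an explicit per-channel bias.
import torch
import torch.nn as nn

x = torch.rand(1, 3, 8, 8)
with_bias = nn.Conv2d(3, 4, kernel_size=1)            # explicit bias
no_bias = nn.Conv2d(4, 4, kernel_size=1, bias=False)  # bias plane instead

with torch.no_grad():
    no_bias.weight[:, :3] = with_bias.weight          # copy the RGB weights
    no_bias.weight[:, 3, 0, 0] = with_bias.bias       # bias -> weight on the 1.0 plane

ones = torch.ones(1, 1, 8, 8)                         # the constant bias plane
out = no_bias(torch.cat([x, ones], dim=1))
assert torch.allclose(out, with_bias(x), atol=1e-6)
```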
handoff(Claude): CNN v2 architecture finalized and documented
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 134 |
1 file changed, 75 insertions(+), 59 deletions(-)
```diff
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 758b044..8b3b91c 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
-"""CNN v2 Training Script - Parametric Static Features
+"""CNN v2 Training Script - Uniform 12D→4D Architecture
 
-Trains a multi-layer CNN with 7D static feature input:
-- RGBD (4D)
-- UV coordinates (2D)
-- sin(10*uv.x) position encoding (1D)
-- Bias dimension (1D, always 1.0)
+Architecture:
+- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10×uv_x), bias
+- Input RGBD (4D): original image mip 0
+- All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
+- Uniform layer structure with bias=False (bias in static features)
 """
 
 import argparse
@@ -21,20 +21,26 @@ import cv2
 
 def compute_static_features(rgb, depth=None):
-    """Generate 7D static features + bias dimension.
+    """Generate 8D static features (parametric + spatial).
 
     Args:
         rgb: (H, W, 3) RGB image [0, 1]
         depth: (H, W) depth map [0, 1], optional
 
     Returns:
-        (H, W, 8) static features tensor
+        (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]
+
+    Note: p0-p3 are parametric features (can be mips, gradients, etc.)
+          For training, we use RGBD as default, but could use mip1/2
     """
     h, w = rgb.shape[:2]
 
-    # RGBD channels
-    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
-    d = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+    # Parametric features (p0-p3) - using RGBD as default
+    # TODO: Experiment with mip1 grayscale, gradients, etc.
+    p0 = rgb[:, :, 0].astype(np.float32)
+    p1 = rgb[:, :, 1].astype(np.float32)
+    p2 = rgb[:, :, 2].astype(np.float32)
+    p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
 
     # UV coordinates (normalized [0, 1])
     uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
@@ -43,65 +49,64 @@ def compute_static_features(rgb, depth=None):
     # Multi-frequency position encoding
     sin10_x = np.sin(10.0 * uv_x).astype(np.float32)
 
-    # Bias dimension (always 1.0)
+    # Bias dimension (always 1.0) - replaces Conv2d bias parameter
     bias = np.ones((h, w), dtype=np.float32)
 
-    # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias]
-    features = np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1)
+    # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
+    features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
 
     return features
 
 
 class CNNv2(nn.Module):
-    """CNN v2 with parametric static features.
+    """CNN v2 - Uniform 12D→4D Architecture
+
+    All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
+    Uses bias=False (bias integrated in static features as 1.0)
 
     TODO: Add quantization-aware training (QAT) for 8-bit weights
     - Use torch.quantization.QuantStub/DeQuantStub
     - Train with fake quantization to adapt to 8-bit precision
-    - Target: ~1.6 KB weights (vs 3.2 KB with f16)
+    - Target: ~1.3 KB weights (vs 2.6 KB with f16)
     """
 
-    def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]):
+    def __init__(self, kernel_size=3, num_layers=3):
         super().__init__()
-        self.kernels = kernels
-        self.channels = channels
-
-        # Input layer: 8D (7 features + bias) → channels[0]
-        self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0],
-                                padding=kernels[0]//2, bias=False)
-
-        # Inner layers: (8 + C_prev) → C_next
-        in_ch_1 = 8 + channels[0]
-        self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1],
-                                padding=kernels[1]//2, bias=False)
+        self.kernel_size = kernel_size
+        self.num_layers = num_layers
+        self.layers = nn.ModuleList()
 
-        # Output layer: (8 + C_last) → 4 (RGBA)
-        in_ch_2 = 8 + channels[1]
-        self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2],
-                                padding=kernels[2]//2, bias=False)
+        # All layers: 12D input (4 RGBD + 8 static) → 4D output
+        for _ in range(num_layers):
+            self.layers.append(
+                nn.Conv2d(12, 4, kernel_size=kernel_size,
+                          padding=kernel_size//2, bias=False)
+            )
 
-    def forward(self, static_features):
-        """Forward pass with static feature concatenation.
+    def forward(self, input_rgbd, static_features):
+        """Forward pass with uniform 12D→4D layers.
 
         Args:
+            input_rgbd: (B, 4, H, W) input image RGBD (mip 0)
             static_features: (B, 8, H, W) static features
 
         Returns:
             (B, 4, H, W) RGBA output [0, 1]
         """
-        # Layer 0: Use full 8D static features
-        x0 = self.layer0(static_features)
-        x0 = F.relu(x0)
+        # Layer 0: input RGBD (4D) + static (8D) = 12D
+        x = torch.cat([input_rgbd, static_features], dim=1)
+        x = self.layers[0](x)
+        x = torch.clamp(x, 0, 1)  # Output [0,1] for layer 0
 
-        # Layer 1: Concatenate static + layer0 output
-        x1_input = torch.cat([static_features, x0], dim=1)
-        x1 = self.layer1(x1_input)
-        x1 = F.relu(x1)
+        # Layer 1+: previous (4D) + static (8D) = 12D
+        for i in range(1, self.num_layers):
+            x_input = torch.cat([x, static_features], dim=1)
+            x = self.layers[i](x_input)
+            if i < self.num_layers - 1:
+                x = F.relu(x)
+            else:
+                x = torch.clamp(x, 0, 1)  # Final output [0,1]
 
-        # Layer 2: Concatenate static + layer1 output
-        x2_input = torch.cat([static_features, x1], dim=1)
-        output = self.layer2(x2_input)
-
-        return torch.sigmoid(output)
+        return x
 
 
 class PatchDataset(Dataset):
@@ -214,14 +219,18 @@ class PatchDataset(Dataset):
         # Compute static features for patch
         static_feat = compute_static_features(input_patch.astype(np.float32))
 
+        # Input RGBD (mip 0) - add depth channel
+        input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
+
         # Convert to tensors (C, H, W)
+        input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
         static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
         target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)
 
         # Pad target to 4 channels (RGBA)
         target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
 
-        return static_feat, target
+        return input_rgbd, static_feat, target
 
 
 class ImagePairDataset(Dataset):
@@ -252,14 +261,19 @@ class ImagePairDataset(Dataset):
         # Compute static features
         static_feat = compute_static_features(input_img.astype(np.float32))
 
+        # Input RGBD (mip 0) - add depth channel
+        h, w = input_img.shape[:2]
+        input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1)
+
         # Convert to tensors (C, H, W)
+        input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
         static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
         target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)
 
         # Pad target to 4 channels (RGBA)
         target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
 
-        return static_feat, target
+        return input_rgbd, static_feat, target
 
 
 def train(args):
@@ -282,9 +296,10 @@ def train(args):
     dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
 
     # Create model
-    model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
+    model = CNNv2(kernel_size=args.kernel_size, num_layers=args.num_layers).to(device)
     total_params = sum(p.numel() for p in model.parameters())
-    print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights")
+    weights_per_layer = 12 * args.kernel_size * args.kernel_size * 4
+    print(f"Model: {args.num_layers} layers, {args.kernel_size}×{args.kernel_size} kernels, {total_params} weights ({weights_per_layer}/layer)")
 
     # Optimizer and loss
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -298,12 +313,13 @@ def train(args):
         model.train()
         epoch_loss = 0.0
 
-        for static_feat, target in dataloader:
+        for input_rgbd, static_feat, target in dataloader:
+            input_rgbd = input_rgbd.to(device)
             static_feat = static_feat.to(device)
             target = target.to(device)
 
             optimizer.zero_grad()
-            output = model(static_feat)
+            output = model(input_rgbd, static_feat)
             loss = criterion(output, target)
             loss.backward()
             optimizer.step()
@@ -327,9 +343,9 @@ def train(args):
                 'optimizer_state_dict': optimizer.state_dict(),
                 'loss': avg_loss,
                 'config': {
-                    'kernels': args.kernel_sizes,
-                    'channels': args.channels,
-                    'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias']
+                    'kernel_size': args.kernel_size,
+                    'num_layers': args.num_layers,
+                    'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
                 }
             }, checkpoint_path)
             print(f"  → Saved checkpoint: {checkpoint_path}")
@@ -361,10 +377,10 @@ def main():
     # Mix salient points with random samples for better generalization
 
     # Model architecture
-    parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
-                        help='Kernel sizes for 3 layers (default: 1 3 5)')
-    parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
-                        help='Output channels for 3 layers (default: 16 8 4)')
+    parser.add_argument('--kernel-size', type=int, default=3,
+                        help='Kernel size (uniform for all layers, default: 3)')
+    parser.add_argument('--num-layers', type=int, default=3,
+                        help='Number of CNN layers (default: 3)')
 
     # Training parameters
     parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
```
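
A quick smoke test of the refactored model (hypothetical snippet, not from the repo; assumes it runs from the training/ directory so train_cnn_v2.py is importable, with the names from the diff above):

```python
# Hypothetical smoke test for the uniform 12D→4D model; assumes this runs
# from training/ so train_cnn_v2.py is importable.
import numpy as np
import torch
from train_cnn_v2 import CNNv2, compute_static_features

rgb = np.random.rand(64, 64, 3).astype(np.float32)
static = compute_static_features(rgb)                      # (64, 64, 8), depth defaults to 0
static = torch.from_numpy(static).permute(2, 0, 1)[None]   # (1, 8, 64, 64)

rgbd = np.concatenate([rgb, np.zeros((64, 64, 1), np.float32)], axis=-1)
rgbd = torch.from_numpy(rgbd).permute(2, 0, 1)[None]       # (1, 4, 64, 64)

model = CNNv2(kernel_size=3, num_layers=3)
assert sum(p.numel() for p in model.parameters()) == 1296  # matches the commit message
out = model(rgbd, static)                                  # (1, 4, 64, 64), RGBA in [0, 1]
print(out.shape, out.min().item(), out.max().item())
```

An actual training run would use the new flags from the diff, e.g. `python train_cnn_v2.py --kernel-size 3 --num-layers 3`.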
