Diffstat (limited to 'training')
-rwxr-xr-x  training/export_cnn_v2_shader.py   127
-rwxr-xr-x  training/export_cnn_v2_weights.py   85
-rwxr-xr-x  training/train_cnn_v2.py           134
3 files changed, 163 insertions(+), 183 deletions(-)
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py
index add28d2..ad5749c 100755
--- a/training/export_cnn_v2_shader.py
+++ b/training/export_cnn_v2_shader.py
@@ -1,8 +1,11 @@
#!/usr/bin/env python3
-"""CNN v2 Shader Export Script
+"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture
Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
Generates one shader per layer with embedded weight arrays.
+
+Note: The storage-buffer approach (export_cnn_v2_weights.py) is preferred for size;
+      this script is for debugging/testing with per-layer shaders.
"""
import argparse
@@ -11,16 +14,13 @@ import torch
from pathlib import Path
-def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channels,
- output_dir, is_output_layer=False):
+def export_layer_shader(layer_idx, weights, kernel_size, output_dir, is_output_layer=False):
"""Generate WGSL compute shader for a single CNN layer.
Args:
- layer_idx: Layer index (0, 1, 2)
- weights: (out_ch, in_ch, k, k) weight tensor
- kernel_size: Kernel size (1, 3, 5, etc.)
- in_channels: Input channels (includes 8D static features)
- out_channels: Output channels
+ layer_idx: Layer index (0, 1, 2, ...)
+ weights: (4, 12, k, k) weight tensor (uniform 12D→4D)
+ kernel_size: Kernel size (3, 5, etc.)
output_dir: Output directory path
is_output_layer: True if this is the final RGBA output layer
"""
@@ -39,12 +39,12 @@ def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channe
if is_output_layer:
activation = "output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation"
- shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated
-// Kernel: {kernel_size}×{kernel_size}, In: {in_channels}, Out: {out_channels}
+ shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
+// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
const KERNEL_SIZE: u32 = {kernel_size}u;
-const IN_CHANNELS: u32 = {in_channels}u;
-const OUT_CHANNELS: u32 = {out_channels}u;
+const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static)
+const OUT_CHANNELS: u32 = 4u; // Uniform output
const KERNEL_RADIUS: i32 = {radius};
// Weights quantized to float16 (stored as f32 in WGSL)
@@ -65,21 +65,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}}
-fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {{
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{
let packed = textureLoad(layer_input, coord, 0);
let v0 = unpack2x16float(packed.x);
let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {{
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {{
return vec4<u32>(
- pack2x16float(vec2<f32>(values[0], values[1])),
- pack2x16float(vec2<f32>(values[2], values[3])),
- pack2x16float(vec2<f32>(values[4], values[5])),
- pack2x16float(vec2<f32>(values[6], values[7]))
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
);
}}
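
The pack/unpack helpers above round-trip f16 pairs through u32 texel components; component 0 lands in the low 16 bits, matching WGSL's pack2x16float. A minimal NumPy sketch of the same bit layout (the Python function names just mirror the WGSL builtins; they are not repo code):

import numpy as np

def pack2x16float(x, y):
    # Two f32 values -> IEEE-754 halves -> one u32 (x in the low 16 bits).
    h = np.array([x, y], dtype=np.float16).view(np.uint16)
    return np.uint32(h[0]) | (np.uint32(h[1]) << np.uint32(16))

def unpack2x16float(p):
    # One u32 -> two f16 bit patterns -> widened back to f32.
    halves = np.array([p & 0xFFFF, (p >> np.uint32(16)) & 0xFFFF], dtype=np.uint16)
    return halves.view(np.float16).astype(np.float32)

packed = pack2x16float(0.25, -1.5)            # both exactly representable in f16
assert np.allclose(unpack2x16float(packed), [0.25, -1.5])
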
@@ -95,9 +93,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
// Load static features (always available)
let static_feat = unpack_static_features(coord);
- // Convolution
- var output: array<f32, OUT_CHANNELS>;
- for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {{
+ // Convolution: 12D input (4 prev + 8 static) → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {{
var sum: f32 = 0.0;
for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
@@ -110,28 +108,27 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
clamp(sample_coord.y, 0, i32(dims.y) - 1)
);
- // Load input features
+ // Load features at this spatial location
let static_local = unpack_static_features(clamped);
- let layer_local = unpack_layer_channels(clamped);
+ let layer_local = unpack_layer_channels(clamped); // 4D
// Weight index calculation
let ky_idx = u32(ky + KERNEL_RADIUS);
let kx_idx = u32(kx + KERNEL_RADIUS);
let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
- // Accumulate: static features (8D)
- for (var i: u32 = 0u; i < 8u; i++) {{
- let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ // Accumulate: previous/input channels (4D)
+ for (var i: u32 = 0u; i < 4u; i++) {{
+ let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * static_local[i];
+ sum += weights[w_idx] * layer_local[i];
}}
- // Accumulate: layer input channels (if layer_idx > 0)
- let prev_channels = IN_CHANNELS - 8u;
- for (var i: u32 = 0u; i < prev_channels; i++) {{
- let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
- (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * layer_local[i];
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {{
+ let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
+ (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
}}
}}
}}
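
The w_idx arithmetic above assumes the flat weight buffer keeps PyTorch's (out_ch, in_ch, kH, kW) order, with the 4 prev/input channels at indices 0-3 and the 8 static features at 4-11 (the same order the training script concatenates them). A quick consistency check in NumPy (names are illustrative):

import numpy as np

K, IN_CH, OUT_CH = 3, 12, 4
w = np.arange(OUT_CH * IN_CH * K * K, dtype=np.float32).reshape(OUT_CH, IN_CH, K, K)
flat = w.flatten()                        # the order export_checkpoint writes

c, i, ky_idx, kx_idx = 2, 7, 1, 0         # arbitrary output channel / input channel / tap
spatial_idx = ky_idx * K + kx_idx
w_idx = c * IN_CH * K * K + i * K * K + spatial_idx
assert flat[w_idx] == w[c, i, ky_idx, kx_idx]   # shader indexing == PyTorch layout
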
@@ -162,53 +159,37 @@ def export_checkpoint(checkpoint_path, output_dir):
state_dict = checkpoint['model_state_dict']
config = checkpoint['config']
+ kernel_size = config.get('kernel_size', 3)
+ num_layers = config.get('num_layers', 3)
+
print(f"Configuration:")
- print(f" Kernels: {config['kernels']}")
- print(f" Channels: {config['channels']}")
- print(f" Features: {config['features']}")
+ print(f" Kernel size: {kernel_size}×{kernel_size}")
+ print(f" Layers: {num_layers}")
+ print(f" Architecture: uniform 12D→4D")
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\nExporting shaders to {output_dir}/")
- # Layer 0: 8 → channels[0]
- layer0_weights = state_dict['layer0.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=0,
- weights=layer0_weights,
- kernel_size=config['kernels'][0],
- in_channels=8,
- out_channels=config['channels'][0],
- output_dir=output_dir,
- is_output_layer=False
- )
+ # All layers uniform: 12D→4D
+ for i in range(num_layers):
+ layer_key = f'layers.{i}.weight'
+ if layer_key not in state_dict:
+ raise ValueError(f"Missing weights for layer {i}: {layer_key}")
- # Layer 1: (8 + channels[0]) → channels[1]
- layer1_weights = state_dict['layer1.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=1,
- weights=layer1_weights,
- kernel_size=config['kernels'][1],
- in_channels=8 + config['channels'][0],
- out_channels=config['channels'][1],
- output_dir=output_dir,
- is_output_layer=False
- )
+ layer_weights = state_dict[layer_key].detach().numpy()
+ is_output = (i == num_layers - 1)
- # Layer 2: (8 + channels[1]) → 4 (RGBA)
- layer2_weights = state_dict['layer2.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=2,
- weights=layer2_weights,
- kernel_size=config['kernels'][2],
- in_channels=8 + config['channels'][1],
- out_channels=4,
- output_dir=output_dir,
- is_output_layer=True
- )
+ export_layer_shader(
+ layer_idx=i,
+ weights=layer_weights,
+ kernel_size=kernel_size,
+ output_dir=output_dir,
+ is_output_layer=is_output
+ )
- print(f"\nExport complete! Generated 3 shader files.")
+ print(f"\nExport complete! Generated {num_layers} shader files.")
def main():
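
With every layer fixed at 12D→4D and bias=False, the weight budget is simply 4·12·k·k per layer; a sketch of the arithmetic behind the size figures quoted in train_cnn_v2.py:

kernel_size, num_layers = 3, 3                      # the new defaults
per_layer = 4 * 12 * kernel_size * kernel_size      # 432 weights/layer
total = per_layer * num_layers                      # 1296 weights
print(total * 2, "bytes at f16")                    # 2592 B (~2.6 KB)
print(total, "bytes at 8-bit")                      # 1296 B (~1.3 KB)
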
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
index 8a2fcdc..07254fc 100755
--- a/training/export_cnn_v2_weights.py
+++ b/training/export_cnn_v2_weights.py
@@ -45,53 +45,38 @@ def export_weights_binary(checkpoint_path, output_path):
state_dict = checkpoint['model_state_dict']
config = checkpoint['config']
+ kernel_size = config.get('kernel_size', 3)
+ num_layers = config.get('num_layers', 3)
+
print(f"Configuration:")
- print(f" Kernels: {config['kernels']}")
- print(f" Channels: {config['channels']}")
+ print(f" Kernel size: {kernel_size}×{kernel_size}")
+ print(f" Layers: {num_layers}")
+ print(f" Architecture: uniform 12D→4D (bias=False)")
- # Collect layer info
+ # Collect layer info - all layers uniform 12D→4D
layers = []
all_weights = []
weight_offset = 0
- # Layer 0: 8 → channels[0]
- layer0_weights = state_dict['layer0.weight'].detach().numpy()
- layer0_flat = layer0_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][0],
- 'in_channels': 8,
- 'out_channels': config['channels'][0],
- 'weight_offset': weight_offset,
- 'weight_count': len(layer0_flat)
- })
- all_weights.extend(layer0_flat)
- weight_offset += len(layer0_flat)
+ for i in range(num_layers):
+ layer_key = f'layers.{i}.weight'
+ if layer_key not in state_dict:
+ raise ValueError(f"Missing weights for layer {i}: {layer_key}")
+
+ layer_weights = state_dict[layer_key].detach().numpy()
+ layer_flat = layer_weights.flatten()
- # Layer 1: (8 + channels[0]) → channels[1]
- layer1_weights = state_dict['layer1.weight'].detach().numpy()
- layer1_flat = layer1_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][1],
- 'in_channels': 8 + config['channels'][0],
- 'out_channels': config['channels'][1],
- 'weight_offset': weight_offset,
- 'weight_count': len(layer1_flat)
- })
- all_weights.extend(layer1_flat)
- weight_offset += len(layer1_flat)
+ layers.append({
+ 'kernel_size': kernel_size,
+ 'in_channels': 12, # 4 (input/prev) + 8 (static)
+ 'out_channels': 4, # Uniform output
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer_flat)
+ })
+ all_weights.extend(layer_flat)
+ weight_offset += len(layer_flat)
- # Layer 2: (8 + channels[1]) → 4 (RGBA output)
- layer2_weights = state_dict['layer2.weight'].detach().numpy()
- layer2_flat = layer2_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][2],
- 'in_channels': 8 + config['channels'][1],
- 'out_channels': 4,
- 'weight_offset': weight_offset,
- 'weight_count': len(layer2_flat)
- })
- all_weights.extend(layer2_flat)
- weight_offset += len(layer2_flat)
+ print(f" Layer {i}: 12D→4D, {len(layer_flat)} weights")
# Convert to f16
# TODO: Use 8-bit quantization for 2× size reduction
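
The f16 conversion this TODO refers to is a one-line NumPy cast; trained conv weights sit well inside f16 range, so overflow to inf is not a practical concern. A self-contained sketch (the sample list stands in for the collected all_weights):

import numpy as np

all_weights = [0.5, -0.25, 0.125]                       # stand-in for the collected list
weights_f16 = np.asarray(all_weights, dtype=np.float32).astype(np.float16)
assert weights_f16.nbytes == len(all_weights) * 2       # half the f32 footprint
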
@@ -183,21 +168,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}
-fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
let packed = textureLoad(layer_input, coord, 0);
let v0 = unpack2x16float(packed.x);
let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
return vec4<u32>(
- pack2x16float(vec2<f32>(values[0], values[1])),
- pack2x16float(vec2<f32>(values[2], values[3])),
- pack2x16float(vec2<f32>(values[4], values[5])),
- pack2x16float(vec2<f32>(values[6], values[7]))
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
);
}
@@ -238,9 +221,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let out_channels = weights[layer0_info_base + 2u];
let weight_offset = weights[layer0_info_base + 3u];
- // Convolution (simplified - expand to full kernel loop)
- var output: array<f32, 8>;
- for (var c: u32 = 0u; c < min(out_channels, 8u); c++) {
+ // Convolution: 12D input (4 prev + 8 static) → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {
output[c] = 0.0; // TODO: Actual convolution
}
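
The writer side of the binary format isn't visible in this hunk, so the exact on-disk layout is an assumption; a hypothetical reader sketch, assuming per-layer records of four little-endian u32 fields (kernel_size, in_channels, out_channels, weight_offset — the fields the shader indexes) followed by the packed f16 weight block:

import struct
import numpy as np

def read_weights_blob(path, num_layers=3):
    # Hypothetical layout: num_layers records of four u32s, then f16 weights.
    with open(path, "rb") as f:
        blob = f.read()
    header_size = num_layers * 4 * 4
    layers = []
    for i in range(num_layers):
        k, in_ch, out_ch, offset = struct.unpack_from("<4I", blob, i * 16)
        layers.append({"kernel_size": k, "in_channels": in_ch,
                       "out_channels": out_ch, "weight_offset": offset})
    weights = np.frombuffer(blob[header_size:], dtype=np.float16)
    return layers, weights
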
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 758b044..8b3b91c 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -1,11 +1,11 @@
#!/usr/bin/env python3
-"""CNN v2 Training Script - Parametric Static Features
+"""CNN v2 Training Script - Uniform 12D→4D Architecture
-Trains a multi-layer CNN with 7D static feature input:
-- RGBD (4D)
-- UV coordinates (2D)
-- sin(10*uv.x) position encoding (1D)
-- Bias dimension (1D, always 1.0)
+Architecture:
+- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10×uv_x), bias
+- Input RGBD (4D): original image mip 0
+- Every layer: 4D (input RGBD at layer 0, previous layer's output after) + static (8D) = 12D → 4 channels
+- Uniform layer structure with bias=False (bias in static features)
"""
import argparse
@@ -21,20 +21,26 @@ import cv2
def compute_static_features(rgb, depth=None):
- """Generate 7D static features + bias dimension.
+ """Generate 8D static features (parametric + spatial).
Args:
rgb: (H, W, 3) RGB image [0, 1]
depth: (H, W) depth map [0, 1], optional
Returns:
- (H, W, 8) static features tensor
+ (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]
+
+    Note: p0-p3 are parametric features (mips, gradients, etc.).
+          Training uses RGBD by default, but mip1/mip2 could be substituted.
"""
h, w = rgb.shape[:2]
- # RGBD channels
- r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
- d = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+ # Parametric features (p0-p3) - using RGBD as default
+ # TODO: Experiment with mip1 grayscale, gradients, etc.
+ p0 = rgb[:, :, 0].astype(np.float32)
+ p1 = rgb[:, :, 1].astype(np.float32)
+ p2 = rgb[:, :, 2].astype(np.float32)
+ p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
# UV coordinates (normalized [0, 1])
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
@@ -43,65 +49,64 @@ def compute_static_features(rgb, depth=None):
# Multi-frequency position encoding
sin10_x = np.sin(10.0 * uv_x).astype(np.float32)
- # Bias dimension (always 1.0)
+ # Bias dimension (always 1.0) - replaces Conv2d bias parameter
bias = np.ones((h, w), dtype=np.float32)
- # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias]
- features = np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1)
+ # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
+ features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
return features
class CNNv2(nn.Module):
- """CNN v2 with parametric static features.
+ """CNN v2 - Uniform 12D→4D Architecture
+
+    Every layer: 4D (input RGBD at layer 0, previous output after) + static (8D) = 12D → 4D
+ Uses bias=False (bias integrated in static features as 1.0)
TODO: Add quantization-aware training (QAT) for 8-bit weights
- Use torch.quantization.QuantStub/DeQuantStub
- Train with fake quantization to adapt to 8-bit precision
- - Target: ~1.6 KB weights (vs 3.2 KB with f16)
+ - Target: ~1.3 KB weights (vs 2.6 KB with f16)
"""
- def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]):
+ def __init__(self, kernel_size=3, num_layers=3):
super().__init__()
- self.kernels = kernels
- self.channels = channels
-
- # Input layer: 8D (7 features + bias) → channels[0]
- self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0],
- padding=kernels[0]//2, bias=False)
-
- # Inner layers: (8 + C_prev) → C_next
- in_ch_1 = 8 + channels[0]
- self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1],
- padding=kernels[1]//2, bias=False)
+ self.kernel_size = kernel_size
+ self.num_layers = num_layers
+ self.layers = nn.ModuleList()
- # Output layer: (8 + C_last) → 4 (RGBA)
- in_ch_2 = 8 + channels[1]
- self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2],
- padding=kernels[2]//2, bias=False)
+ # All layers: 12D input (4 RGBD + 8 static) → 4D output
+ for _ in range(num_layers):
+ self.layers.append(
+ nn.Conv2d(12, 4, kernel_size=kernel_size,
+ padding=kernel_size//2, bias=False)
+ )
- def forward(self, static_features):
- """Forward pass with static feature concatenation.
+ def forward(self, input_rgbd, static_features):
+ """Forward pass with uniform 12D→4D layers.
Args:
+ input_rgbd: (B, 4, H, W) input image RGBD (mip 0)
static_features: (B, 8, H, W) static features
Returns:
(B, 4, H, W) RGBA output [0, 1]
"""
- # Layer 0: Use full 8D static features
- x0 = self.layer0(static_features)
- x0 = F.relu(x0)
+ # Layer 0: input RGBD (4D) + static (8D) = 12D
+ x = torch.cat([input_rgbd, static_features], dim=1)
+ x = self.layers[0](x)
+ x = torch.clamp(x, 0, 1) # Output [0,1] for layer 0
- # Layer 1: Concatenate static + layer0 output
- x1_input = torch.cat([static_features, x0], dim=1)
- x1 = self.layer1(x1_input)
- x1 = F.relu(x1)
+ # Layer 1+: previous (4D) + static (8D) = 12D
+ for i in range(1, self.num_layers):
+ x_input = torch.cat([x, static_features], dim=1)
+ x = self.layers[i](x_input)
+ if i < self.num_layers - 1:
+ x = F.relu(x)
+ else:
+ x = torch.clamp(x, 0, 1) # Final output [0,1]
- # Layer 2: Concatenate static + layer1 output
- x2_input = torch.cat([static_features, x1], dim=1)
- output = self.layer2(x2_input)
-
- return torch.sigmoid(output)
+ return x
class PatchDataset(Dataset):
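
A shape-level smoke test of the new two-tensor forward signature (dummy data; assumes the script's directory is importable):

import torch
from train_cnn_v2 import CNNv2   # assuming training/ is on sys.path

model = CNNv2(kernel_size=3, num_layers=3)
input_rgbd = torch.rand(1, 4, 64, 64)     # mip-0 RGBD
static_feat = torch.rand(1, 8, 64, 64)    # p0-p3, uv_x, uv_y, sin10_x, bias
out = model(input_rgbd, static_feat)
assert out.shape == (1, 4, 64, 64)        # RGBA, clamped to [0, 1]
assert out.min() >= 0.0 and out.max() <= 1.0
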
@@ -214,14 +219,18 @@ class PatchDataset(Dataset):
# Compute static features for patch
static_feat = compute_static_features(input_patch.astype(np.float32))
+ # Input RGBD (mip 0) - add depth channel
+ input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
+
# Convert to tensors (C, H, W)
+ input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)
# Pad target to 4 channels (RGBA)
target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
- return static_feat, target
+ return input_rgbd, static_feat, target
class ImagePairDataset(Dataset):
@@ -252,14 +261,19 @@ class ImagePairDataset(Dataset):
# Compute static features
static_feat = compute_static_features(input_img.astype(np.float32))
+ # Input RGBD (mip 0) - add depth channel
+ h, w = input_img.shape[:2]
+ input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1)
+
# Convert to tensors (C, H, W)
+ input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)
# Pad target to 4 channels (RGBA)
target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
- return static_feat, target
+ return input_rgbd, static_feat, target
def train(args):
@@ -282,9 +296,10 @@ def train(args):
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# Create model
- model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
+ model = CNNv2(kernel_size=args.kernel_size, num_layers=args.num_layers).to(device)
total_params = sum(p.numel() for p in model.parameters())
- print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights")
+ weights_per_layer = 12 * args.kernel_size * args.kernel_size * 4
+ print(f"Model: {args.num_layers} layers, {args.kernel_size}×{args.kernel_size} kernels, {total_params} weights ({weights_per_layer}/layer)")
# Optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -298,12 +313,13 @@ def train(args):
model.train()
epoch_loss = 0.0
- for static_feat, target in dataloader:
+ for input_rgbd, static_feat, target in dataloader:
+ input_rgbd = input_rgbd.to(device)
static_feat = static_feat.to(device)
target = target.to(device)
optimizer.zero_grad()
- output = model(static_feat)
+ output = model(input_rgbd, static_feat)
loss = criterion(output, target)
loss.backward()
optimizer.step()
@@ -327,9 +343,9 @@ def train(args):
'optimizer_state_dict': optimizer.state_dict(),
'loss': avg_loss,
'config': {
- 'kernels': args.kernel_sizes,
- 'channels': args.channels,
- 'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias']
+ 'kernel_size': args.kernel_size,
+ 'num_layers': args.num_layers,
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
}
}, checkpoint_path)
print(f" → Saved checkpoint: {checkpoint_path}")
@@ -361,10 +377,10 @@ def main():
# Mix salient points with random samples for better generalization
# Model architecture
- parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
- help='Kernel sizes for 3 layers (default: 1 3 5)')
- parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
- help='Output channels for 3 layers (default: 16 8 4)')
+ parser.add_argument('--kernel-size', type=int, default=3,
+ help='Kernel size (uniform for all layers, default: 3)')
+ parser.add_argument('--num-layers', type=int, default=3,
+ help='Number of CNN layers (default: 3)')
# Training parameters
parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
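
Checkpoints saved with the new config dict round-trip cleanly into both export scripts, which read them back via config.get('kernel_size', 3) and config.get('num_layers', 3). A minimal check (the checkpoint path is illustrative, not from the repo):

import torch

ckpt = torch.load("checkpoints/cnn_v2_final.pth", map_location="cpu")  # illustrative path
cfg = ckpt["config"]
print(cfg["kernel_size"], cfg["num_layers"])   # e.g. 3, 3 with the defaults
print(cfg["features"])  # ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
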