summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xtraining/train_cnn.py97
-rw-r--r--workspaces/main/shaders/cnn/cnn_conv3x3.wgsl2
-rw-r--r--workspaces/main/shaders/cnn/cnn_conv5x5.wgsl2
-rw-r--r--workspaces/main/shaders/cnn/cnn_layer.wgsl2
4 files changed, 81 insertions, 22 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 6bdb15f..7a2e85a 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -199,7 +199,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
f.write(f" gray, weights_layer{layer_idx});\n")
- f.write(f" // gray_out already in [0,1] from clipped training\n")
+ f.write(f" // gray_out in [0,1] (clamped to match PyTorch training)\n")
f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n")
f.write(f" }}\n")
@@ -306,6 +306,53 @@ def generate_conv_src_function(kernel_size, output_path):
f.write(f"}}\n")
+def generate_conv_final_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to1() function for final layer with clamp"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Final layer: 7→1 channel (scalar output)\n")
+ f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
+ f.write(f"// Output clamped to [0,1] to match PyTorch training\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to1(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" gray: f32,\n")
+ f.write(f" weights: array<array<f32, 8>, {num_positions}>\n")
+ f.write(f") -> f32 {{\n")
+ f.write(f" let step = 1.0 / resolution;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+ f.write(f" var sum = 0.0;\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = textureSample(tex, samp, uv + offset); // Already in [-1,1]\n\n")
+
+ # Accumulate
+ f.write(f" sum += weights[pos][0] * rgbd.r;\n")
+ f.write(f" sum += weights[pos][1] * rgbd.g;\n")
+ f.write(f" sum += weights[pos][2] * rgbd.b;\n")
+ f.write(f" sum += weights[pos][3] * rgbd.a;\n")
+ f.write(f" sum += weights[pos][4] * uv_norm.x;\n")
+ f.write(f" sum += weights[pos][5] * uv_norm.y;\n")
+ f.write(f" sum += weights[pos][6] * gray;\n")
+ f.write(f" sum += weights[pos][7]; // Bias\n\n")
+ f.write(f" pos++;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return clamp(sum, 0.0, 1.0); // Match PyTorch clamp\n")
+ f.write(f"}}\n")
+
+
def train(args):
"""Main training loop"""
@@ -393,23 +440,29 @@ def train(args):
print(f"Generating layer shader to {shader_path}...")
generate_layer_shader(shader_path, args.layers, kernel_sizes)
- # Generate _src variants for kernel sizes (skip 3x3, already exists)
+ # Generate _src and 7to1 variants for kernel sizes
for ks in set(kernel_sizes):
- if ks == 3:
- continue
conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
if not os.path.exists(conv_path):
- print(f"Warning: {conv_path} not found, skipping _src generation")
+ print(f"Warning: {conv_path} not found, skipping function generation")
continue
- # Check if _src already exists
with open(conv_path, 'r') as f:
content = f.read()
- if f"cnn_conv{ks}x{ks}_7to4_src" in content:
- continue
- generate_conv_src_function(ks, conv_path)
- print(f"Added _src variant to {conv_path}")
+ # Generate _src variant (skip 3x3, already exists)
+ if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate 7to1 final layer with clamp (all kernel sizes)
+ if f"cnn_conv{ks}x{ks}_7to1" not in content:
+ generate_conv_final_function(ks, conv_path)
+ print(f"Added 7to1 variant with clamp to {conv_path}")
+ elif "clamp(sum, 0.0, 1.0)" not in content:
+ print(f"Warning: {conv_path} has 7to1 but missing clamp - manual fix needed")
print("Training complete!")
@@ -443,23 +496,29 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print(f"Generating layer shader to {shader_path}...")
generate_layer_shader(shader_path, num_layers, kernel_sizes)
- # Generate _src variants for kernel sizes (skip 3x3, already exists)
+ # Generate _src and 7to1 variants for kernel sizes
for ks in set(kernel_sizes):
- if ks == 3:
- continue
conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
if not os.path.exists(conv_path):
- print(f"Warning: {conv_path} not found, skipping _src generation")
+ print(f"Warning: {conv_path} not found, skipping function generation")
continue
- # Check if _src already exists
with open(conv_path, 'r') as f:
content = f.read()
- if f"cnn_conv{ks}x{ks}_7to4_src" in content:
- continue
- generate_conv_src_function(ks, conv_path)
- print(f"Added _src variant to {conv_path}")
+ # Generate _src variant (skip 3x3, already exists)
+ if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate 7to1 final layer with clamp (all kernel sizes)
+ if f"cnn_conv{ks}x{ks}_7to1" not in content:
+ generate_conv_final_function(ks, conv_path)
+ print(f"Added 7to1 variant with clamp to {conv_path}")
+ elif "clamp(sum, 0.0, 1.0)" not in content:
+ print(f"Warning: {conv_path} has 7to1 but missing clamp - manual fix needed")
print("Export complete!")
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
index 79b0350..00eae22 100644
--- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
@@ -138,5 +138,5 @@ fn cnn_conv3x3_7to1(
}
}
- return sum; // Output in [-1,1]
+ return clamp(sum, 0.0, 1.0); // Match PyTorch clamp
}
diff --git a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
index 5570589..4f0a5f3 100644
--- a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
@@ -77,7 +77,7 @@ fn cnn_conv5x5_7to1(
}
}
- return sum; // Output in [-1,1]
+ return clamp(sum, 0.0, 1.0); // Match PyTorch clamp
}
// Source layer: 7→4 channels (RGBD output)
diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl
index e67ad31..8eccb26 100644
--- a/workspaces/main/shaders/cnn/cnn_layer.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl
@@ -49,7 +49,7 @@ struct CNNLayerParams {
else if (params.layer_index == 2) {
let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,
gray, weights_layer2);
- // gray_out already in [0,1] from clipped training
+ // gray_out in [0,1] (clamped to match PyTorch training)
result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);
return mix(original_raw, result, params.blend_amount); // [0,1]
}