summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 0495c65..8c7b2b3 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -174,10 +174,13 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
# Generate layer switches
for layer_idx in range(num_layers):
is_final = layer_idx == num_layers - 1
+ ks = kernel_sizes[layer_idx]
+ conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"
+
if layer_idx == 0:
f.write(f" // Layer 0: 7→4 (RGBD output)\n")
f.write(f" if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
f.write(f" original, weights_layer{layer_idx});\n")
f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n")
f.write(f" // Denormalize to [0,1] for texture storage\n")
@@ -185,7 +188,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
f.write(f" }}\n")
elif not is_final:
f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
f.write(f" original, weights_layer{layer_idx});\n")
f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n")
f.write(f" // Denormalize to [0,1] for texture storage\n")
@@ -193,7 +196,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
f.write(f" }}\n")
else:
f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
f.write(f" original, weights_layer{layer_idx});\n")
f.write(f" // Denormalize from [-1,1] to [0,1]\n")
f.write(f" let gray_01 = (gray_out + 1.0) * 0.5;\n")