diff options
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 0495c65..8c7b2b3 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -174,10 +174,13 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes): # Generate layer switches for layer_idx in range(num_layers): is_final = layer_idx == num_layers - 1 + ks = kernel_sizes[layer_idx] + conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1" + if layer_idx == 0: f.write(f" // Layer 0: 7→4 (RGBD output)\n") f.write(f" if (params.layer_index == {layer_idx}) {{\n") - f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n") + f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n") f.write(f" original, weights_layer{layer_idx});\n") f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n") f.write(f" // Denormalize to [0,1] for texture storage\n") @@ -185,7 +188,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes): f.write(f" }}\n") elif not is_final: f.write(f" else if (params.layer_index == {layer_idx}) {{\n") - f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n") + f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n") f.write(f" original, weights_layer{layer_idx});\n") f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n") f.write(f" // Denormalize to [0,1] for texture storage\n") @@ -193,7 +196,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes): f.write(f" }}\n") else: f.write(f" else if (params.layer_index == {layer_idx}) {{\n") - f.write(f" let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,\n") + f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n") f.write(f" original, weights_layer{layer_idx});\n") f.write(f" // Denormalize from [-1,1] to [0,1]\n") f.write(f" let gray_01 = (gray_out + 1.0) * 0.5;\n") |
