diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-10 21:01:47 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 21:01:47 +0100 |
| commit | 2fbfc406abe5a42f45face9b07a91ec64c0d4f78 (patch) | |
| tree | 2a65ffc385ad6edbdb24cf3c945bb701f601e1f3 /training/train_cnn.py | |
| parent | bbef66d114ddd8091f79c8b27e6877c80236031b (diff) | |
update train_cnn.py and shader
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 13 |
1 files changed, 2 insertions, 11 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 16f8e7a..902daa8 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -199,20 +199,11 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes): f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n") f.write(f" original, weights_layer{layer_idx});\n") f.write(f" // gray_out already in [0,1] from clipped training\n") - f.write(f" let original_denorm = (original + 1.0) * 0.5;\n") f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n") - f.write(f" let blended = mix(original_denorm, result, params.blend_amount);\n") - f.write(f" return blended; // [0,1]\n") + f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n") f.write(f" }}\n") - # Add else clause for invalid layer index - if num_layers > 0: - f.write(f" else {{\n") - f.write(f" return textureSample(txt, smplr, uv);\n") - f.write(f" }}\n") - - f.write("\n // Non-final layers: denormalize for display\n") - f.write(" return (result + 1.0) * 0.5; // [-1,1] → [0,1]\n") + f.write(" return result; // [-1,1]\n") f.write("}\n") |
