diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-11 16:41:27 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-11 16:41:27 +0100 |
| commit | f71e4b6c3ae7c2b5a0c71fa6b379c44b5d527874 (patch) | |
| tree | 8cedd4592442ca1019a67c0d95f5313a6c371492 /training/train_cnn.py | |
| parent | 09eba6004eb5faa5273e310ca560bfd41e1bc901 (diff) | |
fix: Compute gray from [0,1] RGB in CNN shader generator
Match training forward pass: compute grayscale from original [0,1] RGB
before normalization, then normalize gray to [-1,1].
Previously computed gray from normalized [-1,1] RGB in generated shader,
creating mismatch with train.py which does:
gray = 0.2126*R + 0.7152*G + 0.0722*B # [0,1]
gray = (gray - 0.5) * 2.0 # [-1,1]
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 101 |
1 files changed, 88 insertions, 13 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index ef7a0ae..5ad922e 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -322,7 +322,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes): f.write(" let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);\n") f.write(" let original_raw = textureSample(original_input, smplr, uv);\n") f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n") - f.write(" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n") + f.write(" let gray = (dot(original_raw.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0;\n") f.write(" var result = vec4<f32>(0.0);\n\n") # Generate layer switches @@ -405,6 +405,49 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): f.write(");\n\n") +def generate_conv_base_function(kernel_size, output_path): + """Generate cnn_conv{K}x{K}_7to4() function for inner layers (vec4-optimized)""" + + k = kernel_size + num_positions = k * k + radius = k // 2 + + with open(output_path, 'a') as f: + f.write(f"\n// Inner layers: 7→4 channels (vec4-optimized)\n") + f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n") + f.write(f"fn cnn_conv{k}x{k}_7to4(\n") + f.write(f" tex: texture_2d<f32>,\n") + f.write(f" samp: sampler,\n") + f.write(f" uv: vec2<f32>,\n") + f.write(f" resolution: vec2<f32>,\n") + f.write(f" gray: f32,\n") + f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n") + f.write(f") -> vec4<f32> {{\n") + f.write(f" let step = 1.0 / resolution;\n") + f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n") + f.write(f" var sum = vec4<f32>(0.0);\n") + f.write(f" var pos = 0;\n\n") + + # Convolution loop + f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n") + f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n") + f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n") + f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n") + f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n") + + # Accumulate + f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n") + f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n") + f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n") + f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n") + f.write(f" pos += 8;\n") + f.write(f" }}\n") + f.write(f" }}\n\n") + + f.write(f" return sum;\n") + f.write(f"}}\n") + + def generate_conv_src_function(kernel_size, output_path): """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0 (vec4-optimized)""" @@ -624,27 +667,43 @@ def train(args): print(f"Generating layer shader to {shader_path}...") generate_layer_shader(shader_path, args.layers, kernel_sizes) - # Generate _src and 7to1 variants for kernel sizes + # Generate conv shader files for all kernel sizes for ks in set(kernel_sizes): conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl') + + # Create file with header if it doesn't exist if not os.path.exists(conv_path): - print(f"Warning: {conv_path} not found, skipping function generation") + print(f"Creating {conv_path}...") + with open(conv_path, 'w') as f: + f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n") + generate_conv_base_function(ks, conv_path) + generate_conv_src_function(ks, conv_path) + generate_conv_final_function(ks, conv_path) + print(f"Generated complete {conv_path}") continue + # File exists, check for missing functions with open(conv_path, 'r') as f: content = f.read() - # Generate _src variant (skip 3x3, already exists) - if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content: + # Generate base 7to4 if missing + if f"cnn_conv{ks}x{ks}_7to4" not in content: + generate_conv_base_function(ks, conv_path) + print(f"Added base 7to4 to {conv_path}") + with open(conv_path, 'r') as f: + content = f.read() + + # Generate _src variant if missing + if f"cnn_conv{ks}x{ks}_7to4_src" not in content: generate_conv_src_function(ks, conv_path) print(f"Added _src variant to {conv_path}") with open(conv_path, 'r') as f: content = f.read() - # Generate 7to1 final layer with sigmoid (all kernel sizes) + # Generate 7to1 final layer if missing if f"cnn_conv{ks}x{ks}_7to1" not in content: generate_conv_final_function(ks, conv_path) - print(f"Added 7to1 variant with sigmoid to {conv_path}") + print(f"Added 7to1 variant to {conv_path}") print("Training complete!") @@ -678,27 +737,43 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print(f"Generating layer shader to {shader_path}...") generate_layer_shader(shader_path, num_layers, kernel_sizes) - # Generate _src and 7to1 variants for kernel sizes + # Generate conv shader files for all kernel sizes for ks in set(kernel_sizes): conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl') + + # Create file with header if it doesn't exist if not os.path.exists(conv_path): - print(f"Warning: {conv_path} not found, skipping function generation") + print(f"Creating {conv_path}...") + with open(conv_path, 'w') as f: + f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n") + generate_conv_base_function(ks, conv_path) + generate_conv_src_function(ks, conv_path) + generate_conv_final_function(ks, conv_path) + print(f"Generated complete {conv_path}") continue + # File exists, check for missing functions with open(conv_path, 'r') as f: content = f.read() - # Generate _src variant (skip 3x3, already exists) - if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content: + # Generate base 7to4 if missing + if f"cnn_conv{ks}x{ks}_7to4" not in content: + generate_conv_base_function(ks, conv_path) + print(f"Added base 7to4 to {conv_path}") + with open(conv_path, 'r') as f: + content = f.read() + + # Generate _src variant if missing + if f"cnn_conv{ks}x{ks}_7to4_src" not in content: generate_conv_src_function(ks, conv_path) print(f"Added _src variant to {conv_path}") with open(conv_path, 'r') as f: content = f.read() - # Generate 7to1 final layer with sigmoid (all kernel sizes) + # Generate 7to1 final layer if missing if f"cnn_conv{ks}x{ks}_7to1" not in content: generate_conv_final_function(ks, conv_path) - print(f"Added 7to1 variant with sigmoid to {conv_path}") + print(f"Added 7to1 variant to {conv_path}") print("Export complete!") |
