From f71e4b6c3ae7c2b5a0c71fa6b379c44b5d527874 Mon Sep 17 00:00:00 2001 From: skal Date: Wed, 11 Feb 2026 16:41:27 +0100 Subject: fix: Compute gray from [0,1] RGB in CNN shader generator Match training forward pass: compute grayscale from original [0,1] RGB before normalization, then normalize gray to [-1,1]. Previously computed gray from normalized [-1,1] RGB in generated shader, creating mismatch with train.py which does: gray = 0.2126*R + 0.7152*G + 0.0722*B # [0,1] gray = (gray - 0.5) * 2.0 # [-1,1] Co-Authored-By: Claude Sonnet 4.5 --- training/train_cnn.py | 101 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 88 insertions(+), 13 deletions(-) (limited to 'training/train_cnn.py') diff --git a/training/train_cnn.py b/training/train_cnn.py index ef7a0ae..5ad922e 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -322,7 +322,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes): f.write(" let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);\n") f.write(" let original_raw = textureSample(original_input, smplr, uv);\n") f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n") - f.write(" let gray = dot(original.rgb, vec3(0.2126, 0.7152, 0.0722));\n") + f.write(" let gray = (dot(original_raw.rgb, vec3(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0;\n") f.write(" var result = vec4(0.0);\n\n") # Generate layer switches @@ -405,6 +405,49 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): f.write(");\n\n") +def generate_conv_base_function(kernel_size, output_path): + """Generate cnn_conv{K}x{K}_7to4() function for inner layers (vec4-optimized)""" + + k = kernel_size + num_positions = k * k + radius = k // 2 + + with open(output_path, 'a') as f: + f.write(f"\n// Inner layers: 7→4 channels (vec4-optimized)\n") + f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n") + f.write(f"fn cnn_conv{k}x{k}_7to4(\n") + f.write(f" tex: texture_2d,\n") + f.write(f" samp: sampler,\n") + f.write(f" uv: vec2,\n") + f.write(f" resolution: vec2,\n") + f.write(f" gray: f32,\n") + f.write(f" weights: array, {num_positions * 8}>\n") + f.write(f") -> vec4 {{\n") + f.write(f" let step = 1.0 / resolution;\n") + f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n") + f.write(f" var sum = vec4(0.0);\n") + f.write(f" var pos = 0;\n\n") + + # Convolution loop + f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n") + f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n") + f.write(f" let offset = vec2(f32(dx), f32(dy)) * step;\n") + f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n") + f.write(f" let in1 = vec4(uv_norm, gray, 1.0);\n\n") + + # Accumulate + f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n") + f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n") + f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n") + f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n") + f.write(f" pos += 8;\n") + f.write(f" }}\n") + f.write(f" }}\n\n") + + f.write(f" return sum;\n") + f.write(f"}}\n") + + def generate_conv_src_function(kernel_size, output_path): """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0 (vec4-optimized)""" @@ -624,27 +667,43 @@ def train(args): print(f"Generating layer shader to {shader_path}...") generate_layer_shader(shader_path, args.layers, kernel_sizes) - # Generate _src and 7to1 variants for kernel sizes + # Generate conv shader files for all kernel sizes for ks in set(kernel_sizes): conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl') + + # Create file with header if it doesn't exist if not os.path.exists(conv_path): - print(f"Warning: {conv_path} not found, skipping function generation") + print(f"Creating {conv_path}...") + with open(conv_path, 'w') as f: + f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n") + generate_conv_base_function(ks, conv_path) + generate_conv_src_function(ks, conv_path) + generate_conv_final_function(ks, conv_path) + print(f"Generated complete {conv_path}") continue + # File exists, check for missing functions with open(conv_path, 'r') as f: content = f.read() - # Generate _src variant (skip 3x3, already exists) - if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content: + # Generate base 7to4 if missing + if f"cnn_conv{ks}x{ks}_7to4" not in content: + generate_conv_base_function(ks, conv_path) + print(f"Added base 7to4 to {conv_path}") + with open(conv_path, 'r') as f: + content = f.read() + + # Generate _src variant if missing + if f"cnn_conv{ks}x{ks}_7to4_src" not in content: generate_conv_src_function(ks, conv_path) print(f"Added _src variant to {conv_path}") with open(conv_path, 'r') as f: content = f.read() - # Generate 7to1 final layer with sigmoid (all kernel sizes) + # Generate 7to1 final layer if missing if f"cnn_conv{ks}x{ks}_7to1" not in content: generate_conv_final_function(ks, conv_path) - print(f"Added 7to1 variant with sigmoid to {conv_path}") + print(f"Added 7to1 variant to {conv_path}") print("Training complete!") @@ -678,27 +737,43 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print(f"Generating layer shader to {shader_path}...") generate_layer_shader(shader_path, num_layers, kernel_sizes) - # Generate _src and 7to1 variants for kernel sizes + # Generate conv shader files for all kernel sizes for ks in set(kernel_sizes): conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl') + + # Create file with header if it doesn't exist if not os.path.exists(conv_path): - print(f"Warning: {conv_path} not found, skipping function generation") + print(f"Creating {conv_path}...") + with open(conv_path, 'w') as f: + f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n") + generate_conv_base_function(ks, conv_path) + generate_conv_src_function(ks, conv_path) + generate_conv_final_function(ks, conv_path) + print(f"Generated complete {conv_path}") continue + # File exists, check for missing functions with open(conv_path, 'r') as f: content = f.read() - # Generate _src variant (skip 3x3, already exists) - if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content: + # Generate base 7to4 if missing + if f"cnn_conv{ks}x{ks}_7to4" not in content: + generate_conv_base_function(ks, conv_path) + print(f"Added base 7to4 to {conv_path}") + with open(conv_path, 'r') as f: + content = f.read() + + # Generate _src variant if missing + if f"cnn_conv{ks}x{ks}_7to4_src" not in content: generate_conv_src_function(ks, conv_path) print(f"Added _src variant to {conv_path}") with open(conv_path, 'r') as f: content = f.read() - # Generate 7to1 final layer with sigmoid (all kernel sizes) + # Generate 7to1 final layer if missing if f"cnn_conv{ks}x{ks}_7to1" not in content: generate_conv_final_function(ks, conv_path) - print(f"Added 7to1 variant with sigmoid to {conv_path}") + print(f"Added 7to1 variant to {conv_path}") print("Export complete!") -- cgit v1.2.3