summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn.py101
1 files changed, 88 insertions, 13 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index ef7a0ae..5ad922e 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -322,7 +322,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
f.write(" let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);\n")
f.write(" let original_raw = textureSample(original_input, smplr, uv);\n")
f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
- f.write(" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n")
+ f.write(" let gray = (dot(original_raw.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0;\n")
f.write(" var result = vec4<f32>(0.0);\n\n")
# Generate layer switches
@@ -405,6 +405,49 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
f.write(");\n\n")
+def generate_conv_base_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to4() function for inner layers (vec4-optimized)"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Inner layers: 7→4 channels (vec4-optimized)\n")
+ f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to4(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" gray: f32,\n")
+ f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n")
+ f.write(f") -> vec4<f32> {{\n")
+ f.write(f" let step = 1.0 / resolution;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+ f.write(f" var sum = vec4<f32>(0.0);\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+
+ # Accumulate
+ f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
+ f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n")
+ f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n")
+ f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n")
+ f.write(f" pos += 8;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
+
+
def generate_conv_src_function(kernel_size, output_path):
"""Generate cnn_conv{K}x{K}_7to4_src() function for layer 0 (vec4-optimized)"""
@@ -624,27 +667,43 @@ def train(args):
print(f"Generating layer shader to {shader_path}...")
generate_layer_shader(shader_path, args.layers, kernel_sizes)
- # Generate _src and 7to1 variants for kernel sizes
+ # Generate conv shader files for all kernel sizes
for ks in set(kernel_sizes):
conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+
+ # Create file with header if it doesn't exist
if not os.path.exists(conv_path):
- print(f"Warning: {conv_path} not found, skipping function generation")
+ print(f"Creating {conv_path}...")
+ with open(conv_path, 'w') as f:
+ f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n")
+ generate_conv_base_function(ks, conv_path)
+ generate_conv_src_function(ks, conv_path)
+ generate_conv_final_function(ks, conv_path)
+ print(f"Generated complete {conv_path}")
continue
+ # File exists, check for missing functions
with open(conv_path, 'r') as f:
content = f.read()
- # Generate _src variant (skip 3x3, already exists)
- if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ # Generate base 7to4 if missing
+ if f"cnn_conv{ks}x{ks}_7to4" not in content:
+ generate_conv_base_function(ks, conv_path)
+ print(f"Added base 7to4 to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate _src variant if missing
+ if f"cnn_conv{ks}x{ks}_7to4_src" not in content:
generate_conv_src_function(ks, conv_path)
print(f"Added _src variant to {conv_path}")
with open(conv_path, 'r') as f:
content = f.read()
- # Generate 7to1 final layer with sigmoid (all kernel sizes)
+ # Generate 7to1 final layer if missing
if f"cnn_conv{ks}x{ks}_7to1" not in content:
generate_conv_final_function(ks, conv_path)
- print(f"Added 7to1 variant with sigmoid to {conv_path}")
+ print(f"Added 7to1 variant to {conv_path}")
print("Training complete!")
@@ -678,27 +737,43 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print(f"Generating layer shader to {shader_path}...")
generate_layer_shader(shader_path, num_layers, kernel_sizes)
- # Generate _src and 7to1 variants for kernel sizes
+ # Generate conv shader files for all kernel sizes
for ks in set(kernel_sizes):
conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+
+ # Create file with header if it doesn't exist
if not os.path.exists(conv_path):
- print(f"Warning: {conv_path} not found, skipping function generation")
+ print(f"Creating {conv_path}...")
+ with open(conv_path, 'w') as f:
+ f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n")
+ generate_conv_base_function(ks, conv_path)
+ generate_conv_src_function(ks, conv_path)
+ generate_conv_final_function(ks, conv_path)
+ print(f"Generated complete {conv_path}")
continue
+ # File exists, check for missing functions
with open(conv_path, 'r') as f:
content = f.read()
- # Generate _src variant (skip 3x3, already exists)
- if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ # Generate base 7to4 if missing
+ if f"cnn_conv{ks}x{ks}_7to4" not in content:
+ generate_conv_base_function(ks, conv_path)
+ print(f"Added base 7to4 to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate _src variant if missing
+ if f"cnn_conv{ks}x{ks}_7to4_src" not in content:
generate_conv_src_function(ks, conv_path)
print(f"Added _src variant to {conv_path}")
with open(conv_path, 'r') as f:
content = f.read()
- # Generate 7to1 final layer with sigmoid (all kernel sizes)
+ # Generate 7to1 final layer if missing
if f"cnn_conv{ks}x{ks}_7to1" not in content:
generate_conv_final_function(ks, conv_path)
- print(f"Added 7to1 variant with sigmoid to {conv_path}")
+ print(f"Added 7to1 variant to {conv_path}")
print("Export complete!")