author     skal <pascal.massimino@gmail.com>  2026-02-10 10:27:44 +0100
committer  skal <pascal.massimino@gmail.com>  2026-02-10 10:27:44 +0100
commit     96a349b9874c6cdaac525ba062a0f4f90c9bc3ed (patch)
tree       a4eb24fdb417393cbe5a0dc84bf5063cffc94daf
parent     75af266889b61b5722d842a1a1eb23f79bc06a85 (diff)
feat: Add coordinate-aware CNN layer 0 for position-dependent stylization
- Implement CoordConv2d custom layer accepting (x,y) patch center
- Split layer 0 weights: rgba_weights (9x mat4x4) + coord_weights (mat2x4)
- Add *_with_coord() functions to 3x3/5x5/7x7 convolution shaders
- Update training script to generate coordinate grid and export split weights
- Regenerate placeholder weights with new format

Size impact: +32B coord weights + ~100B shader code = +132B total

All 36 tests passing (100%)

handoff(Claude): CNN coordinate awareness implemented, ready for training

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
-rw-r--r--  training/README.md                                      | 167
-rw-r--r--  training/image_style_processor.py                       | 646
-rw-r--r--  training/input/img_000.png                              | bin 0 -> 420360 bytes
-rw-r--r--  training/input/img_001.png                              | bin 0 -> 232897 bytes
-rw-r--r--  training/input/img_002.png                              | bin 0 -> 183917 bytes
-rw-r--r--  training/input/img_003.png                              | bin 0 -> 183977 bytes
-rw-r--r--  training/input/img_004.png                              | bin 0 -> 358562 bytes
-rw-r--r--  training/input/img_005.png                              | bin 0 -> 218300 bytes
-rw-r--r--  training/input/img_006.png                              | bin 0 -> 445836 bytes
-rw-r--r--  training/input/img_007.png                              | bin 0 -> 349498 bytes
-rw-r--r--  training/output/img_000.png                             | bin 0 -> 16332 bytes
-rw-r--r--  training/output/img_001.png                             | bin 0 -> 7628 bytes
-rw-r--r--  training/output/img_002.png                             | bin 0 -> 7715 bytes
-rw-r--r--  training/output/img_003.png                             | bin 0 -> 7206 bytes
-rw-r--r--  training/output/img_004.png                             | bin 0 -> 12803 bytes
-rw-r--r--  training/output/img_005.png                             | bin 0 -> 5758 bytes
-rw-r--r--  training/output/img_006.png                             | bin 0 -> 27958 bytes
-rw-r--r--  training/output/img_007.png                             | bin 0 -> 21471 bytes
-rwxr-xr-x  training/train_cnn.py                                   | 301
-rw-r--r--  workspaces/main/shaders/cnn/cnn_conv3x3.wgsl            | 27
-rw-r--r--  workspaces/main/shaders/cnn/cnn_conv5x5.wgsl            | 27
-rw-r--r--  workspaces/main/shaders/cnn/cnn_conv7x7.wgsl            | 27
-rw-r--r--  workspaces/main/shaders/cnn/cnn_layer.wgsl              | 6
-rw-r--r--  workspaces/main/shaders/cnn/cnn_weights_generated.wgsl  | 10
24 files changed, 1206 insertions, 5 deletions
diff --git a/training/README.md b/training/README.md
new file mode 100644
index 0000000..08379ee
--- /dev/null
+++ b/training/README.md
@@ -0,0 +1,167 @@
+# Image Style Processor
+
+A comprehensive Python script that applies artistic hand-drawn and futuristic effects to images.
+
+## Requirements
+
+- Python 3
+- OpenCV (cv2)
+- NumPy
+
+Install dependencies:
+```bash
+pip install opencv-python numpy
+```
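+
+To confirm the dependencies are importable before processing:
+
+```bash
+python3 -c "import cv2, numpy; print(cv2.__version__)"
+```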
+
+## Usage
+
+```bash
+python3 image_style_processor.py <input_directory> <output_directory> <style>
+```
+
+### Arguments
+
+- `input_directory`: Directory containing your input images (PNG, JPG, JPEG)
+- `output_directory`: Directory where processed images will be saved (created if it doesn't exist)
+- `style`: The artistic style to apply (see below)
+
+## Available Styles
+
+### Sketch Styles
+
+1. **pencil_sketch** - Dense cross-hatching with progressive layers in shadows
+ - Best for: Detailed technical drawings, architectural scenes
+ - Features: Clean line art, 5 layers of cross-hatching, strong shadow definition
+
+2. **ink_drawing** - Bold black outlines with comic book aesthetic
+ - Best for: Graphic novel style, high contrast scenes
+ - Features: Bold outlines, posterized tones, minimal shading
+
+3. **charcoal_pastel** - Dramatic contrasts with soft, smudged textures
+ - Best for: Portraits, dramatic landscapes
+ - Features: Soft blending, grainy texture, highlighted areas
+
+4. **conte_crayon** - Directional strokes following image contours
+ - Best for: Figure studies, natural forms
+ - Features: Stroke direction follows gradients, cross-hatching in dark areas
+
+5. **gesture_sketch** - Loose, quick observational sketch style
+ - Best for: Quick studies, energetic compositions
+ - Features: Randomized line wobble, sparse suggestion lines
+
+### Futuristic Styles
+
+6. **circuit_board** - Tech blueprint with circuit paths and geometric patterns
+ - Best for: Sci-fi imagery, technological themes
+ - Features: Multi-layer circuit paths, connection nodes, technical grid overlay
+
+7. **glitch_art** - Digital corruption with scan line shifts and pixel sorting
+ - Best for: Cyberpunk aesthetics, digital art
+ - Features: Horizontal scan artifacts, block displacement, pixel sorting, noise strips
+
+8. **wireframe_topo** - Topographic contour lines with holographic grid
+ - Best for: Landscape, abstract patterns, sci-fi hologram effect
+ - Features: 20 contour levels, scan lines, measurement markers, grid overlay
+
+9. **data_mosaic** - Voronoi geometric fragmentation with angular cells
+ - Best for: Abstract art, geometric compositions
+ - Features: 200 Voronoi cells, posterized tones, embedded geometric patterns
+
+10. **holographic_scan** - CRT/hologram display with scanlines and HUD elements
+ - Best for: Retro-futuristic, heads-up display aesthetic
+ - Features: Scanlines, interference patterns, glitch effects, corner brackets, crosshair
+
+## Examples
+
+### Sketch Effects
+
+Process images with pencil sketch:
+```bash
+python3 image_style_processor.py ./photos ./output pencil_sketch
+```
+
+Apply ink drawing style:
+```bash
+python3 image_style_processor.py ./input ./sketches ink_drawing
+```
+
+Create charcoal effect:
+```bash
+python3 image_style_processor.py ./images ./results charcoal_pastel
+```
+
+### Futuristic Effects
+
+Apply circuit board style:
+```bash
+python3 image_style_processor.py ./photos ./output circuit_board
+```
+
+Create glitch art:
+```bash
+python3 image_style_processor.py ./input ./glitched glitch_art
+```
+
+Apply holographic effect:
+```bash
+python3 image_style_processor.py ./images ./holo holographic_scan
+```
+
+## Output
+
+- Processed images are saved to the output directory with **the same filename** as the input
+- Supported input formats: PNG, JPG, JPEG (case-insensitive)
+- Output format follows the input's file extension (PNG inputs stay lossless PNG)
+- Original images are never modified
+
+## Style Comparison
+
+### Sketch Styles
+- **pencil_sketch**: Most detailed, traditional drawing look
+- **ink_drawing**: Boldest, most graphic/comic-like
+- **charcoal_pastel**: Softest, most artistic/painterly
+- **conte_crayon**: Most directional, follows contours
+- **gesture_sketch**: Loosest, most expressive
+
+### Futuristic Styles
+- **circuit_board**: Cleanest, most technical/blueprint-like
+- **glitch_art**: Most chaotic, digital corruption aesthetic
+- **wireframe_topo**: Most structured, topographic/hologram feel
+- **data_mosaic**: Most geometric, fragmented cells
+- **holographic_scan**: Most retro-futuristic, HUD/CRT display
+
+## Tips
+
+- Images are automatically converted to grayscale before processing
+- All styles work best with high-resolution images (300+ DPI recommended)
+- Processing time varies by style:
+ - Fast: ink_drawing, glitch_art, holographic_scan
+ - Medium: charcoal_pastel, gesture_sketch, circuit_board, wireframe_topo
+  - Slow: pencil_sketch, conte_crayon, data_mosaic (dominated by per-pixel and per-cell Python loops)
+- For batch processing large collections, consider processing in smaller batches
+- Randomized styles (glitch_art, gesture_sketch, data_mosaic) will produce slightly different results each run
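+
+The randomized styles draw from NumPy's global generator, so if you call the style functions from your own Python code you can pin their output by seeding first (a minimal sketch; the CLI itself does not expose a seed option):
+
+```python
+import numpy as np
+
+np.random.seed(42)  # any fixed seed makes glitch_art, gesture_sketch, data_mosaic repeatable
+```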
+
+## Technical Notes
+
+### Randomization
+Some styles use randomization for natural variation:
+- **glitch_art**: Random scan line shifts, block positions
+- **gesture_sketch**: Random line wobble, stroke placement
+- **data_mosaic**: Random Voronoi cell centers
+- **circuit_board**: Random pattern placement in dark regions
+- **holographic_scan**: Random glitch line positions
+
+### Processing Details
+- **pencil_sketch**: Uses 5-level progressive cross-hatching algorithm
+- **conte_crayon**: Follows Sobel gradients for directional strokes
+- **wireframe_topo**: Generates 20 brightness-based contour levels
+- **data_mosaic**: Creates 200 Voronoi cells via nearest-neighbor algorithm
+- **holographic_scan**: Applies scanline patterns and interference waves
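+
+The style functions can also be called directly on a grayscale array to embed an effect in another pipeline (a short sketch; `photo.jpg` is a placeholder filename):
+
+```python
+import cv2
+from image_style_processor import apply_pencil_sketch
+
+gray = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2GRAY)
+cv2.imwrite("photo_sketch.png", apply_pencil_sketch(gray))
+```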
+
+## License
+
+Free to use and modify for any purpose.
+
+## Version
+
+Version 1.0 - Complete collection of 10 artistic styles (5 sketch + 5 futuristic)
diff --git a/training/image_style_processor.py b/training/image_style_processor.py
new file mode 100644
index 0000000..fbd247b
--- /dev/null
+++ b/training/image_style_processor.py
@@ -0,0 +1,646 @@
+#!/usr/bin/env python3
+"""
+Hand-Drawn & Futuristic Style Image Processor
+Processes all images from the input directory and saves them to the output directory under the same names.
+
+Usage:
+ python3 image_style_processor.py <input_dir> <output_dir> <style>
+
+Sketch Styles:
+ - pencil_sketch: Dense cross-hatching with progressive layers in shadows
+ - ink_drawing: Bold black outlines with comic book aesthetic
+ - charcoal_pastel: Dramatic contrasts with soft, smudged textures
+ - conte_crayon: Directional strokes following image contours
+ - gesture_sketch: Loose, quick observational sketch style
+
+Futuristic Styles:
+ - circuit_board: Tech blueprint with circuit paths and geometric patterns
+ - glitch_art: Digital corruption with scan line shifts and pixel sorting
+ - wireframe_topo: Topographic contour lines with holographic grid
+ - data_mosaic: Voronoi geometric fragmentation with angular cells
+ - holographic_scan: CRT/hologram display with scanlines and HUD elements
+
+Example:
+ python3 image_style_processor.py ./input ./output pencil_sketch
+ python3 image_style_processor.py ./photos ./results glitch_art
+"""
+
+import cv2
+import numpy as np
+import os
+import sys
+import glob
+
+
+def apply_pencil_sketch(gray):
+ """Dense cross-hatching with progressive layers in shadows"""
+ inverted = 255 - gray
+ blurred = cv2.GaussianBlur(inverted, (21, 21), 0)
+ inverted_blurred = 255 - blurred
+ sketch = cv2.divide(gray, inverted_blurred, scale=256.0)
+
+ edges = cv2.Canny(gray, 50, 150)
+ edges = cv2.dilate(edges, np.ones((2, 2), np.uint8), iterations=1)
+ edges_inverted = 255 - edges
+
+ combined = cv2.multiply(sketch.astype(np.float32), edges_inverted.astype(np.float32) / 255.0)
+ combined = combined.astype(np.uint8)
+
+ height, width = combined.shape
+ cross_hatch = np.ones_like(combined) * 255
+ spacing = 4
+
+ for i in range(0, height, spacing):
+ for j in range(0, width, spacing):
+ region = gray[i:min(i+spacing, height), j:min(j+spacing, width)]
+ darkness = 255 - np.mean(region)
+
+ if darkness > 30:
+ cv2.line(cross_hatch, (j, i), (min(j+spacing, width), min(i+spacing, height)), 160, 1)
+ if darkness > 70:
+ cv2.line(cross_hatch, (j, min(i+spacing, height)), (min(j+spacing, width), i), 140, 1)
+ if darkness > 110:
+ cv2.line(cross_hatch, (j+spacing//2, i), (j+spacing//2, min(i+spacing, height)), 120, 1)
+ if darkness > 150:
+ cv2.line(cross_hatch, (j, i+spacing//2), (min(j+spacing, width), i+spacing//2), 100, 1)
+ if darkness > 190:
+ cv2.line(cross_hatch, (j, i+2), (min(j+spacing, width), min(i+spacing+2, height)), 80, 1)
+
+ final = cv2.min(combined, cross_hatch)
+ final = cv2.convertScaleAbs(final, alpha=1.3, beta=5)
+
+    # Add paper grain; int16 keeps negative noise values from wrapping around in uint8
+    noise = np.random.normal(0, 2, final.shape).astype(np.int16)
+    final = np.clip(final.astype(np.int16) + noise, 0, 255).astype(np.uint8)
+
+ return final
+
+
+def apply_ink_drawing(gray):
+ """Bold black outlines with comic book aesthetic"""
+ edges = cv2.Canny(gray, 30, 100)
+ edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
+ edges_inverted = 255 - edges
+
+ posterized = cv2.convertScaleAbs(gray, alpha=0.3, beta=50)
+ posterized = (posterized // 50) * 50
+
+ smooth = cv2.bilateralFilter(posterized, 9, 75, 75)
+
+ final = cv2.min(smooth, edges_inverted)
+ final = cv2.convertScaleAbs(final, alpha=1.5, beta=-30)
+
+ texture = np.random.normal(0, 5, final.shape).astype(np.int16)
+ final = np.clip(final.astype(np.int16) + texture, 0, 255).astype(np.uint8)
+
+ return final
+
+
+def apply_charcoal_pastel(gray):
+ """Dramatic contrasts with soft, smudged textures"""
+ inverted = 255 - gray
+ blurred = cv2.GaussianBlur(inverted, (25, 25), 0)
+ inverted_blurred = 255 - blurred
+ charcoal = cv2.divide(gray, inverted_blurred, scale=256.0)
+
+ charcoal = cv2.convertScaleAbs(charcoal, alpha=1.8, beta=-40)
+
+ kernel_size = 15
+ kernel = np.ones((kernel_size, kernel_size), np.float32) / (kernel_size * kernel_size)
+ smudged = cv2.filter2D(charcoal, -1, kernel)
+
+ blended = cv2.addWeighted(charcoal, 0.6, smudged, 0.4, 0)
+
+ grain = np.random.normal(0, 15, blended.shape).astype(np.int16)
+ textured = np.clip(blended.astype(np.int16) + grain, 0, 255).astype(np.uint8)
+
+ highlights = cv2.threshold(gray, 180, 255, cv2.THRESH_BINARY)[1]
+ highlights_blurred = cv2.GaussianBlur(highlights, (15, 15), 0)
+
+ final = cv2.addWeighted(textured, 0.85, highlights_blurred, 0.15, 0)
+ final = cv2.convertScaleAbs(final, alpha=1.2, beta=5)
+
+ return final
+
+
+def apply_conte_crayon(gray):
+ """Directional strokes following image contours"""
+ inverted = 255 - gray
+ blurred = cv2.GaussianBlur(inverted, (21, 21), 0)
+ inverted_blurred = 255 - blurred
+ sketch = cv2.divide(gray, inverted_blurred, scale=256.0)
+
+ edges = cv2.Canny(gray, 40, 120)
+ edges = cv2.dilate(edges, np.ones((2, 2), np.uint8), iterations=1)
+
+ height, width = gray.shape
+ strokes = np.ones_like(gray) * 255
+
+ sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=5)
+ sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=5)
+
+ stroke_spacing = 6
+ for i in range(0, height - stroke_spacing, stroke_spacing):
+ for j in range(0, width - stroke_spacing, stroke_spacing):
+ gx = sobelx[i:i+stroke_spacing, j:j+stroke_spacing].mean()
+ gy = sobely[i:i+stroke_spacing, j:j+stroke_spacing].mean()
+ darkness = 255 - gray[i:i+stroke_spacing, j:j+stroke_spacing].mean()
+
+ if darkness > 40:
+ angle = np.arctan2(gy, gx)
+ length = min(stroke_spacing * 2, int(darkness / 30))
+ dx = int(length * np.cos(angle))
+ dy = int(length * np.sin(angle))
+ intensity = max(100, 255 - int(darkness * 1.2))
+ cv2.line(strokes, (j, i), (j + dx, i + dy), intensity, 2)
+
+ if darkness > 100:
+ angle2 = angle + np.pi / 3
+ dx2 = int(length * 0.7 * np.cos(angle2))
+ dy2 = int(length * 0.7 * np.sin(angle2))
+ intensity2 = max(80, 255 - int(darkness * 1.4))
+ cv2.line(strokes, (j, i), (j + dx2, i + dy2), intensity2, 2)
+
+ edges_inv = 255 - edges
+ combined = cv2.min(sketch, strokes)
+ combined = cv2.min(combined, edges_inv)
+
+ texture = np.random.normal(0, 10, combined.shape).astype(np.int16)
+ textured = np.clip(combined.astype(np.int16) + texture, 0, 255).astype(np.uint8)
+
+ final = cv2.convertScaleAbs(textured, alpha=1.25, beta=8)
+ final = cv2.GaussianBlur(final, (3, 3), 0)
+
+ return final
+
+
+def apply_gesture_sketch(gray):
+ """Loose, quick observational sketch style"""
+ edges_strong = cv2.Canny(gray, 60, 180)
+ edges_medium = cv2.Canny(gray, 30, 90)
+ edges_strong = cv2.dilate(edges_strong, np.ones((2, 2), np.uint8), iterations=1)
+
+ height, width = gray.shape
+ sketch = np.ones_like(gray) * 255
+
+ y_coords, x_coords = np.where(edges_strong > 0)
+ for i in range(0, len(y_coords), 2):
+ y, x = y_coords[i], x_coords[i]
+ offset_y = int(np.random.normal(0, 1.5))
+ offset_x = int(np.random.normal(0, 1.5))
+ new_y = np.clip(y + offset_y, 0, height - 1)
+ new_x = np.clip(x + offset_x, 0, width - 1)
+ darkness = int(np.random.uniform(50, 120))
+ cv2.circle(sketch, (new_x, new_y), 1, darkness, -1)
+
+ y_coords2, x_coords2 = np.where(edges_medium > 0)
+ for i in range(0, len(y_coords2), 8):
+ y, x = y_coords2[i], x_coords2[i]
+ offset_y = int(np.random.normal(0, 2))
+ offset_x = int(np.random.normal(0, 2))
+ new_y = np.clip(y + offset_y, 0, height - 1)
+ new_x = np.clip(x + offset_x, 0, width - 1)
+ darkness = int(np.random.uniform(180, 220))
+ cv2.circle(sketch, (new_x, new_y), 1, darkness, -1)
+
+ tone_spacing = 12
+ for i in range(0, height, tone_spacing):
+ for j in range(0, width, tone_spacing):
+ region_dark = 255 - gray[i:min(i+tone_spacing, height), j:min(j+tone_spacing, width)].mean()
+
+ if region_dark > 80:
+ num_strokes = int(region_dark / 60)
+ for _ in range(num_strokes):
+ start_x = j + int(np.random.uniform(0, tone_spacing))
+ start_y = i + int(np.random.uniform(0, tone_spacing))
+ angle = np.random.uniform(0, 2 * np.pi)
+ length = int(np.random.uniform(3, 8))
+ end_x = start_x + int(length * np.cos(angle))
+ end_y = start_y + int(length * np.sin(angle))
+ darkness = int(np.random.uniform(140, 200))
+ cv2.line(sketch, (start_x, start_y), (end_x, end_y), darkness, 1)
+
+ texture = np.random.normal(0, 5, sketch.shape).astype(np.int16)
+ final = np.clip(sketch.astype(np.int16) + texture, 0, 255).astype(np.uint8)
+ final = cv2.GaussianBlur(final, (3, 3), 0)
+
+ return final
+
+
+def apply_circuit_board(gray):
+ """Circuit board / tech blueprint effect with clean lines and geometric patterns"""
+ height, width = gray.shape
+
+ edges1 = cv2.Canny(gray, 30, 90)
+ edges2 = cv2.Canny(gray, 60, 180)
+ edges3 = cv2.Canny(gray, 100, 250)
+
+ edges1 = cv2.dilate(edges1, np.ones((2, 2), np.uint8), iterations=1)
+
+ circuit = np.ones_like(gray) * 255
+ circuit[edges3 > 0] = 0
+ circuit[edges2 > 0] = 80
+ circuit[edges1 > 0] = 160
+
+ kernel = np.ones((3, 3), np.uint8)
+ dilated = cv2.dilate(edges2, kernel, iterations=1)
+ intersections = cv2.bitwise_and(dilated, edges3)
+
+ y_coords, x_coords = np.where(intersections > 0)
+ for i in range(0, len(y_coords), 15):
+ y, x = y_coords[i], x_coords[i]
+ cv2.circle(circuit, (x, y), 3, 0, 1)
+ cv2.circle(circuit, (x, y), 2, 255, -1)
+
+ grid_spacing = 20
+ for i in range(0, height, grid_spacing):
+ darkness = 255 - gray[i, :].mean()
+ if darkness > 50:
+ cv2.line(circuit, (0, i), (width, i), 230, 1)
+
+ for j in range(0, width, grid_spacing):
+ darkness = 255 - gray[:, j].mean()
+ if darkness > 50:
+ cv2.line(circuit, (j, 0), (j, height), 230, 1)
+
+ spacing = 25
+ for i in range(spacing, height - spacing, spacing):
+ for j in range(spacing, width - spacing, spacing):
+ region_dark = 255 - gray[i-10:i+10, j-10:j+10].mean()
+
+ if region_dark > 80:
+ pattern = np.random.randint(0, 4)
+ if pattern == 0:
+ cv2.line(circuit, (j-8, i), (j+8, i), 120, 1)
+ elif pattern == 1:
+ cv2.line(circuit, (j, i-8), (j, i+8), 120, 1)
+ elif pattern == 2:
+ cv2.line(circuit, (j-6, i), (j+6, i), 140, 1)
+ cv2.line(circuit, (j, i-6), (j, i+6), 140, 1)
+ elif pattern == 3:
+ cv2.rectangle(circuit, (j-4, i-4), (j+4, i+4), 100, 1)
+
+ return circuit
+
+
+def apply_glitch_art(gray):
+ """Digital glitch / data corruption aesthetic"""
+ height, width = gray.shape
+ result = gray.copy()
+
+ result = cv2.convertScaleAbs(result, alpha=2.0, beta=-100)
+ _, thresh = cv2.threshold(result, 127, 255, cv2.THRESH_BINARY)
+
+ for i in range(0, height, np.random.randint(3, 8)):
+ if np.random.random() > 0.7:
+ shift = np.random.randint(-20, 20)
+ if shift > 0:
+ thresh[i, shift:] = thresh[i, :-shift]
+ elif shift < 0:
+ thresh[i, :shift] = thresh[i, -shift:]
+
+ num_blocks = np.random.randint(5, 15)
+ for _ in range(num_blocks):
+ x = np.random.randint(0, width - 40)
+ y = np.random.randint(0, height - 30)
+ block_width = np.random.randint(20, 60)
+ block_height = np.random.randint(3, 20)
+
+ if y + block_height < height and x + block_width < width:
+ block = thresh[y:y+block_height, x:x+block_width].copy()
+ shift_y = np.random.randint(-10, 10)
+ new_y = np.clip(y + shift_y, 0, height - block_height)
+ thresh[new_y:new_y+block_height, x:x+block_width] = block
+
+ for i in range(0, height, np.random.randint(15, 40)):
+ if np.random.random() > 0.5:
+ row = thresh[i, :].copy()
+ if np.random.random() > 0.5:
+ sorted_row = np.sort(row)
+ else:
+ sorted_row = np.sort(row)[::-1]
+
+ start = np.random.randint(0, width // 3)
+ end = np.random.randint(2 * width // 3, width)
+ thresh[i, start:end] = sorted_row[start:end]
+
+ num_strips = np.random.randint(3, 8)
+ for _ in range(num_strips):
+ y = np.random.randint(0, height)
+ strip_height = np.random.randint(1, 5)
+ if y + strip_height < height:
+ noise = np.random.choice([0, 255], size=width)
+ thresh[y:y+strip_height, :] = noise
+
+ edges = cv2.Canny(gray, 50, 150)
+ edges_inv = 255 - edges
+
+ glitched = cv2.bitwise_and(thresh, edges_inv)
+
+ for i in range(0, height, 3):
+ if np.random.random() > 0.95:
+ glitched[i, :] = glitched[i, :] // 2
+
+ return glitched
+
+
+def apply_wireframe_topo(gray):
+ """Topographic contour lines / wireframe hologram effect"""
+ height, width = gray.shape
+ result = np.ones_like(gray) * 255
+
+ num_levels = 20
+ for level in range(num_levels):
+ threshold = int((level / num_levels) * 255)
+
+ _, thresh = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
+ contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+
+ if level % 3 == 0:
+ intensity = 0
+ thickness = 2
+ else:
+ intensity = 120
+ thickness = 1
+
+ cv2.drawContours(result, contours, -1, intensity, thickness)
+
+ scan_spacing = 4
+ for i in range(0, height, scan_spacing):
+ darkness = 255 - gray[i, :].mean()
+ if darkness > 30:
+ line_intensity = int(200 + (darkness / 255) * 55)
+ cv2.line(result, (0, i), (width, i), line_intensity, 1)
+
+ grid_spacing = 40
+ for i in range(0, height, grid_spacing):
+ cv2.line(result, (0, i), (width, i), 180, 1)
+
+ for j in range(0, width, grid_spacing):
+ cv2.line(result, (j, 0), (j, height), 180, 1)
+
+ edges = cv2.Canny(gray, 50, 150)
+ edges = cv2.dilate(edges, np.ones((2, 2), np.uint8), iterations=1)
+
+ result[edges > 0] = 0
+
+ marker_size = 20
+ cv2.line(result, (0, 0), (marker_size, 0), 0, 2)
+ cv2.line(result, (0, 0), (0, marker_size), 0, 2)
+ cv2.line(result, (width-marker_size, 0), (width, 0), 0, 2)
+ cv2.line(result, (width-1, 0), (width-1, marker_size), 0, 2)
+ cv2.line(result, (0, height-marker_size), (0, height), 0, 2)
+ cv2.line(result, (0, height-1), (marker_size, height-1), 0, 2)
+ cv2.line(result, (width-marker_size, height-1), (width, height-1), 0, 2)
+ cv2.line(result, (width-1, height-marker_size), (width-1, height), 0, 2)
+
+ return result
+
+
+def apply_data_mosaic(gray):
+ """Geometric fragmentation / data mosaic with angular shapes"""
+ height, width = gray.shape
+ result = np.ones_like(gray) * 255
+
+ num_cells = 200
+ points = []
+ values = []
+
+ for _ in range(num_cells):
+ x = np.random.randint(0, width)
+ y = np.random.randint(0, height)
+ points.append((x, y))
+ values.append(gray[y, x])
+
+ for i in range(height):
+ for j in range(width):
+ min_dist = float('inf')
+ nearest_value = 0
+
+ for idx, (px, py) in enumerate(points):
+ dist = (i - py) ** 2 + (j - px) ** 2
+ if dist < min_dist:
+ min_dist = dist
+ nearest_value = values[idx]
+
+ result[i, j] = nearest_value
+
+ result = (result // 40) * 40
+
+ edges = cv2.Canny(result, 30, 90)
+ edges = cv2.dilate(edges, np.ones((2, 2), np.uint8), iterations=1)
+ result[edges > 0] = 0
+
+ pattern_spacing = 60
+ for i in range(pattern_spacing // 2, height, pattern_spacing):
+ for j in range(pattern_spacing // 2, width, pattern_spacing):
+ cell_value = result[i, j]
+
+ if cell_value < 100:
+ pattern = np.random.randint(0, 5)
+ size = 15
+
+ if pattern == 0:
+ pts = np.array([[j, i - size], [j + size, i], [j, i + size], [j - size, i]], np.int32)
+ cv2.polylines(result, [pts], True, 200, 1)
+ elif pattern == 1:
+ pts = np.array([[j, i - size], [j + size, i + size], [j - size, i + size]], np.int32)
+ cv2.polylines(result, [pts], True, 200, 1)
+ elif pattern == 2:
+ cv2.rectangle(result, (j - size, i - size), (j + size, i + size), 200, 1)
+ elif pattern == 3:
+ cv2.line(result, (j - size, i), (j + size, i), 200, 1)
+ cv2.line(result, (j, i - size), (j, i + size), 200, 1)
+ elif pattern == 4:
+ cv2.circle(result, (j, i), size, 200, 1)
+
+ grid_spacing = 30
+ for i in range(0, height, grid_spacing):
+ if np.random.random() > 0.5:
+ cv2.line(result, (0, i), (width, i), 220, 1)
+
+ for j in range(0, width, grid_spacing):
+ if np.random.random() > 0.5:
+ cv2.line(result, (j, 0), (j, height), 220, 1)
+
+ return result
+
+
+def apply_holographic_scan(gray):
+ """Holographic display with scanlines and interference patterns"""
+ height, width = gray.shape
+ result = gray.copy()
+
+ result = cv2.convertScaleAbs(result, alpha=1.8, beta=-50)
+
+    # Build the scanline multiplier in float32; a uint8 array would truncate 0.9 and 1.2 to 0 and 1
+    scanline_pattern = np.ones_like(result, dtype=np.float32)
+    for i in range(height):
+        if i % 2 == 0:
+            scanline_pattern[i, :] = 0.9
+
+        if i % 8 == 0:
+            scanline_pattern[i, :] = 1.2
+
+    # Clip in float before casting back so bright rows saturate instead of wrapping
+    result = np.clip(result.astype(np.float32) * scanline_pattern, 0, 255).astype(np.uint8)
+
+ interference = np.zeros_like(result, dtype=np.float32)
+ for j in range(width):
+ wave = np.sin(j * 0.1) * 0.15 + 1.0
+ interference[:, j] = wave
+
+    result = np.clip(result.astype(np.float32) * interference, 0, 255).astype(np.uint8)
+
+ edges = cv2.Canny(gray, 100, 200)
+ edges = cv2.dilate(edges, np.ones((2, 2), np.uint8), iterations=1)
+
+ result[edges > 0] = 255
+
+ num_glitches = 15
+ for _ in range(num_glitches):
+ y = np.random.randint(0, height)
+ thickness = np.random.randint(1, 4)
+
+ if y + thickness < height:
+ shift = np.random.randint(-30, 30)
+ if shift > 0 and shift < width:
+ result[y:y+thickness, shift:] = result[y:y+thickness, :-shift]
+ result[y:y+thickness, :shift] = 255
+ elif shift < 0 and abs(shift) < width:
+ result[y:y+thickness, :shift] = result[y:y+thickness, -shift:]
+ result[y:y+thickness, shift:] = 255
+
+ center_x, center_y = width // 2, height // 2
+ vignette = np.zeros_like(result, dtype=np.float32)
+
+    # max_dist is constant for the whole image; hoist it out of the per-pixel loop
+    max_dist = np.sqrt(center_y ** 2 + center_x ** 2)
+    for i in range(height):
+        for j in range(width):
+            dist = np.sqrt((i - center_y) ** 2 + (j - center_x) ** 2)
+            vignette[i, j] = 1.0 + (1 - dist / max_dist) * 0.3
+
+    result = np.clip(result.astype(np.float32) * vignette, 0, 255).astype(np.uint8)
+
+ bracket_size = 30
+ bracket_thickness = 2
+
+ cv2.line(result, (10, 10), (10 + bracket_size, 10), 255, bracket_thickness)
+ cv2.line(result, (10, 10), (10, 10 + bracket_size), 255, bracket_thickness)
+ cv2.line(result, (width - 10 - bracket_size, 10), (width - 10, 10), 255, bracket_thickness)
+ cv2.line(result, (width - 10, 10), (width - 10, 10 + bracket_size), 255, bracket_thickness)
+ cv2.line(result, (10, height - 10), (10 + bracket_size, height - 10), 255, bracket_thickness)
+ cv2.line(result, (10, height - 10 - bracket_size), (10, height - 10), 255, bracket_thickness)
+ cv2.line(result, (width - 10 - bracket_size, height - 10), (width - 10, height - 10), 255, bracket_thickness)
+ cv2.line(result, (width - 10, height - 10 - bracket_size), (width - 10, height - 10), 255, bracket_thickness)
+
+ crosshair_size = 15
+ cv2.line(result, (center_x - crosshair_size, center_y), (center_x + crosshair_size, center_y), 255, 1)
+ cv2.line(result, (center_x, center_y - crosshair_size), (center_x, center_y + crosshair_size), 255, 1)
+ cv2.circle(result, (center_x, center_y), crosshair_size, 255, 1)
+
+ return result
+
+
+# Style function mapping
+STYLE_FUNCTIONS = {
+ # Sketch styles
+ 'pencil_sketch': apply_pencil_sketch,
+ 'ink_drawing': apply_ink_drawing,
+ 'charcoal_pastel': apply_charcoal_pastel,
+ 'conte_crayon': apply_conte_crayon,
+ 'gesture_sketch': apply_gesture_sketch,
+ # Futuristic styles
+ 'circuit_board': apply_circuit_board,
+ 'glitch_art': apply_glitch_art,
+ 'wireframe_topo': apply_wireframe_topo,
+ 'data_mosaic': apply_data_mosaic,
+ 'holographic_scan': apply_holographic_scan,
+}
+
+
+def process_images(input_dir, output_dir, style):
+ """Process all images from input_dir and save to output_dir with same names"""
+
+ if style not in STYLE_FUNCTIONS:
+ print(f"Error: Unknown style '{style}'")
+ print(f"\nAvailable styles:")
+ print("\nSketch Styles:")
+ print(" - pencil_sketch")
+ print(" - ink_drawing")
+ print(" - charcoal_pastel")
+ print(" - conte_crayon")
+ print(" - gesture_sketch")
+ print("\nFuturistic Styles:")
+ print(" - circuit_board")
+ print(" - glitch_art")
+ print(" - wireframe_topo")
+ print(" - data_mosaic")
+ print(" - holographic_scan")
+ return
+
+ os.makedirs(output_dir, exist_ok=True)
+
+    image_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
+    image_files = []
+    for pattern in image_patterns:
+        image_files.extend(glob.glob(os.path.join(input_dir, pattern)))
+    # De-duplicate: case-insensitive filesystems match the same file for both '*.png' and '*.PNG'
+    image_files = sorted(set(image_files))
+
+ if not image_files:
+ print(f"No image files found in {input_dir}")
+ return
+
+ print(f"Found {len(image_files)} images to process")
+ print(f"Style: {style}")
+ print(f"Output directory: {output_dir}")
+ print("-" * 50)
+
+ style_func = STYLE_FUNCTIONS[style]
+
+ for img_path in image_files:
+ try:
+ img = cv2.imread(img_path)
+ if img is None:
+ print(f"Warning: Could not read {img_path}, skipping...")
+ continue
+
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+ result = style_func(gray)
+
+ filename = os.path.basename(img_path)
+ output_path = os.path.join(output_dir, filename)
+
+ cv2.imwrite(output_path, result)
+ print(f"✓ Processed: {filename}")
+
+ except Exception as e:
+ print(f"✗ Error processing {img_path}: {str(e)}")
+
+ print("-" * 50)
+ print(f"Processing complete! Results saved to {output_dir}")
+
+
+def main():
+ if len(sys.argv) != 4:
+ print(__doc__)
+ sys.exit(1)
+
+ input_dir = sys.argv[1]
+ output_dir = sys.argv[2]
+ style = sys.argv[3]
+
+ if not os.path.isdir(input_dir):
+ print(f"Error: Input directory '{input_dir}' does not exist")
+ sys.exit(1)
+
+ process_images(input_dir, output_dir, style)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/training/input/img_000.png b/training/input/img_000.png
new file mode 100644
index 0000000..4ed5bfb
--- /dev/null
+++ b/training/input/img_000.png
Binary files differ
diff --git a/training/input/img_001.png b/training/input/img_001.png
new file mode 100644
index 0000000..2c34b47
--- /dev/null
+++ b/training/input/img_001.png
Binary files differ
diff --git a/training/input/img_002.png b/training/input/img_002.png
new file mode 100644
index 0000000..b79c8cb
--- /dev/null
+++ b/training/input/img_002.png
Binary files differ
diff --git a/training/input/img_003.png b/training/input/img_003.png
new file mode 100644
index 0000000..4365365
--- /dev/null
+++ b/training/input/img_003.png
Binary files differ
diff --git a/training/input/img_004.png b/training/input/img_004.png
new file mode 100644
index 0000000..3e71a37
--- /dev/null
+++ b/training/input/img_004.png
Binary files differ
diff --git a/training/input/img_005.png b/training/input/img_005.png
new file mode 100644
index 0000000..624061c
--- /dev/null
+++ b/training/input/img_005.png
Binary files differ
diff --git a/training/input/img_006.png b/training/input/img_006.png
new file mode 100644
index 0000000..24592a3
--- /dev/null
+++ b/training/input/img_006.png
Binary files differ
diff --git a/training/input/img_007.png b/training/input/img_007.png
new file mode 100644
index 0000000..1ed661e
--- /dev/null
+++ b/training/input/img_007.png
Binary files differ
diff --git a/training/output/img_000.png b/training/output/img_000.png
new file mode 100644
index 0000000..6da3aaf
--- /dev/null
+++ b/training/output/img_000.png
Binary files differ
diff --git a/training/output/img_001.png b/training/output/img_001.png
new file mode 100644
index 0000000..3334699
--- /dev/null
+++ b/training/output/img_001.png
Binary files differ
diff --git a/training/output/img_002.png b/training/output/img_002.png
new file mode 100644
index 0000000..a2582fd
--- /dev/null
+++ b/training/output/img_002.png
Binary files differ
diff --git a/training/output/img_003.png b/training/output/img_003.png
new file mode 100644
index 0000000..27b829a
--- /dev/null
+++ b/training/output/img_003.png
Binary files differ
diff --git a/training/output/img_004.png b/training/output/img_004.png
new file mode 100644
index 0000000..b90582c
--- /dev/null
+++ b/training/output/img_004.png
Binary files differ
diff --git a/training/output/img_005.png b/training/output/img_005.png
new file mode 100644
index 0000000..edcac90
--- /dev/null
+++ b/training/output/img_005.png
Binary files differ
diff --git a/training/output/img_006.png b/training/output/img_006.png
new file mode 100644
index 0000000..002230c
--- /dev/null
+++ b/training/output/img_006.png
Binary files differ
diff --git a/training/output/img_007.png b/training/output/img_007.png
new file mode 100644
index 0000000..5e79250
--- /dev/null
+++ b/training/output/img_007.png
Binary files differ
diff --git a/training/train_cnn.py b/training/train_cnn.py
new file mode 100755
index 0000000..4fc3a6c
--- /dev/null
+++ b/training/train_cnn.py
@@ -0,0 +1,301 @@
+#!/usr/bin/env python3
+"""
+CNN Training Script for Image-to-Image Transformation
+
+Trains a convolutional neural network on multiple input/target image pairs.
+
+Usage:
+ python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
+
+Example:
+ python3 train_cnn.py --input ./training/input --target ./training/output --layers 3 --epochs 100
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import os
+import sys
+import argparse
+import glob
+
+
+class ImagePairDataset(Dataset):
+ """Dataset for loading matching input/target image pairs"""
+
+ def __init__(self, input_dir, target_dir, transform=None):
+ self.input_dir = input_dir
+ self.target_dir = target_dir
+ self.transform = transform
+
+ # Find all images in input directory
+ input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
+ self.image_pairs = []
+
+        input_files = []
+        for pattern in input_patterns:
+            input_files.extend(glob.glob(os.path.join(input_dir, pattern)))
+
+        # De-duplicate: on case-insensitive filesystems '*.png' and '*.PNG' match the same files
+        for input_path in sorted(set(input_files)):
+            filename = os.path.basename(input_path)
+            base_name = os.path.splitext(filename)[0]
+            # Try to find a matching target with the same name but any supported extension
+            target_path = None
+            for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
+                candidate = os.path.join(target_dir, f"{base_name}.{ext}")
+                if os.path.exists(candidate):
+                    target_path = candidate
+                    break
+
+            if target_path:
+                self.image_pairs.append((input_path, target_path))
+
+ if not self.image_pairs:
+ raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
+
+ print(f"Found {len(self.image_pairs)} matching image pairs")
+
+ def __len__(self):
+ return len(self.image_pairs)
+
+ def __getitem__(self, idx):
+ input_path, target_path = self.image_pairs[idx]
+
+ input_img = Image.open(input_path).convert('RGB')
+ target_img = Image.open(target_path).convert('RGB')
+
+ if self.transform:
+ input_img = self.transform(input_img)
+ target_img = self.transform(target_img)
+
+ return input_img, target_img
+
+
+class CoordConv2d(nn.Module):
+ """Conv2d that accepts coordinate input separate from spatial patches"""
+
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+ super().__init__()
+ self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
+ self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01)
+ self.bias = nn.Parameter(torch.zeros(out_channels))
+
+ def forward(self, x, coords):
+ # x: [B, C, H, W] image
+ # coords: [B, 2, H, W] coordinate grid
+ out = self.conv_rgba(x)
+ B, C, H, W = out.shape
+ coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights)
+ out = out + coord_contrib + self.bias.view(1, -1, 1, 1)
+ return out
+
+
+class SimpleCNN(nn.Module):
+ """Simple CNN for image-to-image transformation"""
+
+ def __init__(self, num_layers=1, kernel_sizes=None):
+        super().__init__()
+
+ if kernel_sizes is None:
+ kernel_sizes = [3] * num_layers
+
+ assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
+
+ self.kernel_sizes = kernel_sizes
+ self.layers = nn.ModuleList()
+
+ for i, kernel_size in enumerate(kernel_sizes):
+ padding = kernel_size // 2
+ if i == 0:
+ self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding))
+ else:
+ self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True))
+
+ self.use_residual = True
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
+ x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
+ coords = torch.cat([x_coords, y_coords], dim=1)
+
+ out = self.layers[0](x, coords)
+ out = torch.tanh(out)
+
+ for i in range(1, len(self.layers)):
+ out = self.layers[i](out)
+ if i < len(self.layers) - 1:
+ out = torch.tanh(out)
+
+ if self.use_residual:
+ out = x + out * 0.3
+ return out
+
+
+def export_weights_to_wgsl(model, output_path, kernel_sizes):
+ """Export trained weights to WGSL format"""
+
+ with open(output_path, 'w') as f:
+ f.write("// Auto-generated CNN weights\n")
+ f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+
+ layer_idx = 0
+ for i, layer in enumerate(model.layers):
+ if isinstance(layer, CoordConv2d):
+ # Export RGBA weights
+ weights = layer.conv_rgba.weight.data.cpu().numpy()
+ kernel_size = kernel_sizes[layer_idx]
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
+
+ f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+ for pos in range(num_positions):
+ row = pos // kw
+ col = pos % kw
+ f.write(" mat4x4<f32>(\n")
+ for out_c in range(min(4, out_ch)):
+ vals = []
+ for in_c in range(min(4, in_ch)):
+ vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
+ f.write(f" {', '.join(vals)},\n")
+ f.write(" )")
+ if pos < num_positions - 1:
+ f.write(",\n")
+ else:
+ f.write("\n")
+ f.write(");\n\n")
+
+ # Export coordinate weights
+            coord_w = layer.coord_weights.data.cpu().numpy()
+            f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n")
+            for c in range(2):
+                # Each mat2x4 column carries 4 values; pad past the model's output channels with zeros
+                vals = [f"{coord_w[out_c, c]:.6f}" if out_c < coord_w.shape[0] else "0.000000"
+                        for out_c in range(4)]
+                f.write(f"  {', '.join(vals)}")
+                if c < 1:
+                    f.write(",\n")
+                else:
+                    f.write("\n")
+            f.write(");\n\n")
+
+ # Export bias
+            bias = layer.bias.data.cpu().numpy()
+            f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
+            # vec4 needs exactly 4 components; pad with zeros beyond the layer's channel count
+            f.write(", ".join([f"{bias[i]:.6f}" if i < bias.shape[0] else "0.000000" for i in range(4)]))
+            f.write(");\n\n")
+
+ layer_idx += 1
+ elif isinstance(layer, nn.Conv2d):
+ # Standard conv layer
+ weights = layer.weight.data.cpu().numpy()
+ kernel_size = kernel_sizes[layer_idx]
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
+
+ f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+ for pos in range(num_positions):
+ row = pos // kw
+ col = pos % kw
+ f.write(" mat4x4<f32>(\n")
+ for out_c in range(min(4, out_ch)):
+ vals = []
+ for in_c in range(min(4, in_ch)):
+ vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
+ f.write(f" {', '.join(vals)},\n")
+ f.write(" )")
+ if pos < num_positions - 1:
+ f.write(",\n")
+ else:
+ f.write("\n")
+ f.write(");\n\n")
+
+ # Export bias
+            bias = layer.bias.data.cpu().numpy()
+            f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
+            # vec4 needs exactly 4 components; pad with zeros beyond the layer's channel count
+            f.write(", ".join([f"{bias[i]:.6f}" if i < bias.shape[0] else "0.000000" for i in range(4)]))
+            f.write(");\n\n")
+
+ layer_idx += 1
+
+
+def train(args):
+ """Main training loop"""
+
+ # Setup device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+
+ # Prepare dataset
+ transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ ])
+
+ dataset = ImagePairDataset(args.input, args.target, transform=transform)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+
+ # Parse kernel sizes
+ kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+
+ # Create model
+ model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)
+
+ # Loss and optimizer
+ criterion = nn.MSELoss()
+ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ # Training loop
+ print(f"\nTraining for {args.epochs} epochs...")
+ for epoch in range(args.epochs):
+ epoch_loss = 0.0
+ for batch_idx, (inputs, targets) in enumerate(dataloader):
+ inputs, targets = inputs.to(device), targets.to(device)
+
+ optimizer.zero_grad()
+ outputs = model(inputs)
+ loss = criterion(outputs, targets)
+ loss.backward()
+ optimizer.step()
+
+ epoch_loss += loss.item()
+
+ avg_loss = epoch_loss / len(dataloader)
+ if (epoch + 1) % 10 == 0:
+ print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
+
+ # Export weights
+ output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+ print(f"\nExporting weights to {output_path}...")
+    out_dir = os.path.dirname(output_path)
+    if out_dir:  # makedirs('') raises when the output path has no directory component
+        os.makedirs(out_dir, exist_ok=True)
+ export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+ print("Training complete!")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
+ parser.add_argument('--input', required=True, help='Input image directory')
+ parser.add_argument('--target', required=True, help='Target image directory')
+ parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
+ parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
+ parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
+ parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
+ parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
+ parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+
+ args = parser.parse_args()
+
+ # Validate directories
+ if not os.path.isdir(args.input):
+ print(f"Error: Input directory '{args.input}' does not exist")
+ sys.exit(1)
+
+ if not os.path.isdir(args.target):
+ print(f"Error: Target directory '{args.target}' does not exist")
+ sys.exit(1)
+
+ train(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
index 06ca73a..168c9e2 100644
--- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
@@ -24,3 +24,30 @@ fn cnn_conv3x3(
return sum;
}
+
+fn cnn_conv3x3_with_coord(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ rgba_weights: array<mat4x4<f32>, 9>,
+ coord_weights: mat2x4<f32>,
+ bias: vec4<f32>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+ var sum = bias;
+
+ sum += coord_weights * uv;
+
+ var idx = 0;
+ for (var dy = -1; dy <= 1; dy++) {
+ for (var dx = -1; dx <= 1; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgba = textureSample(tex, samp, uv + offset);
+ sum += rgba_weights[idx] * rgba;
+ idx++;
+ }
+ }
+
+ return sum;
+}
diff --git a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
index 3d4a03a..bd9abfa 100644
--- a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
@@ -24,3 +24,30 @@ fn cnn_conv5x5(
return sum;
}
+
+fn cnn_conv5x5_with_coord(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ rgba_weights: array<mat4x4<f32>, 25>,
+ coord_weights: mat2x4<f32>,
+ bias: vec4<f32>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+ var sum = bias;
+
+ sum += coord_weights * uv;
+
+ var idx = 0;
+ for (var dy = -2; dy <= 2; dy++) {
+ for (var dx = -2; dx <= 2; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgba = textureSample(tex, samp, uv + offset);
+ sum += rgba_weights[idx] * rgba;
+ idx++;
+ }
+ }
+
+ return sum;
+}
diff --git a/workspaces/main/shaders/cnn/cnn_conv7x7.wgsl b/workspaces/main/shaders/cnn/cnn_conv7x7.wgsl
index ba28d64..e68d644 100644
--- a/workspaces/main/shaders/cnn/cnn_conv7x7.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv7x7.wgsl
@@ -24,3 +24,30 @@ fn cnn_conv7x7(
return sum;
}
+
+fn cnn_conv7x7_with_coord(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ rgba_weights: array<mat4x4<f32>, 49>,
+ coord_weights: mat2x4<f32>,
+ bias: vec4<f32>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+ var sum = bias;
+
+ sum += coord_weights * uv;
+
+ var idx = 0;
+ for (var dy = -3; dy <= 3; dy++) {
+ for (var dx = -3; dx <= 3; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgba = textureSample(tex, samp, uv + offset);
+ sum += rgba_weights[idx] * rgba;
+ idx++;
+ }
+ }
+
+ return sum;
+}
diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl
index e026ce8..b2bab26 100644
--- a/workspaces/main/shaders/cnn/cnn_layer.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl
@@ -29,10 +29,10 @@ struct CNNLayerParams {
let uv = p.xy / uniforms.resolution;
var result = vec4<f32>(0.0);
- // Single layer for now (layer 0)
+ // Layer 0 uses coordinate-aware convolution
if (params.layer_index == 0) {
- result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
- weights_layer0, bias_layer0);
+ result = cnn_conv3x3_with_coord(txt, smplr, uv, uniforms.resolution,
+ rgba_weights_layer0, coord_weights_layer0, bias_layer0);
result = cnn_tanh(result);
}
diff --git a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
index 98c17ff..e0a7dc4 100644
--- a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
@@ -2,8 +2,8 @@
// DO NOT EDIT MANUALLY - regenerate with scripts/train_cnn.py
// Placeholder identity-like weights for initial testing
-// Layer 0: 3x3 convolution
-const weights_layer0: array<mat4x4<f32>, 9> = array(
+// Layer 0: 3x3 convolution with coordinate awareness
+const rgba_weights_layer0: array<mat4x4<f32>, 9> = array(
mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
@@ -14,4 +14,10 @@ const weights_layer0: array<mat4x4<f32>, 9> = array(
mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
);
+
+const coord_weights_layer0 = mat2x4<f32>(
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0
+);
+
const bias_layer0 = vec4<f32>(0.0, 0.0, 0.0, 0.0);