author     skal <pascal.massimino@gmail.com>   2026-02-10 16:44:39 +0100
committer  skal <pascal.massimino@gmail.com>   2026-02-10 16:44:39 +0100
commit     61104d5b9e1774c11f0dba3b6d6018dabc2bce8f (patch)
tree       882e642721984cc921cbe5678fe7905721a2ad40
parent     3942653de11542acc4892470243a8a6bf8d5c4f7 (diff)
feat: CNN RGBD→grayscale with 7-channel augmented input
Upgrade the CNN architecture to take RGBD input and produce grayscale output, with 7-channel layer inputs (RGBD + UV coords + grayscale).

Architecture changes:
- Inner layers: Conv2d(7→4), output RGBD
- Final layer: Conv2d(7→1), output grayscale
- All inputs normalized to [-1,1] for tanh activation
- Removed CoordConv2d in favor of the unified 7-channel input

Training (train_cnn.py):
- SimpleCNN: 7→4 (inner), 7→1 (final) architecture
- Forward pass: normalize RGBD/coords/gray to [-1,1]
- Weight export: array<array<f32, 8>, 36> (inner), array<array<f32, 8>, 9> (final)
- Dataset: load RGBA (RGBD) input

Shaders (cnn_conv3x3.wgsl):
- Added cnn_conv3x3_7to4: 7-channel input → RGBD output
- Added cnn_conv3x3_7to1: 7-channel input → grayscale output
- Both normalize inputs and use flattened weight arrays

Documentation:
- CNN_EFFECT.md: updated architecture, training, weight format
- CNN_RGBD_GRAYSCALE_SUMMARY.md: implementation summary
- HOWTO.md: added training command example

Next: train with RGBD input data

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
-rw-r--r--  doc/CNN_EFFECT.md                              |  75
-rw-r--r--  doc/CNN_RGBD_GRAYSCALE_SUMMARY.md              | 134
-rw-r--r--  doc/HOWTO.md                                   |   8
-rwxr-xr-x  training/train_cnn.py                          | 210
-rw-r--r--  workspaces/main/shaders/cnn/cnn_conv3x3.wgsl   | 100
5 files changed, 373 insertions, 154 deletions
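For scale: with the flattened `array<array<f32, 8>, N>` layout, each 3×3 inner layer stores 36 × 8 = 288 floats and the final layer 9 × 8 = 72 floats. A minimal sketch of that arithmetic (illustration only, not part of the commit; assumes 3×3 kernels throughout, as in the shipped conv functions):

```python
# Rough size of the exported weight constants for the new 7-channel layout.
POSITIONS = 9        # 3x3 kernel taps
ENTRY_FLOATS = 8     # 7 weights + 1 bias per entry

def layer_floats(is_final: bool) -> int:
    out_channels = 1 if is_final else 4
    return POSITIONS * out_channels * ENTRY_FLOATS   # 72 (final) or 288 (inner)

def network_floats(num_layers: int) -> int:
    return (num_layers - 1) * layer_floats(False) + layer_floats(True)

n = network_floats(3)                  # e.g. a 3-layer setup as in HOWTO.md
print(n, "floats,", n * 4, "bytes")    # 648 floats, 2592 bytes
```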
diff --git a/doc/CNN_EFFECT.md b/doc/CNN_EFFECT.md
index ae0f38a..b7d157f 100644
--- a/doc/CNN_EFFECT.md
+++ b/doc/CNN_EFFECT.md
@@ -21,27 +21,44 @@ Trainable convolutional neural network layers for artistic stylization (painterl
## Architecture
-### Coordinate-Aware Layer 0
+### RGBD → Grayscale Pipeline
-Layer 0 accepts normalized (x,y) patch center coordinates alongside RGBA samples:
+**Input:** RGBD (RGB + inverse depth D=1/z)
+**Output:** Grayscale (1 channel)
+**Layer Input:** 7 channels = [RGBD, UV coords, grayscale] all normalized to [-1,1]
+
+**Architecture:**
+- **Inner layers (0..N-2):** Conv2d(7→4) - output RGBD
+- **Final layer (N-1):** Conv2d(7→1) - output grayscale
```wgsl
-fn cnn_conv3x3_with_coord(
+// Inner layers: 7→4 (RGBD output)
+fn cnn_conv3x3_7to4(
tex: texture_2d<f32>,
samp: sampler,
- uv: vec2<f32>, # Center position [0,1]
+ uv: vec2<f32>,
resolution: vec2<f32>,
- rgba_weights: array<mat4x4<f32>, 9>, # 9 samples × 4×4 matrix
- coord_weights: mat2x4<f32>, # 2 coords → 4 outputs
- bias: vec4<f32>
+ original: vec4<f32>, // Original RGBD [0,1]
+ weights: array<array<f32, 8>, 36> // 9 pos × 4 out × (7 weights + bias)
) -> vec4<f32>
-```
-**Input structure:** 9 RGBA samples (36 values) + 1 xy coordinate (2 values) = 38 inputs → 4 outputs
+// Final layer: 7→1 (grayscale output)
+fn cnn_conv3x3_7to1(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ original: vec4<f32>,
+ weights: array<array<f32, 8>, 9> // 9 pos × (7 weights + bias)
+) -> f32
+```
-**Size impact:** +32B coord weights, kernel-agnostic
+**Input normalization (all to [-1,1]):**
+- RGBD: `(rgbd - 0.5) * 2`
+- UV coords: `(uv - 0.5) * 2`
+- Grayscale: `(0.2126*R + 0.7152*G + 0.0722*B - 0.5) * 2`
-**Use cases:** Position-dependent stylization (vignettes, corner darkening, radial gradients)
+**Activation:** tanh for inner layers, none for final layer
### Multi-Layer Architecture
@@ -80,18 +97,15 @@ workspaces/main/shaders/cnn/
### 1. Prepare Training Data
Collect input/target image pairs:
-- **Input:** Raw 3D render
-- **Target:** Artistic style (hand-painted, filtered, stylized)
+- **Input:** RGBA (RGB + depth as alpha channel, D=1/z)
+- **Target:** Grayscale stylized output
```bash
-training/input/img_000.png # Raw render
-training/output/img_000.png # Stylized target
+training/input/img_000.png # RGBA render (RGB + depth)
+training/output/img_000.png # Grayscale target
```
-Use `image_style_processor.py` to generate targets:
-```bash
-python3 training/image_style_processor.py input/ output/ pencil_sketch
-```
+**Note:** Input images must be RGBA where alpha = inverse depth (1/z)
### 2. Train Network
@@ -245,20 +259,25 @@ Expands to:
**Weight Storage:**
-**Layer 0 (coordinate-aware):**
+**Inner layers (7→4 RGBD output):**
```wgsl
-const rgba_weights_layer0: array<mat4x4<f32>, 9> = array(...);
-const coord_weights_layer0 = mat2x4<f32>(
- 0.1, -0.2, 0.0, 0.0, # x-coord weights
- -0.1, 0.0, 0.2, 0.0 # y-coord weights
+// Structure: array<array<f32, 8>, 36>
+// 9 positions × 4 output channels, each with 7 weights + bias
+const weights_layer0: array<array<f32, 8>, 36> = array(
+ array<f32, 8>(w0_r, w0_g, w0_b, w0_d, w0_u, w0_v, w0_gray, bias0), // pos0_ch0
+ array<f32, 8>(w1_r, w1_g, w1_b, w1_d, w1_u, w1_v, w1_gray, bias1), // pos0_ch1
+ // ... 34 more entries
);
-const bias_layer0 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
```
-**Layers 1+ (standard):**
+**Final layer (7→1 grayscale output):**
```wgsl
-const weights_layer1: array<mat4x4<f32>, 9> = array(...);
-const bias_layer1 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+// Structure: array<array<f32, 8>, 9>
+// 9 positions, each with 7 weights + bias
+const weights_layerN: array<array<f32, 8>, 9> = array(
+ array<f32, 8>(w0_r, w0_g, w0_b, w0_d, w0_u, w0_v, w0_gray, bias0), // pos0
+ // ... 8 more entries
+);
```
---
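As a cross-check of the entry ordering documented above (entry index = position × 4 + output channel, bias in the 8th slot), the short PyTorch sketch below reproduces that layout from a trained inner layer. It is an illustration, not part of the commit; `flatten_inner_layer` is a hypothetical helper, and how the bias slot is populated is left to the exporter:

```python
import torch
import torch.nn as nn

def flatten_inner_layer(conv: nn.Conv2d) -> torch.Tensor:
    """Reorder a Conv2d(7, 4, 3) into the documented WGSL layout:
    entry index = pos * 4 + out_channel, each entry = 7 weights + 1 bias slot."""
    w = conv.weight.detach()               # [4, 7, 3, 3]
    b = conv.bias.detach()                 # [4]
    out_ch, in_ch, kh, kw = w.shape
    entries = []
    for pos in range(kh * kw):             # 3x3 taps, row-major
        row, col = divmod(pos, kw)
        for c in range(out_ch):            # output channels R, G, B, D
            entries.append(torch.cat([w[c, :, row, col], b[c:c + 1]]))
    return torch.stack(entries)            # [36, 8]

print(flatten_inner_layer(nn.Conv2d(7, 4, 3, padding=1)).shape)  # torch.Size([36, 8])
```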
diff --git a/doc/CNN_RGBD_GRAYSCALE_SUMMARY.md b/doc/CNN_RGBD_GRAYSCALE_SUMMARY.md
new file mode 100644
index 0000000..4c13693
--- /dev/null
+++ b/doc/CNN_RGBD_GRAYSCALE_SUMMARY.md
@@ -0,0 +1,134 @@
+# CNN RGBD→Grayscale Architecture Implementation
+
+## Summary
+
+Implemented CNN architecture upgrade: RGBD input → grayscale output with 7-channel augmented input.
+
+## Changes Made
+
+### Architecture
+
+**Input:** RGBD (4 channels: RGB + inverse depth D=1/z)
+**Output:** Grayscale (1 channel)
+**Layer Input:** 7 channels = [RGBD, UV coords, grayscale] all normalized to [-1,1]
+
+**Layer Configuration:**
+- Inner layers (0..N-2): Conv2d(7→4) - output RGBD with tanh activation
+- Final layer (N-1): Conv2d(7→1) - output grayscale, no activation
+
+### Input Normalization (all to [-1,1])
+
+- **RGBD:** `(rgbd - 0.5) * 2`
+- **UV coords:** `(uv - 0.5) * 2`
+- **Grayscale:** `(0.2126*R + 0.7152*G + 0.0722*B - 0.5) * 2`
+
+**Rationale:** Zero-centered inputs for tanh activation, better gradient flow.
+
+### Modified Files
+
+**Training (`/Users/skal/demo/training/train_cnn.py`):**
+1. Removed `CoordConv2d` class
+2. Updated `SimpleCNN`:
+ - Inner layers: `Conv2d(7, 4)` - RGBD output
+ - Final layer: `Conv2d(7, 1)` - grayscale output
+3. Updated `forward()`:
+ - Normalize RGBD/coords/gray to [-1,1]
+ - Concatenate 7-channel input for each layer
+ - Apply tanh (inner) or none (final)
+ - Denormalize final output
+4. Updated `export_weights_to_wgsl()`:
+ - Inner: `array<array<f32, 8>, 36>` (9 pos × 4 ch × 8 values)
+ - Final: `array<array<f32, 8>, 9>` (9 pos × 8 values)
+5. Updated `generate_layer_shader()`:
+ - Use `cnn_conv3x3_7to4` for inner layers
+ - Use `cnn_conv3x3_7to1` for final layer
+ - Denormalize outputs from [-1,1] to [0,1]
+6. Updated `ImagePairDataset`:
+ - Load RGBA input (was RGB)
+
+**Shaders (`/Users/skal/demo/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl`):**
+1. Added `cnn_conv3x3_7to4()`:
+ - 7-channel input: [RGBD, uv_x, uv_y, gray]
+ - 4-channel output: RGBD
+ - Weights: `array<array<f32, 8>, 36>`
+2. Added `cnn_conv3x3_7to1()`:
+ - 7-channel input: [RGBD, uv_x, uv_y, gray]
+ - 1-channel output: grayscale
+ - Weights: `array<array<f32, 8>, 9>`
+
+**Documentation (`/Users/skal/demo/doc/CNN_EFFECT.md`):**
+1. Updated architecture section with RGBD→grayscale pipeline
+2. Updated training data requirements (RGBA input)
+3. Updated weight storage format
+
+### No C++ Changes
+
+CNNLayerParams and bind groups remain unchanged.
+
+## Data Flow
+
+1. Layer 0 captures original RGBD to `captured_frame`
+2. Each layer:
+ - Samples previous layer output (RGBD in [0,1])
+ - Normalizes RGBD to [-1,1]
+ - Computes UV coords and grayscale, normalizes to [-1,1]
+ - Concatenates 7-channel input
+ - Applies convolution with layer-specific weights
+ - Outputs RGBD (inner) or grayscale (final) in [-1,1]
+ - Applies tanh (inner only)
+ - Denormalizes to [0,1] for texture storage
+ - Blends with original
+
+## Next Steps
+
+1. **Prepare RGBD training data:**
+ - Input: RGBA images (RGB + depth in alpha)
+ - Target: Grayscale stylized output
+
+2. **Train network:**
+ ```bash
+ python3 training/train_cnn.py \
+ --input training/input \
+ --target training/output \
+ --layers 3 \
+ --epochs 1000
+ ```
+
+3. **Verify generated shaders:**
+ - Check `cnn_weights_generated.wgsl` structure
+ - Check `cnn_layer.wgsl` uses new conv functions
+
+4. **Test in demo:**
+ ```bash
+ cmake --build build -j4
+ ./build/demo64k
+ ```
+
+## Design Rationale
+
+**Why [-1,1] normalization?**
+- Centered inputs for tanh (operates best around 0)
+- Better gradient flow
+- Standard ML practice for normalized data
+
+**Why RGBD throughout vs RGB?**
+- Depth information propagates through network
+- Enables depth-aware stylization
+- Consistent 4-channel processing
+
+**Why 7-channel input?**
+- Coordinates: position-dependent effects (vignettes)
+- Grayscale: luminance-aware processing
+- RGBD: full color+depth information
+- Enables richer feature learning
+
+## Testing Checklist
+
+- [ ] Train network with RGBD input data
+- [ ] Verify `cnn_weights_generated.wgsl` structure
+- [ ] Verify `cnn_layer.wgsl` uses `7to4`/`7to1` functions
+- [ ] Build demo without errors
+- [ ] Visual test: inner layers show RGBD evolution
+- [ ] Visual test: final layer produces grayscale
+- [ ] Visual test: blending works correctly
+- [ ] Compare quality with previous RGB→RGB architecture
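The summary leaves "prepare RGBD training data" as the next step. Below is a minimal sketch of one way to pack RGB plus inverse depth into the RGBA files the dataset loader expects; `pack_rgbd`, the depth file format, and the mapping of 1/z to [0,1] are assumptions, not part of the commit:

```python
import numpy as np
from PIL import Image

def pack_rgbd(rgb_path: str, depth_path: str, out_path: str) -> None:
    """Pack RGB + inverse depth (D = 1/z) into the RGBA input format that
    ImagePairDataset loads with .convert('RGBA')."""
    rgb = np.asarray(Image.open(rgb_path).convert("RGB"), dtype=np.uint8)
    z = np.asarray(Image.open(depth_path), dtype=np.float32)            # linear depth, z > 0
    inv = 1.0 / np.maximum(z, 1e-6)
    inv = (inv - inv.min()) / max(float(inv.max() - inv.min()), 1e-6)   # map 1/z to [0,1]
    alpha = np.round(inv * 255.0).astype(np.uint8)
    Image.fromarray(np.dstack([rgb, alpha]), mode="RGBA").save(out_path)

# pack_rgbd("render.png", "depth.png", "training/input/img_000.png")
```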
diff --git a/doc/HOWTO.md b/doc/HOWTO.md
index bdc0214..2c813f7 100644
--- a/doc/HOWTO.md
+++ b/doc/HOWTO.md
@@ -86,6 +86,14 @@ make run_util_tests # Utility tests
---
+## Training
+
+```bash
+./training/train_cnn.py --layers 3 --kernel_sizes 3,3,3 --epochs 10000 --batch_size 8 --input training/input/ --target training/output/ --checkpoint-every 1000
+```
+
+---
+
## Timeline
Edit `workspaces/main/timeline.seq`:
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 1cd6579..0495c65 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -62,7 +62,8 @@ class ImagePairDataset(Dataset):
def __getitem__(self, idx):
input_path, target_path = self.image_pairs[idx]
- input_img = Image.open(input_path).convert('RGB')
+ # Load RGBD input (4 channels: RGB + Depth)
+ input_img = Image.open(input_path).convert('RGBA')
target_img = Image.open(target_path).convert('RGB')
if self.transform:
@@ -72,27 +73,8 @@ class ImagePairDataset(Dataset):
return input_img, target_img
-class CoordConv2d(nn.Module):
- """Conv2d that accepts coordinate input separate from spatial patches"""
-
- def __init__(self, in_channels, out_channels, kernel_size, padding=0):
- super().__init__()
- self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
- self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01)
- self.bias = nn.Parameter(torch.zeros(out_channels))
-
- def forward(self, x, coords):
- # x: [B, C, H, W] image
- # coords: [B, 2, H, W] coordinate grid
- out = self.conv_rgba(x)
- B, C, H, W = out.shape
- coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights)
- out = out + coord_contrib + self.bias.view(1, -1, 1, 1)
- return out
-
-
class SimpleCNN(nn.Module):
- """Simple CNN for image-to-image transformation"""
+ """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""
def __init__(self, num_layers=1, kernel_sizes=None):
super(SimpleCNN, self).__init__()
@@ -107,26 +89,48 @@ class SimpleCNN(nn.Module):
for i, kernel_size in enumerate(kernel_sizes):
padding = kernel_size // 2
- if i == 0:
- self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding))
+ if i < num_layers - 1:
+ # Inner layers: 7→4 (RGBD output)
+ self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
else:
- self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True))
+ # Final layer: 7→1 (grayscale output)
+ self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))
def forward(self, x):
+ # x: [B,4,H,W] - RGBD input (D = 1/z)
B, C, H, W = x.shape
+
+ # Normalize RGBD to [-1,1]
+ x_norm = (x - 0.5) * 2.0
+
+ # Compute coordinates [0,1] then normalize to [-1,1]
y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
- coords = torch.cat([x_coords, y_coords], dim=1)
+ y_coords = (y_coords - 0.5) * 2.0 # [-1,1]
+ x_coords = (x_coords - 0.5) * 2.0 # [-1,1]
- out = self.layers[0](x, coords)
- out = torch.tanh(out)
+ # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
+ gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1]
+ gray = (gray - 0.5) * 2.0 # [-1,1]
- for i in range(1, len(self.layers)):
- out = self.layers[i](out)
- if i < len(self.layers) - 1:
- out = torch.tanh(out)
+ # Layer 0
+ layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W]
+ out = self.layers[0](layer0_input) # [B,4,H,W]
+ out = torch.tanh(out) # [-1,1]
- return out
+ # Inner layers
+ for i in range(1, len(self.layers)-1):
+ layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+ out = self.layers[i](layer_input)
+ out = torch.tanh(out)
+
+ # Final layer (grayscale output)
+ final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+ out = self.layers[-1](final_input) # [B,1,H,W] in [-1,1]
+
+ # Denormalize to [0,1] and expand to RGB for visualization
+ out = (out + 1.0) * 0.5
+ return out.expand(-1, 3, -1, -1)
def generate_layer_shader(output_path, num_layers, kernel_sizes):
@@ -169,25 +173,35 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
# Generate layer switches
for layer_idx in range(num_layers):
- ks = kernel_sizes[layer_idx]
+ is_final = layer_idx == num_layers - 1
if layer_idx == 0:
- f.write(f" // Layer 0 uses coordinate-aware convolution\n")
+ f.write(f" // Layer 0: 7→4 (RGBD output)\n")
f.write(f" if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = cnn_conv{ks}x{ks}_with_coord(txt, smplr, uv, uniforms.resolution,\n")
- f.write(f" rgba_weights_layer{layer_idx}, coord_weights_layer{layer_idx}, bias_layer{layer_idx});\n")
- f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n")
+ f.write(f" // Denormalize to [0,1] for texture storage\n")
+ f.write(f" result = (result + 1.0) * 0.5;\n")
+ f.write(f" }}\n")
+ elif not is_final:
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n")
+ f.write(f" // Denormalize to [0,1] for texture storage\n")
+ f.write(f" result = (result + 1.0) * 0.5;\n")
f.write(f" }}\n")
else:
- is_last = layer_idx == num_layers - 1
- f.write(f" {'else ' if layer_idx > 0 else ''}if (params.layer_index == {layer_idx}) {{\n")
- f.write(f" result = cnn_conv{ks}x{ks}(txt, smplr, uv, uniforms.resolution,\n")
- f.write(f" weights_layer{layer_idx}, bias_layer{layer_idx});\n")
- if not is_last:
- f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,\n")
+ f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" // Denormalize from [-1,1] to [0,1]\n")
+ f.write(f" let gray_01 = (gray_out + 1.0) * 0.5;\n")
+ f.write(f" result = vec4<f32>(gray_01, gray_01, gray_01, 1.0); // Expand to RGB\n")
f.write(f" }}\n")
# Add else clause for invalid layer index
- if num_layers > 1:
+ if num_layers > 0:
f.write(f" else {{\n")
f.write(f" result = input;\n")
f.write(f" }}\n")
@@ -204,96 +218,40 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
f.write("// Auto-generated CNN weights\n")
f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
- layer_idx = 0
for i, layer in enumerate(model.layers):
- if isinstance(layer, CoordConv2d):
- # Export RGBA weights
- weights = layer.conv_rgba.weight.data.cpu().numpy()
- kernel_size = kernel_sizes[layer_idx]
- out_ch, in_ch, kh, kw = weights.shape
- num_positions = kh * kw
-
- f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
- for pos in range(num_positions):
- row = pos // kw
- col = pos % kw
- f.write(" mat4x4<f32>(\n")
- for out_c in range(4):
- vals = []
- for in_c in range(4):
- if out_c < out_ch and in_c < in_ch:
- vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
- else:
- vals.append("0.0")
- f.write(f" {', '.join(vals)},\n")
- f.write(" )")
- if pos < num_positions - 1:
- f.write(",\n")
- else:
- f.write("\n")
- f.write(");\n\n")
+ weights = layer.weight.data.cpu().numpy()
+ bias = layer.bias.data.cpu().numpy()
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
- # Export coordinate weights
- coord_w = layer.coord_weights.data.cpu().numpy()
- f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n")
- for c in range(2):
- vals = []
- for out_c in range(4):
- if out_c < coord_w.shape[0]:
- vals.append(f"{coord_w[out_c, c]:.6f}")
- else:
- vals.append("0.0")
- f.write(f" {', '.join(vals)}")
- if c < 1:
- f.write(",\n")
- else:
- f.write("\n")
- f.write(");\n\n")
+ is_final = (i == len(model.layers) - 1)
- # Export bias
- bias = layer.bias.data.cpu().numpy()
- bias_vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.0" for i in range(4)]
- f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
- f.write(", ".join(bias_vals))
+ if is_final:
+ # Final layer: 7→1, structure: array<array<f32, 8>, 9>
+ # [w0, w1, w2, w3, w4, w5, w6, bias]
+ f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_positions}> = array(\n")
+ for pos in range(num_positions):
+ row, col = pos // kw, pos % kw
+ vals = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(7)]
+                vals.append(f"{bias[0] / num_positions:.6f}") # Bias in 8th slot; the shader adds this slot at every tap, so store bias/9
+ f.write(f" array<f32, 8>({', '.join(vals)})")
+ f.write(",\n" if pos < num_positions-1 else "\n")
f.write(");\n\n")
-
- layer_idx += 1
- elif isinstance(layer, nn.Conv2d):
- # Standard conv layer
- weights = layer.weight.data.cpu().numpy()
- kernel_size = kernel_sizes[layer_idx]
- out_ch, in_ch, kh, kw = weights.shape
- num_positions = kh * kw
-
- f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n")
+ else:
+ # Inner layers: 7→4, structure: array<array<f32, 8>, 36>
+ # Flattened: [pos0_ch0[7w+bias], pos0_ch1[7w+bias], ..., pos8_ch3[7w+bias]]
+ num_entries = num_positions * 4
+ f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_entries}> = array(\n")
for pos in range(num_positions):
- row = pos // kw
- col = pos % kw
- f.write(" mat4x4<f32>(\n")
+ row, col = pos // kw, pos % kw
for out_c in range(4):
- vals = []
- for in_c in range(4):
- if out_c < out_ch and in_c < in_ch:
- vals.append(f"{weights[out_c, in_c, row, col]:.6f}")
- else:
- vals.append("0.0")
- f.write(f" {', '.join(vals)},\n")
- f.write(" )")
- if pos < num_positions - 1:
- f.write(",\n")
- else:
- f.write("\n")
+ vals = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(7)]
+                    vals.append(f"{bias[out_c] / num_positions:.6f}") # Bias split across the 9 taps (shader accumulates this slot per position)
+ idx = pos * 4 + out_c
+ f.write(f" array<f32, 8>({', '.join(vals)})")
+ f.write(",\n" if idx < num_entries-1 else "\n")
f.write(");\n\n")
- # Export bias
- bias = layer.bias.data.cpu().numpy()
- bias_vals = [f"{bias[i]:.6f}" if i < len(bias) else "0.0" for i in range(4)]
- f.write(f"const bias_layer{layer_idx} = vec4<f32>(")
- f.write(", ".join(bias_vals))
- f.write(");\n\n")
-
- layer_idx += 1
-
def train(args):
"""Main training loop"""
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
index 168c9e2..df58b4d 100644
--- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
@@ -51,3 +51,103 @@ fn cnn_conv3x3_with_coord(
return sum;
}
+
+// Inner layers: 7→4 channels (RGBD output)
+// weights: array<array<f32, 8>, 36> (9 positions × 4 channels, each with 7 weights + bias)
+fn cnn_conv3x3_7to4(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ original: vec4<f32>,
+ weights: array<array<f32, 8>, 36>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+
+ // Compute grayscale from original and normalize to [-1,1]
+ let gray_01 = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;
+ let gray = (gray_01 - 0.5) * 2.0;
+
+ // Normalize UV to [-1,1]
+ let uv_norm = (uv - 0.5) * 2.0;
+
+ var sum = vec4<f32>(0.0);
+
+ var pos = 0;
+ for (var dy = -1; dy <= 1; dy++) {
+ for (var dx = -1; dx <= 1; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd_01 = textureSample(tex, samp, uv + offset);
+
+ // Normalize RGBD to [-1,1]
+ let rgbd = (rgbd_01 - 0.5) * 2.0;
+
+ // 7-channel input: [R,G,B,D, uv.x, uv.y, gray] all in [-1,1]
+ let inputs = array<f32, 7>(
+ rgbd.r, rgbd.g, rgbd.b, rgbd.a,
+ uv_norm.x, uv_norm.y, gray
+ );
+
+ // Accumulate for each output channel (RGBD)
+ for (var out_c = 0; out_c < 4; out_c++) {
+ let idx = pos * 4 + out_c;
+ var channel_sum = weights[idx][7]; // Bias (8th element)
+ for (var in_c = 0; in_c < 7; in_c++) {
+ channel_sum += weights[idx][in_c] * inputs[in_c];
+ }
+ sum[out_c] += channel_sum;
+ }
+
+ pos++;
+ }
+ }
+
+    return sum; // Raw convolution output; caller applies tanh, then remaps to [0,1] for storage
+}
+
+// Final layer: 7→1 channel (scalar output)
+// weights: array<array<f32, 8>, 9> (9 positions, each with 7 weights + bias)
+fn cnn_conv3x3_7to1(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ original: vec4<f32>,
+ weights: array<array<f32, 8>, 9>
+) -> f32 {
+ let step = 1.0 / resolution;
+
+ // Normalize grayscale to [-1,1]
+ let gray_01 = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;
+ let gray = (gray_01 - 0.5) * 2.0;
+
+ // Normalize UV to [-1,1]
+ let uv_norm = (uv - 0.5) * 2.0;
+
+ var sum = 0.0;
+
+ var pos = 0;
+ for (var dy = -1; dy <= 1; dy++) {
+ for (var dx = -1; dx <= 1; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd_01 = textureSample(tex, samp, uv + offset);
+
+ // Normalize RGBD to [-1,1]
+ let rgbd = (rgbd_01 - 0.5) * 2.0;
+
+ // 7-channel input all in [-1,1]
+ sum += weights[pos][0] * rgbd.r;
+ sum += weights[pos][1] * rgbd.g;
+ sum += weights[pos][2] * rgbd.b;
+ sum += weights[pos][3] * rgbd.a;
+ sum += weights[pos][4] * uv_norm.x;
+ sum += weights[pos][5] * uv_norm.y;
+ sum += weights[pos][6] * gray;
+ sum += weights[pos][7]; // Bias
+
+ pos++;
+ }
+ }
+
+    return sum; // Raw output (no activation); caller remaps to [0,1] for display
+}
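Before the visual tests in the checklist, a quick shape smoke test of the retrained model can catch wiring mistakes. This is an illustration only; it assumes train_cnn.py is importable from the repository root:

```python
# Shape smoke test for the new RGBD->grayscale architecture.
import sys
sys.path.insert(0, "training")
import torch
from train_cnn import SimpleCNN

model = SimpleCNN(num_layers=3, kernel_sizes=[3, 3, 3])
x = torch.rand(1, 4, 64, 64)            # RGBD input in [0,1], D = 1/z
with torch.no_grad():
    y = model(x)
print(y.shape)                          # expected: torch.Size([1, 3, 64, 64]), grayscale expanded to RGB
print(float(y.min()), float(y.max()))   # roughly [0,1] after denormalization (final layer has no activation)
```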