summaryrefslogtreecommitdiff
path: root/cnn_v3/training/cnn_v3_utils.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-23 00:43:20 +0100
committerskal <pascal.massimino@gmail.com>2026-03-23 00:43:20 +0100
commit13cf1438caa56b34529d4031ddf73d38286b70e5 (patch)
treeb8de1b0a597edcbfadcea5fca862be4b3d72a3db /cnn_v3/training/cnn_v3_utils.py
parent1470dd240f48652d1fe97957fe44a49b0e1ee9a6 (diff)
feat(cnn_v3): shadow→dif migration complete (ch18)
Replace raw shadow (ch18) with dif = max(0,dot(normal,KEY_LIGHT))*shadow across all layers. Channel count stays 20, weight shapes unchanged. - gbuf_pack.wgsl: t1.z = pack4x8unorm(mip2.g, mip2.b, dif, transp); t1.w = 0u - gbuf_deferred.wgsl: read dif from unpack4x8unorm(t1.z).z - gbuf_view.wgsl: revert to 4×5 grid, ch18=dif label, ch19=trns label - tools/shaders.js: FULL_PACK_SHADER adds oct_decode + computes dif - cnn_v3_utils.py: assemble_features() computes dif on-the-fly via oct_decode - docs: CNN_V3.md, HOWTO.md, HOW_TO_CNN.md, GBUF_DIF_MIGRATION.md updated handoff(Gemini): shadow→dif migration done, ready for first training pass
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py54
1 files changed, 37 insertions, 17 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index 5a3d56c..bef4091 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -11,7 +11,7 @@ Imported by train_cnn_v3.py and export_cnn_v3_weights.py.
[9-11] prev.rgb f32 (zero during training)
[12-14] mip1.rgb pyrdown(albedo)
[15-17] mip2.rgb pyrdown(mip1)
- [18] shadow f32 [0,1]
+ [18] dif f32 [0,1] max(0,dot(normal,KEY_LIGHT))*shadow
[19] transp f32 [0,1]
Sample directory layout (per sample_xxx/):
@@ -48,10 +48,11 @@ from torch.utils.data import Dataset
N_FEATURES = 20
GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7] # normal.xy, depth, depth_grad.xy
-CONTEXT_CHANNELS = [8, 18, 19] # mat_id, shadow, transp
+CONTEXT_CHANNELS = [8, 18, 19] # mat_id, dif, transp
TEMPORAL_CHANNELS = [9, 10, 11] # prev.rgb
-_LUMA = np.array([0.2126, 0.7152, 0.0722], dtype=np.float32) # BT.709
+_LUMA = np.array([0.2126, 0.7152, 0.0722], dtype=np.float32) # BT.709
+_KEY_LIGHT = np.array([0.408, 0.816, 0.408 ], dtype=np.float32) # normalize(1,2,1)
# ---------------------------------------------------------------------------
# Image I/O
@@ -102,6 +103,21 @@ def depth_gradient(depth: np.ndarray) -> np.ndarray:
return np.stack([dzdx, dzdy], axis=-1)
+def oct_decode(enc: np.ndarray) -> np.ndarray:
+ """Decode oct-encoded normals (H,W,2) in [0,1] → (H,W,3) unit normals."""
+ f = enc * 2.0 - 1.0 # [0,1] → [-1,1]
+ z = 1.0 - np.abs(f[..., :1]) - np.abs(f[..., 1:2])
+ n = np.concatenate([f, z], axis=-1)
+ neg = n[..., 2:3] < 0.0
+ n = np.concatenate([
+ np.where(neg, (1.0 - np.abs(f[..., 1:2])) * np.sign(f[..., :1]), n[..., :1]),
+ np.where(neg, (1.0 - np.abs(f[..., :1])) * np.sign(f[..., 1:2]), n[..., 1:2]),
+ n[..., 2:3],
+ ], axis=-1)
+ length = np.linalg.norm(n, axis=-1, keepdims=True)
+ return n / np.maximum(length, 1e-8)
+
+
def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
"""Nearest-neighbour upsample (H,W,C) f32 to (h,w,C) — pure numpy, no precision loss."""
sh, sw = a.shape[:2]
@@ -117,25 +133,29 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray,
prev set to zero (no temporal history during training).
mip1/mip2 computed from albedo. depth_grad computed via finite diff.
+ dif (ch18) = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow.
"""
h, w = albedo.shape[:2]
- mip1 = _upsample_nearest(pyrdown(albedo), h, w)
- mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
- dgrad = depth_gradient(depth)
- prev = np.zeros((h, w, 3), dtype=np.float32)
+ mip1 = _upsample_nearest(pyrdown(albedo), h, w)
+ mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
+ dgrad = depth_gradient(depth)
+ prev = np.zeros((h, w, 3), dtype=np.float32)
+ nor3 = oct_decode(normal)
+ diffuse = np.maximum(0.0, (nor3 * _KEY_LIGHT).sum(-1))
+ dif = diffuse * shadow
return np.concatenate([
- albedo, # [0-2] albedo.rgb
- normal, # [3-4] normal.xy
- depth[..., None], # [5] depth
- dgrad, # [6-7] depth_grad.xy
- matid[..., None], # [8] mat_id
- prev, # [9-11] prev.rgb
- mip1, # [12-14] mip1.rgb
- mip2, # [15-17] mip2.rgb
- shadow[..., None], # [18] shadow
- transp[..., None], # [19] transp
+ albedo, # [0-2] albedo.rgb
+ normal, # [3-4] normal.xy
+ depth[..., None], # [5] depth
+ dgrad, # [6-7] depth_grad.xy
+ matid[..., None], # [8] mat_id
+ prev, # [9-11] prev.rgb
+ mip1, # [12-14] mip1.rgb
+ mip2, # [15-17] mip2.rgb
+ dif[..., None], # [18] dif = diffuse * shadow
+ transp[..., None],# [19] transp
], axis=-1).astype(np.float32)