diff options
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 1487c08..a9a311a 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -30,13 +30,13 @@ def compute_static_features(rgb, depth=None, mip_level=0): mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth) Returns: - (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias] + (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias] Note: p0-p3 are parametric features generated from specified mip level TODO: Binary format should support arbitrary layout and ordering for feature vector (7D), alongside mip-level indication. Current layout is hardcoded as: - [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias] + [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias] Future: Allow experimentation with different feature combinations without shader recompilation. Examples: [R, G, B, dx, dy, uv_x, bias] or [mip1.r, mip2.g, laplacian, uv_x, sin20_x, bias] """ @@ -68,13 +68,13 @@ def compute_static_features(rgb, depth=None, mip_level=0): uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32) # Multi-frequency position encoding - sin10_x = np.sin(10.0 * uv_x).astype(np.float32) + sin20_y = np.sin(20.0 * uv_y).astype(np.float32) # Bias dimension (always 1.0) - replaces Conv2d bias parameter bias = np.ones((h, w), dtype=np.float32) - # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias] - features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1) + # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin20_y, bias] + features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias], axis=-1) return features @@ -376,7 +376,7 @@ def train(args): 'kernel_sizes': kernel_sizes, 'num_layers': args.num_layers, 'mip_level': args.mip_level, - 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias'] + 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias'] } }, checkpoint_path) print(f" → Saved checkpoint: {checkpoint_path}") |
