diff options
| -rw-r--r-- | cnn_v3/shaders/gbuf_pack.wgsl | 5 | ||||
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 2 | ||||
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 2 |
3 files changed, 5 insertions, 4 deletions
diff --git a/cnn_v3/shaders/gbuf_pack.wgsl b/cnn_v3/shaders/gbuf_pack.wgsl index 777b4e5..5870938 100644 --- a/cnn_v3/shaders/gbuf_pack.wgsl +++ b/cnn_v3/shaders/gbuf_pack.wgsl @@ -60,8 +60,9 @@ fn pack_features(@builtin(global_invocation_id) id: vec3u) { let depth_raw = load_depth(coord); // Finite-difference depth gradient (central difference, clamped coords) - let dzdx = (load_depth(coord + vec2i(1, 0)) - load_depth(coord - vec2i(1, 0))) * 0.5; - let dzdy = (load_depth(coord + vec2i(0, 1)) - load_depth(coord - vec2i(0, 1))) * 0.5; + // tanh(10x) keeps typical gradients (±0.05–0.1) in [-1,1]; matches training normalization. + let dzdx = tanh((load_depth(coord + vec2i(1, 0)) - load_depth(coord - vec2i(1, 0))) * 5.0); + let dzdy = tanh((load_depth(coord + vec2i(0, 1)) - load_depth(coord - vec2i(0, 1))) * 5.0); // Normal: stored as oct-encoded [0,1] in RG; extract just the encoded xy for feat_tex0 let normal_enc = nm.rg; // already in [0,1] — decode to get the xy for CNN input diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index 68c0798..bd0ca2f 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -140,7 +140,7 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray, mip1 = _upsample_nearest(pyrdown(albedo), h, w) mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w) - dgrad = depth_gradient(depth) + dgrad = np.tanh(depth_gradient(depth) * 10.0) if prev is None: prev = np.zeros((h, w, 3), dtype=np.float32) nor3 = oct_decode(normal) diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index fa0d2e2..5065f22 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -160,7 +160,7 @@ def train(args): optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) - criterion = nn.MSELoss() + criterion = nn.L1Loss() ckpt_dir = Path(args.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) start_epoch = 1 |
