summaryrefslogtreecommitdiff
path: root/cnn_v3/training
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py2
-rw-r--r--cnn_v3/training/train_cnn_v3.py2
2 files changed, 2 insertions, 2 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index 68c0798..bd0ca2f 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -140,7 +140,7 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray,
mip1 = _upsample_nearest(pyrdown(albedo), h, w)
mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
- dgrad = depth_gradient(depth)
+ dgrad = np.tanh(depth_gradient(depth) * 10.0)
if prev is None:
prev = np.zeros((h, w, 3), dtype=np.float32)
nor3 = oct_decode(normal)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index fa0d2e2..5065f22 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -160,7 +160,7 @@ def train(args):
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
- criterion = nn.MSELoss()
+ criterion = nn.L1Loss()
ckpt_dir = Path(args.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
start_epoch = 1