summaryrefslogtreecommitdiff
path: root/cnn_v3/training/cnn_v3_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index 8da276e..5b43a4d 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -273,9 +273,11 @@ class CNNv3Dataset(Dataset):
matid = load_gray(sd / 'matid.png')
shadow = load_gray(sd / 'shadow.png')
transp = load_gray(sd / 'transp.png')
- target = np.asarray(
- Image.open(sd / 'target.png').convert('RGBA'),
- dtype=np.float32) / 255.0
+ h, w = albedo.shape[:2]
+ target_img = Image.open(sd / 'target.png').convert('RGBA')
+ if target_img.size != (w, h):
+ target_img = target_img.resize((w, h), Image.LANCZOS)
+ target = np.asarray(target_img, dtype=np.float32) / 255.0
return albedo, normal, depth, matid, shadow, transp, target
def __getitem__(self, idx):