diff options
Diffstat (limited to 'cnn_v3/training')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index 8da276e..5b43a4d 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -273,9 +273,11 @@ class CNNv3Dataset(Dataset): matid = load_gray(sd / 'matid.png') shadow = load_gray(sd / 'shadow.png') transp = load_gray(sd / 'transp.png') - target = np.asarray( - Image.open(sd / 'target.png').convert('RGBA'), - dtype=np.float32) / 255.0 + h, w = albedo.shape[:2] + target_img = Image.open(sd / 'target.png').convert('RGBA') + if target_img.size != (w, h): + target_img = target_img.resize((w, h), Image.LANCZOS) + target = np.asarray(target_img, dtype=np.float32) / 255.0 return albedo, normal, depth, matid, shadow, transp, target def __getitem__(self, idx): |
