Diffstat (limited to 'training')
-rwxr-xr-x  training/train_cnn_v2.py  |  21
1 file changed, 8 insertions(+), 13 deletions(-)
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 3673b97..dc087c6 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -208,9 +208,10 @@ class PatchDataset(Dataset):
 
         # Load original images (no resize)
         input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0
-        target_img = np.array(Image.open(self.target_paths[img_idx]).convert('RGB')) / 255.0
+        target_pil = Image.open(self.target_paths[img_idx])
+        target_img = np.array(target_pil.convert('RGBA')) / 255.0  # Preserve alpha
 
-        # Detect salient points on original image
+        # Detect salient points on original image (use RGB only)
         salient_points = self._detect_salient_points(input_img)
         cx, cy = salient_points[patch_idx]
 
@@ -220,7 +221,7 @@
 
         x1, x2 = cx - half_patch, cx + half_patch
         input_patch = input_img[y1:y2, x1:x2]
-        target_patch = target_img[y1:y2, x1:x2]
+        target_patch = target_img[y1:y2, x1:x2]  # RGBA
 
         # Compute static features for patch
         static_feat = compute_static_features(input_patch.astype(np.float32))
@@ -231,10 +232,7 @@
         # Convert to tensors (C, H, W)
         input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
         static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
-        target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)
-
-        # Pad target to 4 channels (RGBA)
-        target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
+        target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)  # RGBA from image
 
         return input_rgbd, static_feat, target
 
@@ -255,14 +253,14 @@ class ImagePairDataset(Dataset):
     def __getitem__(self, idx):
         # Load and resize images to fixed size
         input_pil = Image.open(self.input_paths[idx]).convert('RGB')
-        target_pil = Image.open(self.target_paths[idx]).convert('RGB')
+        target_pil = Image.open(self.target_paths[idx])
 
         # Resize to target size
         input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
         target_pil = target_pil.resize(self.target_size, Image.LANCZOS)
 
         input_img = np.array(input_pil) / 255.0
-        target_img = np.array(target_pil) / 255.0
+        target_img = np.array(target_pil.convert('RGBA')) / 255.0  # Preserve alpha
 
         # Compute static features
         static_feat = compute_static_features(input_img.astype(np.float32))
@@ -274,10 +272,7 @@
         # Convert to tensors (C, H, W)
         input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
        static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
-        target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)
-
-        # Pad target to 4 channels (RGBA)
-        target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
+        target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)  # RGBA from image
 
         return input_rgbd, static_feat, target
 
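
For reference, a minimal standalone sketch of the two target-loading paths this commit swaps: the old code loaded an RGB target and appended a constant, fully opaque alpha plane with F.pad, while the new code reads the alpha channel from the file via convert('RGBA'). The helper names load_target_padded and load_target_rgba are illustrative only and do not appear in train_cnn_v2.py.

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

def load_target_padded(path):
    """Old behaviour: load RGB, then append a constant alpha plane of 1.0."""
    img = np.array(Image.open(path).convert('RGB')) / 255.0            # (H, W, 3)
    t = torch.from_numpy(img.astype(np.float32)).permute(2, 0, 1)      # (3, H, W)
    # Last pair of the pad tuple pads the channel dim: one extra plane filled with 1.0
    return F.pad(t, (0, 0, 0, 0, 0, 1), value=1.0)                     # (4, H, W), alpha == 1

def load_target_rgba(path):
    """New behaviour: keep whatever alpha the image file actually carries."""
    img = np.array(Image.open(path).convert('RGBA')) / 255.0           # (H, W, 4)
    return torch.from_numpy(img.astype(np.float32)).permute(2, 0, 1)   # (4, H, W), alpha from file

Both return a (4, H, W) float tensor, so downstream loss and model shapes are unchanged; the two only differ for target images that carry real transparency, where the old path forced alpha to 1 everywhere.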