From 5074e4caec017d6607de5806858d0271a554d77c Mon Sep 17 00:00:00 2001 From: skal Date: Fri, 13 Feb 2026 12:47:43 +0100 Subject: CNN v2 training: Use target image alpha channel Changed target loading from RGB to RGBA to preserve transparency. Model learns to predict alpha channel from target image instead of constant 1.0 padding. Before: Target padded with alpha=1.0 After: Target uses actual alpha from image (or 1.0 if no alpha) Co-Authored-By: Claude Sonnet 4.5 --- training/train_cnn_v2.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) (limited to 'training/train_cnn_v2.py') diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 3673b97..dc087c6 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -208,9 +208,10 @@ class PatchDataset(Dataset): # Load original images (no resize) input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0 - target_img = np.array(Image.open(self.target_paths[img_idx]).convert('RGB')) / 255.0 + target_pil = Image.open(self.target_paths[img_idx]) + target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha - # Detect salient points on original image + # Detect salient points on original image (use RGB only) salient_points = self._detect_salient_points(input_img) cx, cy = salient_points[patch_idx] @@ -220,7 +221,7 @@ class PatchDataset(Dataset): x1, x2 = cx - half_patch, cx + half_patch input_patch = input_img[y1:y2, x1:x2] - target_patch = target_img[y1:y2, x1:x2] + target_patch = target_img[y1:y2, x1:x2] # RGBA # Compute static features for patch static_feat = compute_static_features(input_patch.astype(np.float32)) @@ -231,10 +232,7 @@ class PatchDataset(Dataset): # Convert to tensors (C, H, W) input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1) static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) - target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) - - # Pad target to 4 channels (RGBA) - target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0) + target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) # RGBA from image return input_rgbd, static_feat, target @@ -255,14 +253,14 @@ class ImagePairDataset(Dataset): def __getitem__(self, idx): # Load and resize images to fixed size input_pil = Image.open(self.input_paths[idx]).convert('RGB') - target_pil = Image.open(self.target_paths[idx]).convert('RGB') + target_pil = Image.open(self.target_paths[idx]) # Resize to target size input_pil = input_pil.resize(self.target_size, Image.LANCZOS) target_pil = target_pil.resize(self.target_size, Image.LANCZOS) input_img = np.array(input_pil) / 255.0 - target_img = np.array(target_pil) / 255.0 + target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha # Compute static features static_feat = compute_static_features(input_img.astype(np.float32)) @@ -274,10 +272,7 @@ class ImagePairDataset(Dataset): # Convert to tensors (C, H, W) input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1) static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) - target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) - - # Pad target to 4 channels (RGBA) - target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0) + target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) # RGBA from image return input_rgbd, static_feat, target -- cgit v1.2.3