-rw-r--r--  doc/CNN_V2.md            | 30
-rw-r--r--  doc/HOWTO.md             |  6
-rwxr-xr-x  training/train_cnn_v2.py | 49
3 files changed, 65 insertions(+), 20 deletions(-)
diff --git a/doc/CNN_V2.md b/doc/CNN_V2.md
index e56b022..49086ca 100644
--- a/doc/CNN_V2.md
+++ b/doc/CNN_V2.md
@@ -245,12 +245,28 @@ fn pack_channels(values: vec4<f32>) -> vec4<u32> {
**Static Feature Extraction:**
```python
-def compute_static_features(rgb, depth):
- """Generate parametric features (8D: p0-p3 + spatial)."""
+def compute_static_features(rgb, depth, mip_level=0):
+ """Generate parametric features (8D: p0-p3 + spatial).
+
+ Args:
+ mip_level: 0=original, 1=half res, 2=quarter res, 3=eighth res
+ """
h, w = rgb.shape[:2]
- # Parametric features (example: use input RGBD, but could be mips/gradients)
- p0, p1, p2, p3 = rgb[..., 0], rgb[..., 1], rgb[..., 2], depth
+ # Generate mip level for p0-p3 (downsample then upsample)
+ if mip_level > 0:
+ mip_rgb = rgb.copy()
+ for _ in range(mip_level):
+ mip_rgb = cv2.pyrDown(mip_rgb)
+ for _ in range(mip_level):
+ mip_rgb = cv2.pyrUp(mip_rgb)
+ if mip_rgb.shape[:2] != (h, w):
+ mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
+ else:
+ mip_rgb = rgb
+
+ # Parametric features: p0-p2 from the mip level, p3 from depth
+ p0, p1, p2, p3 = mip_rgb[..., 0], mip_rgb[..., 1], mip_rgb[..., 2], depth
# UV coordinates (normalized)
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0)
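The round trip above is the whole trick: pyrDown followed by pyrUp returns an image at the original resolution with its high-frequency detail removed, so p0-p2 become progressively blurrier copies of the input as mip_level rises. A minimal standalone sketch (synthetic power-of-two image, so the resize fallback never triggers):

```python
import cv2
import numpy as np

rgb = np.random.rand(256, 256, 3).astype(np.float32)  # synthetic stand-in image
mip_level = 2  # quarter resolution

mip_rgb = rgb.copy()
for _ in range(mip_level):
    mip_rgb = cv2.pyrDown(mip_rgb)  # 256 -> 128 -> 64
for _ in range(mip_level):
    mip_rgb = cv2.pyrUp(mip_rgb)    # 64 -> 128 -> 256

assert mip_rgb.shape == rgb.shape
# mip_rgb is a low-pass copy of rgb: detail finer than the mip-2 grid is gone.
```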
@@ -308,6 +324,7 @@ class CNNv2(nn.Module):
# Hyperparameters
kernel_sizes = [3, 3, 3] # Per-layer kernel sizes (e.g., [1,3,5])
num_layers = 3 # Number of CNN layers
+mip_level = 0 # Mip level for p0-p3: 0=orig, 1=half, 2=quarter, 3=eighth
learning_rate = 1e-3
batch_size = 16
epochs = 5000
@@ -318,8 +335,8 @@ epochs = 5000
# Training loop (standard PyTorch f32)
for epoch in range(epochs):
for rgb_batch, depth_batch, target_batch in dataloader:
- # Compute static features (8D)
- static_feat = compute_static_features(rgb_batch, depth_batch)
+ # Compute static features (8D) with mip level
+ static_feat = compute_static_features(rgb_batch, depth_batch, mip_level)
# Input RGBD (4D)
input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1)
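One wrinkle the pseudocode glosses over: compute_static_features() as defined above takes a single (H, W, 3) numpy image, not a batch tensor. A hedged sketch of the per-sample wrapper this loop implies (the helper name and layout are assumptions, not part of the diff):

```python
import numpy as np
import torch

def batch_static_features(rgb_batch, depth_batch, mip_level=0):
    """Apply the per-image extractor across a batch (hypothetical helper)."""
    feats = [compute_static_features(rgb, depth, mip_level)
             for rgb, depth in zip(rgb_batch, depth_batch)]
    # Stack to (B, H, W, 8), then permute to PyTorch's channel-first (B, 8, H, W)
    return torch.from_numpy(np.stack(feats)).permute(0, 3, 1, 2)
```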
@@ -342,6 +359,7 @@ torch.save({
'config': {
'kernel_sizes': [3, 3, 3], # Per-layer kernel sizes
'num_layers': 3,
+ 'mip_level': 0, # Mip level used for p0-p3
'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
},
'epoch': epoch,
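Storing mip_level in the checkpoint config means inference can recover it instead of trusting a CLI flag to match the training run. A small sketch of reading it back (the checkpoint path and inputs are placeholders):

```python
import numpy as np
import torch

ckpt = torch.load("checkpoints/cnn_v2.pt", map_location="cpu")  # placeholder path
mip_level = ckpt["config"].get("mip_level", 0)  # fall back to 0 for old checkpoints

rgb = np.random.rand(256, 256, 3).astype(np.float32)  # stand-in inference inputs
depth = np.zeros((256, 256), dtype=np.float32)
static_feat = compute_static_features(rgb, depth, mip_level=mip_level)
```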
diff --git a/doc/HOWTO.md b/doc/HOWTO.md
index 9c67106..9003fe1 100644
--- a/doc/HOWTO.md
+++ b/doc/HOWTO.md
@@ -166,6 +166,12 @@ Config: 100 epochs, 3×3 kernels, 8→4→4 channels, patch-based (harris detect
--input training/input/ --target training/target_2/ \
--kernel-sizes 1,3,5 \
--epochs 5000 --batch-size 16
+
+# Mip-level for p0-p3 features (0=original, 1=half, 2=quarter, 3=eighth)
+./training/train_cnn_v2.py \
+ --input training/input/ --target training/target_2/ \
+ --mip-level 1 \
+ --epochs 100 --batch-size 16
```
**Export Binary Weights:**
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index dc087c6..3d49d13 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -21,26 +21,40 @@ import time
import cv2
-def compute_static_features(rgb, depth=None):
+def compute_static_features(rgb, depth=None, mip_level=0):
"""Generate 8D static features (parametric + spatial).
Args:
rgb: (H, W, 3) RGB image [0, 1]
depth: (H, W) depth map [0, 1], optional
+ mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
Returns:
(H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]
- Note: p0-p3 are parametric features (can be mips, gradients, etc.)
- For training, we use RGBD as default, but could use mip1/2
+ Note: p0-p2 are parametric features taken from the specified mip level; p3 is the unmipped depth map
"""
h, w = rgb.shape[:2]
- # Parametric features (p0-p3) - using RGBD as default
- # TODO: Experiment with mip1 grayscale, gradients, etc.
- p0 = rgb[:, :, 0].astype(np.float32)
- p1 = rgb[:, :, 1].astype(np.float32)
- p2 = rgb[:, :, 2].astype(np.float32)
+ # Generate mip level for p0-p3
+ if mip_level > 0:
+ # Downsample to mip level
+ mip_rgb = rgb.copy()
+ for _ in range(mip_level):
+ mip_rgb = cv2.pyrDown(mip_rgb)
+ # Upsample back to original size
+ for _ in range(mip_level):
+ mip_rgb = cv2.pyrUp(mip_rgb)
+ # Resize back to the exact original size if needed (pyrUp may overshoot odd dimensions)
+ if mip_rgb.shape[:2] != (h, w):
+ mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
+ else:
+ mip_rgb = rgb
+
+ # Parametric features: p0-p2 from the mip level (p3 = depth, unmipped)
+ p0 = mip_rgb[:, :, 0].astype(np.float32)
+ p1 = mip_rgb[:, :, 1].astype(np.float32)
+ p2 = mip_rgb[:, :, 2].astype(np.float32)
p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
# UV coordinates (normalized [0, 1])
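The resize fallback in this hunk is not dead code: cv2.pyrDown rounds odd dimensions up to (H+1)//2 while cv2.pyrUp always doubles, so non-power-of-two inputs come back one pixel off. For example:

```python
import cv2
import numpy as np

img = np.random.rand(101, 101, 3).astype(np.float32)
down = cv2.pyrDown(img)   # (51, 51, 3): odd dimensions round up
up = cv2.pyrUp(down)      # (102, 102, 3): pyrUp exactly doubles
assert up.shape[:2] != img.shape[:2]  # hence cv2.resize back to (w, h)
```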
@@ -119,12 +133,13 @@ class PatchDataset(Dataset):
"""Patch-based dataset extracting salient regions from images."""
def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
- detector='harris'):
+ detector='harris', mip_level=0):
self.input_paths = sorted(Path(input_dir).glob("*.png"))
self.target_paths = sorted(Path(target_dir).glob("*.png"))
self.patch_size = patch_size
self.patches_per_image = patches_per_image
self.detector = detector
+ self.mip_level = mip_level
assert len(self.input_paths) == len(self.target_paths), \
f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
@@ -224,7 +239,7 @@ class PatchDataset(Dataset):
target_patch = target_img[y1:y2, x1:x2] # RGBA
# Compute static features for patch
- static_feat = compute_static_features(input_patch.astype(np.float32))
+ static_feat = compute_static_features(input_patch.astype(np.float32), mip_level=self.mip_level)
# Input RGBD (mip 0) - add depth channel
input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
@@ -240,10 +255,11 @@ class PatchDataset(Dataset):
class ImagePairDataset(Dataset):
"""Dataset of input/target image pairs (full-image mode)."""
- def __init__(self, input_dir, target_dir, target_size=(256, 256)):
+ def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0):
self.input_paths = sorted(Path(input_dir).glob("*.png"))
self.target_paths = sorted(Path(target_dir).glob("*.png"))
self.target_size = target_size
+ self.mip_level = mip_level
assert len(self.input_paths) == len(self.target_paths), \
f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
@@ -263,7 +279,7 @@ class ImagePairDataset(Dataset):
target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
# Compute static features
- static_feat = compute_static_features(input_img.astype(np.float32))
+ static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level)
# Input RGBD (mip 0) - add depth channel
h, w = input_img.shape[:2]
@@ -286,14 +302,15 @@ def train(args):
if args.full_image:
print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
target_size = (args.image_size, args.image_size)
- dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
+ dataset = ImagePairDataset(args.input, args.target, target_size=target_size, mip_level=args.mip_level)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
else:
print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
dataset = PatchDataset(args.input, args.target,
patch_size=args.patch_size,
patches_per_image=args.patches_per_image,
- detector=args.detector)
+ detector=args.detector,
+ mip_level=args.mip_level)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# Parse kernel sizes
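The kernel-size parsing itself is elided from this diff, but the --kernel-sizes help text below ("single value replicates") pins down the behavior. A hedged sketch of what that parsing presumably looks like:

```python
def parse_kernel_sizes(spec: str, num_layers: int) -> list[int]:
    """Parse "3,5,3" -> [3, 5, 3]; a single value replicates across layers.

    Hypothetical reconstruction; the real implementation is not shown here.
    """
    sizes = [int(s) for s in spec.split(',')]
    if len(sizes) == 1:
        sizes = sizes * num_layers  # "3" -> [3, 3, 3]
    assert len(sizes) == num_layers, \
        f"Expected {num_layers} kernel sizes, got {len(sizes)}"
    return sizes
```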
@@ -306,6 +323,7 @@ def train(args):
total_params = sum(p.numel() for p in model.parameters())
kernel_desc = ','.join(map(str, kernel_sizes))
print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights")
+ print(f"Using mip level {args.mip_level} for p0-p3 features")
# Optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -351,6 +369,7 @@ def train(args):
'config': {
'kernel_sizes': kernel_sizes,
'num_layers': args.num_layers,
+ 'mip_level': args.mip_level,
'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
}
}, checkpoint_path)
@@ -387,6 +406,8 @@ def main():
help='Comma-separated kernel sizes per layer (e.g., "3,5,3"), single value replicates (default: 3)')
parser.add_argument('--num-layers', type=int, default=3,
help='Number of CNN layers (default: 3)')
+ parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3],
+ help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)')
# Training parameters
parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
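A sanity check worth keeping in mind when reviewing this change: the extractor's output shape must be mip-invariant, since downstream code always expects (H, W, 8). A minimal check with a synthetic image (dimensions chosen divisible by 8 so no resize fallback is hit):

```python
import numpy as np

rgb = np.random.rand(240, 320, 3).astype(np.float32)  # synthetic, non-square
depth = np.random.rand(240, 320).astype(np.float32)

for mip_level in (0, 1, 2, 3):
    feat = compute_static_features(rgb, depth, mip_level=mip_level)
    assert feat.shape == (240, 320, 8), feat.shape  # shape unchanged by mip level
```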