Diffstat (limited to 'doc/CNN_V2.md')
| -rw-r--r-- | doc/CNN_V2.md | 30 |
1 file changed, 24 insertions, 6 deletions
diff --git a/doc/CNN_V2.md b/doc/CNN_V2.md
index e56b022..49086ca 100644
--- a/doc/CNN_V2.md
+++ b/doc/CNN_V2.md
@@ -245,12 +245,28 @@ fn pack_channels(values: vec4<f32>) -> vec4<u32> {
 **Static Feature Extraction:**
 
 ```python
-def compute_static_features(rgb, depth):
-    """Generate parametric features (8D: p0-p3 + spatial)."""
+def compute_static_features(rgb, depth, mip_level=0):
+    """Generate parametric features (8D: p0-p3 + spatial).
+
+    Args:
+        mip_level: 0=original, 1=half res, 2=quarter res, 3=eighth res
+    """
     h, w = rgb.shape[:2]
 
-    # Parametric features (example: use input RGBD, but could be mips/gradients)
-    p0, p1, p2, p3 = rgb[..., 0], rgb[..., 1], rgb[..., 2], depth
+    # Generate mip level for p0-p3 (downsample then upsample)
+    if mip_level > 0:
+        mip_rgb = rgb.copy()
+        for _ in range(mip_level):
+            mip_rgb = cv2.pyrDown(mip_rgb)
+        for _ in range(mip_level):
+            mip_rgb = cv2.pyrUp(mip_rgb)
+        if mip_rgb.shape[:2] != (h, w):
+            mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
+    else:
+        mip_rgb = rgb
+
+    # Parametric features from mip level
+    p0, p1, p2, p3 = mip_rgb[..., 0], mip_rgb[..., 1], mip_rgb[..., 2], depth
 
     # UV coordinates (normalized)
     uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0)
@@ -308,6 +324,7 @@ class CNNv2(nn.Module):
 # Hyperparameters
 kernel_sizes = [3, 3, 3]  # Per-layer kernel sizes (e.g., [1,3,5])
 num_layers = 3            # Number of CNN layers
+mip_level = 0             # Mip level for p0-p3: 0=orig, 1=half, 2=quarter, 3=eighth
 learning_rate = 1e-3
 batch_size = 16
 epochs = 5000
@@ -318,8 +335,8 @@ epochs = 5000
 # Training loop (standard PyTorch f32)
 for epoch in range(epochs):
     for rgb_batch, depth_batch, target_batch in dataloader:
-        # Compute static features (8D)
-        static_feat = compute_static_features(rgb_batch, depth_batch)
+        # Compute static features (8D) with mip level
+        static_feat = compute_static_features(rgb_batch, depth_batch, mip_level)
 
         # Input RGBD (4D)
         input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1)
@@ -342,6 +359,7 @@ torch.save({
     'config': {
         'kernel_sizes': [3, 3, 3],  # Per-layer kernel sizes
         'num_layers': 3,
+        'mip_level': 0,  # Mip level used for p0-p3
         'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
     },
     'epoch': epoch,
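Note on the pyrDown/pyrUp hunk above: the round trip only lands back on the original resolution when both dimensions are divisible by 2^mip_level, which is what the added cv2.resize fallback handles. A minimal sketch of that behaviour is below; the odd-sized random frame and the printed shapes are illustrative assumptions, not part of the committed doc.

```python
import cv2
import numpy as np

# Hypothetical odd-sized float32 RGB frame (H=270, W=481), chosen so the
# pyramid round trip does NOT return exactly to the original size.
rgb = np.random.rand(270, 481, 3).astype(np.float32)
h, w = rgb.shape[:2]

mip_level = 2  # quarter resolution, per the docstring added in this commit
mip_rgb = rgb.copy()
for _ in range(mip_level):
    mip_rgb = cv2.pyrDown(mip_rgb)   # Gaussian blur + 2x downsample
for _ in range(mip_level):
    mip_rgb = cv2.pyrUp(mip_rgb)     # 2x upsample back toward full res

print(mip_rgb.shape[:2])             # (272, 484), not (270, 481)
if mip_rgb.shape[:2] != (h, w):      # hence the resize fallback in the diff
    mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
assert mip_rgb.shape == rgb.shape
```

Going through pyrDown/pyrUp rather than a single resize keeps the Gaussian filtering of a real mip chain, so p0-p3 end up band-limited rather than merely rescaled, which appears to be the intent of the mip_level option.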
