diff options
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index e48f684..fa0d2e2 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -10,8 +10,7 @@ Architecture (enc_channels=[8,16]): enc1 Conv(8→16, 3×3) + FiLM + ReLU + pool2 H/2×W/2 2× rgba32uint (16ch split) bottleneck Conv(16→16, 3×3, dilation=2) + ReLU H/4×W/4 2× rgba32uint (16ch split) dec1 upsample×2 + cat(enc1) Conv(32→8) + FiLM H/2×W/2 rgba32uint (8ch) - dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM H×W rgba16float (4ch) - output sigmoid → RGBA + dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM + sigmoid H×W rgba16float (4ch) FiLM MLP: Linear(5→16) → ReLU → Linear(16→72) 72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4) @@ -93,9 +92,9 @@ class CNNv3(nn.Module): torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1) ), gd1, bd1)) - x = F.relu(film_apply(self.dec0( + x = film_apply(self.dec0( torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1) - ), gd0, bd0)) + ), gd0, bd0) return torch.sigmoid(x) |
