summaryrefslogtreecommitdiff
path: root/cnn_v3/src/cnn_v3_effect.h
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.h')
-rw-r--r--cnn_v3/src/cnn_v3_effect.h22
1 files changed, 20 insertions, 2 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h
index 070f988..589680c 100644
--- a/cnn_v3/src/cnn_v3_effect.h
+++ b/cnn_v3/src/cnn_v3_effect.h
@@ -7,7 +7,7 @@
// enc1: Conv(8→16, 3×3) + FiLM16 + ReLU H/2×W/2 2× rgba32uint
// bottleneck: Conv(16→16, 3×3, dil=2) + ReLU H/4×W/4 2× rgba32uint
// dec1: Conv(32→8, 3×3) + FiLM8 + ReLU H/2×W/2 rgba32uint
-// dec0: Conv(16→4, 3×3) + FiLM4 + ReLU + sig H×W rgba16float
+// dec0: Conv(16→4, 3×3) + FiLM4 + sig H×W rgba16float
//
// Inputs: feat_tex0, feat_tex1 (rgba32uint, 20-channel G-buffer)
// Output: output_tex (rgba16float, 4-channel RGBA)
@@ -97,6 +97,17 @@ struct CNNv3FiLMParams {
float style_p1 = 0.0f; // user-defined style param
};
+// FiLM MLP weights: Linear(5→16)→ReLU→Linear(16→72).
+// Loaded from cnn_v3_film_mlp.bin (1320 f32 = 5280 bytes).
+// Layout: l0_w(80) | l0_b(16) | l1_w(1152) | l1_b(72), all row-major f32.
+struct CNNv3FilmMlp {
+ float l0_w[16 * 5]; // (16, 5) row-major
+ float l0_b[16];
+ float l1_w[72 * 16]; // (72, 16) row-major
+ float l1_b[72];
+};
+static_assert(sizeof(CNNv3FilmMlp) == 1320 * 4, "CNNv3FilmMlp size mismatch");
+
class CNNv3Effect : public Effect {
public:
CNNv3Effect(const GpuContext& ctx, const std::vector<std::string>& inputs,
@@ -111,9 +122,13 @@ class CNNv3Effect : public Effect {
// Update FiLM conditioning; call before render() each frame.
void set_film_params(const CNNv3FiLMParams& fp);
- // Upload packed-f16 weights (kWeightsBufBytes bytes of u32 pairs).
+ // Upload packed-f16 conv weights (kWeightsBufBytes bytes of u32 pairs).
void upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes);
+ // Load FiLM MLP weights from cnn_v3_film_mlp.bin (1320 f32 = 5280 bytes).
+ // Must be called before set_film_params() for learned conditioning.
+ void load_film_mlp(const void* data, uint32_t size_bytes);
+
private:
// Intermediate node names (prefixed from output[0])
std::string node_enc0_;
@@ -156,4 +171,7 @@ class CNNv3Effect : public Effect {
void create_pipelines();
void update_bind_groups(NodeRegistry& nodes);
+
+ CNNv3FilmMlp mlp_{};
+ bool mlp_loaded_ = false;
};