diff options
Diffstat (limited to 'cnn_v3/src')
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.cc | 20 | ||||
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.h | 4 |
2 files changed, 18 insertions, 6 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc index d13799c..92178f7 100644 --- a/cnn_v3/src/cnn_v3_effect.cc +++ b/cnn_v3/src/cnn_v3_effect.cc @@ -187,14 +187,17 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx, // --------------------------------------------------------------------------- void CNNv3Effect::declare_nodes(NodeRegistry& registry) { + const int W = registry.default_width(); + const int H = registry.default_height(); + // enc0_tex: rgba16float full-res - registry.declare_node(node_enc0_, NodeType::GBUF_ALBEDO, -1, -1); - // enc1_tex: rgba32uint half-res - registry.declare_node(node_enc1_, NodeType::GBUF_RGBA32UINT, -1, -1); - // bottleneck_tex: rgba32uint quarter-res — declare at 1/4 resolution - registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, -1, -1); + registry.declare_node(node_enc0_, NodeType::GBUF_ALBEDO, W, H); + // enc1_tex: rgba32uint half-res — shaders use textureDimensions() for bounds + registry.declare_node(node_enc1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); + // bottleneck_tex: rgba32uint quarter-res + registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); // dec1_tex: rgba16float half-res - registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, -1, -1); + registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, W / 2, H / 2); // output_tex: rgba16float full-res (the declared output_nodes_[0]) } @@ -202,6 +205,11 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) { // set_film_params — simple linear mapping, no MLP yet // --------------------------------------------------------------------------- +void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data, + uint32_t size_bytes) { + wgpuQueueWriteBuffer(queue, weights_buf_.buffer, 0, data, size_bytes); +} + void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { // Identity + audio/beat modulation. // Replace with FiLM MLP output once training is done. diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h index c358990..36e2797 100644 --- a/cnn_v3/src/cnn_v3_effect.h +++ b/cnn_v3/src/cnn_v3_effect.h @@ -89,6 +89,10 @@ class CNNv3Effect : public Effect { // Update FiLM conditioning; call before render() each frame. void set_film_params(const CNNv3FiLMParams& fp); + // Upload packed-f16 weights (kWeightsBufBytes bytes of u32 pairs). + // Used for testing and inference from trained .bin files. + void upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes); + private: // Intermediate node names (prefixed from output[0]) std::string node_enc0_; |
