summaryrefslogtreecommitdiff
path: root/cnn_v3/src
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/src')
-rw-r--r--cnn_v3/src/cnn_v3_effect.cc20
-rw-r--r--cnn_v3/src/cnn_v3_effect.h4
2 files changed, 18 insertions, 6 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc
index d13799c..92178f7 100644
--- a/cnn_v3/src/cnn_v3_effect.cc
+++ b/cnn_v3/src/cnn_v3_effect.cc
@@ -187,14 +187,17 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx,
// ---------------------------------------------------------------------------
void CNNv3Effect::declare_nodes(NodeRegistry& registry) {
+ const int W = registry.default_width();
+ const int H = registry.default_height();
+
// enc0_tex: rgba16float full-res
- registry.declare_node(node_enc0_, NodeType::GBUF_ALBEDO, -1, -1);
- // enc1_tex: rgba32uint half-res
- registry.declare_node(node_enc1_, NodeType::GBUF_RGBA32UINT, -1, -1);
- // bottleneck_tex: rgba32uint quarter-res — declare at 1/4 resolution
- registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, -1, -1);
+ registry.declare_node(node_enc0_, NodeType::GBUF_ALBEDO, W, H);
+ // enc1_tex: rgba32uint half-res — shaders use textureDimensions() for bounds
+ registry.declare_node(node_enc1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
+ // bottleneck_tex: rgba32uint quarter-res
+ registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
// dec1_tex: rgba16float half-res
- registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, -1, -1);
+ registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, W / 2, H / 2);
// output_tex: rgba16float full-res (the declared output_nodes_[0])
}
@@ -202,6 +205,11 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) {
// set_film_params — simple linear mapping, no MLP yet
// ---------------------------------------------------------------------------
+void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data,
+ uint32_t size_bytes) {
+ wgpuQueueWriteBuffer(queue, weights_buf_.buffer, 0, data, size_bytes);
+}
+
void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) {
// Identity + audio/beat modulation.
// Replace with FiLM MLP output once training is done.
diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h
index c358990..36e2797 100644
--- a/cnn_v3/src/cnn_v3_effect.h
+++ b/cnn_v3/src/cnn_v3_effect.h
@@ -89,6 +89,10 @@ class CNNv3Effect : public Effect {
// Update FiLM conditioning; call before render() each frame.
void set_film_params(const CNNv3FiLMParams& fp);
+ // Upload packed-f16 weights (kWeightsBufBytes bytes of u32 pairs).
+ // Used for testing and inference from trained .bin files.
+ void upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes);
+
private:
// Intermediate node names (prefixed from output[0])
std::string node_enc0_;