summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cnn_v3/src/cnn_v3_effect.cc16
-rw-r--r--src/tests/gpu/test_cnn_v3_parity.cc6
-rwxr-xr-xtools/seq_compiler.py13
-rw-r--r--workspaces/main/timeline.seq4
-rw-r--r--workspaces/main/weights/cnn_v3_film_mlp.binbin0 -> 3104 bytes
-rw-r--r--workspaces/main/weights/cnn_v3_weights.binbin0 -> 3928 bytes
6 files changed, 36 insertions, 3 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc
index 759e7ef..4aa2c25 100644
--- a/cnn_v3/src/cnn_v3_effect.cc
+++ b/cnn_v3/src/cnn_v3_effect.cc
@@ -2,8 +2,16 @@
// See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md ยง7 for shader details.
#include "cnn_v3_effect.h"
+
+#if defined(USE_TEST_ASSETS)
+#include "test_assets.h"
+#else
+#include "generated/assets.h"
+#endif
+
#include "gpu/gpu.h"
#include "gpu/shader_composer.h"
+#include "util/asset_manager.h"
#include "util/fatal_error.h"
#include <cstdint>
#include <cstring>
@@ -180,6 +188,14 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx,
for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; }
create_pipelines();
+
+ // Load trained weights from asset system (zero-initialized if absent).
+ size_t weights_size = 0;
+ const void* weights_data =
+ GetAsset(AssetId::ASSET_WEIGHTS_CNN_V3, &weights_size);
+ if (weights_data && weights_size == kWeightsBufBytes) {
+ upload_weights(ctx_.queue, weights_data, (uint32_t)weights_size);
+ }
}
// ---------------------------------------------------------------------------
diff --git a/src/tests/gpu/test_cnn_v3_parity.cc b/src/tests/gpu/test_cnn_v3_parity.cc
index 4b41c94..15fe818 100644
--- a/src/tests/gpu/test_cnn_v3_parity.cc
+++ b/src/tests/gpu/test_cnn_v3_parity.cc
@@ -188,6 +188,12 @@ static std::vector<float> run_cnn_v3(WebGPUTestFixture& fixture,
if (weights_u32) {
effect.upload_weights(ctx.queue, weights_u32, weights_bytes);
+ } else {
+ // Explicitly zero weights to override any asset-loaded defaults.
+ // kWeightsBufBytes = ((1964+1)/2)*4 = 3928
+ const uint32_t zero_size = ((1964u + 1u) / 2u) * 4u;
+ std::vector<uint8_t> zeros(zero_size, 0);
+ effect.upload_weights(ctx.queue, zeros.data(), zero_size);
}
// Run 5 compute passes
diff --git a/tools/seq_compiler.py b/tools/seq_compiler.py
index dfd2ea4..fbd5c0d 100755
--- a/tools/seq_compiler.py
+++ b/tools/seq_compiler.py
@@ -391,14 +391,21 @@ def generate_cpp(seq: SequenceDecl, sorted_effects: List[EffectDecl],
class_name += f'_{seq_index}_Sequence'
# Generate includes
- # Map class names that share a header file
+ # Map class names to header stems (default path: effects/<stem>_effect.h)
+ # Use a full #include string to override the path entirely.
CLASS_TO_HEADER = {
- 'NtscYiq': 'ntsc',
+ 'NtscYiq': 'ntsc',
+ 'GBufferEffect': '#include "../../cnn_v3/src/gbuffer_effect.h"',
+ 'CNNv3Effect': '#include "../../cnn_v3/src/cnn_v3_effect.h"',
}
includes = set()
for effect in seq.effects:
if effect.class_name in CLASS_TO_HEADER:
- header = CLASS_TO_HEADER[effect.class_name]
+ val = CLASS_TO_HEADER[effect.class_name]
+ if val.startswith('#include'):
+ includes.add(val)
+ continue
+ header = val
else:
# Convert ClassName to snake_case header
header = re.sub('([A-Z])', r'_\1', effect.class_name).lower().lstrip('_')
diff --git a/workspaces/main/timeline.seq b/workspaces/main/timeline.seq
index ee7e1e7..1a9cad3 100644
--- a/workspaces/main/timeline.seq
+++ b/workspaces/main/timeline.seq
@@ -42,3 +42,7 @@ SEQUENCE 30.00 3 "complex_chain"
SEQUENCE 40.00 0 "ntsc"
EFFECT + Scene1 source -> temp1 0.00 8.00
EFFECT + Ntsc temp1 -> sink 0.00 8.00
+
+SEQUENCE 48.00 0 "cnn_v3_test"
+ EFFECT + GBufferEffect source -> gbuf_feat0 gbuf_feat1 0.00 8.00
+ EFFECT + CNNv3Effect gbuf_feat0 gbuf_feat1 -> sink 0.00 8.00
diff --git a/workspaces/main/weights/cnn_v3_film_mlp.bin b/workspaces/main/weights/cnn_v3_film_mlp.bin
new file mode 100644
index 0000000..a49dcbe
--- /dev/null
+++ b/workspaces/main/weights/cnn_v3_film_mlp.bin
Binary files differ
diff --git a/workspaces/main/weights/cnn_v3_weights.bin b/workspaces/main/weights/cnn_v3_weights.bin
new file mode 100644
index 0000000..7890fea
--- /dev/null
+++ b/workspaces/main/weights/cnn_v3_weights.bin
Binary files differ