summaryrefslogtreecommitdiff
path: root/cnn_v3/src/cnn_v3_effect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.cc')
-rw-r--r--cnn_v3/src/cnn_v3_effect.cc16
1 files changed, 16 insertions, 0 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc
index 759e7ef..4aa2c25 100644
--- a/cnn_v3/src/cnn_v3_effect.cc
+++ b/cnn_v3/src/cnn_v3_effect.cc
@@ -2,8 +2,16 @@
// See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md ยง7 for shader details.
#include "cnn_v3_effect.h"
+
+#if defined(USE_TEST_ASSETS)
+#include "test_assets.h"
+#else
+#include "generated/assets.h"
+#endif
+
#include "gpu/gpu.h"
#include "gpu/shader_composer.h"
+#include "util/asset_manager.h"
#include "util/fatal_error.h"
#include <cstdint>
#include <cstring>
@@ -180,6 +188,14 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx,
for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; }
create_pipelines();
+
+ // Load trained weights from asset system (zero-initialized if absent).
+ size_t weights_size = 0;
+ const void* weights_data =
+ GetAsset(AssetId::ASSET_WEIGHTS_CNN_V3, &weights_size);
+ if (weights_data && weights_size == kWeightsBufBytes) {
+ upload_weights(ctx_.queue, weights_data, (uint32_t)weights_size);
+ }
}
// ---------------------------------------------------------------------------