diff options
Diffstat (limited to 'cnn_v3')
| -rw-r--r-- | cnn_v3/docs/HOW_TO_CNN.md | 26 | ||||
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 2 |
2 files changed, 14 insertions, 14 deletions
diff --git a/cnn_v3/docs/HOW_TO_CNN.md b/cnn_v3/docs/HOW_TO_CNN.md index 020f79c..458b68f 100644 --- a/cnn_v3/docs/HOW_TO_CNN.md +++ b/cnn_v3/docs/HOW_TO_CNN.md @@ -458,11 +458,14 @@ Converts a trained `.pth` checkpoint to two raw binary files for the C++ runtime ```bash cd cnn_v3/training -python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth -# writes to export/ by default - python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth \ - --output /path/to/assets/ + --output ../../workspaces/main/weights/ +``` + +Output files are registered in `workspaces/main/assets.txt` as: +``` +WEIGHTS_CNN_V3, BINARY, weights/cnn_v3_weights.bin, "CNN v3 conv weights (f16, 3928 bytes)" +WEIGHTS_CNN_V3_FILM_MLP, BINARY, weights/cnn_v3_film_mlp.bin, "CNN v3 FiLM MLP weights (f32, 3104 bytes)" ``` ### Output files @@ -557,20 +560,15 @@ auto cnn = std::make_shared<CNNv3Effect>( ### Uploading weights -Load `cnn_v3_weights.bin` once at startup, before the first `render()`: +Load `cnn_v3_weights.bin` once at startup via the asset system, before the first `render()`: ```cpp -// Read binary file -std::vector<uint8_t> data; -{ - std::ifstream f("cnn_v3_weights.bin", std::ios::binary | std::ios::ate); - data.resize(f.tellg()); - f.seekg(0); - f.read(reinterpret_cast<char*>(data.data()), data.size()); -} +// Load via asset system +const char* data = SafeGetAsset(AssetId::ASSET_WEIGHTS_CNN_V3); +uint32_t size = GetAssetSize(AssetId::ASSET_WEIGHTS_CNN_V3); // Upload to GPU -cnn->upload_weights(ctx.queue, data.data(), (uint32_t)data.size()); +cnn->upload_weights(ctx.queue, reinterpret_cast<const uint8_t*>(data), size); ``` Before `upload_weights()`: all conv weights are zero, so output is `sigmoid(0) = 0.5` gray. diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index d1d2893..99f3a81 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -3,6 +3,8 @@ # requires-python = ">=3.11" # dependencies = [ # "numpy", +# "opencv-python", +# "pillow", # "torch", # ] # /// |
