summaryrefslogtreecommitdiff
path: root/tools/cnn_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tools/cnn_test.cc')
-rw-r--r--tools/cnn_test.cc30
1 files changed, 25 insertions, 5 deletions
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc
index c504c3d..b4a4bdc 100644
--- a/tools/cnn_test.cc
+++ b/tools/cnn_test.cc
@@ -784,6 +784,20 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
wgpuQueueWriteBuffer(queue, static_params_buffer, 0, &static_params,
sizeof(static_params));
+ // Create linear sampler for bilinear interpolation
+ WGPUSamplerDescriptor linear_sampler_desc = {};
+ linear_sampler_desc.addressModeU = WGPUAddressMode_ClampToEdge;
+ linear_sampler_desc.addressModeV = WGPUAddressMode_ClampToEdge;
+ linear_sampler_desc.addressModeW = WGPUAddressMode_ClampToEdge;
+ linear_sampler_desc.magFilter = WGPUFilterMode_Linear;
+ linear_sampler_desc.minFilter = WGPUFilterMode_Linear;
+ linear_sampler_desc.mipmapFilter = WGPUMipmapFilterMode_Linear;
+ linear_sampler_desc.lodMinClamp = 0.0f;
+ linear_sampler_desc.lodMaxClamp = 32.0f;
+ linear_sampler_desc.maxAnisotropy = 1;
+
+ WGPUSampler linear_sampler = wgpuDeviceCreateSampler(device, &linear_sampler_desc);
+
// Create static features compute pipeline
WGPUShaderSourceWGSL static_wgsl = {};
static_wgsl.chain.sType = WGPUSType_ShaderSourceWGSL;
@@ -796,8 +810,8 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
wgpuDeviceCreateShaderModule(device, &static_module_desc);
// Bind group layout: 0=input, 1=input_mip1, 2=input_mip2, 3=depth, 4=output,
- // 5=params
- WGPUBindGroupLayoutEntry static_bgl_entries[6] = {};
+ // 5=params, 6=linear_sampler
+ WGPUBindGroupLayoutEntry static_bgl_entries[7] = {};
static_bgl_entries[0].binding = 0;
static_bgl_entries[0].visibility = WGPUShaderStage_Compute;
static_bgl_entries[0].texture.sampleType = WGPUTextureSampleType_Float;
@@ -832,8 +846,12 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
static_bgl_entries[5].buffer.minBindingSize =
sizeof(CNNv2StaticFeatureParams);
+ static_bgl_entries[6].binding = 6;
+ static_bgl_entries[6].visibility = WGPUShaderStage_Compute;
+ static_bgl_entries[6].sampler.type = WGPUSamplerBindingType_Filtering;
+
WGPUBindGroupLayoutDescriptor static_bgl_desc = {};
- static_bgl_desc.entryCount = 6;
+ static_bgl_desc.entryCount = 7;
static_bgl_desc.entries = static_bgl_entries;
WGPUBindGroupLayout static_bgl =
@@ -858,7 +876,7 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
wgpuPipelineLayoutRelease(static_pl);
// Create static bind group (use input as all mips for simplicity)
- WGPUBindGroupEntry static_bg_entries[6] = {};
+ WGPUBindGroupEntry static_bg_entries[7] = {};
static_bg_entries[0].binding = 0;
static_bg_entries[0].textureView = input_view;
static_bg_entries[1].binding = 1;
@@ -872,10 +890,12 @@ static bool process_cnn_v2(WGPUDevice device, WGPUQueue queue,
static_bg_entries[5].binding = 5;
static_bg_entries[5].buffer = static_params_buffer;
static_bg_entries[5].size = sizeof(CNNv2StaticFeatureParams);
+ static_bg_entries[6].binding = 6;
+ static_bg_entries[6].sampler = linear_sampler;
WGPUBindGroupDescriptor static_bg_desc = {};
static_bg_desc.layout = static_bgl;
- static_bg_desc.entryCount = 6;
+ static_bg_desc.entryCount = 7;
static_bg_desc.entries = static_bg_entries;
WGPUBindGroup static_bg = wgpuDeviceCreateBindGroup(device, &static_bg_desc);