diff options
Diffstat (limited to 'tools/cnn_test.cc')
| -rw-r--r-- | tools/cnn_test.cc | 324 |
1 files changed, 185 insertions, 139 deletions
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc index beeef8f..a707723 100644 --- a/tools/cnn_test.cc +++ b/tools/cnn_test.cc @@ -33,10 +33,12 @@ static uint16_t f32_to_f16(float f) { uint32_t b; memcpy(&b, &f, 4); uint32_t sign = (b >> 16) & 0x8000u; - int32_t exp = (int32_t)((b >> 23) & 0xFFu) - 127 + 15; + int32_t exp = (int32_t)((b >> 23) & 0xFFu) - 127 + 15; uint32_t mant = b & 0x7FFFFFu; - if (exp <= 0) return (uint16_t)sign; - if (exp >= 31) return (uint16_t)(sign | 0x7C00u); + if (exp <= 0) + return (uint16_t)sign; + if (exp >= 31) + return (uint16_t)(sign | 0x7C00u); return (uint16_t)(sign | ((uint32_t)exp << 10) | (mant >> 13)); } @@ -49,8 +51,10 @@ static uint32_t pack2x16f(float a, float b) { static uint32_t pack4x8u(float a, float b, float c, float d) { auto u8 = [](float v) -> uint32_t { int i = (int)(v * 255.0f + 0.5f); - if (i < 0) i = 0; - if (i > 255) i = 255; + if (i < 0) + i = 0; + if (i > 255) + i = 255; return (uint32_t)i; }; return u8(a) | (u8(b) << 8) | (u8(c) << 16) | (u8(d) << 24); @@ -60,8 +64,8 @@ static uint32_t pack4x8u(float a, float b, float c, float d) { // Oct-decode [0,1] → unit normal (matches Python cnn_v3_utils.oct_decode) // --------------------------------------------------------------------------- -static void oct_decode_01(float nx01, float ny01, - float* out_x, float* out_y, float* out_z) { +static void oct_decode_01(float nx01, float ny01, float* out_x, float* out_y, + float* out_z) { float fx = nx01 * 2.0f - 1.0f; float fy = ny01 * 2.0f - 1.0f; float fz = 1.0f - fabsf(fx) - fabsf(fy); @@ -71,8 +75,9 @@ static void oct_decode_01(float nx01, float ny01, fx = (1.0f - fabsf(fy)) * sx; fy = (1.0f - fabsf(fx)) * sy; } - float len = sqrtf(fx*fx + fy*fy + fz*fz); - if (len < 1e-8f) len = 1e-8f; + float len = sqrtf(fx * fx + fy * fy + fz * fz); + if (len < 1e-8f) + len = 1e-8f; *out_x = fx / len; *out_y = fy / len; *out_z = fz / len; @@ -89,8 +94,8 @@ static void oct_decode_01(float nx01, float ny01, // Output: mip1_out and mip2_out are (H*W*3) float arrays in row-major order. static void compute_mips(const float* rgb, int w, int h, - std::vector<float>& mip1_out, - std::vector<float>& mip2_out) { + std::vector<float>& mip1_out, + std::vector<float>& mip2_out) { const int w2 = w / 2, h2 = h / 2; const int w4 = w / 4, h4 = h / 4; @@ -99,10 +104,9 @@ static void compute_mips(const float* rgb, int w, int h, for (int x2 = 0; x2 < w2; ++x2) { for (int c = 0; c < 3; ++c) { int y0 = y2 * 2, x0 = x2 * 2; - float v = rgb[(y0 * w + x0 ) * 3 + c] - + rgb[(y0 * w + x0+1) * 3 + c] - + rgb[((y0+1) * w + x0 ) * 3 + c] - + rgb[((y0+1) * w + x0+1) * 3 + c]; + float v = rgb[(y0 * w + x0) * 3 + c] + rgb[(y0 * w + x0 + 1) * 3 + c] + + rgb[((y0 + 1) * w + x0) * 3 + c] + + rgb[((y0 + 1) * w + x0 + 1) * 3 + c]; m1[(y2 * w2 + x2) * 3 + c] = v * 0.25f; } } @@ -113,10 +117,9 @@ static void compute_mips(const float* rgb, int w, int h, for (int x4 = 0; x4 < w4; ++x4) { for (int c = 0; c < 3; ++c) { int y0 = y4 * 2, x0 = x4 * 2; - float v = m1[(y0 * w2 + x0 ) * 3 + c] - + m1[(y0 * w2 + x0+1) * 3 + c] - + m1[((y0+1) * w2 + x0 ) * 3 + c] - + m1[((y0+1) * w2 + x0+1) * 3 + c]; + float v = m1[(y0 * w2 + x0) * 3 + c] + m1[(y0 * w2 + x0 + 1) * 3 + c] + + m1[((y0 + 1) * w2 + x0) * 3 + c] + + m1[((y0 + 1) * w2 + x0 + 1) * 3 + c]; m2[(y4 * w4 + x4) * 3 + c] = v * 0.25f; } } @@ -128,14 +131,14 @@ static void compute_mips(const float* rgb, int w, int h, for (int y = 0; y < h; ++y) { for (int x = 0; x < w; ++x) { int i = (y * w + x) * 3; - int i1 = ((y/2) * w2 + (x/2)) * 3; - int i2 = ((y/4) * w4 + (x/4)) * 3; - mip1_out[i ] = (y/2 < h2 && x/2 < w2) ? m1[i1 ] : 0.0f; - mip1_out[i+1] = (y/2 < h2 && x/2 < w2) ? m1[i1+1] : 0.0f; - mip1_out[i+2] = (y/2 < h2 && x/2 < w2) ? m1[i1+2] : 0.0f; - mip2_out[i ] = (y/4 < h4 && x/4 < w4) ? m2[i2 ] : 0.0f; - mip2_out[i+1] = (y/4 < h4 && x/4 < w4) ? m2[i2+1] : 0.0f; - mip2_out[i+2] = (y/4 < h4 && x/4 < w4) ? m2[i2+2] : 0.0f; + int i1 = ((y / 2) * w2 + (x / 2)) * 3; + int i2 = ((y / 4) * w4 + (x / 4)) * 3; + mip1_out[i] = (y / 2 < h2 && x / 2 < w2) ? m1[i1] : 0.0f; + mip1_out[i + 1] = (y / 2 < h2 && x / 2 < w2) ? m1[i1 + 1] : 0.0f; + mip1_out[i + 2] = (y / 2 < h2 && x / 2 < w2) ? m1[i1 + 2] : 0.0f; + mip2_out[i] = (y / 4 < h4 && x / 4 < w4) ? m2[i2] : 0.0f; + mip2_out[i + 1] = (y / 4 < h4 && x / 4 < w4) ? m2[i2 + 1] : 0.0f; + mip2_out[i + 2] = (y / 4 < h4 && x / 4 < w4) ? m2[i2 + 2] : 0.0f; } } } @@ -161,17 +164,17 @@ static void compute_mips(const float* rgb, int w, int h, struct FeatureImages { int w, h; - std::vector<float> albedo; // w*h*3 [0,1] - std::vector<float> normal; // w*h*2 [0,1] oct-encoded - std::vector<float> depth; // w*h [0,1] - std::vector<float> matid; // w*h [0,1] - std::vector<float> shadow; // w*h [0,1] - std::vector<float> transp; // w*h [0,1] + std::vector<float> albedo; // w*h*3 [0,1] + std::vector<float> normal; // w*h*2 [0,1] oct-encoded + std::vector<float> depth; // w*h [0,1] + std::vector<float> matid; // w*h [0,1] + std::vector<float> shadow; // w*h [0,1] + std::vector<float> transp; // w*h [0,1] }; static void pack_features(const FeatureImages& img, - std::vector<uint32_t>& feat0, // w*h*4 u32 - std::vector<uint32_t>& feat1) // w*h*4 u32 + std::vector<uint32_t>& feat0, // w*h*4 u32 + std::vector<uint32_t>& feat1) // w*h*4 u32 { const int W = img.w, H = img.h; feat0.resize(W * H * 4); @@ -184,49 +187,49 @@ static void pack_features(const FeatureImages& img, for (int y = 0; y < H; ++y) { for (int x = 0; x < W; ++x) { - const int pi = y * W + x; - const int i3 = pi * 3; - const int i4 = pi * 4; + const int pi = y * W + x; + const int i3 = pi * 3; + const int i4 = pi * 4; - float ar = img.albedo[i3 ]; - float ag = img.albedo[i3+1]; - float ab = img.albedo[i3+2]; + float ar = img.albedo[i3]; + float ag = img.albedo[i3 + 1]; + float ab = img.albedo[i3 + 2]; - float nx = img.normal[pi * 2 ]; // [0,1] - float ny = img.normal[pi * 2 + 1]; // [0,1] + float nx = img.normal[pi * 2]; // [0,1] + float ny = img.normal[pi * 2 + 1]; // [0,1] float d = img.depth[pi]; // Central finite difference depth gradient - int xm = (x > 0) ? x-1 : 0; - int xp = (x < W-1) ? x+1 : W-1; - int ym = (y > 0) ? y-1 : 0; - int yp = (y < H-1) ? y+1 : H-1; - float dzdx = (img.depth[y * W + xp] - img.depth[y * W + xm]) * 0.5f; - float dzdy = (img.depth[yp * W + x ] - img.depth[ym * W + x ]) * 0.5f; + int xm = (x > 0) ? x - 1 : 0; + int xp = (x < W - 1) ? x + 1 : W - 1; + int ym = (y > 0) ? y - 1 : 0; + int yp = (y < H - 1) ? y + 1 : H - 1; + float dzdx = (img.depth[y * W + xp] - img.depth[y * W + xm]) * 0.5f; + float dzdy = (img.depth[yp * W + x] - img.depth[ym * W + x]) * 0.5f; - float mat = img.matid[pi]; + float mat = img.matid[pi]; float shad = img.shadow[pi]; - float trp = img.transp[pi]; + float trp = img.transp[pi]; // Diffuse = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow float n3x, n3y, n3z; oct_decode_01(nx, ny, &n3x, &n3y, &n3z); - float dif = fmaxf(0.0f, n3x*KEY_X + n3y*KEY_Y + n3z*KEY_Z) * shad; + float dif = fmaxf(0.0f, n3x * KEY_X + n3y * KEY_Y + n3z * KEY_Z) * shad; - float m1r = mip1[i3 ], m1g = mip1[i3+1], m1b = mip1[i3+2]; - float m2r = mip2[i3 ], m2g = mip2[i3+1], m2b = mip2[i3+2]; + float m1r = mip1[i3], m1g = mip1[i3 + 1], m1b = mip1[i3 + 2]; + float m2r = mip2[i3], m2g = mip2[i3 + 1], m2b = mip2[i3 + 2]; // prev.rgb = 0 (no temporal history) - feat0[i4 ] = pack2x16f(ar, ag); - feat0[i4+1] = pack2x16f(ab, nx); - feat0[i4+2] = pack2x16f(ny, d ); - feat0[i4+3] = pack2x16f(dzdx, dzdy); + feat0[i4] = pack2x16f(ar, ag); + feat0[i4 + 1] = pack2x16f(ab, nx); + feat0[i4 + 2] = pack2x16f(ny, d); + feat0[i4 + 3] = pack2x16f(dzdx, dzdy); - feat1[i4 ] = pack4x8u(mat, 0.0f, 0.0f, 0.0f); // mat_id, prev.rgb=0 - feat1[i4+1] = pack4x8u(m1r, m1g, m1b, m2r); - feat1[i4+2] = pack4x8u(m2g, m2b, dif, trp); - feat1[i4+3] = 0u; + feat1[i4] = pack4x8u(mat, 0.0f, 0.0f, 0.0f); // mat_id, prev.rgb=0 + feat1[i4 + 1] = pack4x8u(m1r, m1g, m1b, m2r); + feat1[i4 + 2] = pack4x8u(m2g, m2b, dif, trp); + feat1[i4 + 3] = 0u; } } } @@ -237,41 +240,41 @@ static void pack_features(const FeatureImages& img, static WGPUTexture make_feat_tex(WGPUDevice dev, int W, int H) { WGPUTextureDescriptor d = {}; - d.format = WGPUTextureFormat_RGBA32Uint; - d.usage = WGPUTextureUsage_TextureBinding | WGPUTextureUsage_CopyDst; - d.dimension = WGPUTextureDimension_2D; - d.size = {(uint32_t)W, (uint32_t)H, 1}; + d.format = WGPUTextureFormat_RGBA32Uint; + d.usage = WGPUTextureUsage_TextureBinding | WGPUTextureUsage_CopyDst; + d.dimension = WGPUTextureDimension_2D; + d.size = {(uint32_t)W, (uint32_t)H, 1}; d.mipLevelCount = 1; - d.sampleCount = 1; + d.sampleCount = 1; return wgpuDeviceCreateTexture(dev, &d); } static WGPUTexture make_output_tex(WGPUDevice dev, int W, int H) { WGPUTextureDescriptor d = {}; - d.format = WGPUTextureFormat_RGBA16Float; - d.usage = WGPUTextureUsage_StorageBinding | WGPUTextureUsage_CopySrc; - d.dimension = WGPUTextureDimension_2D; - d.size = {(uint32_t)W, (uint32_t)H, 1}; + d.format = WGPUTextureFormat_RGBA16Float; + d.usage = WGPUTextureUsage_StorageBinding | WGPUTextureUsage_CopySrc; + d.dimension = WGPUTextureDimension_2D; + d.size = {(uint32_t)W, (uint32_t)H, 1}; d.mipLevelCount = 1; - d.sampleCount = 1; + d.sampleCount = 1; return wgpuDeviceCreateTexture(dev, &d); } static WGPUTextureView make_view(WGPUTexture tex, WGPUTextureFormat fmt) { WGPUTextureViewDescriptor d = {}; - d.format = fmt; - d.dimension = WGPUTextureViewDimension_2D; - d.mipLevelCount = 1; + d.format = fmt; + d.dimension = WGPUTextureViewDimension_2D; + d.mipLevelCount = 1; d.arrayLayerCount = 1; return wgpuTextureCreateView(tex, &d); } -static void upload_tex(WGPUQueue queue, WGPUTexture tex, - const uint32_t* data, int W, int H) { +static void upload_tex(WGPUQueue queue, WGPUTexture tex, const uint32_t* data, + int W, int H) { WGPUTexelCopyTextureInfo dst = {}; dst.texture = tex; WGPUTexelCopyBufferLayout layout = {}; - layout.bytesPerRow = (uint32_t)(W * 16); + layout.bytesPerRow = (uint32_t)(W * 16); layout.rowsPerImage = (uint32_t)H; WGPUExtent3D ext = {(uint32_t)W, (uint32_t)H, 1}; wgpuQueueWriteTexture(queue, &dst, data, (size_t)(W * H * 16), &layout, &ext); @@ -281,37 +284,53 @@ static void upload_tex(WGPUQueue queue, WGPUTexture tex, // RGBA16Float readback // --------------------------------------------------------------------------- -static uint16_t fp16_bits_to_f16(float f) { return f32_to_f16(f); } +static uint16_t fp16_bits_to_f16(float f) { + return f32_to_f16(f); +} static float fp16_bits_to_f32(uint16_t h) { uint32_t sign = (uint32_t)(h & 0x8000u) << 16; - uint32_t exp = (h & 0x7C00u) >> 10; + uint32_t exp = (h & 0x7C00u) >> 10; uint32_t mant = h & 0x03FFu; - if (exp == 0 && mant == 0) { float r; memcpy(&r, &sign, 4); return r; } - if (exp == 31) { uint32_t b = sign | 0x7F800000u | (mant << 13); - float r; memcpy(&r, &b, 4); return r; } + if (exp == 0 && mant == 0) { + float r; + memcpy(&r, &sign, 4); + return r; + } + if (exp == 31) { + uint32_t b = sign | 0x7F800000u | (mant << 13); + float r; + memcpy(&r, &b, 4); + return r; + } uint32_t b = sign | ((exp + 112u) << 23) | (mant << 13); - float r; memcpy(&r, &b, 4); return r; + float r; + memcpy(&r, &b, 4); + return r; } -struct MapState { bool done = false; WGPUMapAsyncStatus status = {}; }; +struct MapState { + bool done = false; + WGPUMapAsyncStatus status = {}; +}; static std::vector<float> readback_rgba16f(WGPUDevice device, WGPUQueue queue, - WGPUTexture tex, int W, int H) { - const uint32_t bytes_per_px = 8; - const uint32_t raw_bpr = (uint32_t)(W * bytes_per_px); - const uint32_t aligned_bpr = ((raw_bpr + 255u) / 256u) * 256u; - const size_t buf_size = (size_t)aligned_bpr * (size_t)H; + WGPUTexture tex, int W, int H) { + const uint32_t bytes_per_px = 8; + const uint32_t raw_bpr = (uint32_t)(W * bytes_per_px); + const uint32_t aligned_bpr = ((raw_bpr + 255u) / 256u) * 256u; + const size_t buf_size = (size_t)aligned_bpr * (size_t)H; WGPUBufferDescriptor bd = {}; bd.usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead; - bd.size = buf_size; + bd.size = buf_size; WGPUBuffer staging = wgpuDeviceCreateBuffer(device, &bd); WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(device, nullptr); - WGPUTexelCopyTextureInfo src = {}; src.texture = tex; - WGPUTexelCopyBufferInfo dst = {}; - dst.buffer = staging; - dst.layout.bytesPerRow = aligned_bpr; + WGPUTexelCopyTextureInfo src = {}; + src.texture = tex; + WGPUTexelCopyBufferInfo dst = {}; + dst.buffer = staging; + dst.layout.bytesPerRow = aligned_bpr; dst.layout.rowsPerImage = (uint32_t)H; WGPUExtent3D ext = {(uint32_t)W, (uint32_t)H, 1}; wgpuCommandEncoderCopyTextureToBuffer(enc, &src, &dst, &ext); @@ -323,9 +342,11 @@ static std::vector<float> readback_rgba16f(WGPUDevice device, WGPUQueue queue, MapState ms = {}; WGPUBufferMapCallbackInfo mi = {}; - mi.mode = WGPUCallbackMode_AllowProcessEvents; + mi.mode = WGPUCallbackMode_AllowProcessEvents; mi.callback = [](WGPUMapAsyncStatus s, WGPUStringView, void* u, void*) { - auto* st = (MapState*)u; st->status = s; st->done = true; + auto* st = (MapState*)u; + st->status = s; + st->done = true; }; mi.userdata1 = &ms; wgpuBufferMapAsync(staging, WGPUMapMode_Read, 0, buf_size, mi); @@ -334,11 +355,12 @@ static std::vector<float> readback_rgba16f(WGPUDevice device, WGPUQueue queue, std::vector<float> pixels(W * H * 4, 0.0f); if (ms.done && ms.status == WGPUMapAsyncStatus_Success) { - const uint8_t* mapped = (const uint8_t*) - wgpuBufferGetConstMappedRange(staging, 0, buf_size); + const uint8_t* mapped = + (const uint8_t*)wgpuBufferGetConstMappedRange(staging, 0, buf_size); if (mapped) { for (int y = 0; y < H; ++y) { - const uint16_t* row = (const uint16_t*)(mapped + (size_t)y * aligned_bpr); + const uint16_t* row = + (const uint16_t*)(mapped + (size_t)y * aligned_bpr); for (int x = 0; x < W; ++x) { for (int c = 0; c < 4; ++c) pixels[(y * W + x) * 4 + c] = fp16_bits_to_f32(row[x * 4 + c]); @@ -355,14 +377,16 @@ static std::vector<float> readback_rgba16f(WGPUDevice device, WGPUQueue queue, // Image I/O helpers // --------------------------------------------------------------------------- -static std::vector<float> load_png_rgb(const char* path, int* out_w, int* out_h) { +static std::vector<float> load_png_rgb(const char* path, int* out_w, + int* out_h) { int w, h, ch; uint8_t* data = stbi_load(path, &w, &h, &ch, 3); if (!data) { fprintf(stderr, "Error: cannot load '%s'\n", path); return {}; } - *out_w = w; *out_h = h; + *out_w = w; + *out_h = h; std::vector<float> out(w * h * 3); for (int i = 0; i < w * h * 3; ++i) out[i] = data[i] / 255.0f; @@ -375,14 +399,16 @@ static std::vector<float> load_png_rg(const char* path, int ew, int eh) { int w, h, ch; uint8_t* data = stbi_load(path, &w, &h, &ch, 3); if (!data || w != ew || h != eh) { - if (data) stbi_image_free(data); - fprintf(stderr, "Warning: cannot load normal '%s' — using (0.5,0.5)\n", path); + if (data) + stbi_image_free(data); + fprintf(stderr, "Warning: cannot load normal '%s' — using (0.5,0.5)\n", + path); std::vector<float> def(ew * eh * 2, 0.5f); return def; } std::vector<float> out(w * h * 2); for (int i = 0; i < w * h; ++i) { - out[i * 2 ] = data[i * 3 ] / 255.0f; + out[i * 2] = data[i * 3] / 255.0f; out[i * 2 + 1] = data[i * 3 + 1] / 255.0f; } stbi_image_free(data); @@ -394,7 +420,8 @@ static std::vector<float> load_png_depth16(const char* path, int ew, int eh) { int w, h, ch; uint16_t* data = stbi_load_16(path, &w, &h, &ch, 1); if (!data || w != ew || h != eh) { - if (data) stbi_image_free(data); + if (data) + stbi_image_free(data); fprintf(stderr, "Warning: cannot load depth '%s' — using 0\n", path); return std::vector<float>(ew * eh, 0.0f); } @@ -407,11 +434,12 @@ static std::vector<float> load_png_depth16(const char* path, int ew, int eh) { // Load 8-bit greyscale PNG → [0,1] static std::vector<float> load_png_gray(const char* path, int ew, int eh, - float default_val = 0.0f) { + float default_val = 0.0f) { int w, h, ch; uint8_t* data = stbi_load(path, &w, &h, &ch, 1); if (!data || w != ew || h != eh) { - if (data) stbi_image_free(data); + if (data) + stbi_image_free(data); return std::vector<float>(ew * eh, default_val); } std::vector<float> out(w * h); @@ -468,29 +496,37 @@ static bool load_weights_bin(const char* path, std::vector<uint32_t>& out) { // --------------------------------------------------------------------------- struct Args { - const char* input_path = nullptr; - const char* output_path = nullptr; - const char* sample_dir = nullptr; + const char* input_path = nullptr; + const char* output_path = nullptr; + const char* sample_dir = nullptr; const char* weights_path = nullptr; - bool debug_hex = false; + bool debug_hex = false; }; static void print_usage(const char* prog) { fprintf(stderr, "Usage: %s input.png output.png [OPTIONS]\n", prog); fprintf(stderr, "\nOPTIONS:\n"); - fprintf(stderr, " --sample-dir DIR Full sample dir with albedo/normal/depth/matid/shadow/transp\n"); - fprintf(stderr, " --weights FILE Load weights from cnn_v3_weights.bin\n"); + fprintf(stderr, + " --sample-dir DIR Full sample dir with " + "albedo/normal/depth/matid/shadow/transp\n"); + fprintf(stderr, + " --weights FILE Load weights from cnn_v3_weights.bin\n"); fprintf(stderr, " --debug-hex Print first 8 output pixels as hex\n"); fprintf(stderr, " --help Show this help\n"); - fprintf(stderr, "\nSimple mode (single PNG): geometry channels zeroed, normal=(0.5,0.5).\n"); + fprintf(stderr, + "\nSimple mode (single PNG): geometry channels zeroed, " + "normal=(0.5,0.5).\n"); fprintf(stderr, "FiLM is always identity (gamma=1, beta=0).\n"); - fprintf(stderr, "\nNote: feature packing uses [0,1] oct-normals (training format) to match\n"); + fprintf(stderr, + "\nNote: feature packing uses [0,1] oct-normals (training format) to " + "match\n"); fprintf(stderr, " infer_cnn_v3.py for direct Python/WGSL comparison.\n"); } static bool parse_args(int argc, char** argv, Args* args) { - if (argc < 3) return false; - args->input_path = argv[1]; + if (argc < 3) + return false; + args->input_path = argv[1]; args->output_path = argv[2]; for (int i = 3; i < argc; ++i) { if (strcmp(argv[i], "--sample-dir") == 0 && i + 1 < argc) { @@ -535,7 +571,8 @@ int main(int argc, char** argv) { // --- Load input image --- int W, H; std::vector<float> albedo = load_png_rgb(args.input_path, &W, &H); - if (albedo.empty()) return 1; + if (albedo.empty()) + return 1; // Pad to multiples of 4 (U-Net requires 2 pooling levels) const int W4 = (W + 3) & ~3; @@ -548,14 +585,16 @@ int main(int argc, char** argv) { for (int c = 0; c < 3; ++c) padded[(y * W4 + x) * 3 + c] = albedo[(y * W + x) * 3 + c]; albedo = std::move(padded); - W = W4; H = H4; + W = W4; + H = H4; } printf("Input: %s (%dx%d)\n", args.input_path, W, H); // --- Build FeatureImages --- FeatureImages img; - img.w = W; img.h = H; + img.w = W; + img.h = H; img.albedo = albedo; if (args.sample_dir) { @@ -564,8 +603,8 @@ int main(int argc, char** argv) { return std::string(args.sample_dir) + "/" + name; }; img.normal = load_png_rg(path("normal.png").c_str(), W, H); - img.depth = load_png_depth16(path("depth.png").c_str(), W, H); - img.matid = load_png_gray(path("matid.png").c_str(), W, H, 0.0f); + img.depth = load_png_depth16(path("depth.png").c_str(), W, H); + img.matid = load_png_gray(path("matid.png").c_str(), W, H, 0.0f); img.shadow = load_png_gray(path("shadow.png").c_str(), W, H, 1.0f); img.transp = load_png_gray(path("transp.png").c_str(), W, H, 0.0f); } else { @@ -584,11 +623,13 @@ int main(int argc, char** argv) { // --- Create GPU textures --- WGPUTexture feat0_tex = make_feat_tex(ctx.device, W, H); WGPUTexture feat1_tex = make_feat_tex(ctx.device, W, H); - WGPUTexture out_tex = make_output_tex(ctx.device, W, H); + WGPUTexture out_tex = make_output_tex(ctx.device, W, H); - WGPUTextureView feat0_view = make_view(feat0_tex, WGPUTextureFormat_RGBA32Uint); - WGPUTextureView feat1_view = make_view(feat1_tex, WGPUTextureFormat_RGBA32Uint); - WGPUTextureView out_view = make_view(out_tex, WGPUTextureFormat_RGBA16Float); + WGPUTextureView feat0_view = + make_view(feat0_tex, WGPUTextureFormat_RGBA32Uint); + WGPUTextureView feat1_view = + make_view(feat1_tex, WGPUTextureFormat_RGBA32Uint); + WGPUTextureView out_view = make_view(out_tex, WGPUTextureFormat_RGBA16Float); upload_tex(ctx.queue, feat0_tex, feat0.data(), W, H); upload_tex(ctx.queue, feat1_tex, feat1.data(), W, H); @@ -605,7 +646,8 @@ int main(int argc, char** argv) { // --- Load weights --- if (args.weights_path) { std::vector<uint32_t> wdata; - if (!load_weights_bin(args.weights_path, wdata)) return 1; + if (!load_weights_bin(args.weights_path, wdata)) + return 1; effect.upload_weights(ctx.queue, wdata.data(), (uint32_t)(wdata.size() * 4)); printf("Weights: %s (%zu bytes)\n", args.weights_path, wdata.size() * 4); @@ -616,7 +658,7 @@ int main(int argc, char** argv) { // --- Run 5 compute passes --- WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr); UniformsSequenceParams params = {}; - params.resolution = {(float)W, (float)H}; + params.resolution = {(float)W, (float)H}; params.aspect_ratio = (float)W / (float)H; effect.render(enc, params, registry); @@ -627,23 +669,27 @@ int main(int argc, char** argv) { wgpuDevicePoll(ctx.device, true, nullptr); // --- Readback --- - std::vector<float> pixels = readback_rgba16f(ctx.device, ctx.queue, out_tex, W, H); + std::vector<float> pixels = + readback_rgba16f(ctx.device, ctx.queue, out_tex, W, H); // --- Save output (crop to original size, already same if no padding) --- - if (!save_png(args.output_path, pixels, W, H)) return 1; + if (!save_png(args.output_path, pixels, W, H)) + return 1; printf("Saved: %s\n", args.output_path); if (args.debug_hex) { printf("First 8 output pixels (RGBA f32 → hex):\n"); for (int i = 0; i < 8 && i < W * H; ++i) { - float r = pixels[i*4 ], g = pixels[i*4+1]; - float b = pixels[i*4+2], a = pixels[i*4+3]; - int ri = (int)(r*255+.5f), gi = (int)(g*255+.5f); - int bi = (int)(b*255+.5f), ai = (int)(a*255+.5f); - ri = ri<0?0:ri>255?255:ri; gi = gi<0?0:gi>255?255:gi; - bi = bi<0?0:bi>255?255:bi; ai = ai<0?0:ai>255?255:ai; - printf(" [%d] 0x%02X%02X%02X%02X (%.4f %.4f %.4f %.4f)\n", - i, ri, gi, bi, ai, r, g, b, a); + float r = pixels[i * 4], g = pixels[i * 4 + 1]; + float b = pixels[i * 4 + 2], a = pixels[i * 4 + 3]; + int ri = (int)(r * 255 + .5f), gi = (int)(g * 255 + .5f); + int bi = (int)(b * 255 + .5f), ai = (int)(a * 255 + .5f); + ri = ri < 0 ? 0 : ri > 255 ? 255 : ri; + gi = gi < 0 ? 0 : gi > 255 ? 255 : gi; + bi = bi < 0 ? 0 : bi > 255 ? 255 : bi; + ai = ai < 0 ? 0 : ai > 255 ? 255 : ai; + printf(" [%d] 0x%02X%02X%02X%02X (%.4f %.4f %.4f %.4f)\n", i, ri, gi, + bi, ai, r, g, b, a); } } |
