summary | refs | log | tree | commit | diff
path: root/tools/cnn_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tools/cnn_test.cc')
-rw-r--r-- tools/cnn_test.cc 324
1 files changed, 185 insertions, 139 deletions
diff --git a/tools/cnn_test.cc b/tools/cnn_test.cc
index beeef8f..a707723 100644
--- a/tools/cnn_test.cc
+++ b/tools/cnn_test.cc
@@ -33,10 +33,12 @@ static uint16_t f32_to_f16(float f) {
uint32_t b;
memcpy(&b, &f, 4);
uint32_t sign = (b >> 16) & 0x8000u;
- int32_t exp = (int32_t)((b >> 23) & 0xFFu) - 127 + 15;
+ int32_t exp = (int32_t)((b >> 23) & 0xFFu) - 127 + 15;
uint32_t mant = b & 0x7FFFFFu;
- if (exp <= 0) return (uint16_t)sign;
- if (exp >= 31) return (uint16_t)(sign | 0x7C00u);
+ if (exp <= 0)
+ return (uint16_t)sign;
+ if (exp >= 31)
+ return (uint16_t)(sign | 0x7C00u);
return (uint16_t)(sign | ((uint32_t)exp << 10) | (mant >> 13));
}
@@ -49,8 +51,10 @@ static uint32_t pack2x16f(float a, float b) {
static uint32_t pack4x8u(float a, float b, float c, float d) {
auto u8 = [](float v) -> uint32_t {
int i = (int)(v * 255.0f + 0.5f);
- if (i < 0) i = 0;
- if (i > 255) i = 255;
+ if (i < 0)
+ i = 0;
+ if (i > 255)
+ i = 255;
return (uint32_t)i;
};
return u8(a) | (u8(b) << 8) | (u8(c) << 16) | (u8(d) << 24);
@@ -60,8 +64,8 @@ static uint32_t pack4x8u(float a, float b, float c, float d) {
// Oct-decode [0,1] → unit normal (matches Python cnn_v3_utils.oct_decode)
// ---------------------------------------------------------------------------
-static void oct_decode_01(float nx01, float ny01,
- float* out_x, float* out_y, float* out_z) {
+static void oct_decode_01(float nx01, float ny01, float* out_x, float* out_y,
+ float* out_z) {
float fx = nx01 * 2.0f - 1.0f;
float fy = ny01 * 2.0f - 1.0f;
float fz = 1.0f - fabsf(fx) - fabsf(fy);
@@ -71,8 +75,9 @@ static void oct_decode_01(float nx01, float ny01,
fx = (1.0f - fabsf(fy)) * sx;
fy = (1.0f - fabsf(fx)) * sy;
}
- float len = sqrtf(fx*fx + fy*fy + fz*fz);
- if (len < 1e-8f) len = 1e-8f;
+ float len = sqrtf(fx * fx + fy * fy + fz * fz);
+ if (len < 1e-8f)
+ len = 1e-8f;
*out_x = fx / len;
*out_y = fy / len;
*out_z = fz / len;
@@ -89,8 +94,8 @@ static void oct_decode_01(float nx01, float ny01,
// Output: mip1_out and mip2_out are (H*W*3) float arrays in row-major order.
static void compute_mips(const float* rgb, int w, int h,
- std::vector<float>& mip1_out,
- std::vector<float>& mip2_out) {
+ std::vector<float>& mip1_out,
+ std::vector<float>& mip2_out) {
const int w2 = w / 2, h2 = h / 2;
const int w4 = w / 4, h4 = h / 4;
@@ -99,10 +104,9 @@ static void compute_mips(const float* rgb, int w, int h,
for (int x2 = 0; x2 < w2; ++x2) {
for (int c = 0; c < 3; ++c) {
int y0 = y2 * 2, x0 = x2 * 2;
- float v = rgb[(y0 * w + x0 ) * 3 + c]
- + rgb[(y0 * w + x0+1) * 3 + c]
- + rgb[((y0+1) * w + x0 ) * 3 + c]
- + rgb[((y0+1) * w + x0+1) * 3 + c];
+ float v = rgb[(y0 * w + x0) * 3 + c] + rgb[(y0 * w + x0 + 1) * 3 + c] +
+ rgb[((y0 + 1) * w + x0) * 3 + c] +
+ rgb[((y0 + 1) * w + x0 + 1) * 3 + c];
m1[(y2 * w2 + x2) * 3 + c] = v * 0.25f;
}
}
@@ -113,10 +117,9 @@ static void compute_mips(const float* rgb, int w, int h,
for (int x4 = 0; x4 < w4; ++x4) {
for (int c = 0; c < 3; ++c) {
int y0 = y4 * 2, x0 = x4 * 2;
- float v = m1[(y0 * w2 + x0 ) * 3 + c]
- + m1[(y0 * w2 + x0+1) * 3 + c]
- + m1[((y0+1) * w2 + x0 ) * 3 + c]
- + m1[((y0+1) * w2 + x0+1) * 3 + c];
+ float v = m1[(y0 * w2 + x0) * 3 + c] + m1[(y0 * w2 + x0 + 1) * 3 + c] +
+ m1[((y0 + 1) * w2 + x0) * 3 + c] +
+ m1[((y0 + 1) * w2 + x0 + 1) * 3 + c];
m2[(y4 * w4 + x4) * 3 + c] = v * 0.25f;
}
}
@@ -128,14 +131,14 @@ static void compute_mips(const float* rgb, int w, int h,
for (int y = 0; y < h; ++y) {
for (int x = 0; x < w; ++x) {
int i = (y * w + x) * 3;
- int i1 = ((y/2) * w2 + (x/2)) * 3;
- int i2 = ((y/4) * w4 + (x/4)) * 3;
- mip1_out[i ] = (y/2 < h2 && x/2 < w2) ? m1[i1 ] : 0.0f;
- mip1_out[i+1] = (y/2 < h2 && x/2 < w2) ? m1[i1+1] : 0.0f;
- mip1_out[i+2] = (y/2 < h2 && x/2 < w2) ? m1[i1+2] : 0.0f;
- mip2_out[i ] = (y/4 < h4 && x/4 < w4) ? m2[i2 ] : 0.0f;
- mip2_out[i+1] = (y/4 < h4 && x/4 < w4) ? m2[i2+1] : 0.0f;
- mip2_out[i+2] = (y/4 < h4 && x/4 < w4) ? m2[i2+2] : 0.0f;
+ int i1 = ((y / 2) * w2 + (x / 2)) * 3;
+ int i2 = ((y / 4) * w4 + (x / 4)) * 3;
+ mip1_out[i] = (y / 2 < h2 && x / 2 < w2) ? m1[i1] : 0.0f;
+ mip1_out[i + 1] = (y / 2 < h2 && x / 2 < w2) ? m1[i1 + 1] : 0.0f;
+ mip1_out[i + 2] = (y / 2 < h2 && x / 2 < w2) ? m1[i1 + 2] : 0.0f;
+ mip2_out[i] = (y / 4 < h4 && x / 4 < w4) ? m2[i2] : 0.0f;
+ mip2_out[i + 1] = (y / 4 < h4 && x / 4 < w4) ? m2[i2 + 1] : 0.0f;
+ mip2_out[i + 2] = (y / 4 < h4 && x / 4 < w4) ? m2[i2 + 2] : 0.0f;
}
}
}
@@ -161,17 +164,17 @@ static void compute_mips(const float* rgb, int w, int h,
struct FeatureImages {
int w, h;
- std::vector<float> albedo; // w*h*3 [0,1]
- std::vector<float> normal; // w*h*2 [0,1] oct-encoded
- std::vector<float> depth; // w*h [0,1]
- std::vector<float> matid; // w*h [0,1]
- std::vector<float> shadow; // w*h [0,1]
- std::vector<float> transp; // w*h [0,1]
+ std::vector<float> albedo; // w*h*3 [0,1]
+ std::vector<float> normal; // w*h*2 [0,1] oct-encoded
+ std::vector<float> depth; // w*h [0,1]
+ std::vector<float> matid; // w*h [0,1]
+ std::vector<float> shadow; // w*h [0,1]
+ std::vector<float> transp; // w*h [0,1]
};
static void pack_features(const FeatureImages& img,
- std::vector<uint32_t>& feat0, // w*h*4 u32
- std::vector<uint32_t>& feat1) // w*h*4 u32
+ std::vector<uint32_t>& feat0, // w*h*4 u32
+ std::vector<uint32_t>& feat1) // w*h*4 u32
{
const int W = img.w, H = img.h;
feat0.resize(W * H * 4);
@@ -184,49 +187,49 @@ static void pack_features(const FeatureImages& img,
for (int y = 0; y < H; ++y) {
for (int x = 0; x < W; ++x) {
- const int pi = y * W + x;
- const int i3 = pi * 3;
- const int i4 = pi * 4;
+ const int pi = y * W + x;
+ const int i3 = pi * 3;
+ const int i4 = pi * 4;
- float ar = img.albedo[i3 ];
- float ag = img.albedo[i3+1];
- float ab = img.albedo[i3+2];
+ float ar = img.albedo[i3];
+ float ag = img.albedo[i3 + 1];
+ float ab = img.albedo[i3 + 2];
- float nx = img.normal[pi * 2 ]; // [0,1]
- float ny = img.normal[pi * 2 + 1]; // [0,1]
+ float nx = img.normal[pi * 2]; // [0,1]
+ float ny = img.normal[pi * 2 + 1]; // [0,1]
float d = img.depth[pi];
// Central finite difference depth gradient
- int xm = (x > 0) ? x-1 : 0;
- int xp = (x < W-1) ? x+1 : W-1;
- int ym = (y > 0) ? y-1 : 0;
- int yp = (y < H-1) ? y+1 : H-1;
- float dzdx = (img.depth[y * W + xp] - img.depth[y * W + xm]) * 0.5f;
- float dzdy = (img.depth[yp * W + x ] - img.depth[ym * W + x ]) * 0.5f;
+ int xm = (x > 0) ? x - 1 : 0;
+ int xp = (x < W - 1) ? x + 1 : W - 1;
+ int ym = (y > 0) ? y - 1 : 0;
+ int yp = (y < H - 1) ? y + 1 : H - 1;
+ float dzdx = (img.depth[y * W + xp] - img.depth[y * W + xm]) * 0.5f;
+ float dzdy = (img.depth[yp * W + x] - img.depth[ym * W + x]) * 0.5f;
- float mat = img.matid[pi];
+ float mat = img.matid[pi];
float shad = img.shadow[pi];
- float trp = img.transp[pi];
+ float trp = img.transp[pi];
// Diffuse = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow
float n3x, n3y, n3z;
oct_decode_01(nx, ny, &n3x, &n3y, &n3z);
- float dif = fmaxf(0.0f, n3x*KEY_X + n3y*KEY_Y + n3z*KEY_Z) * shad;
+ float dif = fmaxf(0.0f, n3x * KEY_X + n3y * KEY_Y + n3z * KEY_Z) * shad;
- float m1r = mip1[i3 ], m1g = mip1[i3+1], m1b = mip1[i3+2];
- float m2r = mip2[i3 ], m2g = mip2[i3+1], m2b = mip2[i3+2];
+ float m1r = mip1[i3], m1g = mip1[i3 + 1], m1b = mip1[i3 + 2];
+ float m2r = mip2[i3], m2g = mip2[i3 + 1], m2b = mip2[i3 + 2];
// prev.rgb = 0 (no temporal history)
- feat0[i4 ] = pack2x16f(ar, ag);
- feat0[i4+1] = pack2x16f(ab, nx);
- feat0[i4+2] = pack2x16f(ny, d );
- feat0[i4+3] = pack2x16f(dzdx, dzdy);
+ feat0[i4] = pack2x16f(ar, ag);
+ feat0[i4 + 1] = pack2x16f(ab, nx);
+ feat0[i4 + 2] = pack2x16f(ny, d);
+ feat0[i4 + 3] = pack2x16f(dzdx, dzdy);
- feat1[i4 ] = pack4x8u(mat, 0.0f, 0.0f, 0.0f); // mat_id, prev.rgb=0
- feat1[i4+1] = pack4x8u(m1r, m1g, m1b, m2r);
- feat1[i4+2] = pack4x8u(m2g, m2b, dif, trp);
- feat1[i4+3] = 0u;
+ feat1[i4] = pack4x8u(mat, 0.0f, 0.0f, 0.0f); // mat_id, prev.rgb=0
+ feat1[i4 + 1] = pack4x8u(m1r, m1g, m1b, m2r);
+ feat1[i4 + 2] = pack4x8u(m2g, m2b, dif, trp);
+ feat1[i4 + 3] = 0u;
}
}
}
@@ -237,41 +240,41 @@ static void pack_features(const FeatureImages& img,
static WGPUTexture make_feat_tex(WGPUDevice dev, int W, int H) {
WGPUTextureDescriptor d = {};
- d.format = WGPUTextureFormat_RGBA32Uint;
- d.usage = WGPUTextureUsage_TextureBinding | WGPUTextureUsage_CopyDst;
- d.dimension = WGPUTextureDimension_2D;
- d.size = {(uint32_t)W, (uint32_t)H, 1};
+ d.format = WGPUTextureFormat_RGBA32Uint;
+ d.usage = WGPUTextureUsage_TextureBinding | WGPUTextureUsage_CopyDst;
+ d.dimension = WGPUTextureDimension_2D;
+ d.size = {(uint32_t)W, (uint32_t)H, 1};
d.mipLevelCount = 1;
- d.sampleCount = 1;
+ d.sampleCount = 1;
return wgpuDeviceCreateTexture(dev, &d);
}
static WGPUTexture make_output_tex(WGPUDevice dev, int W, int H) {
WGPUTextureDescriptor d = {};
- d.format = WGPUTextureFormat_RGBA16Float;
- d.usage = WGPUTextureUsage_StorageBinding | WGPUTextureUsage_CopySrc;
- d.dimension = WGPUTextureDimension_2D;
- d.size = {(uint32_t)W, (uint32_t)H, 1};
+ d.format = WGPUTextureFormat_RGBA16Float;
+ d.usage = WGPUTextureUsage_StorageBinding | WGPUTextureUsage_CopySrc;
+ d.dimension = WGPUTextureDimension_2D;
+ d.size = {(uint32_t)W, (uint32_t)H, 1};
d.mipLevelCount = 1;
- d.sampleCount = 1;
+ d.sampleCount = 1;
return wgpuDeviceCreateTexture(dev, &d);
}
static WGPUTextureView make_view(WGPUTexture tex, WGPUTextureFormat fmt) {
WGPUTextureViewDescriptor d = {};
- d.format = fmt;
- d.dimension = WGPUTextureViewDimension_2D;
- d.mipLevelCount = 1;
+ d.format = fmt;
+ d.dimension = WGPUTextureViewDimension_2D;
+ d.mipLevelCount = 1;
d.arrayLayerCount = 1;
return wgpuTextureCreateView(tex, &d);
}
-static void upload_tex(WGPUQueue queue, WGPUTexture tex,
- const uint32_t* data, int W, int H) {
+static void upload_tex(WGPUQueue queue, WGPUTexture tex, const uint32_t* data,
+ int W, int H) {
WGPUTexelCopyTextureInfo dst = {};
dst.texture = tex;
WGPUTexelCopyBufferLayout layout = {};
- layout.bytesPerRow = (uint32_t)(W * 16);
+ layout.bytesPerRow = (uint32_t)(W * 16);
layout.rowsPerImage = (uint32_t)H;
WGPUExtent3D ext = {(uint32_t)W, (uint32_t)H, 1};
wgpuQueueWriteTexture(queue, &dst, data, (size_t)(W * H * 16), &layout, &ext);
@@ -281,37 +284,53 @@ static void upload_tex(WGPUQueue queue, WGPUTexture tex,
// RGBA16Float readback
// ---------------------------------------------------------------------------
-static uint16_t fp16_bits_to_f16(float f) { return f32_to_f16(f); }
+static uint16_t fp16_bits_to_f16(float f) {
+ return f32_to_f16(f);
+}
static float fp16_bits_to_f32(uint16_t h) {
uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
- uint32_t exp = (h & 0x7C00u) >> 10;
+ uint32_t exp = (h & 0x7C00u) >> 10;
uint32_t mant = h & 0x03FFu;
- if (exp == 0 && mant == 0) { float r; memcpy(&r, &sign, 4); return r; }
- if (exp == 31) { uint32_t b = sign | 0x7F800000u | (mant << 13);
- float r; memcpy(&r, &b, 4); return r; }
+ if (exp == 0 && mant == 0) {
+ float r;
+ memcpy(&r, &sign, 4);
+ return r;
+ }
+ if (exp == 31) {
+ uint32_t b = sign | 0x7F800000u | (mant << 13);
+ float r;
+ memcpy(&r, &b, 4);
+ return r;
+ }
uint32_t b = sign | ((exp + 112u) << 23) | (mant << 13);
- float r; memcpy(&r, &b, 4); return r;
+ float r;
+ memcpy(&r, &b, 4);
+ return r;
}
-struct MapState { bool done = false; WGPUMapAsyncStatus status = {}; };
+struct MapState {
+ bool done = false;
+ WGPUMapAsyncStatus status = {};
+};
static std::vector<float> readback_rgba16f(WGPUDevice device, WGPUQueue queue,
- WGPUTexture tex, int W, int H) {
- const uint32_t bytes_per_px = 8;
- const uint32_t raw_bpr = (uint32_t)(W * bytes_per_px);
- const uint32_t aligned_bpr = ((raw_bpr + 255u) / 256u) * 256u;
- const size_t buf_size = (size_t)aligned_bpr * (size_t)H;
+ WGPUTexture tex, int W, int H) {
+ const uint32_t bytes_per_px = 8;
+ const uint32_t raw_bpr = (uint32_t)(W * bytes_per_px);
+ const uint32_t aligned_bpr = ((raw_bpr + 255u) / 256u) * 256u;
+ const size_t buf_size = (size_t)aligned_bpr * (size_t)H;
WGPUBufferDescriptor bd = {};
bd.usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead;
- bd.size = buf_size;
+ bd.size = buf_size;
WGPUBuffer staging = wgpuDeviceCreateBuffer(device, &bd);
WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(device, nullptr);
- WGPUTexelCopyTextureInfo src = {}; src.texture = tex;
- WGPUTexelCopyBufferInfo dst = {};
- dst.buffer = staging;
- dst.layout.bytesPerRow = aligned_bpr;
+ WGPUTexelCopyTextureInfo src = {};
+ src.texture = tex;
+ WGPUTexelCopyBufferInfo dst = {};
+ dst.buffer = staging;
+ dst.layout.bytesPerRow = aligned_bpr;
dst.layout.rowsPerImage = (uint32_t)H;
WGPUExtent3D ext = {(uint32_t)W, (uint32_t)H, 1};
wgpuCommandEncoderCopyTextureToBuffer(enc, &src, &dst, &ext);
@@ -323,9 +342,11 @@ static std::vector<float> readback_rgba16f(WGPUDevice device, WGPUQueue queue,
MapState ms = {};
WGPUBufferMapCallbackInfo mi = {};
- mi.mode = WGPUCallbackMode_AllowProcessEvents;
+ mi.mode = WGPUCallbackMode_AllowProcessEvents;
mi.callback = [](WGPUMapAsyncStatus s, WGPUStringView, void* u, void*) {
- auto* st = (MapState*)u; st->status = s; st->done = true;
+ auto* st = (MapState*)u;
+ st->status = s;
+ st->done = true;
};
mi.userdata1 = &ms;
wgpuBufferMapAsync(staging, WGPUMapMode_Read, 0, buf_size, mi);
@@ -334,11 +355,12 @@ static std::vector<float> readback_rgba16f(WGPUDevice device, WGPUQueue queue,
std::vector<float> pixels(W * H * 4, 0.0f);
if (ms.done && ms.status == WGPUMapAsyncStatus_Success) {
- const uint8_t* mapped = (const uint8_t*)
- wgpuBufferGetConstMappedRange(staging, 0, buf_size);
+ const uint8_t* mapped =
+ (const uint8_t*)wgpuBufferGetConstMappedRange(staging, 0, buf_size);
if (mapped) {
for (int y = 0; y < H; ++y) {
- const uint16_t* row = (const uint16_t*)(mapped + (size_t)y * aligned_bpr);
+ const uint16_t* row =
+ (const uint16_t*)(mapped + (size_t)y * aligned_bpr);
for (int x = 0; x < W; ++x) {
for (int c = 0; c < 4; ++c)
pixels[(y * W + x) * 4 + c] = fp16_bits_to_f32(row[x * 4 + c]);
@@ -355,14 +377,16 @@ static std::vector<float> readback_rgba16f(WGPUDevice device, WGPUQueue queue,
// Image I/O helpers
// ---------------------------------------------------------------------------
-static std::vector<float> load_png_rgb(const char* path, int* out_w, int* out_h) {
+static std::vector<float> load_png_rgb(const char* path, int* out_w,
+ int* out_h) {
int w, h, ch;
uint8_t* data = stbi_load(path, &w, &h, &ch, 3);
if (!data) {
fprintf(stderr, "Error: cannot load '%s'\n", path);
return {};
}
- *out_w = w; *out_h = h;
+ *out_w = w;
+ *out_h = h;
std::vector<float> out(w * h * 3);
for (int i = 0; i < w * h * 3; ++i)
out[i] = data[i] / 255.0f;
@@ -375,14 +399,16 @@ static std::vector<float> load_png_rg(const char* path, int ew, int eh) {
int w, h, ch;
uint8_t* data = stbi_load(path, &w, &h, &ch, 3);
if (!data || w != ew || h != eh) {
- if (data) stbi_image_free(data);
- fprintf(stderr, "Warning: cannot load normal '%s' — using (0.5,0.5)\n", path);
+ if (data)
+ stbi_image_free(data);
+ fprintf(stderr, "Warning: cannot load normal '%s' — using (0.5,0.5)\n",
+ path);
std::vector<float> def(ew * eh * 2, 0.5f);
return def;
}
std::vector<float> out(w * h * 2);
for (int i = 0; i < w * h; ++i) {
- out[i * 2 ] = data[i * 3 ] / 255.0f;
+ out[i * 2] = data[i * 3] / 255.0f;
out[i * 2 + 1] = data[i * 3 + 1] / 255.0f;
}
stbi_image_free(data);
@@ -394,7 +420,8 @@ static std::vector<float> load_png_depth16(const char* path, int ew, int eh) {
int w, h, ch;
uint16_t* data = stbi_load_16(path, &w, &h, &ch, 1);
if (!data || w != ew || h != eh) {
- if (data) stbi_image_free(data);
+ if (data)
+ stbi_image_free(data);
fprintf(stderr, "Warning: cannot load depth '%s' — using 0\n", path);
return std::vector<float>(ew * eh, 0.0f);
}
@@ -407,11 +434,12 @@ static std::vector<float> load_png_depth16(const char* path, int ew, int eh) {
// Load 8-bit greyscale PNG → [0,1]
static std::vector<float> load_png_gray(const char* path, int ew, int eh,
- float default_val = 0.0f) {
+ float default_val = 0.0f) {
int w, h, ch;
uint8_t* data = stbi_load(path, &w, &h, &ch, 1);
if (!data || w != ew || h != eh) {
- if (data) stbi_image_free(data);
+ if (data)
+ stbi_image_free(data);
return std::vector<float>(ew * eh, default_val);
}
std::vector<float> out(w * h);
@@ -468,29 +496,37 @@ static bool load_weights_bin(const char* path, std::vector<uint32_t>& out) {
// ---------------------------------------------------------------------------
struct Args {
- const char* input_path = nullptr;
- const char* output_path = nullptr;
- const char* sample_dir = nullptr;
+ const char* input_path = nullptr;
+ const char* output_path = nullptr;
+ const char* sample_dir = nullptr;
const char* weights_path = nullptr;
- bool debug_hex = false;
+ bool debug_hex = false;
};
static void print_usage(const char* prog) {
fprintf(stderr, "Usage: %s input.png output.png [OPTIONS]\n", prog);
fprintf(stderr, "\nOPTIONS:\n");
- fprintf(stderr, " --sample-dir DIR Full sample dir with albedo/normal/depth/matid/shadow/transp\n");
- fprintf(stderr, " --weights FILE Load weights from cnn_v3_weights.bin\n");
+ fprintf(stderr,
+ " --sample-dir DIR Full sample dir with "
+ "albedo/normal/depth/matid/shadow/transp\n");
+ fprintf(stderr,
+ " --weights FILE Load weights from cnn_v3_weights.bin\n");
fprintf(stderr, " --debug-hex Print first 8 output pixels as hex\n");
fprintf(stderr, " --help Show this help\n");
- fprintf(stderr, "\nSimple mode (single PNG): geometry channels zeroed, normal=(0.5,0.5).\n");
+ fprintf(stderr,
+ "\nSimple mode (single PNG): geometry channels zeroed, "
+ "normal=(0.5,0.5).\n");
fprintf(stderr, "FiLM is always identity (gamma=1, beta=0).\n");
- fprintf(stderr, "\nNote: feature packing uses [0,1] oct-normals (training format) to match\n");
+ fprintf(stderr,
+ "\nNote: feature packing uses [0,1] oct-normals (training format) to "
+ "match\n");
fprintf(stderr, " infer_cnn_v3.py for direct Python/WGSL comparison.\n");
}
static bool parse_args(int argc, char** argv, Args* args) {
- if (argc < 3) return false;
- args->input_path = argv[1];
+ if (argc < 3)
+ return false;
+ args->input_path = argv[1];
args->output_path = argv[2];
for (int i = 3; i < argc; ++i) {
if (strcmp(argv[i], "--sample-dir") == 0 && i + 1 < argc) {
@@ -535,7 +571,8 @@ int main(int argc, char** argv) {
// --- Load input image ---
int W, H;
std::vector<float> albedo = load_png_rgb(args.input_path, &W, &H);
- if (albedo.empty()) return 1;
+ if (albedo.empty())
+ return 1;
// Pad to multiples of 4 (U-Net requires 2 pooling levels)
const int W4 = (W + 3) & ~3;
@@ -548,14 +585,16 @@ int main(int argc, char** argv) {
for (int c = 0; c < 3; ++c)
padded[(y * W4 + x) * 3 + c] = albedo[(y * W + x) * 3 + c];
albedo = std::move(padded);
- W = W4; H = H4;
+ W = W4;
+ H = H4;
}
printf("Input: %s (%dx%d)\n", args.input_path, W, H);
// --- Build FeatureImages ---
FeatureImages img;
- img.w = W; img.h = H;
+ img.w = W;
+ img.h = H;
img.albedo = albedo;
if (args.sample_dir) {
@@ -564,8 +603,8 @@ int main(int argc, char** argv) {
return std::string(args.sample_dir) + "/" + name;
};
img.normal = load_png_rg(path("normal.png").c_str(), W, H);
- img.depth = load_png_depth16(path("depth.png").c_str(), W, H);
- img.matid = load_png_gray(path("matid.png").c_str(), W, H, 0.0f);
+ img.depth = load_png_depth16(path("depth.png").c_str(), W, H);
+ img.matid = load_png_gray(path("matid.png").c_str(), W, H, 0.0f);
img.shadow = load_png_gray(path("shadow.png").c_str(), W, H, 1.0f);
img.transp = load_png_gray(path("transp.png").c_str(), W, H, 0.0f);
} else {
@@ -584,11 +623,13 @@ int main(int argc, char** argv) {
// --- Create GPU textures ---
WGPUTexture feat0_tex = make_feat_tex(ctx.device, W, H);
WGPUTexture feat1_tex = make_feat_tex(ctx.device, W, H);
- WGPUTexture out_tex = make_output_tex(ctx.device, W, H);
+ WGPUTexture out_tex = make_output_tex(ctx.device, W, H);
- WGPUTextureView feat0_view = make_view(feat0_tex, WGPUTextureFormat_RGBA32Uint);
- WGPUTextureView feat1_view = make_view(feat1_tex, WGPUTextureFormat_RGBA32Uint);
- WGPUTextureView out_view = make_view(out_tex, WGPUTextureFormat_RGBA16Float);
+ WGPUTextureView feat0_view =
+ make_view(feat0_tex, WGPUTextureFormat_RGBA32Uint);
+ WGPUTextureView feat1_view =
+ make_view(feat1_tex, WGPUTextureFormat_RGBA32Uint);
+ WGPUTextureView out_view = make_view(out_tex, WGPUTextureFormat_RGBA16Float);
upload_tex(ctx.queue, feat0_tex, feat0.data(), W, H);
upload_tex(ctx.queue, feat1_tex, feat1.data(), W, H);
@@ -605,7 +646,8 @@ int main(int argc, char** argv) {
// --- Load weights ---
if (args.weights_path) {
std::vector<uint32_t> wdata;
- if (!load_weights_bin(args.weights_path, wdata)) return 1;
+ if (!load_weights_bin(args.weights_path, wdata))
+ return 1;
effect.upload_weights(ctx.queue, wdata.data(),
(uint32_t)(wdata.size() * 4));
printf("Weights: %s (%zu bytes)\n", args.weights_path, wdata.size() * 4);
@@ -616,7 +658,7 @@ int main(int argc, char** argv) {
// --- Run 5 compute passes ---
WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
UniformsSequenceParams params = {};
- params.resolution = {(float)W, (float)H};
+ params.resolution = {(float)W, (float)H};
params.aspect_ratio = (float)W / (float)H;
effect.render(enc, params, registry);
@@ -627,23 +669,27 @@ int main(int argc, char** argv) {
wgpuDevicePoll(ctx.device, true, nullptr);
// --- Readback ---
- std::vector<float> pixels = readback_rgba16f(ctx.device, ctx.queue, out_tex, W, H);
+ std::vector<float> pixels =
+ readback_rgba16f(ctx.device, ctx.queue, out_tex, W, H);
// --- Save output (crop to original size, already same if no padding) ---
- if (!save_png(args.output_path, pixels, W, H)) return 1;
+ if (!save_png(args.output_path, pixels, W, H))
+ return 1;
printf("Saved: %s\n", args.output_path);
if (args.debug_hex) {
printf("First 8 output pixels (RGBA f32 → hex):\n");
for (int i = 0; i < 8 && i < W * H; ++i) {
- float r = pixels[i*4 ], g = pixels[i*4+1];
- float b = pixels[i*4+2], a = pixels[i*4+3];
- int ri = (int)(r*255+.5f), gi = (int)(g*255+.5f);
- int bi = (int)(b*255+.5f), ai = (int)(a*255+.5f);
- ri = ri<0?0:ri>255?255:ri; gi = gi<0?0:gi>255?255:gi;
- bi = bi<0?0:bi>255?255:bi; ai = ai<0?0:ai>255?255:ai;
- printf(" [%d] 0x%02X%02X%02X%02X (%.4f %.4f %.4f %.4f)\n",
- i, ri, gi, bi, ai, r, g, b, a);
+ float r = pixels[i * 4], g = pixels[i * 4 + 1];
+ float b = pixels[i * 4 + 2], a = pixels[i * 4 + 3];
+ int ri = (int)(r * 255 + .5f), gi = (int)(g * 255 + .5f);
+ int bi = (int)(b * 255 + .5f), ai = (int)(a * 255 + .5f);
+ ri = ri < 0 ? 0 : ri > 255 ? 255 : ri;
+ gi = gi < 0 ? 0 : gi > 255 ? 255 : gi;
+ bi = bi < 0 ? 0 : bi > 255 ? 255 : bi;
+ ai = ai < 0 ? 0 : ai > 255 ? 255 : ai;
+ printf(" [%d] 0x%02X%02X%02X%02X (%.4f %.4f %.4f %.4f)\n", i, ri, gi,
+ bi, ai, r, g, b, a);
}
}