diff options
Diffstat (limited to 'cnn_v3/training/pack_blender_sample.py')
| -rw-r--r-- | cnn_v3/training/pack_blender_sample.py | 268 |
1 files changed, 268 insertions, 0 deletions
diff --git a/cnn_v3/training/pack_blender_sample.py b/cnn_v3/training/pack_blender_sample.py new file mode 100644 index 0000000..84344c1 --- /dev/null +++ b/cnn_v3/training/pack_blender_sample.py @@ -0,0 +1,268 @@ +""" +Pack a Blender multi-layer EXR into CNN v3 training sample files. + +Reads a multi-layer EXR produced by blender_export.py and writes separate PNG +files per channel into an output directory, ready for the CNN v3 dataloader. + +Output files: + albedo.png — RGB uint8 (DiffCol pass, gamma-corrected) + normal.png — RG uint8 (octahedral-encoded world normal in [0,1]) + depth.png — R uint16 (1/(z+1) normalized to [0,1], 16-bit PNG) + matid.png — R uint8 (IndexOB / 255) + shadow.png — R uint8 (1 - shadow_catcher, so 255 = fully lit) + transp.png — R uint8 (alpha from Combined pass, 0=opaque) + target.png — RGBA uint8 (Combined beauty pass) + +depth_grad, mip1, mip2 are computed on-the-fly by the dataloader (not stored). +prev = zero during training (no temporal history for static frames). + +Usage: + python3 pack_blender_sample.py --exr renders/frame_001.exr \\ + --output dataset/full/sample_001/ + +Dependencies: + numpy, Pillow, OpenEXR (pip install openexr) + — or use imageio[freeimage] as alternative EXR reader. +""" + +import argparse +import os +import sys +import numpy as np +from PIL import Image + + +# ---- EXR loading ---- + +def load_exr_openexr(path: str) -> dict: + """Load a multi-layer EXR using the OpenEXR Python binding.""" + import OpenEXR + import Imath + + exr = OpenEXR.InputFile(path) + header = exr.header() + dw = header["dataWindow"] + width = dw.max.x - dw.min.x + 1 + height = dw.max.y - dw.min.y + 1 + channels = {} + float_type = Imath.PixelType(Imath.PixelType.FLOAT) + for ch_name in header["channels"]: + raw = exr.channel(ch_name, float_type) + arr = np.frombuffer(raw, dtype=np.float32).reshape((height, width)) + channels[ch_name] = arr + return channels, width, height + + +def load_exr_imageio(path: str) -> dict: + """Load a multi-layer EXR using imageio (freeimage backend).""" + import imageio + data = imageio.imread(path, format="exr") + # imageio may return (H, W, C); treat as single layer + h, w = data.shape[:2] + c = data.shape[2] if data.ndim == 3 else 1 + channels = {} + names = ["R", "G", "B", "A"][:c] + for i, n in enumerate(names): + channels[n] = data[:, :, i].astype(np.float32) + return channels, w, h + + +def load_exr(path: str): + """Try OpenEXR first, fall back to imageio.""" + try: + return load_exr_openexr(path) + except ImportError: + pass + try: + return load_exr_imageio(path) + except ImportError: + pass + raise ImportError( + "No EXR reader found. Install OpenEXR or imageio[freeimage]:\n" + " pip install openexr\n" + " pip install imageio[freeimage]" + ) + + +# ---- Octahedral encoding ---- + +def oct_encode(normals: np.ndarray) -> np.ndarray: + """ + Octahedral-encode world-space normals. + + Args: + normals: (H, W, 3) float32, unit vectors. + Returns: + (H, W, 2) float32 in [0, 1] for PNG storage. + """ + nx, ny, nz = normals[..., 0], normals[..., 1], normals[..., 2] + # L1-normalize projection onto the octahedron + l1 = np.abs(nx) + np.abs(ny) + np.abs(nz) + 1e-9 + ox = nx / l1 + oy = ny / l1 + # Fold lower hemisphere + mask = nz < 0.0 + ox_folded = np.where(mask, (1.0 - np.abs(oy)) * np.sign(ox + 1e-9), ox) + oy_folded = np.where(mask, (1.0 - np.abs(ox)) * np.sign(oy + 1e-9), oy) + # Remap [-1, 1] → [0, 1] + encoded = np.stack([ox_folded, oy_folded], axis=-1) * 0.5 + 0.5 + return np.clip(encoded, 0.0, 1.0) + + +# ---- Channel extraction helpers ---- + +def get_pass_rgb(channels: dict, prefix: str) -> np.ndarray: + """Extract an RGB pass (prefix.R, prefix.G, prefix.B).""" + r = channels.get(f"{prefix}.R", channels.get("R", None)) + g = channels.get(f"{prefix}.G", channels.get("G", None)) + b = channels.get(f"{prefix}.B", channels.get("B", None)) + if r is None or g is None or b is None: + raise KeyError(f"Could not find RGB channels for pass '{prefix}'.") + return np.stack([r, g, b], axis=-1) + + +def get_pass_rgba(channels: dict, prefix: str) -> np.ndarray: + """Extract an RGBA pass.""" + rgb = get_pass_rgb(channels, prefix) + a = channels.get(f"{prefix}.A", np.ones_like(rgb[..., 0])) + return np.concatenate([rgb, a[..., np.newaxis]], axis=-1) + + +def get_pass_r(channels: dict, prefix: str, default: float = 0.0) -> np.ndarray: + """Extract a single-channel pass.""" + ch = channels.get(f"{prefix}.R", channels.get(prefix, None)) + if ch is None: + h, w = next(iter(channels.values())).shape[:2] + return np.full((h, w), default, dtype=np.float32) + return ch.astype(np.float32) + + +def get_pass_xyz(channels: dict, prefix: str) -> np.ndarray: + """Extract an XYZ pass (Normal uses .X .Y .Z in Blender).""" + x = channels.get(f"{prefix}.X") + y = channels.get(f"{prefix}.Y") + z = channels.get(f"{prefix}.Z") + if x is None or y is None or z is None: + # Fall back to R/G/B naming + return get_pass_rgb(channels, prefix) + return np.stack([x, y, z], axis=-1) + + +# ---- Main packing ---- + +def pack_blender_sample(exr_path: str, output_dir: str) -> None: + os.makedirs(output_dir, exist_ok=True) + + print(f"[pack_blender_sample] Loading {exr_path} …") + channels, width, height = load_exr(exr_path) + print(f" Dimensions: {width}×{height}") + print(f" Channels: {sorted(channels.keys())}") + + # ---- albedo (DiffCol → RGB uint8, gamma-correct linear→sRGB) ---- + try: + albedo_lin = get_pass_rgb(channels, "DiffCol") + except KeyError: + print(" WARNING: DiffCol pass not found; using zeros.") + albedo_lin = np.zeros((height, width, 3), dtype=np.float32) + # Convert linear → sRGB (approximate gamma 2.2) + albedo_srgb = np.clip(np.power(np.clip(albedo_lin, 0, 1), 1.0 / 2.2), 0, 1) + albedo_u8 = (albedo_srgb * 255.0).astype(np.uint8) + Image.fromarray(albedo_u8, mode="RGB").save( + os.path.join(output_dir, "albedo.png") + ) + + # ---- normal (Normal pass → oct-encoded RG uint8) ---- + try: + # Blender world normals use .X .Y .Z channels + normal_xyz = get_pass_xyz(channels, "Normal") + # Normalize to unit length (may not be exactly unit after compression) + nlen = np.linalg.norm(normal_xyz, axis=-1, keepdims=True) + 1e-9 + normal_unit = normal_xyz / nlen + normal_enc = oct_encode(normal_unit) # (H, W, 2) in [0, 1] + normal_u8 = (normal_enc * 255.0).astype(np.uint8) + # Store in RGB with B=0 (unused) + normal_rgb = np.concatenate( + [normal_u8, np.zeros((height, width, 1), dtype=np.uint8)], axis=-1 + ) + except KeyError: + print(" WARNING: Normal pass not found; using zeros.") + normal_rgb = np.zeros((height, width, 3), dtype=np.uint8) + Image.fromarray(normal_rgb, mode="RGB").save( + os.path.join(output_dir, "normal.png") + ) + + # ---- depth (Z pass → 1/(z+1), stored as 16-bit PNG) ---- + z_raw = get_pass_r(channels, "Z", default=0.0) + # 1/z style: 1/(z + 1) maps z=0→1.0, z=∞→0.0 + depth_norm = 1.0 / (np.clip(z_raw, 0.0, None) + 1.0) + depth_norm = np.clip(depth_norm, 0.0, 1.0) + depth_u16 = (depth_norm * 65535.0).astype(np.uint16) + Image.fromarray(depth_u16, mode="I;16").save( + os.path.join(output_dir, "depth.png") + ) + + # ---- matid (IndexOB → u8) ---- + # Blender object index is an integer; clamp to [0, 255]. + matid_raw = get_pass_r(channels, "IndexOB", default=0.0) + matid_u8 = np.clip(matid_raw, 0, 255).astype(np.uint8) + Image.fromarray(matid_u8, mode="L").save( + os.path.join(output_dir, "matid.png") + ) + + # ---- shadow (Shadow pass → invert: 1=fully lit, stored u8) ---- + # Blender Shadow pass: 1=lit, 0=shadowed. We keep that convention + # (shadow=1 means fully lit), so just convert directly. + shadow_raw = get_pass_r(channels, "Shadow", default=1.0) + shadow_u8 = (np.clip(shadow_raw, 0.0, 1.0) * 255.0).astype(np.uint8) + Image.fromarray(shadow_u8, mode="L").save( + os.path.join(output_dir, "shadow.png") + ) + + # ---- transp (Alpha from Combined pass → u8, 0=opaque) ---- + # Blender alpha: 1=opaque, 0=transparent. + # CNN convention: transp=0 means opaque, transp=1 means transparent. + # So transp = 1 - alpha. + try: + combined_rgba = get_pass_rgba(channels, "Combined") + alpha = combined_rgba[..., 3] + except KeyError: + alpha = np.ones((height, width), dtype=np.float32) + transp = 1.0 - np.clip(alpha, 0.0, 1.0) + transp_u8 = (transp * 255.0).astype(np.uint8) + Image.fromarray(transp_u8, mode="L").save( + os.path.join(output_dir, "transp.png") + ) + + # ---- target (Combined beauty pass → RGBA uint8, gamma-correct) ---- + try: + combined_rgba = get_pass_rgba(channels, "Combined") + # Convert linear → sRGB for display (RGB channels only) + c_rgb = np.power(np.clip(combined_rgba[..., :3], 0, 1), 1.0 / 2.2) + c_alpha = combined_rgba[..., 3:4] + target_lin = np.concatenate([c_rgb, c_alpha], axis=-1) + target_u8 = (np.clip(target_lin, 0, 1) * 255.0).astype(np.uint8) + except KeyError: + print(" WARNING: Combined pass not found; target will be zeros.") + target_u8 = np.zeros((height, width, 4), dtype=np.uint8) + Image.fromarray(target_u8, mode="RGBA").save( + os.path.join(output_dir, "target.png") + ) + + print(f"[pack_blender_sample] Wrote sample to {output_dir}") + print(" Files: albedo.png normal.png depth.png matid.png " + "shadow.png transp.png target.png") + print(" Note: depth_grad, mip1, mip2 are computed on-the-fly by the dataloader.") + + +def main(): + parser = argparse.ArgumentParser( + description="Pack a Blender multi-layer EXR into CNN v3 training sample files." + ) + parser.add_argument("--exr", required=True, help="Input multi-layer EXR file") + parser.add_argument("--output", required=True, help="Output directory for sample files") + args = parser.parse_args() + pack_blender_sample(args.exr, args.output) + + +if __name__ == "__main__": + main() |
