summaryrefslogtreecommitdiff
path: root/cnn_v3/training/pack_photo_sample.py
blob: b2943fb32bedc54bff69e6a6e24d7a8d82b0a3ad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Pack a photo into CNN v3 simple training sample files.

Converts a single RGB or RGBA photo into the CNN v3 sample layout.
Geometric channels carry neutral placeholders (depth and matid are zero,
normal encodes a forward-facing normal); the network degrades gracefully
due to channel-dropout training.

Output files:
    albedo.png    — RGB uint8   (photo RGB)
    normal.png    — RGB uint8   (R=G=128, B=0 — neutral octahedral normal; no geometry data)
    depth.png     — R uint16    (zero — no depth data)
    matid.png     — R uint8     (zero — no material data)
    shadow.png    — R uint8     (255 = fully lit — assume unoccluded)
    transp.png    — R uint8     (255 - alpha, i.e. 1 - alpha; 0 if no alpha channel)
    target.png    — RGBA uint8  (= photo RGB + alpha; no ground-truth styled target)

mip1 and mip2 are computed on-the-fly by the dataloader from albedo.
prev = zero during training (no temporal history).

Usage:
    python3 pack_photo_sample.py --photo photos/img_001.png \\
                                 --output dataset/simple/sample_001/

Dependencies:
    numpy, Pillow
"""

import argparse
import os
import numpy as np
from PIL import Image


# ---- Mip computation ----

def pyrdown(img: np.ndarray) -> np.ndarray:
    """
    2×2 average pooling (half resolution).

    An odd trailing row/column is dropped before pooling, so the output has
    spatial size (H//2, W//2). Any axes after the first two (e.g. channels)
    are pooled independently, so both (H, W) and (H, W, C) inputs work.

    Args:
        img: (H, W) or (H, W, C) array, typically float32 in [0, 1].
    Returns:
        (H//2, W//2) or (H//2, W//2, C) array; float32 input stays float32.
    Raises:
        ValueError: if img has fewer than two dimensions.
    """
    if img.ndim < 2:
        raise ValueError(
            f"pyrdown expects an (H, W[, C]) array, got shape {img.shape}"
        )
    h2, w2 = img.shape[0] // 2, img.shape[1] // 2
    # Crop to even height/width so every output pixel has a full 2×2 footprint.
    cropped = img[:h2 * 2, :w2 * 2]
    # Sum the four interleaved phases of each 2×2 tile (strided views, no copy).
    return 0.25 * (
        cropped[0::2, 0::2] +
        cropped[1::2, 0::2] +
        cropped[0::2, 1::2] +
        cropped[1::2, 1::2]
    )


# ---- Main packing ----

def _save_png(array: np.ndarray, mode: str, output_dir: str, filename: str) -> None:
    """Save *array* as ``output_dir/filename`` using the given Pillow mode."""
    Image.fromarray(np.ascontiguousarray(array), mode=mode).save(
        os.path.join(output_dir, filename)
    )


def pack_photo_sample(photo_path: str, output_dir: str) -> None:
    """
    Convert one photo into the CNN v3 sample file layout.

    The photo is converted to RGBA and written out as albedo (RGB) and
    target (RGBA). The geometry/lighting channels get placeholder maps:
    a neutral normal, zero depth and matid, a fully lit shadow, and a
    transparency map of 255 - alpha.

    Args:
        photo_path: Path to the input image (any format Pillow can open).
        output_dir: Directory to write the PNG files into; created if missing.
    Raises:
        OSError: if the photo cannot be opened/decoded or a file cannot be
            written (Pillow's decode errors subclass OSError).
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"[pack_photo_sample] Loading {photo_path} …")
    # The context manager releases the source file handle as soon as the
    # pixels have been copied out.
    with Image.open(photo_path) as src:
        rgba = np.asarray(src.convert("RGBA"), dtype=np.uint8)  # (H, W, 4)
    height, width = rgba.shape[:2]
    print(f"  Dimensions: {width}×{height}")

    # All 8-bit outputs are taken straight from the uint8 pixels. A float
    # round-trip (x / 255 * 255 followed by a truncating astype) can land just
    # below the integer value and silently lose 1 LSB.

    # ---- albedo — photo RGB ----
    _save_png(rgba[..., :3], "RGB", output_dir, "albedo.png")

    # ---- normal — neutral (no geometry) ----
    # Encode "no normal" as (0.5, 0.5) in octahedral space → (128, 128)
    # This maps to oct = (0, 0) → reconstructed normal = (0, 0, 1) (pointing forward)
    normal = np.zeros((height, width, 3), dtype=np.uint8)
    normal[..., :2] = 128
    _save_png(normal, "RGB", output_dir, "normal.png")

    # ---- depth — zero ----
    depth_zero = np.zeros((height, width), dtype=np.uint16)
    _save_png(depth_zero, "I;16", output_dir, "depth.png")

    # ---- matid — zero ----
    matid_zero = np.zeros((height, width), dtype=np.uint8)
    _save_png(matid_zero, "L", output_dir, "matid.png")

    # ---- shadow — 255 (fully lit, assume unoccluded) ----
    shadow_full = np.full((height, width), 255, dtype=np.uint8)
    _save_png(shadow_full, "L", output_dir, "shadow.png")

    # ---- transp — 255 - alpha (0 = opaque, 255 = fully transparent) ----
    # If the photo has no meaningful alpha, this is zero everywhere.
    transp_u8 = np.uint8(255) - rgba[..., 3]
    _save_png(transp_u8, "L", output_dir, "transp.png")

    # ---- target — photo RGBA (no GT styled target) ----
    # Alpha is kept for potential masking by the dataloader.
    _save_png(rgba, "RGBA", output_dir, "target.png")

    # ---- mip1 / mip2 — informational only, not saved ----
    # The dataloader computes mip1/mip2 on-the-fly from albedo.
    # Verify they look reasonable here for debugging.
    rgb = rgba[..., :3].astype(np.float32) / 255.0
    mip1 = pyrdown(rgb)
    mip2 = pyrdown(mip1)
    print(f"  mip1: {mip1.shape[1]}×{mip1.shape[0]}  "
          f"mip2: {mip2.shape[1]}×{mip2.shape[0]}  (computed on-the-fly)")

    print(f"[pack_photo_sample] Wrote sample to {output_dir}")
    print("  Files: albedo.png  normal.png  depth.png  matid.png  "
          "shadow.png  transp.png  target.png")
    print("  Note: depth/matid are zeroed and normal is neutral (no geometry data).")
    print("  Note: target = albedo (no ground-truth styled target).")


def main(argv: "list[str] | None" = None) -> None:
    """
    Command-line entry point.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]`` (useful for
            tests and programmatic calls). None means the real CLI args.
    """
    parser = argparse.ArgumentParser(
        description="Pack a photo into CNN v3 simple training sample files."
    )
    parser.add_argument("--photo", required=True,
                        help="Input photo file (RGB or RGBA PNG/JPG)")
    parser.add_argument("--output", required=True,
                        help="Output directory for sample files")
    args = parser.parse_args(argv)
    # Report a missing input as a usage error (exit code 2) rather than
    # letting the image loader fail with a traceback.
    if not os.path.isfile(args.photo):
        parser.error(f"--photo: no such file: {args.photo}")
    pack_photo_sample(args.photo, args.output)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()