#!/usr/bin/env python3
"""Cross-check WGSL uniform struct layouts against C++ ``static_assert`` sizes.

For every C++ struct that has a ``static_assert(sizeof(Name) == N, ...)``,
the WGSL struct of the same name (from standalone ``.wgsl`` files or from WGSL
embedded in C++ raw string literals) is laid out with the WGSL alignment/size
rules for host-shareable types and its size is compared with ``N``. The script
exits with status 1 if any size disagrees.

Known limitations: ``@size``/``@align`` member attributes are ignored (layouts
that depend on them are computed as if they were absent); type aliases,
runtime-sized arrays and arrays whose element count is not an integer literal
are reported as "could not calculate" rather than validated.
"""
import os
import re
import sys

# Alignment (in bytes) of the scalar/vector names accepted in their bare,
# untemplated form. Retained for backward compatibility with importers; the
# actual layout is computed by get_wgsl_type_size_and_alignment, which follows
# the full WGSL alignment-and-size table (note vec3 is align 16 but size 12).
WGSL_ALIGNMENT = {
    "f32": 4,
    "vec2": 8,
    "vec3": 16,
    "vec4": 16,
}

# Byte size (== alignment) of the host-shareable WGSL scalar types.
_SCALAR_SIZES = {"f32": 4, "i32": 4, "u32": 4, "f16": 2}
# Component size implied by the shorthand suffix of vec3f / mat4x4h etc.
# A bare "vecN"/"matCxR" without suffix is treated as 32-bit components.
_SHORTHAND_COMPONENT_SIZES = {"f": 4, "i": 4, "u": 4, "h": 2}

_VECTOR_RE = re.compile(r"vec([234])([fiuh])?")
_MATRIX_RE = re.compile(r"mat([234])x([234])([fh])?")
# Fixed array element count: a decimal or hex integer literal, optional i/u suffix.
_ARRAY_COUNT_RE = re.compile(r"(0[xX][0-9a-fA-F]+|[1-9][0-9]*)[iu]?")

_WGSL_COMMENT_RE = re.compile(r"//[^\n]*|/\*.*?\*/", re.DOTALL)
_WGSL_STRUCT_RE = re.compile(r"\bstruct\s+(\w+)\s*\{(.*?)\}", re.DOTALL)
# Member attributes such as @align(16), @builtin(position), legacy [[offset(0)]].
_WGSL_ATTRIBUTE_RE = re.compile(r"@\s*\w+(?:\s*\([^)]*\))?|\[\[.*?\]\]")
_WGSL_MEMBER_RE = re.compile(r"(\w+)\s*:\s*(.+)", re.DOTALL)

# C++ raw string literal R"delim( ... )delim"; the delimiter may be empty.
_CPP_RAW_STRING_RE = re.compile(r'R"([^()\\\s]{0,16})\((.*?)\)\1"', re.DOTALL)
_CPP_STRUCT_HEADER_RE = re.compile(
    r"\bstruct\s+(?:alignas\s*\([^)]*\)\s*)?(\w+)\s*(?:final\s*)?(?::[^{;]*)?\{"
)
# static_assert(sizeof(Name) == N) with an optional message and optional
# namespace qualification / `struct` keyword inside sizeof(...).
_CPP_SIZE_ASSERT_RE = re.compile(
    r"static_assert\s*\(\s*sizeof\s*\(\s*(?:struct\s+)?(?:::)?(?:\w+::)*(\w+)\s*\)"
    r"\s*==\s*(\d+)[uUlL]*\s*(?:,.*?)?\)\s*;",
    re.DOTALL,
)
_CPP_MEMBER_RE = re.compile(r"(.*?)\s+(\w+)\s*(?:=\s*.*?|\s*\{.*?\})?;")


def _round_up(value, alignment):
    """Round *value* up to the next multiple of *alignment* (> 0)."""
    return -(-value // alignment) * alignment


def _split_top_level(text, separators=","):
    """Split *text* on *separators* not nested inside <>, () or []; drop blanks."""
    parts = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "<([":
            depth += 1
        elif char in ">)]":
            depth = max(depth - 1, 0)
        elif char in separators and depth == 0:
            parts.append(text[start:index])
            start = index + 1
    parts.append(text[start:])
    return [part.strip() for part in parts if part.strip()]


def _split_template(type_name):
    """Split ``base<arg, ...>`` into ``(base, [args])``.

    Returns ``(type_name, [])`` for an untemplated name and ``(name, None)``
    when the template brackets are malformed or empty.
    """
    open_index = type_name.find("<")
    if open_index == -1:
        return type_name, []
    if not type_name.endswith(">"):
        return type_name, None
    args = _split_top_level(type_name[open_index + 1:-1])
    return type_name[:open_index].strip(), (args or None)


def _vector_layout(count, component_size):
    """Return (size, align) of a vector of *count* components."""
    # WGSL: vec2 aligns to 2 components, vec3 and vec4 to 4 components.
    alignment = component_size * (2 if count == 2 else 4)
    return count * component_size, alignment


def _matrix_layout(columns, rows, component_size):
    """Return (size, align) of a matCxR: C column vectors of R components."""
    column_size, column_align = _vector_layout(rows, component_size)
    return columns * _round_up(column_size, column_align), column_align


def _type_layout(type_name, structs, visiting):
    """Return (size, align) of *type_name*, or (0, 0) if it cannot be computed.

    *structs* maps struct names to member lists for nested-struct lookup;
    *visiting* holds the struct names currently being laid out (cycle guard).
    """
    base, args = _split_template(type_name.strip())
    if args is None:
        return 0, 0
    vector = _VECTOR_RE.fullmatch(base)
    matrix = _MATRIX_RE.fullmatch(base)

    if not args:
        if base in _SCALAR_SIZES:
            size = _SCALAR_SIZES[base]
            return size, size
        if vector:
            component = _SHORTHAND_COMPONENT_SIZES[vector.group(2) or "f"]
            return _vector_layout(int(vector.group(1)), component)
        if matrix:
            component = _SHORTHAND_COMPONENT_SIZES[matrix.group(3) or "f"]
            return _matrix_layout(int(matrix.group(1)), int(matrix.group(2)), component)
        if base in structs and base not in visiting:
            return _struct_layout(structs[base], structs, visiting | {base})
        return 0, 0

    if vector and vector.group(2) is None and len(args) == 1 and args[0] in _SCALAR_SIZES:
        return _vector_layout(int(vector.group(1)), _SCALAR_SIZES[args[0]])
    if matrix and matrix.group(3) is None and len(args) == 1 and args[0] in ("f32", "f16"):
        return _matrix_layout(
            int(matrix.group(1)), int(matrix.group(2)), _SCALAR_SIZES[args[0]]
        )
    if base == "atomic" and len(args) == 1 and args[0] in ("i32", "u32"):
        return 4, 4
    if base == "array" and len(args) == 2:
        # Runtime-sized arrays (one argument) have no static size; skip them.
        count = _ARRAY_COUNT_RE.fullmatch(args[1])
        if count:
            element_size, element_align = _type_layout(args[0], structs, visiting)
            if element_size:
                stride = _round_up(element_size, element_align)
                return int(count.group(1), 0) * stride, element_align
    return 0, 0


def _struct_layout(members, structs, visiting):
    """Return (size, align) of a struct with *members*, or (0, 0) if unknown."""
    offset = 0
    max_alignment = 0
    for _member_name, member_type in members:
        size, alignment = _type_layout(member_type, structs, visiting)
        if size == 0:
            return 0, 0
        # Each member starts at the next multiple of its own alignment.
        offset = _round_up(offset, alignment) + size
        max_alignment = max(max_alignment, alignment)
    if max_alignment == 0:
        return 0, 0
    # The struct size is padded to the struct's (largest member) alignment.
    return _round_up(offset, max_alignment), max_alignment


def get_wgsl_type_size_and_alignment(type_name, structs=None):
    """Return ``(size, alignment)`` in bytes of a WGSL type.

    Handles scalars (f32/i32/u32/f16), vectors and matrices in both
    ``vec3<f32>`` and shorthand ``vec3f`` form, ``atomic<T>``, fixed-size
    ``array<T, N>`` and, when *structs* (name -> member list, as returned by
    parse_wgsl_struct) is given, nested user structs.

    Returns ``(0, 0)`` for unknown or unsupported types.
    """
    return _type_layout(type_name, structs or {}, frozenset())


def parse_wgsl_struct(wgsl_content):
    """Parse WGSL struct declarations.

    Returns a dict mapping struct name to a list of ``(member_name,
    member_type)`` tuples in declaration order. Comments and member attributes
    are ignored; members may be separated by commas (or legacy semicolons) on
    the same or different lines.
    """
    source = _WGSL_COMMENT_RE.sub(" ", wgsl_content)
    structs = {}
    for struct_match in _WGSL_STRUCT_RE.finditer(source):
        body = _WGSL_ATTRIBUTE_RE.sub(" ", struct_match.group(2))
        members = []
        for declaration in _split_top_level(body, ",;"):
            member = _WGSL_MEMBER_RE.fullmatch(declaration)
            if member:
                # Collapse internal whitespace/newlines inside the type text.
                members.append((member.group(1), " ".join(member.group(2).split())))
        structs[struct_match.group(1)] = members
    return structs


def find_embedded_wgsl_in_cpp(cpp_content):
    """Return the bodies of all C++ raw string literals (``R"(...)"``,
    ``R"delim(...)delim"``), which is where embedded WGSL usually lives."""
    return [match.group(2) for match in _CPP_RAW_STRING_RE.finditer(cpp_content)]


def calculate_wgsl_struct_size(struct_name, struct_members, structs=None):
    """Return ``(size, alignment)`` of a WGSL struct, or ``(0, 0)`` if any
    member type is unknown or the struct is empty.

    *struct_members* is a list of ``(name, type)`` tuples. Pass *structs*
    (name -> member list) to resolve members whose type is another struct.
    """
    return _struct_layout(struct_members, structs or {}, frozenset([struct_name]))


def _matching_brace(text, open_index):
    """Return the index of the ``}`` closing the ``{`` at *open_index*, or -1."""
    depth = 0
    for index in range(open_index, len(text)):
        if text[index] == "{":
            depth += 1
        elif text[index] == "}":
            depth -= 1
            if depth == 0:
                return index
    return -1


def parse_cpp_static_asserts(cpp_content):
    """Find C++ structs that have a ``static_assert(sizeof(Name) == N ...)``.

    Returns ``{name: {"members": [(member_name, member_type), ...],
    "expected_size": N}}``. Struct definitions and assertions are matched by
    name independently, so their relative order does not matter. Raw string
    literals are removed first so embedded WGSL is not mistaken for C++.
    """
    code = _CPP_RAW_STRING_RE.sub('""', cpp_content)
    expected_sizes = {
        match.group(1): int(match.group(2)) for match in _CPP_SIZE_ASSERT_RE.finditer(code)
    }
    cpp_structs = {}
    for header in _CPP_STRUCT_HEADER_RE.finditer(code):
        struct_name = header.group(1)
        if struct_name not in expected_sizes:
            continue
        open_brace = header.end() - 1
        close_brace = _matching_brace(code, open_brace)
        if close_brace == -1:
            continue
        body = code[open_brace + 1:close_brace]
        members = [
            (member.group(2).strip(), member.group(1).strip())
            for member in _CPP_MEMBER_RE.finditer(body)
        ]
        cpp_structs[struct_name] = {
            "members": members,
            "expected_size": expected_sizes[struct_name],
        }
    return cpp_structs


def validate_uniforms(wgsl_files, cpp_files):
    """Compare WGSL struct sizes with the C++ ``static_assert`` sizes.

    All WGSL sources (standalone files and WGSL embedded in every C++ file)
    are collected before any assertion is checked, so the C++ file order does
    not matter. Every struct is reported; if any size mismatches, the process
    exits with status 1 after all checks. Unreadable files are reported on
    stderr and skipped.
    """
    all_wgsl_structs = {}

    for file_path in wgsl_files:
        try:
            with open(file_path, encoding="utf-8") as f:
                wgsl_content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error parsing WGSL file {file_path}: {e}", file=sys.stderr)
            continue
        all_wgsl_structs.update(parse_wgsl_struct(wgsl_content))

    cpp_sources = []
    for cpp_file_path in cpp_files:
        try:
            with open(cpp_file_path, encoding="utf-8") as f:
                cpp_content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error processing C++ file {cpp_file_path}: {e}", file=sys.stderr)
            continue
        cpp_sources.append(cpp_content)
        for block in find_embedded_wgsl_in_cpp(cpp_content):
            all_wgsl_structs.update(parse_wgsl_struct(block))

    mismatch_found = False
    for cpp_content in cpp_sources:
        for struct_name, data in parse_cpp_static_asserts(cpp_content).items():
            expected_size = data["expected_size"]
            wgsl_members = all_wgsl_structs.get(struct_name)
            if wgsl_members is None:
                print(f"Validation Warning for '{struct_name}': Matching WGSL struct not found.")
                continue
            calculated_wgsl_size, wgsl_max_alignment = calculate_wgsl_struct_size(
                struct_name, wgsl_members, all_wgsl_structs
            )
            if calculated_wgsl_size == 0:
                # Do not let an unsupported type pass silently.
                print(
                    f"Validation Warning for '{struct_name}': Could not calculate WGSL size "
                    "(unsupported or unknown member type)."
                )
                continue
            if calculated_wgsl_size != expected_size:
                print(
                    f"Validation Mismatch for '{struct_name}':\n"
                    f" WGSL Calculated Size: {calculated_wgsl_size}\n"
                    f" C++ Expected Size: {expected_size}\n"
                    f" Max WGSL Alignment: {wgsl_max_alignment}",
                    file=sys.stderr,
                )
                mismatch_found = True
            else:
                print(
                    f"Validation OK for '{struct_name}': Size {calculated_wgsl_size} "
                    "matches C++ expected size."
                )

    if mismatch_found:
        sys.exit(1)


def main():
    """CLI entry point: ``validate_uniforms.py <wgsl_file_or_dir> <cpp_file>...``."""
    if len(sys.argv) < 3:
        print(
            "Usage: validate_uniforms.py <wgsl_file_or_dir> <cpp_file> [<cpp_file> ...]",
            file=sys.stderr,
        )
        sys.exit(1)

    wgsl_input = sys.argv[1]
    cpp_files = sys.argv[2:]

    wgsl_files = []
    if os.path.isfile(wgsl_input):
        wgsl_files.append(wgsl_input)
    elif os.path.isdir(wgsl_input):
        for root, dirs, files in os.walk(wgsl_input):
            dirs.sort()  # deterministic traversal order
            wgsl_files.extend(
                os.path.join(root, name) for name in sorted(files) if name.endswith(".wgsl")
            )
    else:
        print(
            f"Warning: WGSL input '{wgsl_input}' is neither a file nor a directory; "
            "only WGSL embedded in the C++ files will be checked.",
            file=sys.stderr,
        )

    # Proceed even if wgsl_files is empty: C++ files may contain embedded WGSL.
    validate_uniforms(wgsl_files, cpp_files)


if __name__ == "__main__":
    main()