diff options
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/validate_uniforms.py | 178 |
1 files changed, 178 insertions, 0 deletions
diff --git a/tools/validate_uniforms.py b/tools/validate_uniforms.py new file mode 100644 index 0000000..40d1b0f --- /dev/null +++ b/tools/validate_uniforms.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 + +import sys +import re +import os + +# WGSL alignment rules (simplified for common types) +WGSL_ALIGNMENT = { + "f32": 4, + "vec2<f32>": 8, + "vec3<f32>": 16, + "vec4<f32>": 16, + # Add other types as needed (e.g., u32, i32, mat4x4<f32>) +} + +def get_wgsl_type_size_and_alignment(type_name): + type_name = type_name.strip() + if type_name in WGSL_ALIGNMENT: + return WGSL_ALIGNMENT[type_name], WGSL_ALIGNMENT[type_name] + # Handle arrays, e.g., array<f32, 5> + if type_name.startswith("array"): + match = re.search(r"array<([\w<>, ]+)>", type_name) + if match: + inner_type = match.group(1).split(",")[0].strip() + # For simplicity, assume scalar array doesn't change alignment of base type + return get_wgsl_type_size_and_alignment(inner_type) + # Handle structs recursively (simplified, assumes no nested structs for now) + return 0, 0 # Unknown or complex type + +def parse_wgsl_struct(wgsl_content): + structs = {} + # Regex to find struct definitions: struct StructName { ... } + struct_matches = re.finditer(r"struct\s+(\w+)\s*\{\s*(.*?)\s*\}", wgsl_content, re.DOTALL) + for struct_match in struct_matches: + struct_name = struct_match.group(1) + members_content = struct_match.group(2) + members = [] + # Regex to find members: member_name: member_type + # Adjusted regex to handle types with brackets and spaces, and comments. + # CHANGED: \s to [ \t] to avoid consuming newlines + member_matches = re.finditer(r"(\w+)\s*:\s*([\w<>,\[\] \t]+)(?:\s*//.*)?", members_content) + for member_match in member_matches: + member_name = member_match.group(1) + member_type = member_match.group(2).strip() + if member_type.endswith(','): + member_type = member_type[:-1].strip() + members.append((member_name, member_type)) + structs[struct_name] = members + # print(f"DEBUG: Parsed WGSL struct '{struct_name}' with members: {members}") + return structs + +def find_embedded_wgsl_in_cpp(cpp_content): + # Regex to find raw string literals R"(...)" which often contain WGSL + wgsl_blocks = [] + matches = re.finditer(r'R"\((.*?)\)"', cpp_content, re.DOTALL) + for match in matches: + wgsl_blocks.append(match.group(1)) + return wgsl_blocks + +def calculate_wgsl_struct_size(struct_name, struct_members): + total_size = 0 + max_alignment = 0 + members_info = [] + + for member_name, member_type in struct_members: + size, alignment = get_wgsl_type_size_and_alignment(member_type) + if size == 0: # If type is unknown or complex, we can't reliably calculate + # print(f"Warning: Unknown or complex WGSL type '{member_type}' for member '{member_name}'. Cannot reliably calculate size.", file=sys.stderr) + return 0, 0 + members_info.append((member_name, member_type, size, alignment)) + max_alignment = max(max_alignment, alignment) + + current_offset = 0 + for member_name, member_type, size, alignment in members_info: + # Align current offset to the alignment of the current member + current_offset = (current_offset + alignment - 1) & ~(alignment - 1) + current_offset += size + + # The total size of the struct is the final offset, padded to the max alignment + if max_alignment > 0: + total_size = (current_offset + max_alignment - 1) & ~(max_alignment - 1) + else: + total_size = current_offset + + return total_size, max_alignment + +def parse_cpp_static_asserts(cpp_content): + cpp_structs = {} + # Regex to find C++ struct definitions with static_asserts for sizeof + # This regex is simplified and might need adjustments for more complex C++ code + struct_matches = re.finditer(r"struct\s+(\w+)\s*\{\s*(.*?)\s*\}\s*;.*?static_assert\(sizeof\(\1\)\s*==\s*(\d+)\s*,.*?\);", cpp_content, re.DOTALL | re.MULTILINE) + for struct_match in struct_matches: + struct_name = struct_match.group(1) + members_content = struct_match.group(2) + expected_size = int(struct_match.group(3)) + members = [] + # Regex to find members: type member_name; + member_matches = re.finditer(r"(.*?)\s+(\w+)\s*(?:=\s*.*?|\s*\{.*?\})?;", members_content) + for member_match in member_matches: + member_type = member_match.group(1).strip() + member_name = member_match.group(2).strip() + members.append((member_name, member_type)) + cpp_structs[struct_name] = {"members": members, "expected_size": expected_size} + return cpp_structs + +def validate_uniforms(wgsl_files, cpp_files): + all_wgsl_structs = {} + + # Parse separate WGSL files + for file_path in wgsl_files: + try: + with open(file_path, 'r') as f: + wgsl_content = f.read() + structs = parse_wgsl_struct(wgsl_content) + all_wgsl_structs.update(structs) + except Exception as e: + print(f"Error parsing WGSL file {file_path}: {e}", file=sys.stderr) + continue + + # Parse C++ files for embedded WGSL and static_asserts + for cpp_file_path in cpp_files: + try: + with open(cpp_file_path, 'r') as f: + cpp_content = f.read() + + # Parse embedded WGSL + wgsl_blocks = find_embedded_wgsl_in_cpp(cpp_content) + for block in wgsl_blocks: + structs = parse_wgsl_struct(block) + all_wgsl_structs.update(structs) + + # Parse C++ structs and static_asserts + cpp_structs = parse_cpp_static_asserts(cpp_content) + for struct_name, data in cpp_structs.items(): + expected_size = data["expected_size"] + # Try to find the matching WGSL struct + if struct_name in all_wgsl_structs: + wgsl_members = all_wgsl_structs[struct_name] + calculated_wgsl_size, wgsl_max_alignment = calculate_wgsl_struct_size(struct_name, wgsl_members) + + if calculated_wgsl_size == 0: # If calculation failed + # print(f"Validation Warning for '{struct_name}': Could not calculate WGSL size.") + continue + + if calculated_wgsl_size != expected_size: + print(f"Validation Mismatch for '{struct_name}':\n WGSL Calculated Size: {calculated_wgsl_size}\n C++ Expected Size: {expected_size}\n Max WGSL Alignment: {wgsl_max_alignment}", file=sys.stderr) + sys.exit(1) + else: + print(f"Validation OK for '{struct_name}': Size {calculated_wgsl_size} matches C++ expected size.") + else: + print(f"Validation Warning for '{struct_name}': Matching WGSL struct not found.") + except Exception as e: + print(f"Error processing C++ file {cpp_file_path}: {e}", file=sys.stderr) + continue + +def main(): + if len(sys.argv) < 3: + print("Usage: validate_uniforms.py <wgsl_dir_or_file> <cpp_file1> [<cpp_file2> ...]", file=sys.stderr) + sys.exit(1) + + wgsl_input = sys.argv[1] + cpp_files = sys.argv[2:] + + wgsl_files = [] + if os.path.isfile(wgsl_input): + wgsl_files.append(wgsl_input) + elif os.path.isdir(wgsl_input): + for root, _, files in os.walk(wgsl_input): + for file in files: + if file.endswith(".wgsl"): + wgsl_files.append(os.path.join(root, file)) + + # We proceed even if wgsl_files is empty, because C++ files might contain embedded WGSL + + validate_uniforms(wgsl_files, cpp_files) + +if __name__ == "__main__": + main()
\ No newline at end of file |
