summaryrefslogtreecommitdiff
path: root/tools/seq_compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/seq_compiler.py')
-rwxr-xr-xtools/seq_compiler.py42
1 files changed, 38 insertions, 4 deletions
diff --git a/tools/seq_compiler.py b/tools/seq_compiler.py
index 3acb260..fc07f84 100755
--- a/tools/seq_compiler.py
+++ b/tools/seq_compiler.py
@@ -243,6 +243,25 @@ def validate_dag(seq: SequenceDecl) -> None:
print(f"Error: No path from 'source' to 'sink' in DAG", file=sys.stderr)
sys.exit(1)
+ # 5. Check producer/consumer lifespan constraints
+ # Multi-output effects (can't auto-passthrough) must have lifespan >= all consumers
+ for producer in seq.effects:
+ # Check if producer can auto-passthrough (1:1 topology)
+ if len(producer.inputs) == 1 and len(producer.outputs) == 1:
+ continue # Can auto-passthrough, no constraint
+
+ # Find all consumers of producer's outputs
+ for output_node in producer.outputs:
+ for consumer in seq.effects:
+ if output_node in consumer.inputs:
+ # Verify: producer.[start, end] contains consumer.[start, end]
+ if not (producer.start <= consumer.start and consumer.end <= producer.end):
+ print(f"Error: Producer '{producer.class_name}' [{producer.start}, {producer.end}] "
+ f"has lifespan shorter than consumer '{consumer.class_name}' [{consumer.start}, {consumer.end}] "
+ f"for node '{output_node}'", file=sys.stderr)
+ print(f" Multi-output effects cannot auto-passthrough and must span consumer lifespans", file=sys.stderr)
+ sys.exit(1)
+
def topological_sort(seq: SequenceDecl) -> List[EffectDecl]:
"""Sort effects in execution order using Kahn's algorithm."""
@@ -417,7 +436,8 @@ class {class_name} : public Sequence {{
cpp += f''' effect_dag_.push_back({{
.effect = std::make_shared<{effect.class_name}>(ctx,
std::vector<std::string>{{{inputs_str}}},
- std::vector<std::string>{{{outputs_str}}}),
+ std::vector<std::string>{{{outputs_str}}},
+ {effect.start}f, {effect.end}f),
.input_nodes = {{{inputs_str}}},
.output_nodes = {{{outputs_str}}},
.execution_order = {effect.execution_order}
@@ -434,8 +454,9 @@ class {class_name} : public Sequence {{
def main():
parser = argparse.ArgumentParser(description='Sequence compiler with DAG optimization')
parser.add_argument('input', help='Input .seq file')
- parser.add_argument('--output', '-o', help='Output .cc file', required=True)
+ parser.add_argument('--output', '-o', help='Output .cc file')
parser.add_argument('--flatten', action='store_true', help='Generate flattened code (FINAL_STRIP mode)')
+ parser.add_argument('--validate', action='store_true', help='Validate DAG only (no code generation)')
args = parser.parse_args()
@@ -449,6 +470,20 @@ def main():
# Sort sequences by start time
sequences.sort(key=lambda s: s.start_time)
+ # Validate all sequences (always)
+ for seq in sequences:
+ validate_dag(seq)
+
+ # If --validate flag, exit after validation
+ if args.validate:
+ print(f"Validation passed: {len(sequences)} sequence(s) validated")
+ sys.exit(0)
+
+ # Require --output for code generation
+ if not args.output:
+ print("Error: --output is required for code generation", file=sys.stderr)
+ sys.exit(1)
+
# Calculate demo duration from max effect end time (absolute time)
demo_duration = 0.0
for seq in sequences:
@@ -466,8 +501,7 @@ def main():
'''
for seq_idx, seq in enumerate(sequences):
- # Validate DAG
- validate_dag(seq)
+ # validate_dag() already called above
# Topological sort
sorted_effects = topological_sort(seq)