From 3f7bf360d41612e772ff1acbe6b6cbc88edfe454 Mon Sep 17 00:00:00 2001 From: Associate 1 Date: Thu, 19 Feb 2026 12:01:12 -0700 Subject: [PATCH] Fix infinite loops in parser and support nested IF constructs The parser would enter infinite loops (consuming memory until OOM) when encountering: (1) nested/replicated IF within an IF block, and (2) unknown type declarations inside variant receive CASE blocks. Both were caused by parsing loops that failed to advance the token position on unrecognized constructs. Fix nested IF by detecting IF tokens as choices in parseIfStatement and recursively parsing them, storing the result in a new IfChoice.NestedIf field. Codegen flattens non-replicated nested IFs and emits replicated ones as loops with an _ifmatched flag. Add progress guards to all parser loops (IF, CASE, ALT, variant receive) as a safety net. Co-Authored-By: Claude Opus 4.6 --- ast/ast.go | 1 + codegen/codegen.go | 198 +++++++++++++++++++++++++++++------- codegen/e2e_control_test.go | 93 +++++++++++++++++ parser/parser.go | 82 +++++++++++++-- 4 files changed, 329 insertions(+), 45 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index 370a898..3363ffd 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -205,6 +205,7 @@ type IfStatement struct { type IfChoice struct { Condition Expression Body []Statement + NestedIf *IfStatement // non-nil when this choice is a nested/replicated IF } func (i *IfStatement) statementNode() {} diff --git a/codegen/codegen.go b/codegen/codegen.go index b56ba9b..0e8c8a3 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -221,6 +221,11 @@ func (g *Generator) containsPar(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsPar(choice.NestedIf) { + return true + } + } for _, inner := range choice.Body { if g.containsPar(inner) { return true @@ -289,6 +294,11 @@ func (g *Generator) containsPrint(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices 
{ + if choice.NestedIf != nil { + if g.containsPrint(choice.NestedIf) { + return true + } + } for _, inner := range choice.Body { if g.containsPrint(inner) { return true @@ -360,6 +370,11 @@ func (g *Generator) containsTimer(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsTimer(choice.NestedIf) { + return true + } + } for _, inner := range choice.Body { if g.containsTimer(inner) { return true @@ -428,6 +443,11 @@ func (g *Generator) containsStop(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsStop(choice.NestedIf) { + return true + } + } for _, inner := range choice.Body { if g.containsStop(inner) { return true @@ -504,6 +524,11 @@ func (g *Generator) containsMostExpr(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsMostExpr(choice.NestedIf) { + return true + } + } if g.exprNeedsMath(choice.Condition) { return true } @@ -943,6 +968,9 @@ func (g *Generator) collectChanProtocols(stmt ast.Statement) { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + g.collectChanProtocols(choice.NestedIf) + } for _, inner := range choice.Body { g.collectChanProtocols(inner) } @@ -999,6 +1027,9 @@ func (g *Generator) collectRecordVars(stmt ast.Statement) { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + g.collectRecordVars(choice.NestedIf) + } for _, inner := range choice.Body { g.collectRecordVars(inner) } @@ -1551,35 +1582,104 @@ func (g *Generator) generateWhileLoop(loop *ast.WhileLoop) { func (g *Generator) generateIfStatement(stmt *ast.IfStatement) { if stmt.Replicator != nil { // Replicated IF: IF i = start FOR count → for loop with break on first match - v := stmt.Replicator.Variable - if stmt.Replicator.Step != nil { - counter := "_repl_" + v - 
g.builder.WriteString(strings.Repeat("\t", g.indent)) - g.write(fmt.Sprintf("for %s := 0; %s < ", counter, counter)) - g.generateExpression(stmt.Replicator.Count) - g.write(fmt.Sprintf("; %s++ {\n", counter)) - g.indent++ - g.builder.WriteString(strings.Repeat("\t", g.indent)) - g.write(fmt.Sprintf("%s := ", v)) - g.generateExpression(stmt.Replicator.Start) - g.write(fmt.Sprintf(" + %s * ", counter)) - g.generateExpression(stmt.Replicator.Step) - g.write("\n") + g.generateReplicatedIfLoop(stmt, false) + } else { + // Flatten non-replicated nested IFs into the parent choice list + choices := g.flattenIfChoices(stmt.Choices) + g.generateIfChoiceChain(choices, true) + } +} + +// flattenIfChoices inlines choices from non-replicated nested IFs into a flat list. +// Replicated nested IFs are preserved as-is (they need special loop codegen). +func (g *Generator) flattenIfChoices(choices []ast.IfChoice) []ast.IfChoice { + var flat []ast.IfChoice + for _, c := range choices { + if c.NestedIf != nil && c.NestedIf.Replicator == nil { + // Non-replicated nested IF: inline its choices recursively + flat = append(flat, g.flattenIfChoices(c.NestedIf.Choices)...) } else { - g.builder.WriteString(strings.Repeat("\t", g.indent)) - g.write(fmt.Sprintf("for %s := ", v)) - g.generateExpression(stmt.Replicator.Start) - g.write(fmt.Sprintf("; %s < ", v)) - g.generateExpression(stmt.Replicator.Start) - g.write(" + ") - g.generateExpression(stmt.Replicator.Count) - g.write(fmt.Sprintf("; %s++ {\n", v)) - g.indent++ + flat = append(flat, c) } + } + return flat +} - for i, choice := range stmt.Choices { +// generateReplicatedIfLoop emits a for loop that breaks on first matching choice. +// When withinFlag is true, it sets _ifmatched = true before breaking. 
+func (g *Generator) generateReplicatedIfLoop(stmt *ast.IfStatement, withinFlag bool) { + repl := stmt.Replicator + v := repl.Variable + if repl.Step != nil { + counter := "_repl_" + v + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write(fmt.Sprintf("for %s := 0; %s < ", counter, counter)) + g.generateExpression(repl.Count) + g.write(fmt.Sprintf("; %s++ {\n", counter)) + g.indent++ + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write(fmt.Sprintf("%s := ", v)) + g.generateExpression(repl.Start) + g.write(fmt.Sprintf(" + %s * ", counter)) + g.generateExpression(repl.Step) + g.write("\n") + } else { + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write(fmt.Sprintf("for %s := ", v)) + g.generateExpression(repl.Start) + g.write(fmt.Sprintf("; %s < ", v)) + g.generateExpression(repl.Start) + g.write(" + ") + g.generateExpression(repl.Count) + g.write(fmt.Sprintf("; %s++ {\n", v)) + g.indent++ + } + + for i, choice := range stmt.Choices { + g.builder.WriteString(strings.Repeat("\t", g.indent)) + if i == 0 { + g.write("if ") + } else { + g.write("} else if ") + } + g.generateExpression(choice.Condition) + g.write(" {\n") + g.indent++ + + for _, s := range choice.Body { + g.generateStatement(s) + } + if withinFlag { + g.writeLine("_ifmatched = true") + } + g.writeLine("break") + + g.indent-- + } + g.writeLine("}") + + g.indent-- + g.writeLine("}") +} + +// generateIfChoiceChain emits a chain of if/else-if for the given choices. +// When a replicated nested IF is encountered, it splits the chain and uses +// a _ifmatched flag to determine whether remaining choices should be tried. 
+func (g *Generator) generateIfChoiceChain(choices []ast.IfChoice, isFirst bool) { + // Find first replicated nested IF + replIdx := -1 + for i, c := range choices { + if c.NestedIf != nil && c.NestedIf.Replicator != nil { + replIdx = i + break + } + } + + if replIdx == -1 { + // No replicated nested IFs — simple if/else-if chain + for i, choice := range choices { g.builder.WriteString(strings.Repeat("\t", g.indent)) - if i == 0 { + if i == 0 && isFirst { g.write("if ") } else { g.write("} else if ") @@ -1591,18 +1691,25 @@ func (g *Generator) generateIfStatement(stmt *ast.IfStatement) { for _, s := range choice.Body { g.generateStatement(s) } - g.writeLine("break") g.indent-- } - g.writeLine("}") + if len(choices) > 0 { + g.writeLine("}") + } + return + } - g.indent-- - g.writeLine("}") - } else { - for i, choice := range stmt.Choices { + // Split at the replicated nested IF + before := choices[:replIdx] + replChoice := choices[replIdx] + after := choices[replIdx+1:] + + // Emit choices before the replicated IF as a normal if-else chain + if len(before) > 0 { + for i, choice := range before { g.builder.WriteString(strings.Repeat("\t", g.indent)) - if i == 0 { + if i == 0 && isFirst { g.write("if ") } else { g.write("} else if ") @@ -1610,13 +1717,36 @@ func (g *Generator) generateIfStatement(stmt *ast.IfStatement) { g.generateExpression(choice.Condition) g.write(" {\n") g.indent++ - for _, s := range choice.Body { g.generateStatement(s) } - g.indent-- } + // Open else block for the replicated IF + remaining choices + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write("} else {\n") + g.indent++ + } + + // Emit the replicated nested IF with a flag + needFlag := len(after) > 0 + if needFlag { + g.writeLine("_ifmatched := false") + } + g.generateReplicatedIfLoop(replChoice.NestedIf, needFlag) + + // Emit remaining choices inside if !_ifmatched (recursive for multiple) + if len(after) > 0 { + g.builder.WriteString(strings.Repeat("\t", g.indent)) + 
g.write("if !_ifmatched {\n") + g.indent++ + g.generateIfChoiceChain(after, true) // recursive for remaining + g.indent-- + g.writeLine("}") + } + + if len(before) > 0 { + g.indent-- g.writeLine("}") } } diff --git a/codegen/e2e_control_test.go b/codegen/e2e_control_test.go index ccedf8a..e07f032 100644 --- a/codegen/e2e_control_test.go +++ b/codegen/e2e_control_test.go @@ -259,6 +259,99 @@ func TestE2E_MultiStatementWhileBody(t *testing.T) { } } +func TestE2E_NestedReplicatedIfWithDefault(t *testing.T) { + // Replicated IF as a choice within outer IF, with TRUE default + occam := `SEQ + [5]INT arr: + INT result: + SEQ i = 0 FOR 5 + arr[i] := i * 10 + IF + IF i = 0 FOR 5 + arr[i] > 25 + result := arr[i] + TRUE + result := -1 + print.int(result) +` + output := transpileCompileRun(t, occam) + expected := "30\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + +func TestE2E_NestedReplicatedIfNoMatch(t *testing.T) { + // Replicated IF where no choice matches, falls through to TRUE + occam := `SEQ + [3]INT arr: + INT result: + SEQ i = 0 FOR 3 + arr[i] := i + IF + IF i = 0 FOR 3 + arr[i] > 100 + result := arr[i] + TRUE + result := -1 + print.int(result) +` + output := transpileCompileRun(t, occam) + expected := "-1\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + +func TestE2E_NestedReplicatedIfWithPrecedingChoice(t *testing.T) { + // Normal choice before replicated IF, then default + occam := `SEQ + [3]INT arr: + INT result: + SEQ i = 0 FOR 3 + arr[i] := i + INT x: + x := 99 + IF + x > 100 + result := x + IF i = 0 FOR 3 + arr[i] = 2 + result := arr[i] + TRUE + result := -1 + print.int(result) +` + output := transpileCompileRun(t, occam) + expected := "2\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + +func TestE2E_NestedNonReplicatedIf(t *testing.T) { + // Non-replicated nested IF (choices inlined into parent) + occam := `SEQ + INT x: + INT result: + 
x := 5 + IF + IF + x > 10 + result := 1 + x > 3 + result := 2 + TRUE + result := 0 + print.int(result) +` + output := transpileCompileRun(t, occam) + expected := "2\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + func TestE2E_ChannelDirAtCallSite(t *testing.T) { occam := `PROC worker(CHAN OF INT in?, CHAN OF INT out!) INT x: diff --git a/parser/parser.go b/parser/parser.go index 3340c70..accf746 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1238,12 +1238,17 @@ func (p *Parser) parseVariantReceive(channel string, token lexer.Token) *ast.Var break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + // Parse a variant case: tag [; var]* \n INDENT body vc := ast.VariantCase{} if !p.curTokenIs(lexer.IDENT) { p.addError(fmt.Sprintf("expected variant tag name, got %s", p.curToken.Type)) - return stmt + p.nextToken() // skip unrecognized token to avoid infinite loop + continue } vc.Tag = p.curToken.Literal @@ -1273,6 +1278,14 @@ func (p *Parser) parseVariantReceive(channel string, token lexer.Token) *ast.Var } stmt.Cases = append(stmt.Cases, vc) + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + p.nextToken() // force progress + if p.curToken == prevToken { + break + } + } } return stmt @@ -1323,11 +1336,16 @@ func (p *Parser) parseVariantReceiveWithIndex(channel string, channelIndex ast.E break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + vc := ast.VariantCase{} if !p.curTokenIs(lexer.IDENT) { p.addError(fmt.Sprintf("expected variant tag name, got %s", p.curToken.Type)) - return stmt + p.nextToken() // skip unrecognized token to avoid infinite loop + continue } vc.Tag = p.curToken.Literal @@ -1354,6 +1372,14 @@ func (p *Parser) parseVariantReceiveWithIndex(channel string, 
channelIndex ast.E } stmt.Cases = append(stmt.Cases, vc) + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + p.nextToken() // force progress + if p.curToken == prevToken { + break + } + } } return stmt @@ -1517,11 +1543,20 @@ func (p *Parser) parseAltCases() []ast.AltCase { break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + // Parse an ALT case: [guard &] channel ? var altCase := p.parseAltCase() if altCase != nil { cases = append(cases, *altCase) } + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + break + } } return cases @@ -2269,21 +2304,37 @@ func (p *Parser) parseIfStatement() *ast.IfStatement { break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + choice := ast.IfChoice{} - choice.Condition = p.parseExpression(LOWEST) - // Skip newlines and expect INDENT for body - for p.peekTokenIs(lexer.NEWLINE) { - p.nextToken() - } + // Nested IF (plain or replicated) used as a choice within this IF + if p.curTokenIs(lexer.IF) { + nestedIf := p.parseIfStatement() + choice.NestedIf = nestedIf + } else { + choice.Condition = p.parseExpression(LOWEST) - if p.peekTokenIs(lexer.INDENT) { - p.nextToken() // consume INDENT - p.nextToken() // move to body - choice.Body = p.parseBodyStatements() + // Skip newlines and expect INDENT for body + for p.peekTokenIs(lexer.NEWLINE) { + p.nextToken() + } + + if p.peekTokenIs(lexer.INDENT) { + p.nextToken() // consume INDENT + p.nextToken() // move to body + choice.Body = p.parseBodyStatements() + } } stmt.Choices = append(stmt.Choices, choice) + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + break + } } return stmt @@ 
-2338,6 +2389,10 @@ func (p *Parser) parseCaseStatement() *ast.CaseStatement { break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + choice := ast.CaseChoice{} if p.curTokenIs(lexer.ELSE) { @@ -2359,6 +2414,11 @@ func (p *Parser) parseCaseStatement() *ast.CaseStatement { } stmt.Choices = append(stmt.Choices, choice) + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + break + } } return stmt