diff --git a/ast/ast.go b/ast/ast.go index 370a898..3363ffd 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -205,6 +205,7 @@ type IfStatement struct { type IfChoice struct { Condition Expression Body []Statement + NestedIf *IfStatement // non-nil when this choice is a nested/replicated IF } func (i *IfStatement) statementNode() {} diff --git a/codegen/codegen.go b/codegen/codegen.go index b56ba9b..0e8c8a3 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -221,6 +221,11 @@ func (g *Generator) containsPar(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsPar(choice.NestedIf) { + return true + } + } for _, inner := range choice.Body { if g.containsPar(inner) { return true @@ -289,6 +294,11 @@ func (g *Generator) containsPrint(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsPrint(choice.NestedIf) { + return true + } + } for _, inner := range choice.Body { if g.containsPrint(inner) { return true @@ -360,6 +370,11 @@ func (g *Generator) containsTimer(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsTimer(choice.NestedIf) { + return true + } + } for _, inner := range choice.Body { if g.containsTimer(inner) { return true @@ -428,6 +443,11 @@ func (g *Generator) containsStop(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsStop(choice.NestedIf) { + return true + } + } for _, inner := range choice.Body { if g.containsStop(inner) { return true @@ -504,6 +524,11 @@ func (g *Generator) containsMostExpr(stmt ast.Statement) bool { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + if g.containsMostExpr(choice.NestedIf) { + return true + } + } if g.exprNeedsMath(choice.Condition) { return true } @@ -943,6 +968,9 @@ func (g *Generator) collectChanProtocols(stmt ast.Statement) { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + g.collectChanProtocols(choice.NestedIf) + } for _, inner := range choice.Body { g.collectChanProtocols(inner) } @@ -999,6 +1027,9 @@ func (g *Generator) collectRecordVars(stmt ast.Statement) { } case *ast.IfStatement: for _, choice := range s.Choices { + if choice.NestedIf != nil { + g.collectRecordVars(choice.NestedIf) + } for _, inner := range choice.Body { g.collectRecordVars(inner) } @@ -1551,35 +1582,104 @@ func (g *Generator) generateWhileLoop(loop *ast.WhileLoop) { func (g *Generator) generateIfStatement(stmt *ast.IfStatement) { if stmt.Replicator != nil { // Replicated IF: IF i = start FOR count → for loop with break on first match - v := stmt.Replicator.Variable - if stmt.Replicator.Step != nil { - counter := "_repl_" + v - g.builder.WriteString(strings.Repeat("\t", g.indent)) - g.write(fmt.Sprintf("for %s := 0; %s < ", counter, counter)) - g.generateExpression(stmt.Replicator.Count) - g.write(fmt.Sprintf("; %s++ {\n", counter)) - g.indent++ - g.builder.WriteString(strings.Repeat("\t", g.indent)) - g.write(fmt.Sprintf("%s := ", v)) - g.generateExpression(stmt.Replicator.Start) - g.write(fmt.Sprintf(" + %s * ", counter)) - g.generateExpression(stmt.Replicator.Step) - g.write("\n") + g.generateReplicatedIfLoop(stmt, false) + } else { + // Flatten non-replicated nested IFs into the parent choice list + choices := g.flattenIfChoices(stmt.Choices) + g.generateIfChoiceChain(choices, true) + } +} + +// flattenIfChoices inlines choices from non-replicated nested IFs into a flat list. +// Replicated nested IFs are preserved as-is (they need special loop codegen). +func (g *Generator) flattenIfChoices(choices []ast.IfChoice) []ast.IfChoice { + var flat []ast.IfChoice + for _, c := range choices { + if c.NestedIf != nil && c.NestedIf.Replicator == nil { + // Non-replicated nested IF: inline its choices recursively + flat = append(flat, g.flattenIfChoices(c.NestedIf.Choices)...) } else { - g.builder.WriteString(strings.Repeat("\t", g.indent)) - g.write(fmt.Sprintf("for %s := ", v)) - g.generateExpression(stmt.Replicator.Start) - g.write(fmt.Sprintf("; %s < ", v)) - g.generateExpression(stmt.Replicator.Start) - g.write(" + ") - g.generateExpression(stmt.Replicator.Count) - g.write(fmt.Sprintf("; %s++ {\n", v)) - g.indent++ + flat = append(flat, c) } + } + return flat +} - for i, choice := range stmt.Choices { +// generateReplicatedIfLoop emits a for loop that breaks on first matching choice. +// When withinFlag is true, it sets _ifmatched = true before breaking. +func (g *Generator) generateReplicatedIfLoop(stmt *ast.IfStatement, withinFlag bool) { + repl := stmt.Replicator + v := repl.Variable + if repl.Step != nil { + counter := "_repl_" + v + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write(fmt.Sprintf("for %s := 0; %s < ", counter, counter)) + g.generateExpression(repl.Count) + g.write(fmt.Sprintf("; %s++ {\n", counter)) + g.indent++ + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write(fmt.Sprintf("%s := ", v)) + g.generateExpression(repl.Start) + g.write(fmt.Sprintf(" + %s * ", counter)) + g.generateExpression(repl.Step) + g.write("\n") + } else { + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write(fmt.Sprintf("for %s := ", v)) + g.generateExpression(repl.Start) + g.write(fmt.Sprintf("; %s < ", v)) + g.generateExpression(repl.Start) + g.write(" + ") + g.generateExpression(repl.Count) + g.write(fmt.Sprintf("; %s++ {\n", v)) + g.indent++ + } + + for i, choice := range stmt.Choices { + g.builder.WriteString(strings.Repeat("\t", g.indent)) + if i == 0 { + g.write("if ") + } else { + g.write("} else if ") + } + g.generateExpression(choice.Condition) + g.write(" {\n") + g.indent++ + + for _, s := range choice.Body { + g.generateStatement(s) + } + if withinFlag { + g.writeLine("_ifmatched = true") + } + g.writeLine("break") + + g.indent-- + } + g.writeLine("}") + + g.indent-- + g.writeLine("}") +} + +// generateIfChoiceChain emits a chain of if/else-if for the given choices. +// When a replicated nested IF is encountered, it splits the chain and uses +// a _ifmatched flag to determine whether remaining choices should be tried. +func (g *Generator) generateIfChoiceChain(choices []ast.IfChoice, isFirst bool) { + // Find first replicated nested IF + replIdx := -1 + for i, c := range choices { + if c.NestedIf != nil && c.NestedIf.Replicator != nil { + replIdx = i + break + } + } + + if replIdx == -1 { + // No replicated nested IFs — simple if/else-if chain + for i, choice := range choices { g.builder.WriteString(strings.Repeat("\t", g.indent)) - if i == 0 { + if i == 0 && isFirst { g.write("if ") } else { g.write("} else if ") @@ -1591,18 +1691,25 @@ func (g *Generator) generateIfStatement(stmt *ast.IfStatement) { for _, s := range choice.Body { g.generateStatement(s) } - g.writeLine("break") g.indent-- } - g.writeLine("}") + if len(choices) > 0 { + g.writeLine("}") + } + return + } - g.indent-- - g.writeLine("}") - } else { - for i, choice := range stmt.Choices { + // Split at the replicated nested IF + before := choices[:replIdx] + replChoice := choices[replIdx] + after := choices[replIdx+1:] + + // Emit choices before the replicated IF as a normal if-else chain + if len(before) > 0 { + for i, choice := range before { g.builder.WriteString(strings.Repeat("\t", g.indent)) - if i == 0 { + if i == 0 && isFirst { g.write("if ") } else { g.write("} else if ") @@ -1610,13 +1717,36 @@ func (g *Generator) generateIfStatement(stmt *ast.IfStatement) { g.generateExpression(choice.Condition) g.write(" {\n") g.indent++ - for _, s := range choice.Body { g.generateStatement(s) } - g.indent-- } + // Open else block for the replicated IF + remaining choices + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write("} else {\n") + g.indent++ + } + + // Emit the replicated nested IF with a flag + needFlag := len(after) > 0 + if needFlag { + g.writeLine("_ifmatched := false") + } + g.generateReplicatedIfLoop(replChoice.NestedIf, needFlag) + + // Emit remaining choices inside if !_ifmatched (recursive for multiple) + if len(after) > 0 { + g.builder.WriteString(strings.Repeat("\t", g.indent)) + g.write("if !_ifmatched {\n") + g.indent++ + g.generateIfChoiceChain(after, true) // recursive for remaining + g.indent-- + g.writeLine("}") + } + + if len(before) > 0 { + g.indent-- g.writeLine("}") } } diff --git a/codegen/e2e_control_test.go b/codegen/e2e_control_test.go index ccedf8a..e07f032 100644 --- a/codegen/e2e_control_test.go +++ b/codegen/e2e_control_test.go @@ -259,6 +259,99 @@ func TestE2E_MultiStatementWhileBody(t *testing.T) { } } +func TestE2E_NestedReplicatedIfWithDefault(t *testing.T) { + // Replicated IF as a choice within outer IF, with TRUE default + occam := `SEQ + [5]INT arr: + INT result: + SEQ i = 0 FOR 5 + arr[i] := i * 10 + IF + IF i = 0 FOR 5 + arr[i] > 25 + result := arr[i] + TRUE + result := -1 + print.int(result) +` + output := transpileCompileRun(t, occam) + expected := "30\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + +func TestE2E_NestedReplicatedIfNoMatch(t *testing.T) { + // Replicated IF where no choice matches, falls through to TRUE + occam := `SEQ + [3]INT arr: + INT result: + SEQ i = 0 FOR 3 + arr[i] := i + IF + IF i = 0 FOR 3 + arr[i] > 100 + result := arr[i] + TRUE + result := -1 + print.int(result) +` + output := transpileCompileRun(t, occam) + expected := "-1\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + +func TestE2E_NestedReplicatedIfWithPrecedingChoice(t *testing.T) { + // Normal choice before replicated IF, then default + occam := `SEQ + [3]INT arr: + INT result: + SEQ i = 0 FOR 3 + arr[i] := i + INT x: + x := 99 + IF + x > 100 + result := x + IF i = 0 FOR 3 + arr[i] = 2 + result := arr[i] + TRUE + result := -1 + print.int(result) +` + output := transpileCompileRun(t, occam) + expected := "2\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + +func TestE2E_NestedNonReplicatedIf(t *testing.T) { + // Non-replicated nested IF (choices inlined into parent) + occam := `SEQ + INT x: + INT result: + x := 5 + IF + IF + x > 10 + result := 1 + x > 3 + result := 2 + TRUE + result := 0 + print.int(result) +` + output := transpileCompileRun(t, occam) + expected := "2\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + func TestE2E_ChannelDirAtCallSite(t *testing.T) { occam := `PROC worker(CHAN OF INT in?, CHAN OF INT out!) INT x: diff --git a/parser/parser.go b/parser/parser.go index 3340c70..accf746 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1238,12 +1238,17 @@ func (p *Parser) parseVariantReceive(channel string, token lexer.Token) *ast.Var break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + // Parse a variant case: tag [; var]* \n INDENT body vc := ast.VariantCase{} if !p.curTokenIs(lexer.IDENT) { p.addError(fmt.Sprintf("expected variant tag name, got %s", p.curToken.Type)) - return stmt + p.nextToken() // skip unrecognized token to avoid infinite loop + continue } vc.Tag = p.curToken.Literal @@ -1273,6 +1278,14 @@ func (p *Parser) parseVariantReceive(channel string, token lexer.Token) *ast.Var } stmt.Cases = append(stmt.Cases, vc) + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + p.nextToken() // force progress + if p.curToken == prevToken { + break + } + } } return stmt @@ -1323,11 +1336,16 @@ func (p *Parser) parseVariantReceiveWithIndex(channel string, channelIndex ast.E break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + vc := ast.VariantCase{} if !p.curTokenIs(lexer.IDENT) { p.addError(fmt.Sprintf("expected variant tag name, got %s", p.curToken.Type)) - return stmt + p.nextToken() // skip unrecognized token to avoid infinite loop + continue } vc.Tag = p.curToken.Literal @@ -1354,6 +1372,14 @@ func (p *Parser) parseVariantReceiveWithIndex(channel string, channelIndex ast.E } stmt.Cases = append(stmt.Cases, vc) + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + p.nextToken() // force progress + if p.curToken == prevToken { + break + } + } } return stmt @@ -1517,11 +1543,20 @@ func (p *Parser) parseAltCases() []ast.AltCase { break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + // Parse an ALT case: [guard &] channel ? var altCase := p.parseAltCase() if altCase != nil { cases = append(cases, *altCase) } + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + break + } } return cases @@ -2269,21 +2304,37 @@ func (p *Parser) parseIfStatement() *ast.IfStatement { break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + choice := ast.IfChoice{} - choice.Condition = p.parseExpression(LOWEST) - // Skip newlines and expect INDENT for body - for p.peekTokenIs(lexer.NEWLINE) { - p.nextToken() - } + // Nested IF (plain or replicated) used as a choice within this IF + if p.curTokenIs(lexer.IF) { + nestedIf := p.parseIfStatement() + choice.NestedIf = nestedIf + } else { + choice.Condition = p.parseExpression(LOWEST) - if p.peekTokenIs(lexer.INDENT) { - p.nextToken() // consume INDENT - p.nextToken() // move to body - choice.Body = p.parseBodyStatements() + // Skip newlines and expect INDENT for body + for p.peekTokenIs(lexer.NEWLINE) { + p.nextToken() + } + + if p.peekTokenIs(lexer.INDENT) { + p.nextToken() // consume INDENT + p.nextToken() // move to body + choice.Body = p.parseBodyStatements() + } } stmt.Choices = append(stmt.Choices, choice) + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + break + } } return stmt @@ -2338,6 +2389,10 @@ func (p *Parser) parseCaseStatement() *ast.CaseStatement { break } + // Safety guard: record position before parsing to detect no-progress + prevToken := p.curToken + prevPeek := p.peekToken + choice := ast.CaseChoice{} if p.curTokenIs(lexer.ELSE) { @@ -2359,6 +2414,11 @@ func (p *Parser) parseCaseStatement() *ast.CaseStatement { } stmt.Choices = append(stmt.Choices, choice) + + // No-progress guard: if we haven't moved, break to prevent infinite loop + if p.curToken == prevToken && p.peekToken == prevPeek { + break + } } return stmt