Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ type IfStatement struct {
type IfChoice struct {
Condition Expression
Body []Statement
NestedIf *IfStatement // non-nil when this choice is a nested/replicated IF
}

func (i *IfStatement) statementNode() {}
Expand Down
198 changes: 164 additions & 34 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ func (g *Generator) containsPar(stmt ast.Statement) bool {
}
case *ast.IfStatement:
for _, choice := range s.Choices {
if choice.NestedIf != nil {
if g.containsPar(choice.NestedIf) {
return true
}
}
for _, inner := range choice.Body {
if g.containsPar(inner) {
return true
Expand Down Expand Up @@ -289,6 +294,11 @@ func (g *Generator) containsPrint(stmt ast.Statement) bool {
}
case *ast.IfStatement:
for _, choice := range s.Choices {
if choice.NestedIf != nil {
if g.containsPrint(choice.NestedIf) {
return true
}
}
for _, inner := range choice.Body {
if g.containsPrint(inner) {
return true
Expand Down Expand Up @@ -360,6 +370,11 @@ func (g *Generator) containsTimer(stmt ast.Statement) bool {
}
case *ast.IfStatement:
for _, choice := range s.Choices {
if choice.NestedIf != nil {
if g.containsTimer(choice.NestedIf) {
return true
}
}
for _, inner := range choice.Body {
if g.containsTimer(inner) {
return true
Expand Down Expand Up @@ -428,6 +443,11 @@ func (g *Generator) containsStop(stmt ast.Statement) bool {
}
case *ast.IfStatement:
for _, choice := range s.Choices {
if choice.NestedIf != nil {
if g.containsStop(choice.NestedIf) {
return true
}
}
for _, inner := range choice.Body {
if g.containsStop(inner) {
return true
Expand Down Expand Up @@ -504,6 +524,11 @@ func (g *Generator) containsMostExpr(stmt ast.Statement) bool {
}
case *ast.IfStatement:
for _, choice := range s.Choices {
if choice.NestedIf != nil {
if g.containsMostExpr(choice.NestedIf) {
return true
}
}
if g.exprNeedsMath(choice.Condition) {
return true
}
Expand Down Expand Up @@ -943,6 +968,9 @@ func (g *Generator) collectChanProtocols(stmt ast.Statement) {
}
case *ast.IfStatement:
for _, choice := range s.Choices {
if choice.NestedIf != nil {
g.collectChanProtocols(choice.NestedIf)
}
for _, inner := range choice.Body {
g.collectChanProtocols(inner)
}
Expand Down Expand Up @@ -999,6 +1027,9 @@ func (g *Generator) collectRecordVars(stmt ast.Statement) {
}
case *ast.IfStatement:
for _, choice := range s.Choices {
if choice.NestedIf != nil {
g.collectRecordVars(choice.NestedIf)
}
for _, inner := range choice.Body {
g.collectRecordVars(inner)
}
Expand Down Expand Up @@ -1551,35 +1582,104 @@ func (g *Generator) generateWhileLoop(loop *ast.WhileLoop) {
func (g *Generator) generateIfStatement(stmt *ast.IfStatement) {
if stmt.Replicator != nil {
// Replicated IF: IF i = start FOR count → for loop with break on first match
v := stmt.Replicator.Variable
if stmt.Replicator.Step != nil {
counter := "_repl_" + v
g.builder.WriteString(strings.Repeat("\t", g.indent))
g.write(fmt.Sprintf("for %s := 0; %s < ", counter, counter))
g.generateExpression(stmt.Replicator.Count)
g.write(fmt.Sprintf("; %s++ {\n", counter))
g.indent++
g.builder.WriteString(strings.Repeat("\t", g.indent))
g.write(fmt.Sprintf("%s := ", v))
g.generateExpression(stmt.Replicator.Start)
g.write(fmt.Sprintf(" + %s * ", counter))
g.generateExpression(stmt.Replicator.Step)
g.write("\n")
g.generateReplicatedIfLoop(stmt, false)
} else {
// Flatten non-replicated nested IFs into the parent choice list
choices := g.flattenIfChoices(stmt.Choices)
g.generateIfChoiceChain(choices, true)
}
}

// flattenIfChoices inlines choices from non-replicated nested IFs into a flat list.
// Replicated nested IFs are preserved as-is (they need special loop codegen).
func (g *Generator) flattenIfChoices(choices []ast.IfChoice) []ast.IfChoice {
var flat []ast.IfChoice
for _, c := range choices {
if c.NestedIf != nil && c.NestedIf.Replicator == nil {
// Non-replicated nested IF: inline its choices recursively
flat = append(flat, g.flattenIfChoices(c.NestedIf.Choices)...)
} else {
g.builder.WriteString(strings.Repeat("\t", g.indent))
g.write(fmt.Sprintf("for %s := ", v))
g.generateExpression(stmt.Replicator.Start)
g.write(fmt.Sprintf("; %s < ", v))
g.generateExpression(stmt.Replicator.Start)
g.write(" + ")
g.generateExpression(stmt.Replicator.Count)
g.write(fmt.Sprintf("; %s++ {\n", v))
g.indent++
flat = append(flat, c)
}
}
return flat
}

for i, choice := range stmt.Choices {
// generateReplicatedIfLoop emits a for loop that breaks on first matching choice.
// When withinFlag is true, it sets _ifmatched = true before breaking.
func (g *Generator) generateReplicatedIfLoop(stmt *ast.IfStatement, withinFlag bool) {
repl := stmt.Replicator
v := repl.Variable
if repl.Step != nil {
counter := "_repl_" + v
g.builder.WriteString(strings.Repeat("\t", g.indent))
g.write(fmt.Sprintf("for %s := 0; %s < ", counter, counter))
g.generateExpression(repl.Count)
g.write(fmt.Sprintf("; %s++ {\n", counter))
g.indent++
g.builder.WriteString(strings.Repeat("\t", g.indent))
g.write(fmt.Sprintf("%s := ", v))
g.generateExpression(repl.Start)
g.write(fmt.Sprintf(" + %s * ", counter))
g.generateExpression(repl.Step)
g.write("\n")
} else {
g.builder.WriteString(strings.Repeat("\t", g.indent))
g.write(fmt.Sprintf("for %s := ", v))
g.generateExpression(repl.Start)
g.write(fmt.Sprintf("; %s < ", v))
g.generateExpression(repl.Start)
g.write(" + ")
g.generateExpression(repl.Count)
g.write(fmt.Sprintf("; %s++ {\n", v))
g.indent++
}

for i, choice := range stmt.Choices {
g.builder.WriteString(strings.Repeat("\t", g.indent))
if i == 0 {
g.write("if ")
} else {
g.write("} else if ")
}
g.generateExpression(choice.Condition)
g.write(" {\n")
g.indent++

for _, s := range choice.Body {
g.generateStatement(s)
}
if withinFlag {
g.writeLine("_ifmatched = true")
}
g.writeLine("break")

g.indent--
}
g.writeLine("}")

g.indent--
g.writeLine("}")
}

// generateIfChoiceChain emits a chain of if/else-if for the given choices.
// When a replicated nested IF is encountered, it splits the chain and uses
// a _ifmatched flag to determine whether remaining choices should be tried.
func (g *Generator) generateIfChoiceChain(choices []ast.IfChoice, isFirst bool) {
// Find first replicated nested IF
replIdx := -1
for i, c := range choices {
if c.NestedIf != nil && c.NestedIf.Replicator != nil {
replIdx = i
break
}
}

if replIdx == -1 {
// No replicated nested IFs — simple if/else-if chain
for i, choice := range choices {
g.builder.WriteString(strings.Repeat("\t", g.indent))
if i == 0 {
if i == 0 && isFirst {
g.write("if ")
} else {
g.write("} else if ")
Expand All @@ -1591,32 +1691,62 @@ func (g *Generator) generateIfStatement(stmt *ast.IfStatement) {
for _, s := range choice.Body {
g.generateStatement(s)
}
g.writeLine("break")

g.indent--
}
g.writeLine("}")
if len(choices) > 0 {
g.writeLine("}")
}
return
}

g.indent--
g.writeLine("}")
} else {
for i, choice := range stmt.Choices {
// Split at the replicated nested IF
before := choices[:replIdx]
replChoice := choices[replIdx]
after := choices[replIdx+1:]

// Emit choices before the replicated IF as a normal if-else chain
if len(before) > 0 {
for i, choice := range before {
g.builder.WriteString(strings.Repeat("\t", g.indent))
if i == 0 {
if i == 0 && isFirst {
g.write("if ")
} else {
g.write("} else if ")
}
g.generateExpression(choice.Condition)
g.write(" {\n")
g.indent++

for _, s := range choice.Body {
g.generateStatement(s)
}

g.indent--
}
// Open else block for the replicated IF + remaining choices
g.builder.WriteString(strings.Repeat("\t", g.indent))
g.write("} else {\n")
g.indent++
}

// Emit the replicated nested IF with a flag
needFlag := len(after) > 0
if needFlag {
g.writeLine("_ifmatched := false")
}
g.generateReplicatedIfLoop(replChoice.NestedIf, needFlag)

// Emit remaining choices inside if !_ifmatched (recursive for multiple)
if len(after) > 0 {
g.builder.WriteString(strings.Repeat("\t", g.indent))
g.write("if !_ifmatched {\n")
g.indent++
g.generateIfChoiceChain(after, true) // recursive for remaining
g.indent--
g.writeLine("}")
}

if len(before) > 0 {
g.indent--
g.writeLine("}")
}
}
Expand Down
93 changes: 93 additions & 0 deletions codegen/e2e_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,99 @@ func TestE2E_MultiStatementWhileBody(t *testing.T) {
}
}

func TestE2E_NestedReplicatedIfWithDefault(t *testing.T) {
// Replicated IF as a choice within outer IF, with TRUE default
occam := `SEQ
[5]INT arr:
INT result:
SEQ i = 0 FOR 5
arr[i] := i * 10
IF
IF i = 0 FOR 5
arr[i] > 25
result := arr[i]
TRUE
result := -1
print.int(result)
`
output := transpileCompileRun(t, occam)
expected := "30\n"
if output != expected {
t.Errorf("expected %q, got %q", expected, output)
}
}

func TestE2E_NestedReplicatedIfNoMatch(t *testing.T) {
// Replicated IF where no choice matches, falls through to TRUE
occam := `SEQ
[3]INT arr:
INT result:
SEQ i = 0 FOR 3
arr[i] := i
IF
IF i = 0 FOR 3
arr[i] > 100
result := arr[i]
TRUE
result := -1
print.int(result)
`
output := transpileCompileRun(t, occam)
expected := "-1\n"
if output != expected {
t.Errorf("expected %q, got %q", expected, output)
}
}

func TestE2E_NestedReplicatedIfWithPrecedingChoice(t *testing.T) {
// Normal choice before replicated IF, then default
occam := `SEQ
[3]INT arr:
INT result:
SEQ i = 0 FOR 3
arr[i] := i
INT x:
x := 99
IF
x > 100
result := x
IF i = 0 FOR 3
arr[i] = 2
result := arr[i]
TRUE
result := -1
print.int(result)
`
output := transpileCompileRun(t, occam)
expected := "2\n"
if output != expected {
t.Errorf("expected %q, got %q", expected, output)
}
}

func TestE2E_NestedNonReplicatedIf(t *testing.T) {
// Non-replicated nested IF (choices inlined into parent)
occam := `SEQ
INT x:
INT result:
x := 5
IF
IF
x > 10
result := 1
x > 3
result := 2
TRUE
result := 0
print.int(result)
`
output := transpileCompileRun(t, occam)
expected := "2\n"
if output != expected {
t.Errorf("expected %q, got %q", expected, output)
}
}

func TestE2E_ChannelDirAtCallSite(t *testing.T) {
occam := `PROC worker(CHAN OF INT in?, CHAN OF INT out!)
INT x:
Expand Down
Loading
Loading