Skip to content

Commit 445664d

Browse files
committed
Simplify If Expressions
1 parent f8d3f3e commit 445664d

File tree

4 files changed

+154
-0
lines changed

4 files changed

+154
-0
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import liquidjava.rj_language.ast.BinaryExpression;
44
import liquidjava.rj_language.ast.Expression;
55
import liquidjava.rj_language.ast.GroupExpression;
6+
import liquidjava.rj_language.ast.Ite;
67
import liquidjava.rj_language.ast.LiteralBoolean;
78
import liquidjava.rj_language.ast.LiteralInt;
89
import liquidjava.rj_language.ast.LiteralReal;
910
import liquidjava.rj_language.ast.UnaryExpression;
1011
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
1112
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
13+
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
1214
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
1315
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1416

@@ -26,6 +28,9 @@ public static ValDerivationNode fold(ValDerivationNode node) {
2628
if (exp instanceof UnaryExpression)
2729
return foldUnary(node);
2830

31+
if (exp instanceof Ite)
32+
return foldIte(node);
33+
2934
if (exp instanceof GroupExpression group) {
3035
if (group.getChildren().size() == 1) {
3136
return fold(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()));
@@ -191,4 +196,45 @@ private static ValDerivationNode foldUnary(ValDerivationNode node) {
191196
DerivationNode origin = operandNode.getOrigin() != null ? new UnaryDerivationNode(operandNode, operator) : null;
192197
return new ValDerivationNode(unaryExp, origin);
193198
}
199+
200+
/**
201+
* Folds ternary expressions by checking if condition is a boolean literal or both branches are the same
202+
*/
203+
private static ValDerivationNode foldIte(ValDerivationNode node) {
204+
Ite iteExp = (Ite) node.getValue();
205+
206+
ValDerivationNode condNode = fold(new ValDerivationNode(iteExp.getCondition(), null));
207+
ValDerivationNode thenNode = fold(new ValDerivationNode(iteExp.getThen(), null));
208+
ValDerivationNode elseNode = fold(new ValDerivationNode(iteExp.getElse(), null));
209+
210+
Expression condition = condNode.getValue();
211+
Expression thenExp = thenNode.getValue();
212+
Expression elseExp = elseNode.getValue();
213+
214+
iteExp.setChild(0, condition);
215+
iteExp.setChild(1, thenExp);
216+
iteExp.setChild(2, elseExp);
217+
218+
// if condition is a boolean literal, select the corresponding branch: true ? a : b => a, false ? a : b => b
219+
if (condition instanceof LiteralBoolean boolCond) {
220+
Expression selected = boolCond.isBooleanTrue() ? thenExp : elseExp;
221+
DerivationNode origin = new IteDerivationNode(condNode, thenNode, elseNode);
222+
return new ValDerivationNode(selected, origin);
223+
}
224+
225+
// if both branches are the same, return one of them (e.g. cond ? b : b => b)
226+
if (thenExp.equals(elseExp)) {
227+
DerivationNode origin = new IteDerivationNode(condNode, thenNode, elseNode);
228+
return new ValDerivationNode(thenExp, origin);
229+
}
230+
231+
// no folding, but keep track of the folding steps in the origin
232+
DerivationNode origin = hasIteChildOrigin(condNode, thenNode, elseNode)
233+
? new IteDerivationNode(condNode, thenNode, elseNode) : node.getOrigin();
234+
return new ValDerivationNode(iteExp, origin);
235+
}
236+
237+
private static boolean hasIteChildOrigin(ValDerivationNode cond, ValDerivationNode then, ValDerivationNode els) {
238+
return cond.getOrigin() != null || then.getOrigin() != null || els.getOrigin() != null;
239+
}
194240
}

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import liquidjava.rj_language.ast.Var;
77
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
88
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
9+
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
910
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
1011
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1112
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
@@ -134,6 +135,10 @@ private static void extractVarOrigins(ValDerivationNode node, Map<String, Deriva
134135
extractVarOrigins(binOrigin.getRight(), varOrigins);
135136
} else if (origin instanceof UnaryDerivationNode unaryOrigin) {
136137
extractVarOrigins(unaryOrigin.getOperand(), varOrigins);
138+
} else if (origin instanceof IteDerivationNode iteOrigin) {
139+
extractVarOrigins(iteOrigin.getCondition(), varOrigins);
140+
extractVarOrigins(iteOrigin.getThenBranch(), varOrigins);
141+
extractVarOrigins(iteOrigin.getElseBranch(), varOrigins);
137142
} else if (origin instanceof ValDerivationNode valOrigin) {
138143
extractVarOrigins(valOrigin, varOrigins);
139144
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package liquidjava.rj_language.opt.derivation_node;
2+
3+
public class IteDerivationNode extends DerivationNode {
4+
5+
private final ValDerivationNode condition;
6+
private final ValDerivationNode thenBranch;
7+
private final ValDerivationNode elseBranch;
8+
9+
public IteDerivationNode(ValDerivationNode condition, ValDerivationNode thenBranch, ValDerivationNode elseBranch) {
10+
this.condition = condition;
11+
this.thenBranch = thenBranch;
12+
this.elseBranch = elseBranch;
13+
}
14+
15+
public ValDerivationNode getCondition() {
16+
return condition;
17+
}
18+
19+
public ValDerivationNode getThenBranch() {
20+
return thenBranch;
21+
}
22+
23+
public ValDerivationNode getElseBranch() {
24+
return elseBranch;
25+
}
26+
}

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
import liquidjava.rj_language.ast.BinaryExpression;
66
import liquidjava.rj_language.ast.Expression;
7+
import liquidjava.rj_language.ast.Ite;
78
import liquidjava.rj_language.ast.LiteralBoolean;
89
import liquidjava.rj_language.ast.LiteralInt;
910
import liquidjava.rj_language.ast.UnaryExpression;
1011
import liquidjava.rj_language.ast.Var;
1112
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
1213
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
14+
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
1315
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
1416
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1517
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
@@ -550,6 +552,76 @@ void testTransitive() {
550552
assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1");
551553
}
552554

555+
@Test
556+
void testIteTrueConditionSimplifiesToThenBranch() {
557+
// Given: true ? a : b
558+
// Expected: a
559+
560+
Expression expr = new Ite(new LiteralBoolean(true), new Var("a"), new Var("b"));
561+
562+
// When
563+
ValDerivationNode result = ExpressionSimplifier.simplify(expr);
564+
565+
// Then
566+
assertNotNull(result, "Result should not be null");
567+
assertEquals("a", result.getValue().toString(), "Expected result to be a");
568+
569+
ValDerivationNode conditionNode = new ValDerivationNode(new LiteralBoolean(true), null);
570+
ValDerivationNode thenNode = new ValDerivationNode(new Var("a"), null);
571+
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
572+
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
573+
ValDerivationNode expected = new ValDerivationNode(new Var("a"), iteOrigin);
574+
575+
assertDerivationEquals(expected, result, "");
576+
}
577+
578+
@Test
579+
void testIteFalseConditionSimplifiesToElseBranch() {
580+
// Given: false ? a : b
581+
// Expected: b
582+
583+
Expression expr = new Ite(new LiteralBoolean(false), new Var("a"), new Var("b"));
584+
585+
// When
586+
ValDerivationNode result = ExpressionSimplifier.simplify(expr);
587+
588+
// Then
589+
assertNotNull(result, "Result should not be null");
590+
assertEquals("b", result.getValue().toString(), "Expected result to be b");
591+
592+
ValDerivationNode conditionNode = new ValDerivationNode(new LiteralBoolean(false), null);
593+
ValDerivationNode thenNode = new ValDerivationNode(new Var("a"), null);
594+
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
595+
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
596+
ValDerivationNode expected = new ValDerivationNode(new Var("b"), iteOrigin);
597+
598+
assertDerivationEquals(expected, result, "");
599+
}
600+
601+
@Test
602+
void testIteEqualBranchesSimplifiesToBranch() {
603+
// Given: cond ? b : b
604+
// Expected: b
605+
606+
Expression branch = new Var("b");
607+
Expression expr = new Ite(new Var("cond"), branch, branch.clone());
608+
609+
// When
610+
ValDerivationNode result = ExpressionSimplifier.simplify(expr);
611+
612+
// Then
613+
assertNotNull(result, "Result should not be null");
614+
assertEquals("b", result.getValue().toString(), "Expected result to be b");
615+
616+
ValDerivationNode conditionNode = new ValDerivationNode(new Var("cond"), null);
617+
ValDerivationNode thenNode = new ValDerivationNode(new Var("b"), null);
618+
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
619+
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
620+
ValDerivationNode expected = new ValDerivationNode(new Var("b"), iteOrigin);
621+
622+
assertDerivationEquals(expected, result, "");
623+
}
624+
553625
/**
554626
* Helper method to compare two derivation nodes recursively
555627
*/
@@ -576,6 +648,11 @@ private void assertDerivationEquals(DerivationNode expected, DerivationNode actu
576648
UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual;
577649
assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match");
578650
assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand");
651+
} else if (expected instanceof IteDerivationNode expectedIte) {
652+
IteDerivationNode actualIte = (IteDerivationNode) actual;
653+
assertDerivationEquals(expectedIte.getCondition(), actualIte.getCondition(), message + " > condition");
654+
assertDerivationEquals(expectedIte.getThenBranch(), actualIte.getThenBranch(), message + " > then");
655+
assertDerivationEquals(expectedIte.getElseBranch(), actualIte.getElseBranch(), message + " > else");
579656
}
580657
}
581658
}

0 commit comments

Comments
 (0)