Skip to content

Commit 77a3fad

Browse files
committed
Unwrap Unhelpful Literals
1 parent e9efbc4 commit 77a3fad

File tree

2 files changed

+193
-22
lines changed

2 files changed

+193
-22
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import liquidjava.rj_language.ast.BinaryExpression;
44
import liquidjava.rj_language.ast.Expression;
55
import liquidjava.rj_language.ast.LiteralBoolean;
6+
import liquidjava.rj_language.ast.UnaryExpression;
67
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
78
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
9+
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
810
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
911

1012
public class ExpressionSimplifier {
@@ -15,7 +17,8 @@ public class ExpressionSimplifier {
1517
*/
1618
public static ValDerivationNode simplify(Expression exp) {
1719
ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp);
18-
return simplifyValDerivationNode(fixedPoint);
20+
ValDerivationNode simplified = simplifyValDerivationNode(fixedPoint);
21+
return unwrapDerivedBooleans(simplified);
1922
}
2023

2124
/**
@@ -119,4 +122,61 @@ private static boolean isRedundant(Expression exp) {
119122
}
120123
return false;
121124
}
125+
126+
/**
127+
* Recursively traverses the derivation tree and replaces boolean literals with the expressions that produced them,
128+
* but only when at least one operand in the derivation is non-boolean. e.g. "x == true" where true came from "1 >
129+
* 0" becomes "x == 1 > 0"
130+
*/
131+
static ValDerivationNode unwrapDerivedBooleans(ValDerivationNode node) {
132+
Expression value = node.getValue();
133+
DerivationNode origin = node.getOrigin();
134+
135+
if (origin == null)
136+
return node;
137+
138+
// unwrap binary expressions
139+
if (value instanceof BinaryExpression binExp && origin instanceof BinaryDerivationNode binOrigin) {
140+
ValDerivationNode left = unwrapDerivedBooleans(binOrigin.getLeft());
141+
ValDerivationNode right = unwrapDerivedBooleans(binOrigin.getRight());
142+
if (left != binOrigin.getLeft() || right != binOrigin.getRight()) {
143+
Expression newValue = new BinaryExpression(left.getValue(), binExp.getOperator(), right.getValue());
144+
return new ValDerivationNode(newValue, new BinaryDerivationNode(left, right, binOrigin.getOp()));
145+
}
146+
return node;
147+
}
148+
149+
// unwrap unary expressions
150+
if (value instanceof UnaryExpression unaryExp && origin instanceof UnaryDerivationNode unaryOrigin) {
151+
ValDerivationNode operand = unwrapDerivedBooleans(unaryOrigin.getOperand());
152+
if (operand != unaryOrigin.getOperand()) {
153+
Expression newValue = new UnaryExpression(unaryExp.getOp(), operand.getValue());
154+
return new ValDerivationNode(newValue, new UnaryDerivationNode(operand, unaryOrigin.getOp()));
155+
}
156+
return node;
157+
}
158+
159+
// boolean literal with binary origin: unwrap if at least one child is non-boolean
160+
if (value instanceof LiteralBoolean && origin instanceof BinaryDerivationNode binOrigin) {
161+
ValDerivationNode left = unwrapDerivedBooleans(binOrigin.getLeft());
162+
ValDerivationNode right = unwrapDerivedBooleans(binOrigin.getRight());
163+
if (!(left.getValue() instanceof LiteralBoolean) || !(right.getValue() instanceof LiteralBoolean)) {
164+
Expression newValue = new BinaryExpression(left.getValue(), binOrigin.getOp(), right.getValue());
165+
return new ValDerivationNode(newValue, new BinaryDerivationNode(left, right, binOrigin.getOp()));
166+
}
167+
return node;
168+
}
169+
170+
// boolean literal with unary origin: unwrap if operand is non-boolean
171+
if (value instanceof LiteralBoolean && origin instanceof UnaryDerivationNode unaryOrigin) {
172+
ValDerivationNode operand = unwrapDerivedBooleans(unaryOrigin.getOperand());
173+
if (!(operand.getValue() instanceof LiteralBoolean)) {
174+
Expression newValue = new UnaryExpression(unaryOrigin.getOp(), operand.getValue());
175+
return new ValDerivationNode(newValue, new UnaryDerivationNode(operand, unaryOrigin.getOp()));
176+
}
177+
return node;
178+
}
179+
180+
return node;
181+
}
122182
}

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 132 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,13 @@ void testComplexArithmeticWithMultipleOperations() {
243243
// When
244244
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
245245

246-
// Then
246+
// Then: boolean literals are unwrapped to show the verified conditions
247247
assertNotNull(result, "Result should not be null");
248248
assertNotNull(result.getValue(), "Result value should not be null");
249-
assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should be a boolean literal");
250-
assertTrue(result.getValue().isBooleanTrue(), "Expected result to be true");
249+
assertEquals("14 == 14 && 5 == 5 && 7 == 7 && 14 == 14", result.getValue().toString(),
250+
"All verified conditions should be visible instead of collapsed to true");
251251

252-
// 5 * 2 + 7 - 3
252+
// 5 * 2 + 7 - 3 = 14
253253
ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
254254
ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null);
255255
BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*");
@@ -266,39 +266,45 @@ void testComplexArithmeticWithMultipleOperations() {
266266
// 14 from variable c
267267
ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
268268

269-
// 14 == 14
269+
// 14 == 14 (unwrapped from true)
270270
BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "==");
271-
ValDerivationNode trueFromComparison = new ValDerivationNode(new LiteralBoolean(true), compare14);
271+
Expression expr14Eq14 = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14));
272+
ValDerivationNode compare14Node = new ValDerivationNode(expr14Eq14, compare14);
272273

273-
// a == 5 => true
274+
// a == 5 => 5 == 5 (unwrapped from true)
274275
ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
275276
ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null);
276277
BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "==");
277-
ValDerivationNode trueFromA = new ValDerivationNode(new LiteralBoolean(true), compareA5);
278+
Expression expr5Eq5 = new BinaryExpression(new LiteralInt(5), "==", new LiteralInt(5));
279+
ValDerivationNode compare5Node = new ValDerivationNode(expr5Eq5, compareA5);
278280

279-
// b == 7 => true
281+
// b == 7 => 7 == 7 (unwrapped from true)
280282
ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b"));
281283
ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null);
282284
BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "==");
283-
ValDerivationNode trueFromB = new ValDerivationNode(new LiteralBoolean(true), compareB7);
285+
Expression expr7Eq7 = new BinaryExpression(new LiteralInt(7), "==", new LiteralInt(7));
286+
ValDerivationNode compare7Node = new ValDerivationNode(expr7Eq7, compareB7);
284287

285-
// (a == 5) && (b == 7) => true
286-
BinaryDerivationNode andAB = new BinaryDerivationNode(trueFromA, trueFromB, "&&");
287-
ValDerivationNode trueFromAB = new ValDerivationNode(new LiteralBoolean(true), andAB);
288+
// (5 == 5) && (7 == 7) (unwrapped from true)
289+
BinaryDerivationNode andAB = new BinaryDerivationNode(compare5Node, compare7Node, "&&");
290+
Expression expr5And7 = new BinaryExpression(expr5Eq5, "&&", expr7Eq7);
291+
ValDerivationNode and5And7Node = new ValDerivationNode(expr5And7, andAB);
288292

289-
// c == 14 => true
293+
// c == 14 => 14 == 14 (unwrapped from true)
290294
ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
291295
ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null);
292296
BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "==");
293-
ValDerivationNode trueFromC = new ValDerivationNode(new LiteralBoolean(true), compareC14);
297+
Expression expr14Eq14b = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14));
298+
ValDerivationNode compare14bNode = new ValDerivationNode(expr14Eq14b, compareC14);
294299

295-
// ((a == 5) && (b == 7)) && (c == 14) => true
296-
BinaryDerivationNode andABC = new BinaryDerivationNode(trueFromAB, trueFromC, "&&");
297-
ValDerivationNode trueFromAllConditions = new ValDerivationNode(new LiteralBoolean(true), andABC);
300+
// ((5 == 5) && (7 == 7)) && (14 == 14) (unwrapped from true)
301+
BinaryDerivationNode andABC = new BinaryDerivationNode(and5And7Node, compare14bNode, "&&");
302+
Expression exprConditions = new BinaryExpression(expr5And7, "&&", expr14Eq14b);
303+
ValDerivationNode conditionsNode = new ValDerivationNode(exprConditions, andABC);
298304

299-
// 14 == 14 => true
300-
BinaryDerivationNode finalAnd = new BinaryDerivationNode(trueFromComparison, trueFromAllConditions, "&&");
301-
ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(true), finalAnd);
305+
// (14 == 14) && ((5 == 5 && 7 == 7) && 14 == 14)
306+
BinaryDerivationNode finalAnd = new BinaryDerivationNode(compare14Node, conditionsNode, "&&");
307+
ValDerivationNode expected = new ValDerivationNode(result.getValue(), finalAnd);
302308

303309
// Compare the derivation trees
304310
assertDerivationEquals(expected, result, "");
@@ -580,6 +586,111 @@ void testShouldNotOversimplifyToTrue() {
580586
"Should stop simplification before collapsing to true");
581587
}
582588

589+
@Test
590+
void testShouldUnwrapBooleanInEquality() {
591+
// Given: x == (1 > 0)
592+
// Without unwrapping: x == true (unhelpful - hides what "true" came from)
593+
// Expected: x == 1 > 0 (unwrapped to show the original comparison)
594+
595+
Expression varX = new Var("x");
596+
Expression one = new LiteralInt(1);
597+
Expression zero = new LiteralInt(0);
598+
Expression oneGreaterZero = new BinaryExpression(one, ">", zero);
599+
Expression fullExpression = new BinaryExpression(varX, "==", oneGreaterZero);
600+
601+
// When
602+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
603+
604+
// Then
605+
assertNotNull(result, "Result should not be null");
606+
assertEquals("x == 1 > 0", result.getValue().toString(),
607+
"Boolean in equality should be unwrapped to show the original comparison");
608+
}
609+
610+
@Test
611+
void testShouldUnwrapBooleanInEqualityWithPropagation() {
612+
// Given: x == (a > b) && a == 3 && b == 1
613+
// Without unwrapping: x == true (unhelpful)
614+
// Expected: x == 3 > 1 (unwrapped and propagated)
615+
616+
Expression varX = new Var("x");
617+
Expression varA = new Var("a");
618+
Expression varB = new Var("b");
619+
Expression aGreaterB = new BinaryExpression(varA, ">", varB);
620+
Expression xEqualsComp = new BinaryExpression(varX, "==", aGreaterB);
621+
622+
Expression three = new LiteralInt(3);
623+
Expression aEquals3 = new BinaryExpression(varA, "==", three);
624+
Expression one = new LiteralInt(1);
625+
Expression bEquals1 = new BinaryExpression(varB, "==", one);
626+
627+
Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals1);
628+
Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions);
629+
630+
// When
631+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
632+
633+
// Then
634+
assertNotNull(result, "Result should not be null");
635+
assertEquals("x == 3 > 1", result.getValue().toString(),
636+
"Boolean in equality should be unwrapped after propagation");
637+
}
638+
639+
@Test
640+
void testShouldNotUnwrapBooleanWithBooleanChildren() {
641+
// Given: (y || true) && !true && y == false
642+
// Expected: false (both children of the fold are boolean, so no unwrapping needed)
643+
644+
Expression varY = new Var("y");
645+
Expression trueExp = new LiteralBoolean(true);
646+
Expression yOrTrue = new BinaryExpression(varY, "||", trueExp);
647+
Expression notTrue = new UnaryExpression("!", trueExp);
648+
Expression falseExp = new LiteralBoolean(false);
649+
Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp);
650+
651+
Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue);
652+
Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse);
653+
654+
// When
655+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
656+
657+
// Then: false stays as false since both sides in the derivation are booleans
658+
assertNotNull(result, "Result should not be null");
659+
assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should remain a boolean");
660+
assertFalse(result.getValue().isBooleanTrue(), "Expected result to be false");
661+
}
662+
663+
@Test
664+
void testShouldUnwrapNestedBooleanInEquality() {
665+
// Given: x == (a + b > 10) && a == 3 && b == 5
666+
// Without unwrapping: x == true (unhelpful)
667+
// Expected: x == 8 > 10 (shows the actual comparison that produced the boolean)
668+
669+
Expression varX = new Var("x");
670+
Expression varA = new Var("a");
671+
Expression varB = new Var("b");
672+
Expression aPlusB = new BinaryExpression(varA, "+", varB);
673+
Expression ten = new LiteralInt(10);
674+
Expression comparison = new BinaryExpression(aPlusB, ">", ten);
675+
Expression xEqualsComp = new BinaryExpression(varX, "==", comparison);
676+
677+
Expression three = new LiteralInt(3);
678+
Expression aEquals3 = new BinaryExpression(varA, "==", three);
679+
Expression five = new LiteralInt(5);
680+
Expression bEquals5 = new BinaryExpression(varB, "==", five);
681+
682+
Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5);
683+
Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions);
684+
685+
// When
686+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
687+
688+
// Then
689+
assertNotNull(result, "Result should not be null");
690+
assertEquals("x == 8 > 10", result.getValue().toString(),
691+
"Boolean in equality should be unwrapped to show the computed comparison");
692+
}
693+
583694
/**
584695
* Helper method to compare two derivation nodes recursively
585696
*/

0 commit comments

Comments
 (0)