Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.LiteralBoolean;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;

public class ExpressionSimplifier {
Expand All @@ -15,12 +17,13 @@ public class ExpressionSimplifier {
*/
public static ValDerivationNode simplify(Expression exp) {
ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp);
return simplifyValDerivationNode(fixedPoint);
ValDerivationNode simplified = simplifyValDerivationNode(fixedPoint);
return unwrapDerivedBooleans(simplified);
}

/**
* Recursively applies propagation and folding until the expression stops changing (fixed point) Stops early if the
* expression simplifies to 'true', which means we've simplified too much
* expression simplifies to a boolean literal, which means we've simplified too much
*/
private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, Expression prevExp) {
// apply propagation and folding
Expand All @@ -34,6 +37,11 @@ private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current,
return current;
}

// prevent oversimplification
if (current != null && currExp instanceof LiteralBoolean && !(current.getValue() instanceof LiteralBoolean)) {
return current;
}

// continue simplifying
return simplifyToFixedPoint(simplified, simplified.getValue());
}
Expand Down Expand Up @@ -114,4 +122,61 @@ private static boolean isRedundant(Expression exp) {
}
return false;
}

/**
* Recursively traverses the derivation tree and replaces boolean literals with the expressions that produced them,
* but only when at least one operand in the derivation is non-boolean. e.g. "x == true" where true came from "1 >
* 0" becomes "x == 1 > 0"
*/
static ValDerivationNode unwrapDerivedBooleans(ValDerivationNode node) {
Expression value = node.getValue();
DerivationNode origin = node.getOrigin();

if (origin == null)
return node;

// unwrap binary expressions
if (value instanceof BinaryExpression binExp && origin instanceof BinaryDerivationNode binOrigin) {
ValDerivationNode left = unwrapDerivedBooleans(binOrigin.getLeft());
ValDerivationNode right = unwrapDerivedBooleans(binOrigin.getRight());
if (left != binOrigin.getLeft() || right != binOrigin.getRight()) {
Expression newValue = new BinaryExpression(left.getValue(), binExp.getOperator(), right.getValue());
return new ValDerivationNode(newValue, new BinaryDerivationNode(left, right, binOrigin.getOp()));
}
return node;
}

// unwrap unary expressions
if (value instanceof UnaryExpression unaryExp && origin instanceof UnaryDerivationNode unaryOrigin) {
ValDerivationNode operand = unwrapDerivedBooleans(unaryOrigin.getOperand());
if (operand != unaryOrigin.getOperand()) {
Expression newValue = new UnaryExpression(unaryExp.getOp(), operand.getValue());
return new ValDerivationNode(newValue, new UnaryDerivationNode(operand, unaryOrigin.getOp()));
}
return node;
}

// boolean literal with binary origin: unwrap if at least one child is non-boolean
if (value instanceof LiteralBoolean && origin instanceof BinaryDerivationNode binOrigin) {
ValDerivationNode left = unwrapDerivedBooleans(binOrigin.getLeft());
ValDerivationNode right = unwrapDerivedBooleans(binOrigin.getRight());
if (!(left.getValue() instanceof LiteralBoolean) || !(right.getValue() instanceof LiteralBoolean)) {
Expression newValue = new BinaryExpression(left.getValue(), binOrigin.getOp(), right.getValue());
return new ValDerivationNode(newValue, new BinaryDerivationNode(left, right, binOrigin.getOp()));
}
return node;
}

// boolean literal with unary origin: unwrap if operand is non-boolean
if (value instanceof LiteralBoolean && origin instanceof UnaryDerivationNode unaryOrigin) {
ValDerivationNode operand = unwrapDerivedBooleans(unaryOrigin.getOperand());
if (!(operand.getValue() instanceof LiteralBoolean)) {
Expression newValue = new UnaryExpression(unaryOrigin.getOp(), operand.getValue());
return new ValDerivationNode(newValue, new UnaryDerivationNode(operand, unaryOrigin.getOp()));
}
return node;
}

return node;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,13 @@ void testComplexArithmeticWithMultipleOperations() {
// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
// Then: boolean literals are unwrapped to show the verified conditions
assertNotNull(result, "Result should not be null");
assertNotNull(result.getValue(), "Result value should not be null");
assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should be a boolean literal");
assertTrue(result.getValue().isBooleanTrue(), "Expected result to be true");
assertEquals("14 == 14 && 5 == 5 && 7 == 7 && 14 == 14", result.getValue().toString(),
"All verified conditions should be visible instead of collapsed to true");

// 5 * 2 + 7 - 3
// 5 * 2 + 7 - 3 = 14
ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null);
BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*");
Expand All @@ -266,39 +266,45 @@ void testComplexArithmeticWithMultipleOperations() {
// 14 from variable c
ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));

// 14 == 14
// 14 == 14 (unwrapped from true)
BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "==");
ValDerivationNode trueFromComparison = new ValDerivationNode(new LiteralBoolean(true), compare14);
Expression expr14Eq14 = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14));
ValDerivationNode compare14Node = new ValDerivationNode(expr14Eq14, compare14);

// a == 5 => true
// a == 5 => 5 == 5 (unwrapped from true)
ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null);
BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "==");
ValDerivationNode trueFromA = new ValDerivationNode(new LiteralBoolean(true), compareA5);
Expression expr5Eq5 = new BinaryExpression(new LiteralInt(5), "==", new LiteralInt(5));
ValDerivationNode compare5Node = new ValDerivationNode(expr5Eq5, compareA5);

// b == 7 => true
// b == 7 => 7 == 7 (unwrapped from true)
ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b"));
ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null);
BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "==");
ValDerivationNode trueFromB = new ValDerivationNode(new LiteralBoolean(true), compareB7);
Expression expr7Eq7 = new BinaryExpression(new LiteralInt(7), "==", new LiteralInt(7));
ValDerivationNode compare7Node = new ValDerivationNode(expr7Eq7, compareB7);

// (a == 5) && (b == 7) => true
BinaryDerivationNode andAB = new BinaryDerivationNode(trueFromA, trueFromB, "&&");
ValDerivationNode trueFromAB = new ValDerivationNode(new LiteralBoolean(true), andAB);
// (5 == 5) && (7 == 7) (unwrapped from true)
BinaryDerivationNode andAB = new BinaryDerivationNode(compare5Node, compare7Node, "&&");
Expression expr5And7 = new BinaryExpression(expr5Eq5, "&&", expr7Eq7);
ValDerivationNode and5And7Node = new ValDerivationNode(expr5And7, andAB);

// c == 14 => true
// c == 14 => 14 == 14 (unwrapped from true)
ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null);
BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "==");
ValDerivationNode trueFromC = new ValDerivationNode(new LiteralBoolean(true), compareC14);
Expression expr14Eq14b = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14));
ValDerivationNode compare14bNode = new ValDerivationNode(expr14Eq14b, compareC14);

// ((a == 5) && (b == 7)) && (c == 14) => true
BinaryDerivationNode andABC = new BinaryDerivationNode(trueFromAB, trueFromC, "&&");
ValDerivationNode trueFromAllConditions = new ValDerivationNode(new LiteralBoolean(true), andABC);
// ((5 == 5) && (7 == 7)) && (14 == 14) (unwrapped from true)
BinaryDerivationNode andABC = new BinaryDerivationNode(and5And7Node, compare14bNode, "&&");
Expression exprConditions = new BinaryExpression(expr5And7, "&&", expr14Eq14b);
ValDerivationNode conditionsNode = new ValDerivationNode(exprConditions, andABC);

// 14 == 14 => true
BinaryDerivationNode finalAnd = new BinaryDerivationNode(trueFromComparison, trueFromAllConditions, "&&");
ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(true), finalAnd);
// (14 == 14) && ((5 == 5 && 7 == 7) && 14 == 14)
BinaryDerivationNode finalAnd = new BinaryDerivationNode(compare14Node, conditionsNode, "&&");
ValDerivationNode expected = new ValDerivationNode(result.getValue(), finalAnd);

// Compare the derivation trees
assertDerivationEquals(expected, result, "");
Expand Down Expand Up @@ -550,6 +556,141 @@ void testTransitive() {
assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1");
}

@Test
void testShouldNotOversimplifyToTrue() {
// Given: x > 5 && x == y && y == 10
// Iteration 1: resolves y == 10, substitutes y -> 10: x > 5 && x == 10
// Iteration 2: resolves x == 10, substitutes x -> 10: 10 > 5 && 10 == 10 -> true
// Expected: x > 5 && x == 10 (should NOT simplify to true)

Expression varX = new Var("x");
Expression varY = new Var("y");
Expression five = new LiteralInt(5);
Expression ten = new LiteralInt(10);

Expression xGreater5 = new BinaryExpression(varX, ">", five);
Expression xEqualsY = new BinaryExpression(varX, "==", varY);
Expression yEquals10 = new BinaryExpression(varY, "==", ten);

Expression firstAnd = new BinaryExpression(xGreater5, "&&", xEqualsY);
Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEquals10);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
assertNotNull(result, "Result should not be null");
assertFalse(result.getValue() instanceof LiteralBoolean,
"Should not oversimplify to a boolean literal, but got: " + result.getValue());
assertEquals("x > 5 && x == 10", result.getValue().toString(),
"Should stop simplification before collapsing to true");
}

@Test
void testShouldUnwrapBooleanInEquality() {
// Given: x == (1 > 0)
// Without unwrapping: x == true (unhelpful - hides what "true" came from)
// Expected: x == 1 > 0 (unwrapped to show the original comparison)

Expression varX = new Var("x");
Expression one = new LiteralInt(1);
Expression zero = new LiteralInt(0);
Expression oneGreaterZero = new BinaryExpression(one, ">", zero);
Expression fullExpression = new BinaryExpression(varX, "==", oneGreaterZero);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("x == 1 > 0", result.getValue().toString(),
"Boolean in equality should be unwrapped to show the original comparison");
}

@Test
void testShouldUnwrapBooleanInEqualityWithPropagation() {
// Given: x == (a > b) && a == 3 && b == 1
// Without unwrapping: x == true (unhelpful)
// Expected: x == 3 > 1 (unwrapped and propagated)

Expression varX = new Var("x");
Expression varA = new Var("a");
Expression varB = new Var("b");
Expression aGreaterB = new BinaryExpression(varA, ">", varB);
Expression xEqualsComp = new BinaryExpression(varX, "==", aGreaterB);

Expression three = new LiteralInt(3);
Expression aEquals3 = new BinaryExpression(varA, "==", three);
Expression one = new LiteralInt(1);
Expression bEquals1 = new BinaryExpression(varB, "==", one);

Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals1);
Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("x == 3 > 1", result.getValue().toString(),
"Boolean in equality should be unwrapped after propagation");
}

@Test
void testShouldNotUnwrapBooleanWithBooleanChildren() {
// Given: (y || true) && !true && y == false
// Expected: false (both children of the fold are boolean, so no unwrapping needed)

Expression varY = new Var("y");
Expression trueExp = new LiteralBoolean(true);
Expression yOrTrue = new BinaryExpression(varY, "||", trueExp);
Expression notTrue = new UnaryExpression("!", trueExp);
Expression falseExp = new LiteralBoolean(false);
Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp);

Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue);
Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then: false stays as false since both sides in the derivation are booleans
assertNotNull(result, "Result should not be null");
assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should remain a boolean");
assertFalse(result.getValue().isBooleanTrue(), "Expected result to be false");
}

@Test
void testShouldUnwrapNestedBooleanInEquality() {
// Given: x == (a + b > 10) && a == 3 && b == 5
// Without unwrapping: x == true (unhelpful)
// Expected: x == 8 > 10 (shows the actual comparison that produced the boolean)

Expression varX = new Var("x");
Expression varA = new Var("a");
Expression varB = new Var("b");
Expression aPlusB = new BinaryExpression(varA, "+", varB);
Expression ten = new LiteralInt(10);
Expression comparison = new BinaryExpression(aPlusB, ">", ten);
Expression xEqualsComp = new BinaryExpression(varX, "==", comparison);

Expression three = new LiteralInt(3);
Expression aEquals3 = new BinaryExpression(varA, "==", three);
Expression five = new LiteralInt(5);
Expression bEquals5 = new BinaryExpression(varB, "==", five);

Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5);
Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("x == 8 > 10", result.getValue().toString(),
"Boolean in equality should be unwrapped to show the computed comparison");
}

/**
* Helper method to compare two derivation nodes recursively
*/
Expand Down
Loading