44
55import liquidjava .rj_language .ast .BinaryExpression ;
66import liquidjava .rj_language .ast .Expression ;
7+ import liquidjava .rj_language .ast .Ite ;
78import liquidjava .rj_language .ast .LiteralBoolean ;
89import liquidjava .rj_language .ast .LiteralInt ;
910import liquidjava .rj_language .ast .UnaryExpression ;
1011import liquidjava .rj_language .ast .Var ;
1112import liquidjava .rj_language .opt .derivation_node .BinaryDerivationNode ;
1213import liquidjava .rj_language .opt .derivation_node .DerivationNode ;
14+ import liquidjava .rj_language .opt .derivation_node .IteDerivationNode ;
1315import liquidjava .rj_language .opt .derivation_node .UnaryDerivationNode ;
1416import liquidjava .rj_language .opt .derivation_node .ValDerivationNode ;
1517import liquidjava .rj_language .opt .derivation_node .VarDerivationNode ;
@@ -550,6 +552,76 @@ void testTransitive() {
550552 assertEquals ("a == 1" , result .getValue ().toString (), "Expected result to be a == 1" );
551553 }
552554
555+ @ Test
556+ void testIteTrueConditionSimplifiesToThenBranch () {
557+ // Given: true ? a : b
558+ // Expected: a
559+
560+ Expression expr = new Ite (new LiteralBoolean (true ), new Var ("a" ), new Var ("b" ));
561+
562+ // When
563+ ValDerivationNode result = ExpressionSimplifier .simplify (expr );
564+
565+ // Then
566+ assertNotNull (result , "Result should not be null" );
567+ assertEquals ("a" , result .getValue ().toString (), "Expected result to be a" );
568+
569+ ValDerivationNode conditionNode = new ValDerivationNode (new LiteralBoolean (true ), null );
570+ ValDerivationNode thenNode = new ValDerivationNode (new Var ("a" ), null );
571+ ValDerivationNode elseNode = new ValDerivationNode (new Var ("b" ), null );
572+ IteDerivationNode iteOrigin = new IteDerivationNode (conditionNode , thenNode , elseNode );
573+ ValDerivationNode expected = new ValDerivationNode (new Var ("a" ), iteOrigin );
574+
575+ assertDerivationEquals (expected , result , "" );
576+ }
577+
578+ @ Test
579+ void testIteFalseConditionSimplifiesToElseBranch () {
580+ // Given: false ? a : b
581+ // Expected: b
582+
583+ Expression expr = new Ite (new LiteralBoolean (false ), new Var ("a" ), new Var ("b" ));
584+
585+ // When
586+ ValDerivationNode result = ExpressionSimplifier .simplify (expr );
587+
588+ // Then
589+ assertNotNull (result , "Result should not be null" );
590+ assertEquals ("b" , result .getValue ().toString (), "Expected result to be b" );
591+
592+ ValDerivationNode conditionNode = new ValDerivationNode (new LiteralBoolean (false ), null );
593+ ValDerivationNode thenNode = new ValDerivationNode (new Var ("a" ), null );
594+ ValDerivationNode elseNode = new ValDerivationNode (new Var ("b" ), null );
595+ IteDerivationNode iteOrigin = new IteDerivationNode (conditionNode , thenNode , elseNode );
596+ ValDerivationNode expected = new ValDerivationNode (new Var ("b" ), iteOrigin );
597+
598+ assertDerivationEquals (expected , result , "" );
599+ }
600+
601+ @ Test
602+ void testIteEqualBranchesSimplifiesToBranch () {
603+ // Given: cond ? b : b
604+ // Expected: b
605+
606+ Expression branch = new Var ("b" );
607+ Expression expr = new Ite (new Var ("cond" ), branch , branch .clone ());
608+
609+ // When
610+ ValDerivationNode result = ExpressionSimplifier .simplify (expr );
611+
612+ // Then
613+ assertNotNull (result , "Result should not be null" );
614+ assertEquals ("b" , result .getValue ().toString (), "Expected result to be b" );
615+
616+ ValDerivationNode conditionNode = new ValDerivationNode (new Var ("cond" ), null );
617+ ValDerivationNode thenNode = new ValDerivationNode (new Var ("b" ), null );
618+ ValDerivationNode elseNode = new ValDerivationNode (new Var ("b" ), null );
619+ IteDerivationNode iteOrigin = new IteDerivationNode (conditionNode , thenNode , elseNode );
620+ ValDerivationNode expected = new ValDerivationNode (new Var ("b" ), iteOrigin );
621+
622+ assertDerivationEquals (expected , result , "" );
623+ }
624+
553625 /**
554626 * Helper method to compare two derivation nodes recursively
555627 */
@@ -576,6 +648,11 @@ private void assertDerivationEquals(DerivationNode expected, DerivationNode actu
576648 UnaryDerivationNode actualUnary = (UnaryDerivationNode ) actual ;
577649 assertEquals (expectedUnary .getOp (), actualUnary .getOp (), message + ": operators should match" );
578650 assertDerivationEquals (expectedUnary .getOperand (), actualUnary .getOperand (), message + " > operand" );
651+ } else if (expected instanceof IteDerivationNode expectedIte ) {
652+ IteDerivationNode actualIte = (IteDerivationNode ) actual ;
653+ assertDerivationEquals (expectedIte .getCondition (), actualIte .getCondition (), message + " > condition" );
654+ assertDerivationEquals (expectedIte .getThenBranch (), actualIte .getThenBranch (), message + " > then" );
655+ assertDerivationEquals (expectedIte .getElseBranch (), actualIte .getElseBranch (), message + " > else" );
579656 }
580657 }
581658}
0 commit comments