diff --git a/src/larktools/ebnf_grammar.py b/src/larktools/ebnf_grammar.py index d013d77..8450561 100644 --- a/src/larktools/ebnf_grammar.py +++ b/src/larktools/ebnf_grammar.py @@ -30,7 +30,24 @@ assignment: VARNAME "=" arith_expr - + logic_expr: logic_state | logic_operation | logic_comparison # removing logic_comparison here will make other logic tests succeed. + + logic_state: BOOLEAN + + logic_operation: logic_and | logic_or | logic_not + logic_and: logic_expr "and" logic_state + logic_or: logic_expr "or" logic_state + logic_not: "not" logic_expr + + logic_comparison: logic_greater_than | logic_greater_equal | logic_equal | logic_smaller_equal | logic_smaller_than | logic_unequal + logic_greater_than: arith_expr ">" arith_expr + logic_greater_equal: arith_expr ">=" arith_expr + logic_equal: arith_expr "==" arith_expr + logic_smaller_equal: arith_expr "<=" arith_expr + logic_smaller_than: arith_expr "<" arith_expr + logic_unequal: arith_expr "!=" arith_expr + + arith_expr: sum sum: product | addition | subtraction addition: sum "+" product @@ -53,6 +70,7 @@ SIGNED_INT: ["+"|"-"] INT DECIMAL: INT "." INT? | "." INT + _EXP: ("e"|"E") SIGNED_INT FLOAT: INT _EXP | DECIMAL _EXP? SIGNED_FLOAT: ["+"|"-"] FLOAT @@ -63,6 +81,8 @@ LETTER: UCASE_LETTER | LCASE_LETTER WORD: LETTER+ + BOOLEAN: "True" | "False" + // Whitespace characters are filtered out before parsing. // However, linebreaks are preserved. diff --git a/src/larktools/evaluation.py b/src/larktools/evaluation.py index 9a2199c..c28aa5c 100644 --- a/src/larktools/evaluation.py +++ b/src/larktools/evaluation.py @@ -77,9 +77,12 @@ def __call__(self, env): class NumberNode: def __init__(self, lark_node): node_name = get_name(lark_node) - self._value = { - "SIGNED_FLOAT": float, "INT": int, "INDEX": int, - }[node_name](get_value(lark_node)) + op_map = { + "SIGNED_FLOAT": float, + "INT": int, + "INDEX": int, + "BOOLEAN": lambda x: True if x == "True" else (False if x == "False" else None)} + self._value = op_map[node_name](get_value(lark_node)) def __call__(self, env): return self._value @@ -102,7 +105,8 @@ class UnaryOperatorNode(MappedOperatorNode): def __init__(self, lark_node): super().__init__( lark_node, - op_map={"neg_atom": lambda x: -x[0]} + op_map={"neg_atom": lambda x: -x[0], + "logic_not": lambda x: not x[0]} ) @@ -115,6 +119,14 @@ def __init__(self, lark_node): "subtraction": lambda x: x[0] - x[1], "multiplication": lambda x: x[0] * x[1], "division": lambda x: x[0] / x[1], + "logic_and": lambda x: bool(x[0]) and bool(x[1]), + "logic_or": lambda x: x[0] or x[1], + "logic_greater_than": lambda x: x[0] > x[1], + "logic_greater_equal": lambda x: x[0] >= x[1], + "logic_equal": lambda x: x[0] == x[1], + "logic_smaller_equal": lambda x: x[0] <= x[1], + "logic_smaller_than": lambda x: x[0] < x[1], + "logic_unequal": lambda x: x[0] != x[1] } ) @@ -122,10 +134,13 @@ def __init__(self, lark_node): NODE_MAP = { RootNode: ("multi_line_block",), AssignNode: ("assignment",), - UnaryOperatorNode: ("neg_atom",), - BinaryOperatorNode: ("addition", "subtraction", "multiplication", "division"), + UnaryOperatorNode: ("neg_atom", "logic_not"), + BinaryOperatorNode: ("addition", "subtraction", "multiplication", "division", + "logic_and", "logic_or", + "logic_greater_than","logic_greater_equal","logic_equal", + "logic_smaller_equal","logic_smaller_than","logic_unequal"), VariableNode: ("variable", "varname"), - NumberNode: ("INT", "SIGNED_INT", "FLOAT", "SIGNED_FLOAT", "INDEX"), + NumberNode: ("INT", "SIGNED_INT", "FLOAT", "SIGNED_FLOAT", "INDEX", "BOOLEAN"), } INV_NODE_MAP = {k: v for v in NODE_MAP for k in NODE_MAP[v]} diff --git a/tests/test_logic.py b/tests/test_logic.py new file mode 100644 index 0000000..bd11997 --- /dev/null +++ b/tests/test_logic.py @@ -0,0 +1,67 @@ +import pytest +from typing import Optional, Union + +from lark import Lark + +from larktools.ebnf_grammar import grammar +from larktools.evaluation import instantiate_eval_tree + + +class LogicParser: + def __init__(self): + self.parser = Lark(grammar, parser="lalr", start="logic_expr") + + def parse_and_eval(self, expression: str, env: Optional[dict] = None) -> Union[int, float]: + tree = self.parser.parse(expression) + eval_tree = instantiate_eval_tree(tree) + res = eval_tree({} if env is None else env) + return res + + +def _parse_and_assert(expression: str, expected: Union[int, float], env: Optional[dict] = None) -> None: + parser = LogicParser() + res = parser.parse_and_eval(expression, env) + assert expected == res + +def _parse_and_assert_collection(tests: list[str, Union[int, float]]) -> None: + for ipt, expected in tests: + _parse_and_assert(ipt, expected) + + +def test_comparison(): + _parse_and_assert("3 > 5", False) + _parse_and_assert("3 >= 5", False) + _parse_and_assert("3 >= 3", True) + _parse_and_assert("3 >= 3", True) + _parse_and_assert("5 == 3 + 2", True) + _parse_and_assert("2 + 3 == 3 + 2", True) + _parse_and_assert("5 == 3", False) + _parse_and_assert("3 <= 5", True) + _parse_and_assert("3 <= 3", True) + _parse_and_assert("3 <= 2", False) + _parse_and_assert("3 < 5", True) + _parse_and_assert("3 != 5", True) + _parse_and_assert("5 != 5", False) + + +def test_logic_states(): + _parse_and_assert("True", True) + _parse_and_assert("False", False) + +def test_logic_operations(): + _parse_and_assert("False or True", True) + _parse_and_assert("False or False", False) + _parse_and_assert("True or True", True) + _parse_and_assert("True or False", True) + + _parse_and_assert("False and True", False) + _parse_and_assert("False and False", False) + _parse_and_assert("True and True", True) + _parse_and_assert("True and False", False) + +def test_logic_negation(): + _parse_and_assert("not True", False) + _parse_and_assert("not False", True) + + +