From daf417013fb209b9ab86a21358709307c2b0e20c Mon Sep 17 00:00:00 2001 From: "SPRENGER Julia, NEA/SCI/DB" Date: Tue, 29 Oct 2024 18:28:08 +0100 Subject: [PATCH] Add first version of logic grammar --- src/larktools/ebnf_grammar.py | 22 ++++++- src/larktools/evaluation.py | 114 ++++++++++++++++++++++++++++++++++ tests/test_logic.py | 62 ++++++++++++++++++ 3 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 tests/test_logic.py diff --git a/src/larktools/ebnf_grammar.py b/src/larktools/ebnf_grammar.py index d013d77..8450561 100644 --- a/src/larktools/ebnf_grammar.py +++ b/src/larktools/ebnf_grammar.py @@ -30,7 +30,24 @@ assignment: VARNAME "=" arith_expr - + logic_expr: logic_state | logic_operation | logic_comparison # removing logic_comparison here will make other logic tests succeed. + + logic_state: BOOLEAN + + logic_operation: logic_and | logic_or | logic_not + logic_and: logic_expr "and" logic_state + logic_or: logic_expr "or" logic_state + logic_not: "not" logic_expr + + logic_comparison: logic_greater_than | logic_greater_equal | logic_equal | logic_smaller_equal | logic_smaller_than | logic_unequal + logic_greater_than: arith_expr ">" arith_expr + logic_greater_equal: arith_expr ">=" arith_expr + logic_equal: arith_expr "==" arith_expr + logic_smaller_equal: arith_expr "<=" arith_expr + logic_smaller_than: arith_expr "<" arith_expr + logic_unequal: arith_expr "!=" arith_expr + + arith_expr: sum sum: product | addition | subtraction addition: sum "+" product @@ -53,6 +70,7 @@ SIGNED_INT: ["+"|"-"] INT DECIMAL: INT "." INT? | "." INT + _EXP: ("e"|"E") SIGNED_INT FLOAT: INT _EXP | DECIMAL _EXP? SIGNED_FLOAT: ["+"|"-"] FLOAT @@ -63,6 +81,8 @@ LETTER: UCASE_LETTER | LCASE_LETTER WORD: LETTER+ + BOOLEAN: "True" | "False" + // Whitespace characters are filtered out before parsing. // However, linebreaks are preserved. diff --git a/src/larktools/evaluation.py b/src/larktools/evaluation.py index 54eebc2..390f071 100644 --- a/src/larktools/evaluation.py +++ b/src/larktools/evaluation.py @@ -153,3 +153,117 @@ def eval_variable(node, env): idx = int(get_value(ch)) value = value[idx] return value + +def eval_logic_expr(node, env): + child = get_children(node)[0] + child_name = get_name(child) + if child_name == 'logic_state': + return eval_logic_state(child, env) + if child_name == 'logic_operation': + return eval_logic_operation(child, env) + if child_name == 'logic_comparison': + return eval_logic_comparison(child, env) + raise ValueError('Unexpected child name') + +def eval_logic_state(node,env): + child = get_children(node)[0] + child_name = get_name(child) + if child_name == 'BOOLEAN': + value = get_value(child) + if value == 'True': + return True + if value == 'False': + return False + if child_name == 'variable': + return eval_variable(child, env) + +def eval_logic_operation(node, env): + child = get_children(node)[0] + child_name = get_name(child) + if child_name == 'logic_and': + return eval_logic_and(child, env) + if child_name == 'logic_or': + return eval_logic_or(child, env) + if child_name == 'logic_not': + return eval_logic_not(child, env) + +def eval_logic_and(node, env): + child1, child2 = get_children(node) + assert get_name(child1) == 'logic_expr' + assert get_name(child2) == 'logic_state' + + return eval_logic_expr(child1, env) & eval_logic_state(child2, env) + +def eval_logic_or(node, env): + child1, child2 = get_children(node) + assert get_name(child1) == 'logic_expr' + assert get_name(child2) == 'logic_state' + + return eval_logic_expr(child1, env) | eval_logic_state(child2, env) + +def eval_logic_not(node, env): + child = get_children(node)[0] + assert get_name(child) == 'logic_expr' + + return not eval_logic_expr(child, env) + +def eval_logic_comparison(node, env): + child = get_children(node)[0] + child_name = get_name(child) + + if child_name == 'logic_greater_than': + return eval_logic_greater_than(child, env) + if child_name == 'logic_greater_equal': + return eval_logic_greater_equal(child, env) + if child_name == 'logic_equal': + return eval_logic_equal(child, env) + if child_name == 'logic_smaller_equal': + return eval_logic_smaller_equal(child, env) + if child_name == 'logic_smaller_than': + return eval_logic_smaller_than(child, env) + if child_name == 'logic_unequal': + return eval_logic_unequal(child, env) + +def eval_logic_greater_than(node, env): + child1, child2 = get_children(node) + assert get_name(child1) == 'arith_expr' + assert get_name(child2) == 'arith_expr' + + return eval_arith_expr(child1, env) > eval_arith_expr(child2, env) + + +def eval_logic_greater_equal(node, env): + child1, child2 = get_children(node) + assert get_name(child1) == 'arith_expr' + assert get_name(child2) == 'arith_expr' + + return eval_arith_expr(child1, env) >= eval_arith_expr(child2, env) + + +def eval_logic_equal(node, env): + child1, child2 = get_children(node) + assert get_name(child1) == 'arith_expr' + assert get_name(child2) == 'arith_expr' + + return eval_arith_expr(child1, env) == eval_arith_expr(child2, env) + +def eval_logic_smaller_equal(node, env): + child1, child2 = get_children(node) + assert get_name(child1) == 'arith_expr' + assert get_name(child2) == 'arith_expr' + + return eval_arith_expr(child1, env) <= eval_arith_expr(child2, env) + +def eval_logic_smaller_than(node, env): + child1, child2 = get_children(node) + assert get_name(child1) == 'arith_expr' + assert get_name(child2) == 'arith_expr' + + return eval_arith_expr(child1, env) <= eval_arith_expr(child2, env) + +def eval_logic_unequal(node, env): + child1, child2 = get_children(node) + assert get_name(child1) == 'arith_expr' + assert get_name(child2) == 'arith_expr' + + return eval_arith_expr(child1, env) != eval_arith_expr(child2, env) diff --git a/tests/test_logic.py b/tests/test_logic.py new file mode 100644 index 0000000..2753dca --- /dev/null +++ b/tests/test_logic.py @@ -0,0 +1,62 @@ +import pytest +from typing import Optional, Union + +from lark import Lark + +from larktools.ebnf_grammar import grammar +from larktools.evaluation import eval_logic_expr + + +class ArithParser: + def __init__(self): + self.parser = Lark(grammar, parser="lalr", start="logic_expr") + self.parse = self.parser.parse + + def parse_and_eval(self, expression: str, env: Optional[Union[None, dict]] = None) -> Union[int, float]: + tree = self.parse(expression) + res = eval_logic_expr(tree, {} if env is None else env) + return res + + +def _parse_and_assert(expression: str, expected: Union[int, float]) -> None: + parser = ArithParser() + res = parser.parse_and_eval(expression) + assert expected == res + +def test_comparison(): + _parse_and_assert("3 > 5", False) + _parse_and_assert("3 >= 5", False) + _parse_and_assert("3 >= 3", True) + _parse_and_assert("3 >= 3", True) + _parse_and_assert("5 == 3 + 2", True) + _parse_and_assert("2 + 3 == 3 + 2", True) + _parse_and_assert("5 == 3", False) + _parse_and_assert("3 <= 5", True) + _parse_and_assert("3 <= 3", True) + _parse_and_assert("3 <= 2", False) + _parse_and_assert("3 < 5", True) + _parse_and_assert("3 != 5", True) + _parse_and_assert("5 != 5", False) + + +def test_logic_states(): + _parse_and_assert("True", True) + _parse_and_assert("False", False) + +def test_logic_operations(): + _parse_and_assert("False or True", True) + _parse_and_assert("False or False", False) + _parse_and_assert("True or True", True) + _parse_and_assert("True or False", True) + + _parse_and_assert("False and True", False) + _parse_and_assert("False and False", False) + _parse_and_assert("True and True", True) + _parse_and_assert("True and False", False) + +def test_logic_negation(): + _parse_and_assert("not True", False) + _parse_and_assert("not False", True) + + +