diff --git a/app/domain/dataapi/actions.py b/app/domain/dataapi/actions.py index 7a4d96b5..d4ed75fc 100644 --- a/app/domain/dataapi/actions.py +++ b/app/domain/dataapi/actions.py @@ -1,12 +1,8 @@ from typing import final -import astropy.units as u -from astropy import coordinates as coords - from app.data import model, repositories -from app.data.repositories import layer2 -from app.domain import expressions, responders -from app.domain.dataapi import parameterized_query +from app.domain import responders +from app.domain.dataapi import parameterized_query, search_parsers from app.presentation import dataapi ENABLED_CATALOGS = [ @@ -29,9 +25,7 @@ def __init__( ) def query(self, query: dataapi.QueryRequest) -> dataapi.QueryResponse: - expression = expressions.parse_expression(query.q) - filters, search_params = expression_to_filter(expression) - + filters, search_params = search_parsers.query_to_filters(query.q, search_parsers.DEFAULT_PARSERS) objects = self.layer2_repo.query( ENABLED_CATALOGS, filters, @@ -39,7 +33,6 @@ def query(self, query: dataapi.QueryRequest) -> dataapi.QueryResponse: query.page_size, query.page, ) - responder = responders.JSONResponder() pgc_objects = responder.build_response(objects) return dataapi.QueryResponse(objects=pgc_objects) @@ -49,55 +42,3 @@ def query_fits(self, query: dataapi.FITSRequest) -> bytes: def query_simple(self, query: dataapi.QuerySimpleRequest) -> dataapi.QuerySimpleResponse: return self.parameterized_query_manager.query_simple(query) - - -def parse_coordinates(coord_str: str) -> coords.SkyCoord: - try: - if coord_str.startswith(("J", "B")): - return coords.SkyCoord(coord_str[1:], unit=(u.Unit("hourangle"), u.Unit("deg"))) - - if coord_str.startswith("G"): - long, lat = map(float, coord_str[1:].split("+")) - coord = coords.SkyCoord(l=long * u.Unit("deg"), b=lat * u.Unit("deg"), frame="galactic") - return coord.transform_to("icrs") - - return coords.SkyCoord(coord_str, unit=(u.Unit("hourangle"), u.Unit("deg"))) - except 
Exception as e: - raise ValueError(f"Invalid coordinate format: {coord_str}") from e - - -def parse_function_node(node: expressions.FunctionNode) -> tuple[layer2.Filter, layer2.SearchParams]: - if node.function == expressions.FunctionName.PGC: - try: - pgc = int(node.value) - except ValueError as e: - raise ValueError(f"Invalid PGC value: '{node.value}'") from e - - return layer2.PGCOneOfFilter([pgc]), layer2.CombinedSearchParams([]) - if node.function == expressions.FunctionName.NAME: - return layer2.DesignationCloseFilter(2), layer2.DesignationSearchParams(node.value) - if node.function == expressions.FunctionName.POS: - return layer2.ICRSCoordinatesInRadiusFilter(1 * u.Unit("arcsec")), layer2.ICRSSearchParams( - coords=parse_coordinates(node.value) - ) - - raise ValueError(f"Unsupported function: {node.function}") - - -def expression_to_filter(expr: expressions.Node) -> tuple[layer2.Filter, layer2.SearchParams]: - if isinstance(expr, expressions.AndNode): - left_filter, left_search_params = expression_to_filter(expr.left) - right_filter, right_search_params = expression_to_filter(expr.right) - - return layer2.AndFilter([left_filter, right_filter]), layer2.CombinedSearchParams( - [left_search_params, right_search_params] - ) - if isinstance(expr, expressions.OrNode): - left_filter, left_search_params = expression_to_filter(expr.left) - right_filter, right_search_params = expression_to_filter(expr.right) - - return layer2.OrFilter([left_filter, right_filter]), layer2.CombinedSearchParams( - [left_search_params, right_search_params] - ) - - return parse_function_node(expr) diff --git a/app/domain/dataapi/search_parsers.py b/app/domain/dataapi/search_parsers.py new file mode 100644 index 00000000..76d7ee1e --- /dev/null +++ b/app/domain/dataapi/search_parsers.py @@ -0,0 +1,107 @@ +import abc +import re +from typing import final + +import astropy.units as u +from astropy import coordinates as coords + +from app.data.repositories import layer2 + +RADIUS_ARCSEC = 1 * 
u.Unit("arcsec") + +HMS_DMS_PATTERN = re.compile(r"^(\d+h\d+m[\d.]+s)([+-]\d+d\d+m[\d.]+s)$", re.IGNORECASE) +J_COORD_PATTERN = re.compile(r"^J(\d{2})(\d{2})([\d.]+)([+-])(\d{2})(\d{2})([\d.]+)$", re.IGNORECASE) + + +class SearchParser(abc.ABC): +    @abc.abstractmethod +    def parse(self, query: str) -> tuple[layer2.Filter, layer2.SearchParams] | None: +        pass + + +@final +class NameSearchParser(SearchParser): +    def parse(self, query: str) -> tuple[layer2.Filter, layer2.SearchParams] | None: +        return ( +            layer2.DesignationLikeFilter(), +            layer2.DesignationSearchParams(query.strip()), +        ) + + +@final +class HMSDMSCoordinateParser(SearchParser): +    def parse(self, query: str) -> tuple[layer2.Filter, layer2.SearchParams] | None: +        query = query.strip() +        m = HMS_DMS_PATTERN.match(query) +        if m is None: +            return None +        try: +            ra_str = m.group(1).replace("h", ":").replace("m", ":").replace("s", "") +            dec_str = m.group(2).replace("d", ":").replace("m", ":").replace("s", "") +            sky_coord = coords.SkyCoord( +                ra_str + " " + dec_str, +                unit=(u.Unit("hourangle"), u.Unit("deg")), +            ) +        except Exception: +            return None +        return ( +            layer2.ICRSCoordinatesInRadiusFilter(RADIUS_ARCSEC), +            layer2.ICRSSearchParams(coords=sky_coord), +        ) + + +def _parse_j_coord_to_skycoord(query: str) -> coords.SkyCoord: +    m = J_COORD_PATTERN.match(query.strip()) +    if m is None: +        raise ValueError("Does not match J coordinate format") +    ra_h, ra_m, ra_s, dec_sign, dec_d, dec_m, dec_s = m.groups() +    ra_str = f"{ra_h}:{ra_m}:{ra_s}" +    dec_str = f"{dec_sign}{dec_d}:{dec_m}:{dec_s}" +    return coords.SkyCoord( +        ra_str + " " + dec_str, +        unit=(u.Unit("hourangle"), u.Unit("deg")), +    ) + + +@final +class JCoordinateParser(SearchParser): +    def parse(self, query: str) -> tuple[layer2.Filter, layer2.SearchParams] | None: +        query = query.strip() +        if not query.upper().startswith("J"): +            return None +        try: +            sky_coord = _parse_j_coord_to_skycoord(query) +        except Exception: +            return None +        return ( + 
layer2.ICRSCoordinatesInRadiusFilter(RADIUS_ARCSEC), + layer2.ICRSSearchParams(coords=sky_coord), + ) + + +DEFAULT_PARSERS: list[SearchParser] = [ + NameSearchParser(), + HMSDMSCoordinateParser(), + JCoordinateParser(), +] + + +def query_to_filters( + query: str, + parsers: list[SearchParser], +) -> tuple[layer2.Filter, layer2.SearchParams]: + results: list[tuple[layer2.Filter, layer2.SearchParams]] = [] + for parser in parsers: + parsed = parser.parse(query) + if parsed is not None: + results.append(parsed) + if not results: + return ( + layer2.DesignationLikeFilter(), + layer2.DesignationSearchParams(query.strip()), + ) + if len(results) == 1: + return results[0] + filters = [f for f, _ in results] + params = [p for _, p in results] + return layer2.OrFilter(filters), layer2.CombinedSearchParams(params) diff --git a/app/domain/expressions/__init__.py b/app/domain/expressions/__init__.py deleted file mode 100644 index 3559eba4..00000000 --- a/app/domain/expressions/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from app.domain.expressions.parser import AndNode, FunctionNode, Node, OrNode, parse_expression -from app.domain.expressions.tokenizer import FunctionName - -__all__ = ["FunctionName", "parse_expression", "Node", "OrNode", "AndNode", "FunctionNode"] diff --git a/app/domain/expressions/parser.py b/app/domain/expressions/parser.py deleted file mode 100644 index ae3d4b6c..00000000 --- a/app/domain/expressions/parser.py +++ /dev/null @@ -1,100 +0,0 @@ -from dataclasses import dataclass - -from app.domain.expressions import tokenizer - - -@dataclass -class AndNode: - left: "Node" - right: "Node" - - -@dataclass -class OrNode: - left: "Node" - right: "Node" - - -@dataclass -class FunctionNode: - function: tokenizer.FunctionName - value: str - - -Node = AndNode | OrNode | FunctionNode - -PRECEDENCE = { - tokenizer.OperatorName.AND: 1, - tokenizer.OperatorName.OR: 0, -} - - -def peek[T](stack: list[T]) -> T | None: - if stack: - return stack[-1] - - return None - - -def 
infix_to_postfix(tokens: list[tokenizer.Token]) -> list[tokenizer.Token]: - holding_stack: list[tokenizer.Token] = [] - output: list[tokenizer.Token] = [] - - for token in tokens: - if isinstance(token, tokenizer.FunctionToken): - output.append(token) - elif isinstance(token, tokenizer.OperatorToken): - last_item = peek(holding_stack) - while ( - last_item is not None - and isinstance(last_item, tokenizer.OperatorToken) - and PRECEDENCE[last_item.name] >= PRECEDENCE[token.name] - ): - output.append(holding_stack.pop()) - last_item = peek(holding_stack) - - holding_stack.append(token) - elif isinstance(token, tokenizer.LParenToken): - holding_stack.append(token) - else: # Right paren - last_item = peek(holding_stack) - while last_item is not None and not isinstance(last_item, tokenizer.LParenToken): - output.append(holding_stack.pop()) - last_item = peek(holding_stack) - - if isinstance(last_item, tokenizer.LParenToken): - holding_stack.pop() - else: - raise RuntimeError("Mismatched parentheses") - - while len(holding_stack) > 0: - last_item = holding_stack.pop() - if isinstance(last_item, tokenizer.LParenToken): - raise RuntimeError("Mismatched parentheses") - output.append(last_item) - - return output - - -def solve_postfix(tokens: list[tokenizer.Token]) -> Node: - stack: list[Node] = [] - - for token in tokens: - if isinstance(token, tokenizer.FunctionToken): - stack.append(FunctionNode(token.name, token.value)) - elif isinstance(token, tokenizer.OperatorToken): - right = stack.pop() - left = stack.pop() - - if token.name == tokenizer.OperatorName.AND: - stack.append(AndNode(left, right)) - elif token.name == tokenizer.OperatorName.OR: - stack.append(OrNode(left, right)) - - return stack[0] - - -def parse_expression(s: str) -> Node: - tokens = tokenizer.tokenize(s) - postfix_tokens = infix_to_postfix(tokens) - return solve_postfix(postfix_tokens) diff --git a/app/domain/expressions/tokenizer.py b/app/domain/expressions/tokenizer.py deleted file mode 100644 index 
f03cd5a3..00000000 --- a/app/domain/expressions/tokenizer.py +++ /dev/null @@ -1,125 +0,0 @@ -import enum -import re -from dataclasses import dataclass - -# Use https://regex101.com/ to explain these regex -function_call_pattern = r'^([a-z0-9-]+):((?:[a-zA-Z0-9_+.]+)|(?:"[^"]+"))' -operator_pattern = r"^([a-z-]+)\s+" - - -class OperatorName(enum.Enum): - AND = "and" - OR = "or" - - -@dataclass -class OperatorToken: - name: OperatorName - - -class FunctionName(enum.Enum): - NAME = "name" - POS = "pos" - PGC = "pgc" - - -@dataclass -class FunctionToken: - name: FunctionName - value: str - - -@dataclass -class LParenToken: - pass - - -@dataclass -class RParenToken: - pass - - -Token = OperatorToken | FunctionToken | LParenToken | RParenToken - - -def parse_function_call(s: str) -> tuple[FunctionToken, int] | None: - match = re.match(function_call_pattern, s) - if match is None: - return None - - groups = match.groups() - if len(groups) != 2: - raise RuntimeError(f"Unable to parse string: {s}") - - function_name_str, parameter = groups - function_name = None - - chars_consumed = match.end() - - if function_name_str == "pos": - function_name = FunctionName.POS - elif function_name_str == "name": - function_name = FunctionName.NAME - elif function_name_str == "pgc": - function_name = FunctionName.PGC - else: - raise RuntimeError(f"Unknown function: {function_name_str}") - - parameter = parameter.strip('"') - - return FunctionToken(function_name, parameter), chars_consumed - - -def parse_operator(s: str) -> tuple[OperatorToken, int] | None: - match = re.match(operator_pattern, s) - if match is None: - return None - - operator_str = match.group(1) - chars_consumed = match.end() - - if operator_str == "and": - return OperatorToken(OperatorName.AND), chars_consumed - - if operator_str == "or": - return OperatorToken(OperatorName.OR), 2 - - raise RuntimeError(f"Unknown operator: {operator_str}") - - -def tokenize(s: str) -> list[Token]: - tokens: list[Token] = [] - i = 0 - - 
while i < len(s): - if s[i].isspace(): - i += 1 - continue - - if s[i] == "(": - tokens.append(LParenToken()) - i += 1 - continue - - if s[i] == ")": - tokens.append(RParenToken()) - i += 1 - continue - - parsed_operator = parse_operator(s[i:]) - if parsed_operator is not None: - token, offset = parsed_operator - tokens.append(token) - i += offset - continue - - parsed_func = parse_function_call(s[i:]) - if parsed_func is not None: - token, offset = parsed_func - tokens.append(token) - i += offset - continue - - raise RuntimeError(f"Invalid syntax at position {i}: {s[i:]}") - - return tokens diff --git a/postgres/migrations/V019__designation_trgm_index.sql b/postgres/migrations/V019__designation_trgm_index.sql new file mode 100644 index 00000000..707a0c3d --- /dev/null +++ b/postgres/migrations/V019__designation_trgm_index.sql @@ -0,0 +1,6 @@ +/* pgmigrate-encoding: utf-8 */ + +CREATE EXTENSION IF NOT EXISTS pg_trgm; + +CREATE INDEX layer2_designation_design_trgm_idx + ON layer2.designation USING GIN (design gin_trgm_ops); diff --git a/tests/unit/domain/parse_coords_test.py b/tests/unit/domain/parse_coords_test.py index 6e6a2d0e..d581dde6 100644 --- a/tests/unit/domain/parse_coords_test.py +++ b/tests/unit/domain/parse_coords_test.py @@ -4,43 +4,132 @@ from astropy import coordinates as coords from parameterized import param, parameterized -from app.domain.dataapi.actions import parse_coordinates +from app.data.repositories import layer2 +from app.domain.dataapi.search_parsers import ( + HMSDMSCoordinateParser, + JCoordinateParser, + _parse_j_coord_to_skycoord, +) -class TestParseCoordinates(unittest.TestCase): +class TestJCoordinateParser(unittest.TestCase): def setUp(self): - pass + self.parser = JCoordinateParser() @parameterized.expand( [ param( - "J12:34:56+12:34:56", coords.SkyCoord("12:34:56 +12:34:56", unit=(u.Unit("hourangle"), u.Unit("deg"))) + "J123456+123456", + coords.SkyCoord( + "12:34:56 +12:34:56", + unit=(u.Unit("hourangle"), u.Unit("deg")), + ), 
), param( - "B12:34:56-12:34:56", coords.SkyCoord("12:34:56 -12:34:56", unit=(u.Unit("hourangle"), u.Unit("deg"))) + "J001122.33+443322.1", + coords.SkyCoord( + "00:11:22.33 +44:33:22.1", + unit=(u.Unit("hourangle"), u.Unit("deg")), + ), ), param( - "G180.5+45.8", - coords.SkyCoord(l=180.5 * u.Unit("deg"), b=45.8 * u.Unit("deg"), frame="galactic").transform_to("icrs"), + "j001122+443322", + coords.SkyCoord( + "00:11:22 +44:33:22", + unit=(u.Unit("hourangle"), u.Unit("deg")), + ), ), + ] + ) + def test_accepts_valid_j_format(self, input_str, expected): + result = self.parser.parse(input_str) + assert result is not None + filter_obj, search_params = result + assert isinstance(filter_obj, layer2.ICRSCoordinatesInRadiusFilter) + assert isinstance(search_params, layer2.ICRSSearchParams) + assert search_params.get_params()["ra"] == expected.ra.deg + assert search_params.get_params()["dec"] == expected.dec.deg + + @parameterized.expand( + [ + "M33", + "12h30m49s+12d22m33s", + "J12:34:56+12:34:56", + "J123", + "J12345+123", + ] + ) + def test_rejects_invalid(self, input_str): + result = self.parser.parse(input_str) + assert result is None + + +class TestJCoordToSkyCoord(unittest.TestCase): + @parameterized.expand( + [ param( - "12:34:56 +12:34:56", coords.SkyCoord("12:34:56 +12:34:56", unit=(u.Unit("hourangle"), u.Unit("deg"))) + "J001122.33+443322.1", + coords.SkyCoord( + "00:11:22.33 +44:33:22.1", + unit=(u.Unit("hourangle"), u.Unit("deg")), + ), ), ] ) - def test_valid_coordinates(self, input_str, expected): - result = parse_coordinates(input_str) - self.assertIsInstance(result, coords.SkyCoord) - assert result.ra is not None - assert result.dec is not None - self.assertEqual(result.ra.deg, expected.ra.deg) - self.assertEqual(result.dec.deg, expected.dec.deg) - - @parameterized.expand(["invalid", "J12:34", "G180", "X12:34:56 +12:34:56"]) - def test_invalid_coordinates(self, input_str): - with self.assertRaisesRegex(ValueError, "Invalid coordinate format"): - 
parse_coordinates(input_str) - - -if __name__ == "__main__": - unittest.main() + def test_parsed_coordinates_match(self, input_str, expected): + result = _parse_j_coord_to_skycoord(input_str) + assert isinstance(result, coords.SkyCoord) + assert abs(result.ra.deg - expected.ra.deg) < 1e-5 + assert abs(result.dec.deg - expected.dec.deg) < 1e-5 + + def test_invalid_raises(self): + try: + _parse_j_coord_to_skycoord("not-j-format") + raise AssertionError("expected ValueError") + except ValueError: + pass + + +class TestHMSDMSCoordinateParser(unittest.TestCase): + def setUp(self): + self.parser = HMSDMSCoordinateParser() + + @parameterized.expand( + [ + param( + "12h32m22s+15d22m45s", + coords.SkyCoord( + "12:32:22 +15:22:45", + unit=(u.Unit("hourangle"), u.Unit("deg")), + ), + ), + param( + "12h30m49.32s+12d22m33.2s", + coords.SkyCoord( + "12:30:49.32 +12:22:33.2", + unit=(u.Unit("hourangle"), u.Unit("deg")), + ), + ), + ] + ) + def test_accepts_valid_hms_dms_format(self, input_str, expected): + result = self.parser.parse(input_str) + assert result is not None + filter_obj, search_params = result + assert isinstance(filter_obj, layer2.ICRSCoordinatesInRadiusFilter) + assert isinstance(search_params, layer2.ICRSSearchParams) + assert abs(search_params.get_params()["ra"] - expected.ra.deg) < 1e-5 + assert abs(search_params.get_params()["dec"] - expected.dec.deg) < 1e-5 + + @parameterized.expand( + [ + "M33", + "J123049.32+122233.2", + "12:30:49 +12:22:33", + "12h30m49s", + "invalid", + ] + ) + def test_rejects_invalid(self, input_str): + result = self.parser.parse(input_str) + assert result is None diff --git a/tests/unit/domain/parser_test.py b/tests/unit/domain/parser_test.py deleted file mode 100644 index 98d9be03..00000000 --- a/tests/unit/domain/parser_test.py +++ /dev/null @@ -1,138 +0,0 @@ -import unittest - -from parameterized import param, parameterized - -from app.domain.expressions import parser, tokenizer - - -class InfixToPostfixTest(unittest.TestCase): - 
@parameterized.expand( - [ - param( - "no operators", - [ - tokenizer.LParenToken(), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.RParenToken(), - ], - [tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33")], - ), - param( - "single operator", - [ - tokenizer.LParenToken(), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.OperatorToken(tokenizer.OperatorName.AND), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M34"), - tokenizer.RParenToken(), - ], - [ - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M34"), - tokenizer.OperatorToken(tokenizer.OperatorName.AND), - ], - ), - param( - "multiple operators", - [ - tokenizer.LParenToken(), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.OperatorToken(tokenizer.OperatorName.AND), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M34"), - tokenizer.OperatorToken(tokenizer.OperatorName.OR), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M35"), - tokenizer.RParenToken(), - ], - [ - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M34"), - tokenizer.OperatorToken(tokenizer.OperatorName.AND), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M35"), - tokenizer.OperatorToken(tokenizer.OperatorName.OR), - ], - ), - param( - "nested operators", - [ - tokenizer.LParenToken(), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.OperatorToken(tokenizer.OperatorName.AND), - tokenizer.LParenToken(), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M34"), - tokenizer.OperatorToken(tokenizer.OperatorName.OR), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M35"), - tokenizer.RParenToken(), - tokenizer.RParenToken(), - ], - [ - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, 
"M34"), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M35"), - tokenizer.OperatorToken(tokenizer.OperatorName.OR), - tokenizer.OperatorToken(tokenizer.OperatorName.AND), - ], - ), - ] - ) - def test_happy(self, name, input_tokens, expected): - actual = parser.infix_to_postfix(input_tokens) - self.assertEqual(actual, expected) - - @parameterized.expand( - [ - param( - "mismatched paren", - [tokenizer.LParenToken(), tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33")], - "Mismatched parentheses", - ), - ] - ) - def test_fails(self, name, input_tokens, expected_err): - with self.assertRaises(RuntimeError) as err: - _ = parser.infix_to_postfix(input_tokens) - - self.assertIn(expected_err, str(err.exception)) - - -class SolvePostfixTest(unittest.TestCase): - @parameterized.expand( - [ - param( - "single function", - [tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33")], - parser.FunctionNode(tokenizer.FunctionName.NAME, "M33"), - ), - param( - "single operator", - [ - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M34"), - tokenizer.OperatorToken(tokenizer.OperatorName.AND), - ], - parser.AndNode( - left=parser.FunctionNode(tokenizer.FunctionName.NAME, "M33"), - right=parser.FunctionNode(tokenizer.FunctionName.NAME, "M34"), - ), - ), - param( - "multiple operators", - [ - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M33"), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M34"), - tokenizer.OperatorToken(tokenizer.OperatorName.AND), - tokenizer.FunctionToken(tokenizer.FunctionName.NAME, "M35"), - tokenizer.OperatorToken(tokenizer.OperatorName.OR), - ], - parser.OrNode( - left=parser.AndNode( - left=parser.FunctionNode(tokenizer.FunctionName.NAME, "M33"), - right=parser.FunctionNode(tokenizer.FunctionName.NAME, "M34"), - ), - right=parser.FunctionNode(tokenizer.FunctionName.NAME, "M35"), - ), - ), - ] - ) - def test_happy(self, name, input_tokens, expected): - 
actual = parser.solve_postfix(input_tokens) - self.assertEqual(actual, expected) diff --git a/tests/unit/domain/search_handler_test.py b/tests/unit/domain/search_handler_test.py new file mode 100644 index 00000000..01e246b7 --- /dev/null +++ b/tests/unit/domain/search_handler_test.py @@ -0,0 +1,41 @@ +import unittest + +from app.data.repositories import layer2 +from app.domain.dataapi.search_parsers import ( + DEFAULT_PARSERS, + query_to_filters, +) + + +class TestQueryToFiltersPipeline(unittest.TestCase): + def test_plain_name_produces_designation_like_only(self): + filters, search_params = query_to_filters("M33", DEFAULT_PARSERS) + assert isinstance(filters, layer2.DesignationLikeFilter) + assert isinstance(search_params, layer2.DesignationSearchParams) + assert search_params.get_params()["design"] == "M33" + + def test_j_coord_produces_or_filter_with_name_and_coords(self): + filters, search_params = query_to_filters("J123049.32+122233.2", DEFAULT_PARSERS) + assert isinstance(filters, layer2.OrFilter) + assert isinstance(search_params, layer2.CombinedSearchParams) + params = search_params.get_params() + assert "design" in params + assert "ra" in params + assert "dec" in params + assert " OR " in filters.get_query() + + def test_hms_dms_coord_produces_or_filter_with_name_and_coords(self): + filters, search_params = query_to_filters("12h30m49.32s+12d22m33.2s", DEFAULT_PARSERS) + assert isinstance(filters, layer2.OrFilter) + assert isinstance(search_params, layer2.CombinedSearchParams) + params = search_params.get_params() + assert "design" in params + assert "ra" in params + assert "dec" in params + assert " OR " in filters.get_query() + + def test_non_coord_string_produces_only_designation_like(self): + filters, search_params = query_to_filters("some random object name", DEFAULT_PARSERS) + assert isinstance(filters, layer2.DesignationLikeFilter) + assert isinstance(search_params, layer2.DesignationSearchParams) + assert search_params.get_params()["design"] == 
"some random object name" diff --git a/tests/unit/domain/tokenizer_test.py b/tests/unit/domain/tokenizer_test.py deleted file mode 100644 index 037e9d22..00000000 --- a/tests/unit/domain/tokenizer_test.py +++ /dev/null @@ -1,135 +0,0 @@ -import unittest - -from parameterized import param, parameterized - -from app.domain.expressions.tokenizer import ( - FunctionName, - FunctionToken, - LParenToken, - OperatorName, - OperatorToken, - RParenToken, - parse_function_call, - parse_operator, - tokenize, -) - - -class ParserTest(unittest.TestCase): - @parameterized.expand( - [ - param("name:M33", FunctionToken(FunctionName.NAME, "M33")), - param("pos:J123049.32+122233.2", FunctionToken(FunctionName.POS, "J123049.32+122233.2")), - param( - "pos:\"12h 30m 49.32s +12d 23' 33.2''\"", - FunctionToken(FunctionName.POS, "12h 30m 49.32s +12d 23' 33.2''"), - ), - param("pos:B122817.46+123907.1", FunctionToken(FunctionName.POS, "B122817.46+123907.1")), - param("pos:G283.79325+74.47647", FunctionToken(FunctionName.POS, "G283.79325+74.47647")), - param("pos:M33", FunctionToken(FunctionName.POS, "M33")), - param("pgc:123456", FunctionToken(FunctionName.PGC, "123456")), - param( - "pos:\"12h 30m 49.32s +12d 23' 33.2''\" there is some additional text", - FunctionToken(FunctionName.POS, "12h 30m 49.32s +12d 23' 33.2''"), - ), - param("pgc:123456 some text", FunctionToken(FunctionName.PGC, "123456")), - param("totally not function", None), - param("and (name:M33)", None), - ], - ) - def test_parse_function_call_happy(self, s, expected): - actual = parse_function_call(s) - - if actual is None: - self.assertEqual(actual, expected) - else: - actual_token, _ = actual - self.assertEqual(actual_token, expected) - - @parameterized.expand( - [ - param("nonexistingfunc:M33", "Unknown function"), - ], - ) - def test_parse_function_call_errors(self, s, err_substr): - with self.assertRaises(RuntimeError) as err: - _ = parse_function_call(s) - - self.assertIn(err_substr, str(err.exception)) - - 
@parameterized.expand( - [ - param("and ", OperatorToken(OperatorName.AND)), - param("and name:M33 or name:M44", OperatorToken(OperatorName.AND)), - param("or (((", OperatorToken(OperatorName.OR)), - param("not:an:operator", None), - ] - ) - def test_parse_operator_happy(self, s, expected): - actual = parse_operator(s) - - if actual is None: - self.assertEqual(actual, expected) - else: - actual_token, _ = actual - self.assertEqual(actual_token, expected) - - @parameterized.expand( - [ - param("nonexistentoperator ", "Unknown operator"), - ], - ) - def test_parse_operator_errors(self, s, err_substr): - with self.assertRaises(RuntimeError) as err: - _ = parse_operator(s) - - self.assertIn(err_substr, str(err.exception)) - - -class TokenizerTest(unittest.TestCase): - @parameterized.expand( - [ - param("name:M33", [FunctionToken(FunctionName.NAME, "M33")]), - param("(name:M33)", [LParenToken(), FunctionToken(FunctionName.NAME, "M33"), RParenToken()]), - param( - "and (name:M33)", - [ - OperatorToken(OperatorName.AND), - LParenToken(), - FunctionToken(FunctionName.NAME, "M33"), - RParenToken(), - ], - ), - param( - "(name:M33) or (pgc:123456)", - [ - LParenToken(), - FunctionToken(FunctionName.NAME, "M33"), - RParenToken(), - OperatorToken(OperatorName.OR), - LParenToken(), - FunctionToken(FunctionName.PGC, "123456"), - RParenToken(), - ], - ), - param( - "((name:M33) or (pgc:123456)) and name:M44", - [ - LParenToken(), - LParenToken(), - FunctionToken(FunctionName.NAME, "M33"), - RParenToken(), - OperatorToken(OperatorName.OR), - LParenToken(), - FunctionToken(FunctionName.PGC, "123456"), - RParenToken(), - RParenToken(), - OperatorToken(OperatorName.AND), - FunctionToken(FunctionName.NAME, "M44"), - ], - ), - ] - ) - def test_happy(self, s, expected): - actual = tokenize(s) - self.assertEqual(actual, expected)