From 70c15119a5e0cb27d9999cef64d67137fa59d9a9 Mon Sep 17 00:00:00 2001 From: kraysent Date: Sun, 22 Feb 2026 12:09:21 +0000 Subject: [PATCH 1/4] remove config for tasks --- .env.example | 5 +++++ .gitignore | 1 + app/commands/runtask/command.py | 14 +------------- app/lib/storage/postgres/config.py | 10 +++++----- app/tasks/interface.py | 2 +- configs/dev/tasks.yaml | 6 ------ main.py | 11 ++--------- tests/env_test.py | 9 --------- tests/regression/upload_simple_table.py | 4 ---- 9 files changed, 15 insertions(+), 47 deletions(-) create mode 100644 .env.example delete mode 100644 configs/dev/tasks.yaml diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..10dc7b72 --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ +STORAGE_ENDPOINT=localhost +STORAGE_PORT=6432 +STORAGE_DBNAME=hyperleda +STORAGE_USER=hyperleda +STORAGE_PASSWORD=password diff --git a/.gitignore b/.gitignore index fe0c64dd..e8bf489f 100644 --- a/.gitignore +++ b/.gitignore @@ -102,6 +102,7 @@ celerybeat.pid # Environments .env.* .env +!.env.example .venv env/ venv/ diff --git a/app/commands/runtask/command.py b/app/commands/runtask/command.py index 64f9b2c0..d57c7a96 100644 --- a/app/commands/runtask/command.py +++ b/app/commands/runtask/command.py @@ -3,7 +3,6 @@ from typing import Any, final import structlog -import yaml from app import tasks from app.lib import commands @@ -14,7 +13,6 @@ class RunTaskCommand(commands.Command): def __init__( self, task_name: str, - config_path: str, input_data_path: str | None = None, input_data: dict[str, Any] | None = None, task_args: tuple[str, ...] | None = None, @@ -26,7 +24,6 @@ def __init__( task_args = () self.task_name = task_name - self.config_path = config_path self.input_data_path = input_data_path self.input_data = input_data self.task_args = task_args @@ -40,7 +37,7 @@ def help(cls) -> str: """ def prepare(self): - cfg = parse_config(self.config_path) + cfg = tasks.Config() input_data = self.input_data @@ -86,12 +83,3 @@ def _parse_task_args(self, task_args: tuple[str, ...]) -> dict[str, Any]: i += 1 return args_dict - - -def parse_config(path: str) -> tasks.Config: - p = Path(path) - if not p.is_file(): - raise FileNotFoundError(f"Config file not found: '{path}'") - - data = yaml.safe_load(p.read_text()) - return tasks.Config(**data) diff --git a/app/lib/storage/postgres/config.py b/app/lib/storage/postgres/config.py index 2d7e3bc8..d2e943c1 100644 --- a/app/lib/storage/postgres/config.py +++ b/app/lib/storage/postgres/config.py @@ -6,11 +6,11 @@ class PgStorageConfig(config.ConfigSettings): model_config = settings.SettingsConfigDict(env_prefix="STORAGE_") - endpoint: str - port: int - dbname: str - user: str - password: str + endpoint: str = "localhost" + port: int = 6432 + dbname: str = "hyperleda" + user: str = "hyperleda" + password: str = "password" def get_dsn(self) -> str: # TODO: SSL and other options like transaction timeout diff --git a/app/tasks/interface.py b/app/tasks/interface.py index add62443..68c43839 100644 --- a/app/tasks/interface.py +++ b/app/tasks/interface.py @@ -5,7 +5,7 @@ class Config(config.ConfigSettings): - storage: postgres.PgStorageConfig + storage: postgres.PgStorageConfig = postgres.PgStorageConfig() class Task(abc.ABC): diff --git a/configs/dev/tasks.yaml b/configs/dev/tasks.yaml deleted file mode 100644 index dcf8270c..00000000 --- a/configs/dev/tasks.yaml +++ /dev/null @@ -1,6 +0,0 @@ -storage: - endpoint: localhost - port: 6432 - dbname: hyperleda - user: hyperleda - password: password diff --git a/main.py b/main.py index 0469be1b..b1b6a60c 100644 --- a/main.py +++ b/main.py @@ -44,13 +44,6 @@ def dataapi(config: str): required=True, type=str, ) -@click.option( - "-c", - "--config", - type=str, - default=lambda: os.environ.get("CONFIG", ""), - help="Path to configuration file", -) @click.option( "-i", "--input-data", @@ -64,8 +57,8 @@ def dataapi(config: str): help="Set the logging level", ) @click.argument("task_args", nargs=-1, type=click.UNPROCESSED) -def runtask(task_name: str, config: str, input_data: str | None, log_level: str, task_args: tuple[str, ...]): - commands.run(RunTaskCommand(task_name, config, input_data, None, task_args, log_level)) +def runtask(task_name: str, input_data: str | None, log_level: str, task_args: tuple[str, ...]): + commands.run(RunTaskCommand(task_name, input_data, None, task_args, log_level)) @cli.command(short_help=GenerateSpecCommand.help()) diff --git a/tests/env_test.py b/tests/env_test.py index bb6b605f..935bf596 100644 --- a/tests/env_test.py +++ b/tests/env_test.py @@ -5,7 +5,6 @@ import app.commands.adminapi.command as adminapi import app.commands.dataapi.command as dataapi -import app.commands.runtask.command as runtask MINIMAL_PYTHON_VERSION = (3, 10) @@ -38,11 +37,3 @@ def test_parse_adminapi_config(self, path): ) def test_parse_dataapi_config(self, path): _ = dataapi.parse_config(path) - - @parameterized.expand( - [ - ("configs/dev/tasks.yaml"), - ] - ) - def test_parse_runtask_config(self, path): - _ = runtask.parse_config(path) diff --git a/tests/regression/upload_simple_table.py b/tests/regression/upload_simple_table.py index 31b36a1f..e09931f5 100644 --- a/tests/regression/upload_simple_table.py +++ b/tests/regression/upload_simple_table.py @@ -144,7 +144,6 @@ def start_marking(table_name: str): commands.run( RunTaskCommand( "layer0-marking", - "configs/dev/tasks.yaml", input_data={"table_name": table_name, "batch_size": 200, "workers": 8}, log_level="warn", ), @@ -156,7 +155,6 @@ def start_crossmatch(table_name: str): commands.run( RunTaskCommand( "crossmatch", - "configs/dev/tasks.yaml", input_data={"table_name": table_name}, log_level="warn", ), @@ -323,7 +321,6 @@ def submit_crossmatch(table_name: str): commands.run( RunTaskCommand( "submit-crossmatch", - "configs/dev/tasks.yaml", input_data={"table_name": table_name, "batch_size": OBJECTS_NUM // 2}, log_level="warn", ), @@ -335,7 +332,6 @@ def layer2_import(): commands.run( RunTaskCommand( "layer2-import", - "configs/dev/tasks.yaml", input_data={"batch_size": OBJECTS_NUM // 5}, log_level="warn", ), From 638e1484c3ea20971c233d03d85ef0a76a398868 Mon Sep 17 00:00:00 2001 From: kraysent Date: Sun, 22 Feb 2026 12:20:27 +0000 Subject: [PATCH 2/4] rename modifier_name to name and added example to PATCH table method --- app/domain/adminapi/table_upload.py | 2 +- app/presentation/adminapi/interface.py | 2 +- app/presentation/adminapi/server.py | 138 ++++++++++++++++++++++++- tests/integration/create_table_test.py | 4 +- 4 files changed, 141 insertions(+), 5 deletions(-) diff --git a/app/domain/adminapi/table_upload.py b/app/domain/adminapi/table_upload.py index 0414c3f1..b0d9ba9b 100644 --- a/app/domain/adminapi/table_upload.py +++ b/app/domain/adminapi/table_upload.py @@ -76,7 +76,7 @@ def patch_table(self, r: adminapi.PatchTableRequest) -> adminapi.PatchTableRespo self.layer0_repo.update_column_metadata(r.table_name, column_metadata) if spec.modifiers is not None: - modifiers = [model.Modifier(column_name, m.modifier_name, m.params) for m in spec.modifiers] + modifiers = [model.Modifier(column_name, m.name, m.params) for m in spec.modifiers] self.layer0_repo.set_modifiers(r.table_name, column_name, modifiers) return adminapi.PatchTableResponse() diff --git a/app/presentation/adminapi/interface.py b/app/presentation/adminapi/interface.py index dd81f98d..9256998a 100644 --- a/app/presentation/adminapi/interface.py +++ b/app/presentation/adminapi/interface.py @@ -129,7 +129,7 @@ class AddDataResponse(pydantic.BaseModel): class ModifierSpec(pydantic.BaseModel): - modifier_name: str + name: str params: dict[str, Any] = {} diff --git a/app/presentation/adminapi/server.py b/app/presentation/adminapi/server.py index 89846ea9..a586d310 100644 --- a/app/presentation/adminapi/server.py +++ b/app/presentation/adminapi/server.py @@ -138,7 +138,143 @@ def __init__( http.HTTPMethod.PATCH, api.patch_table, "Patch table schema", - "Patch the schema of the table", + """Patches the schema of the table. Allows updating column metadata (UCD, unit, description) and +setting column modifiers. Modifiers are transformations applied to column values during the unification process. + +Only provided fields will be updated; omitted fields will remain unchanged. + +**Example 1**: Update column metadata (UCD and unit): +```json +{ + "table_name": "my_table", + "columns": { + "ra": { + "ucd": "pos.eq.ra", + "unit": "hourangle" + }, + "dec": { + "ucd": "pos.eq.dec", + "unit": "deg" + } + } +} +``` + +**Example 2**: Add a column description: +```json +{ + "table_name": "my_table", + "columns": { + "vmag": { + "description": "Visual magnitude in the V band" + } + } +} +``` + +**Example 3**: Set a `map` modifier to convert categorical string values to numeric ones. +For instance, mapping morphological types to numeric codes: +```json +{ + "table_name": "my_table", + "columns": { + "morph_type": { + "modifiers": [ + { + "name": "map", + "params": { + "mapping": [ + {"from": "E", "to": -5}, + {"from": "S0", "to": 0}, + {"from": "Sa", "to": 1}, + {"from": "Sb", "to": 3} + ], + "default": null + } + } + ] + } + } +} +``` + +**Example 4**: Set a `format` modifier to reformat string values using a Python format pattern: +```json +{ + "table_name": "my_table", + "columns": { + "obj_id": { + "modifiers": [ + { + "name": "format", + "params": { + "pattern": "SDSS J{}" + } + } + ] + } + } +} +``` + +**Example 5**: Set an `add_unit` modifier to override the unit attached to a column during processing: +```json +{ + "table_name": "my_table", + "columns": { + "velocity": { + "modifiers": [ + { + "name": "add_unit", + "params": { + "unit": "km/s" + } + } + ] + } + } +} +``` + +**Example 6**: Set a `constant` modifier to replace all values in a column with a fixed value: +```json +{ + "table_name": "my_table", + "columns": { + "survey": { + "modifiers": [ + { + "name": "constant", + "params": { + "constant": "SDSS" + } + } + ] + } + } +} +``` + +**Example 7**: Combine metadata updates with modifiers in a single request: +```json +{ + "table_name": "my_table", + "columns": { + "ra": { + "ucd": "pos.eq.ra", + "unit": "deg", + "modifiers": [ + { + "name": "add_unit", + "params": { + "unit": "hourangle" + } + } + ] + } + } +} +```""", ), server.Route( "/v1/login", diff --git a/tests/integration/create_table_test.py b/tests/integration/create_table_test.py index 2cf84ab7..58f694a9 100644 --- a/tests/integration/create_table_test.py +++ b/tests/integration/create_table_test.py @@ -115,8 +115,8 @@ def test_create_table_with_patch_modifiers(self): columns={ "ra": presentation.PatchColumnSpec( modifiers=[ - presentation.ModifierSpec(modifier_name="constant", params={"constant": 1}), - presentation.ModifierSpec(modifier_name="add_unit", params={"unit": "deg"}), + presentation.ModifierSpec(name="constant", params={"constant": 1}), + presentation.ModifierSpec(name="add_unit", params={"unit": "deg"}), ] ), }, From 8e5a0bc9fc056afa8fdcfecf05c869b1643aa56f Mon Sep 17 00:00:00 2001 From: kraysent Date: Sun, 22 Feb 2026 16:14:01 +0000 Subject: [PATCH 3/4] Add modification time to return value for the /tables method --- app/data/model/table.py | 1 + app/data/repositories/layer0/tables.py | 4 +++- app/domain/adminapi/table_upload.py | 1 + app/presentation/adminapi/interface.py | 2 ++ tests/unit/data/layer0_repository_test.py | 2 ++ 5 files changed, 9 insertions(+), 1 deletion(-) diff --git a/app/data/model/table.py b/app/data/model/table.py index 1038f361..a90891db 100644 --- a/app/data/model/table.py +++ b/app/data/model/table.py @@ -41,6 +41,7 @@ class Layer0TableListItem: description: str num_entries: int num_fields: int + modification_dt: datetime.datetime @dataclass diff --git a/app/data/repositories/layer0/tables.py b/app/data/repositories/layer0/tables.py index 1ab6bdb9..93c52ed2 100644 --- a/app/data/repositories/layer0/tables.py +++ b/app/data/repositories/layer0/tables.py @@ -326,6 +326,7 @@ def search_tables( sql = """ SELECT t.table_name, + t.modification_dt, COALESCE(ti.param->>'description', '') AS description, COALESCE(ps.n_live_tup::bigint, 0)::int AS num_entries, ( @@ -341,7 +342,7 @@ def search_tables( LEFT JOIN pg_stat_user_tables ps ON ps.schemaname = %s AND ps.relname = t.table_name WHERE t.table_name ILIKE %s OR COALESCE(ti.param->>'description', '') ILIKE %s - ORDER BY t.table_name + ORDER BY t.modification_dt DESC LIMIT %s OFFSET %s """ params = [ @@ -361,6 +362,7 @@ def search_tables( description=row["description"] or "", num_entries=int(row["num_entries"]), num_fields=int(row["num_fields"]), + modification_dt=row["modification_dt"], ) for row in rows ] diff --git a/app/domain/adminapi/table_upload.py b/app/domain/adminapi/table_upload.py index b0d9ba9b..dbc3533b 100644 --- a/app/domain/adminapi/table_upload.py +++ b/app/domain/adminapi/table_upload.py @@ -169,6 +169,7 @@ def get_table_list(self, r: adminapi.GetTableListRequest) -> adminapi.GetTableLi description=item.description, num_entries=item.num_entries, num_fields=item.num_fields, + modification_dt=item.modification_dt, ) for item in items ] diff --git a/app/presentation/adminapi/interface.py b/app/presentation/adminapi/interface.py index 9256998a..de6f8a8f 100644 --- a/app/presentation/adminapi/interface.py +++ b/app/presentation/adminapi/interface.py @@ -1,4 +1,5 @@ import abc +import datetime import enum from typing import Annotated, Any @@ -62,6 +63,7 @@ class TableListItem(pydantic.BaseModel): description: str num_entries: int num_fields: int + modification_dt: datetime.datetime class GetTableListResponse(pydantic.BaseModel): diff --git a/tests/unit/data/layer0_repository_test.py b/tests/unit/data/layer0_repository_test.py index ef8dbf98..dea2eb55 100644 --- a/tests/unit/data/layer0_repository_test.py +++ b/tests/unit/data/layer0_repository_test.py @@ -1,3 +1,4 @@ +import datetime import unittest import uuid from unittest import mock @@ -63,6 +64,7 @@ def test_search_tables_calls_query_with_expected_structure(self): "description": "A test table", "num_entries": 100, "num_fields": 6, + "modification_dt": datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC), } ] From 8a189b0f61c974a03c63cd65b0e2d3f33a702158 Mon Sep 17 00:00:00 2001 From: kraysent Date: Sun, 22 Feb 2026 16:32:30 +0000 Subject: [PATCH 4/4] remove sql injections --- app/data/repositories/layer0/records.py | 8 +- app/data/repositories/layer0/tables.py | 91 ++++++++++++++------ app/data/template.py | 25 +++--- app/lib/storage/postgres/postgres_storage.py | 17 ++-- tests/unit/data/layer0_repository_test.py | 27 +++--- 5 files changed, 111 insertions(+), 57 deletions(-) diff --git a/app/data/repositories/layer0/records.py b/app/data/repositories/layer0/records.py index efc0f6fa..8ad71741 100644 --- a/app/data/repositories/layer0/records.py +++ b/app/data/repositories/layer0/records.py @@ -1,6 +1,8 @@ import json from collections.abc import Sequence +from psycopg import sql + from app.data import model, template from app.data.repositories.layer0.common import RAWDATA_SCHEMA from app.lib import concurrency @@ -133,7 +135,11 @@ def get_table_statistics(self, table_name: str) -> model.TableStatistics: params=[table_id], ) total_original_rows_res = errgr.run( - self._storage.query_one, f'SELECT COUNT(1) AS cnt FROM {RAWDATA_SCHEMA}."{table_name}"' + self._storage.query_one, + sql.SQL("SELECT COUNT(1) AS cnt FROM {}.{}").format( + sql.Identifier(RAWDATA_SCHEMA), + sql.Identifier(table_name), + ), ) errgr.wait() diff --git a/app/data/repositories/layer0/tables.py b/app/data/repositories/layer0/tables.py index 93c52ed2..c665586f 100644 --- a/app/data/repositories/layer0/tables.py +++ b/app/data/repositories/layer0/tables.py @@ -7,6 +7,7 @@ import structlog from astropy import table from astropy import units as u +from psycopg import sql from app.data import model, repositories, template from app.data.repositories.layer0.common import INTERNAL_ID_COLUMN_NAME, RAWDATA_SCHEMA @@ -55,8 +56,7 @@ def create_table(self, data: model.Layer0TableMeta) -> model.Layer0CreationRespo table_id = int(row.get("id")) self._storage.exec( - template.render_query( - template.CREATE_TABLE, + template.build_create_table_query( schema=RAWDATA_SCHEMA, name=data.table_name, fields=fields, @@ -85,7 +85,7 @@ def insert_raw_data(self, data: model.Layer0RawData) -> None: log.warn("trying to insert 0 rows into the table", table_name=data.table_name) return - fields = data.data.columns + fields = list(data.data.columns) values = [] params = [] @@ -98,15 +98,16 @@ def insert_raw_data(self, data: model.Layer0RawData) -> None: params.append(value) - values.append(f"({','.join(['%s'] * len(fields))})") + values.append(sql.SQL("({})").format(sql.SQL(",").join([sql.Placeholder()] * len(fields)))) - fields = [f'"{field}"' for field in fields] + field_identifiers = sql.SQL(",").join([sql.Identifier(f) for f in fields]) - query = f""" - INSERT INTO rawdata."{data.table_name}" ({",".join(fields)}) - VALUES {",".join(values)} - ON CONFLICT DO NOTHING - """ + query = sql.SQL("INSERT INTO {}.{} ({}) VALUES {} ON CONFLICT DO NOTHING").format( + sql.Identifier(RAWDATA_SCHEMA), + sql.Identifier(data.table_name), + field_identifiers, + sql.SQL(",").join(values), + ) self._storage.exec(query, params=params) @@ -130,25 +131,39 @@ def fetch_table( meta = self.fetch_metadata_by_name(table_name) - columns_str = ",".join(columns or ["*"]) + if columns: + columns_sql = sql.SQL(",").join([sql.Identifier(c) for c in columns]) + else: + columns_sql = sql.SQL("*") params = [] - query = f""" - SELECT {columns_str} FROM {RAWDATA_SCHEMA}."{table_name}"\n - """ + parts: list[sql.Composable] = [ + sql.SQL("SELECT {} FROM {}.{}").format( + columns_sql, + sql.Identifier(RAWDATA_SCHEMA), + sql.Identifier(table_name), + ) + ] if offset is not None: - query += f"WHERE {repositories.INTERNAL_ID_COLUMN_NAME} > %s\n" + parts.append( + sql.SQL(" WHERE {} > %s").format( + sql.Identifier(repositories.INTERNAL_ID_COLUMN_NAME), + ) + ) params.append(offset) if order_column is not None: - query += f"ORDER BY {order_column} {order_direction}\n" + if order_direction not in ("asc", "desc"): + raise ValueError(f"invalid order direction: {order_direction}") + parts.append(sql.SQL(" ORDER BY {} ").format(sql.Identifier(order_column))) + parts.append(sql.SQL(order_direction)) if limit is not None: - query += "LIMIT %s\n" + parts.append(sql.SQL(" LIMIT %s")) params.append(limit) - rows = self._storage.query(query, params=params) + rows = self._storage.query(sql.Composed(parts), params=params) df = pandas.DataFrame(rows) tbl = table.Table() if len(df) == 0: @@ -198,31 +213,53 @@ def fetch_raw_data( if table_name is None: raise ValueError("either table_name or record_id must be provided") - columns_str = ",".join(columns or ["*"]) + if columns: + columns_sql = sql.SQL(",").join([sql.Identifier(c) for c in columns]) + else: + columns_sql = sql.SQL("*") params = [] - where_stmnt = [] + where_parts: list[sql.Composable] = [] if offset is not None: - where_stmnt.append(f"{repositories.INTERNAL_ID_COLUMN_NAME} > %s") + where_parts.append( + sql.SQL("{} > %s").format( + sql.Identifier(repositories.INTERNAL_ID_COLUMN_NAME), + ) + ) params.append(offset) if record_id is not None: - where_stmnt.append(f"{INTERNAL_ID_COLUMN_NAME} = %s") + where_parts.append( + sql.SQL("{} = %s").format( + sql.Identifier(INTERNAL_ID_COLUMN_NAME), + ) + ) params.append(record_id) - where_clause = f"WHERE {' AND '.join(where_stmnt)}" if where_stmnt else "" + parts: list[sql.Composable] = [ + sql.SQL("SELECT {} FROM {}.{}").format( + columns_sql, + sql.Identifier(RAWDATA_SCHEMA), + sql.Identifier(table_name), + ) + ] - query = f'SELECT {columns_str} FROM {RAWDATA_SCHEMA}."{table_name}" {where_clause}\n' + if where_parts: + parts.append(sql.SQL(" WHERE ")) + parts.append(sql.SQL(" AND ").join(where_parts)) if order_column is not None: - query += f"ORDER BY {order_column} {order_direction}\n" + if order_direction not in ("asc", "desc"): + raise ValueError(f"invalid order direction: {order_direction}") + parts.append(sql.SQL(" ORDER BY {} ").format(sql.Identifier(order_column))) + parts.append(sql.SQL(order_direction)) if limit is not None: - query += "LIMIT %s\n" + parts.append(sql.SQL(" LIMIT %s")) params.append(limit) - rows = self._storage.query(query, params=params) + rows = self._storage.query(sql.Composed(parts), params=params) return model.Layer0RawData(table_name, pandas.DataFrame(rows)) def _resolve_table_name(self, record_id: str) -> str | None: diff --git a/app/data/template.py b/app/data/template.py index dbcb1239..c166baf1 100644 --- a/app/data/template.py +++ b/app/data/template.py @@ -1,10 +1,19 @@ -import jinja2 +from psycopg import sql -def render_query(query_string: str, **kwargs) -> str: - tpl = jinja2.Environment(loader=jinja2.BaseLoader()).from_string(query_string) +def build_create_table_query(schema: str, name: str, fields: list[tuple[str, str, str]]) -> sql.Composed: + field_parts = [] + for field_name, field_type, constraint in fields: + parts = [sql.Identifier(field_name), sql.SQL(" "), sql.SQL(field_type)] + if constraint: + parts.extend([sql.SQL(" "), sql.SQL(constraint)]) + field_parts.append(sql.Composed(parts)) - return tpl.render(**kwargs) + return sql.SQL("CREATE TABLE {}.{} ({})").format( + sql.Identifier(schema), + sql.Identifier(name), + sql.SQL(", ").join(field_parts), + ) GET_SOURCE_BY_CODE = """ @@ -19,14 +28,6 @@ def render_query(query_string: str, **kwargs) -> str: LIMIT 1 """ -CREATE_TABLE = """ -CREATE TABLE {{ schema }}."{{ name }}" ( - {% for field_name, field_type, constraint in fields %} - "{{field_name}}" {{field_type}} {{constraint}}{% if not loop.last %},{% endif %} - {% endfor %} -) -""" - INSERT_TABLE_REGISTRY_ITEM = """ INSERT INTO layer0.tables (bib, table_name, datatype) VALUES (%s, %s, %s) diff --git a/app/lib/storage/postgres/postgres_storage.py b/app/lib/storage/postgres/postgres_storage.py index 582ea111..feb665e6 100644 --- a/app/lib/storage/postgres/postgres_storage.py +++ b/app/lib/storage/postgres/postgres_storage.py @@ -3,7 +3,7 @@ import numpy as np import psycopg import structlog -from psycopg import rows +from psycopg import rows, sql from psycopg.types import enum, numeric from app.lib.storage import enums @@ -84,24 +84,29 @@ def disconnect(self) -> None: self._connection.close() - def exec(self, query: str, *, params: list[Any] | None = None) -> None: + def _query_str(self, query: str | sql.SQL | sql.Composed) -> str: + if isinstance(query, str): + return query + return query.as_string(self._connection) + + def exec(self, query: str | sql.SQL | sql.Composed, *, params: list[Any] | None = None) -> None: if params is None: params = [] if self._connection is None: raise RuntimeError("Unable to execute query: connection to Postgres was not established") - log.debug("SQL query", query=query.replace("\n", " "), args=params) + log.debug("SQL query", query=self._query_str(query).replace("\n", " "), args=params) cursor = self._connection.cursor() cursor.execute(query, params) - def query(self, query: str, *, params: list[Any] | None = None) -> list[rows.DictRow]: + def query(self, query: str | sql.SQL | sql.Composed, *, params: list[Any] | None = None) -> list[rows.DictRow]: if params is None: params = [] if self._connection is None: raise RuntimeError("Unable to execute query: connection to Postgres was not established") - log.debug("SQL query", query=query.replace("\n", " "), args=params) + log.debug("SQL query", query=self._query_str(query).replace("\n", " "), args=params) cursor = self._connection.cursor() cursor.execute(query, params) @@ -111,7 +116,7 @@ def query(self, query: str, *, params: list[Any] | None = None) -> list[rows.Dic return result - def query_one(self, query: str, *, params: list[Any] | None = None) -> rows.DictRow: + def query_one(self, query: str | sql.SQL | sql.Composed, *, params: list[Any] | None = None) -> rows.DictRow: result = self.query(query, params=params) if len(result) != 1: diff --git a/tests/unit/data/layer0_repository_test.py b/tests/unit/data/layer0_repository_test.py index dea2eb55..df320c0f 100644 --- a/tests/unit/data/layer0_repository_test.py +++ b/tests/unit/data/layer0_repository_test.py @@ -5,11 +5,18 @@ import structlog from parameterized import param, parameterized +from psycopg import sql from app.data.repositories import Layer0Repository from tests import lib +def normalize_query(s: str | sql.Composable) -> str: + if not isinstance(s, str): + s = s.as_string(None) + return " ".join(s.replace("\n", " ").replace(", ", ",").lower().split()) + + class Layer0RepositoryTest(unittest.TestCase): def setUp(self) -> None: self.storage_mock = mock.MagicMock() @@ -17,18 +24,18 @@ def setUp(self) -> None: @parameterized.expand( [ - param("no kwargs", {}, 'SELECT * FROM rawdata."ironman"'), - param("with columns", {"columns": ["one", "two"]}, 'SELECT one, two FROM rawdata."ironman"'), + param("no kwargs", {}, 'SELECT * FROM "rawdata"."ironman"'), + param("with columns", {"columns": ["one", "two"]}, 'SELECT "one","two" FROM "rawdata"."ironman"'), param( "with order by", {"order_column": "one", "order_direction": "desc"}, - 'SELECT * FROM rawdata."ironman" ORDER BY one DESC', + 'SELECT * FROM "rawdata"."ironman" ORDER BY "one" DESC', ), - param("with limit", {"limit": 10}, 'SELECT * FROM rawdata."ironman" LIMIT %s'), + param("with limit", {"limit": 10}, 'SELECT * FROM "rawdata"."ironman" LIMIT %s'), param( "with offset", {"offset": uuid.uuid4()}, - 'SELECT * FROM rawdata."ironman" WHERE hyperleda_internal_id > %s', + 'SELECT * FROM "rawdata"."ironman" WHERE "hyperleda_internal_id" > %s', ), param( "with all", @@ -39,7 +46,8 @@ def setUp(self) -> None: "offset": uuid.uuid4(), "limit": 10, }, - 'SELECT one, two FROM rawdata."ironman" WHERE hyperleda_internal_id > %s ORDER BY one DESC LIMIT %s', + 'SELECT "one","two" FROM "rawdata"."ironman" WHERE "hyperleda_internal_id" > %s' + ' ORDER BY "one" DESC LIMIT %s', ), ] ) @@ -49,11 +57,8 @@ def test_fetch_raw_data(self, name: str, kwargs: dict, expected_query: str): _ = self.repo.fetch_raw_data("ironman", **kwargs) args, _ = self.storage_mock.query.call_args - def transform(s): - return " ".join(s.replace("\n", " ").replace(", ", ",").lower().split()) - - actual = transform(args[0]) - expected = transform(expected_query) + actual = normalize_query(args[0]) + expected = normalize_query(expected_query) self.assertEqual(actual, expected)