Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
STORAGE_ENDPOINT=localhost
STORAGE_PORT=6432
STORAGE_DBNAME=hyperleda
STORAGE_USER=hyperleda
STORAGE_PASSWORD=password
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ celerybeat.pid
# Environments
.env.*
.env
!.env.example
.venv
env/
venv/
Expand Down
14 changes: 1 addition & 13 deletions app/commands/runtask/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, final

import structlog
import yaml

from app import tasks
from app.lib import commands
Expand All @@ -14,7 +13,6 @@ class RunTaskCommand(commands.Command):
def __init__(
self,
task_name: str,
config_path: str,
input_data_path: str | None = None,
input_data: dict[str, Any] | None = None,
task_args: tuple[str, ...] | None = None,
Expand All @@ -26,7 +24,6 @@ def __init__(
task_args = ()

self.task_name = task_name
self.config_path = config_path
self.input_data_path = input_data_path
self.input_data = input_data
self.task_args = task_args
Expand All @@ -40,7 +37,7 @@ def help(cls) -> str:
"""

def prepare(self):
cfg = parse_config(self.config_path)
cfg = tasks.Config()

input_data = self.input_data

Expand Down Expand Up @@ -86,12 +83,3 @@ def _parse_task_args(self, task_args: tuple[str, ...]) -> dict[str, Any]:
i += 1

return args_dict


def parse_config(path: str) -> tasks.Config:
    """Load a task configuration from a YAML file at *path*.

    Raises:
        FileNotFoundError: if *path* does not refer to an existing regular file.
    """
    config_file = Path(path)
    if not config_file.is_file():
        raise FileNotFoundError(f"Config file not found: '{path}'")

    # safe_load avoids executing arbitrary YAML tags from the config file.
    raw = yaml.safe_load(config_file.read_text())
    return tasks.Config(**raw)
1 change: 1 addition & 0 deletions app/data/model/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Layer0TableListItem:
description: str
num_entries: int
num_fields: int
modification_dt: datetime.datetime


@dataclass
Expand Down
8 changes: 7 additions & 1 deletion app/data/repositories/layer0/records.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from collections.abc import Sequence

from psycopg import sql

from app.data import model, template
from app.data.repositories.layer0.common import RAWDATA_SCHEMA
from app.lib import concurrency
Expand Down Expand Up @@ -133,7 +135,11 @@ def get_table_statistics(self, table_name: str) -> model.TableStatistics:
params=[table_id],
)
total_original_rows_res = errgr.run(
self._storage.query_one, f'SELECT COUNT(1) AS cnt FROM {RAWDATA_SCHEMA}."{table_name}"'
self._storage.query_one,
sql.SQL("SELECT COUNT(1) AS cnt FROM {}.{}").format(
sql.Identifier(RAWDATA_SCHEMA),
sql.Identifier(table_name),
),
)
errgr.wait()

Expand Down
95 changes: 67 additions & 28 deletions app/data/repositories/layer0/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import structlog
from astropy import table
from astropy import units as u
from psycopg import sql

from app.data import model, repositories, template
from app.data.repositories.layer0.common import INTERNAL_ID_COLUMN_NAME, RAWDATA_SCHEMA
Expand Down Expand Up @@ -55,8 +56,7 @@ def create_table(self, data: model.Layer0TableMeta) -> model.Layer0CreationRespo
table_id = int(row.get("id"))

self._storage.exec(
template.render_query(
template.CREATE_TABLE,
template.build_create_table_query(
schema=RAWDATA_SCHEMA,
name=data.table_name,
fields=fields,
Expand Down Expand Up @@ -85,7 +85,7 @@ def insert_raw_data(self, data: model.Layer0RawData) -> None:
log.warn("trying to insert 0 rows into the table", table_name=data.table_name)
return

fields = data.data.columns
fields = list(data.data.columns)

values = []
params = []
Expand All @@ -98,15 +98,16 @@ def insert_raw_data(self, data: model.Layer0RawData) -> None:

params.append(value)

values.append(f"({','.join(['%s'] * len(fields))})")
values.append(sql.SQL("({})").format(sql.SQL(",").join([sql.Placeholder()] * len(fields))))

fields = [f'"{field}"' for field in fields]
field_identifiers = sql.SQL(",").join([sql.Identifier(f) for f in fields])

query = f"""
INSERT INTO rawdata."{data.table_name}" ({",".join(fields)})
VALUES {",".join(values)}
ON CONFLICT DO NOTHING
"""
query = sql.SQL("INSERT INTO {}.{} ({}) VALUES {} ON CONFLICT DO NOTHING").format(
sql.Identifier(RAWDATA_SCHEMA),
sql.Identifier(data.table_name),
field_identifiers,
sql.SQL(",").join(values),
)

self._storage.exec(query, params=params)

Expand All @@ -130,25 +131,39 @@ def fetch_table(

meta = self.fetch_metadata_by_name(table_name)

columns_str = ",".join(columns or ["*"])
if columns:
columns_sql = sql.SQL(",").join([sql.Identifier(c) for c in columns])
else:
columns_sql = sql.SQL("*")

params = []
query = f"""
SELECT {columns_str} FROM {RAWDATA_SCHEMA}."{table_name}"\n
"""
parts: list[sql.Composable] = [
sql.SQL("SELECT {} FROM {}.{}").format(
columns_sql,
sql.Identifier(RAWDATA_SCHEMA),
sql.Identifier(table_name),
)
]

if offset is not None:
query += f"WHERE {repositories.INTERNAL_ID_COLUMN_NAME} > %s\n"
parts.append(
sql.SQL(" WHERE {} > %s").format(
sql.Identifier(repositories.INTERNAL_ID_COLUMN_NAME),
)
)
params.append(offset)

if order_column is not None:
query += f"ORDER BY {order_column} {order_direction}\n"
if order_direction not in ("asc", "desc"):
raise ValueError(f"invalid order direction: {order_direction}")
parts.append(sql.SQL(" ORDER BY {} ").format(sql.Identifier(order_column)))
parts.append(sql.SQL(order_direction))

if limit is not None:
query += "LIMIT %s\n"
parts.append(sql.SQL(" LIMIT %s"))
params.append(limit)

rows = self._storage.query(query, params=params)
rows = self._storage.query(sql.Composed(parts), params=params)
df = pandas.DataFrame(rows)
tbl = table.Table()
if len(df) == 0:
Expand Down Expand Up @@ -198,31 +213,53 @@ def fetch_raw_data(
if table_name is None:
raise ValueError("either table_name or record_id must be provided")

columns_str = ",".join(columns or ["*"])
if columns:
columns_sql = sql.SQL(",").join([sql.Identifier(c) for c in columns])
else:
columns_sql = sql.SQL("*")

params = []
where_stmnt = []
where_parts: list[sql.Composable] = []

if offset is not None:
where_stmnt.append(f"{repositories.INTERNAL_ID_COLUMN_NAME} > %s")
where_parts.append(
sql.SQL("{} > %s").format(
sql.Identifier(repositories.INTERNAL_ID_COLUMN_NAME),
)
)
params.append(offset)

if record_id is not None:
where_stmnt.append(f"{INTERNAL_ID_COLUMN_NAME} = %s")
where_parts.append(
sql.SQL("{} = %s").format(
sql.Identifier(INTERNAL_ID_COLUMN_NAME),
)
)
params.append(record_id)

where_clause = f"WHERE {' AND '.join(where_stmnt)}" if where_stmnt else ""
parts: list[sql.Composable] = [
sql.SQL("SELECT {} FROM {}.{}").format(
columns_sql,
sql.Identifier(RAWDATA_SCHEMA),
sql.Identifier(table_name),
)
]

query = f'SELECT {columns_str} FROM {RAWDATA_SCHEMA}."{table_name}" {where_clause}\n'
if where_parts:
parts.append(sql.SQL(" WHERE "))
parts.append(sql.SQL(" AND ").join(where_parts))

if order_column is not None:
query += f"ORDER BY {order_column} {order_direction}\n"
if order_direction not in ("asc", "desc"):
raise ValueError(f"invalid order direction: {order_direction}")
parts.append(sql.SQL(" ORDER BY {} ").format(sql.Identifier(order_column)))
parts.append(sql.SQL(order_direction))

if limit is not None:
query += "LIMIT %s\n"
parts.append(sql.SQL(" LIMIT %s"))
params.append(limit)

rows = self._storage.query(query, params=params)
rows = self._storage.query(sql.Composed(parts), params=params)
return model.Layer0RawData(table_name, pandas.DataFrame(rows))

def _resolve_table_name(self, record_id: str) -> str | None:
Expand Down Expand Up @@ -326,6 +363,7 @@ def search_tables(
sql = """
SELECT
t.table_name,
t.modification_dt,
COALESCE(ti.param->>'description', '') AS description,
COALESCE(ps.n_live_tup::bigint, 0)::int AS num_entries,
(
Expand All @@ -341,7 +379,7 @@ def search_tables(
LEFT JOIN pg_stat_user_tables ps
ON ps.schemaname = %s AND ps.relname = t.table_name
WHERE t.table_name ILIKE %s OR COALESCE(ti.param->>'description', '') ILIKE %s
ORDER BY t.table_name
ORDER BY t.modification_dt DESC
LIMIT %s OFFSET %s
"""
params = [
Expand All @@ -361,6 +399,7 @@ def search_tables(
description=row["description"] or "",
num_entries=int(row["num_entries"]),
num_fields=int(row["num_fields"]),
modification_dt=row["modification_dt"],
)
for row in rows
]
Expand Down
25 changes: 13 additions & 12 deletions app/data/template.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import jinja2
from psycopg import sql


def render_query(query_string: str, **kwargs) -> str:
tpl = jinja2.Environment(loader=jinja2.BaseLoader()).from_string(query_string)
def build_create_table_query(schema: str, name: str, fields: list[tuple[str, str, str]]) -> sql.Composed:
    """Compose a CREATE TABLE statement with safely quoted identifiers.

    Args:
        schema: target schema name; quoted as an SQL identifier.
        name: table name; quoted as an SQL identifier.
        fields: (field_name, field_type, constraint) triples. The field name is
            quoted as an identifier; field_type and constraint are spliced in as
            raw SQL, so they must come from trusted code, never user input.
            An empty constraint string is omitted.

    Returns:
        A psycopg Composed query ready for execution.
    """
    columns: list[sql.Composable] = []
    for col_name, col_type, col_constraint in fields:
        pieces: list[sql.Composable] = [sql.Identifier(col_name), sql.SQL(" "), sql.SQL(col_type)]
        if col_constraint:
            pieces += [sql.SQL(" "), sql.SQL(col_constraint)]
        columns.append(sql.Composed(pieces))

    return sql.SQL("CREATE TABLE {}.{} ({})").format(
        sql.Identifier(schema),
        sql.Identifier(name),
        sql.SQL(", ").join(columns),
    )


GET_SOURCE_BY_CODE = """
Expand All @@ -19,14 +28,6 @@ def render_query(query_string: str, **kwargs) -> str:
LIMIT 1
"""

CREATE_TABLE = """
CREATE TABLE {{ schema }}."{{ name }}" (
{% for field_name, field_type, constraint in fields %}
"{{field_name}}" {{field_type}} {{constraint}}{% if not loop.last %},{% endif %}
{% endfor %}
)
"""

INSERT_TABLE_REGISTRY_ITEM = """
INSERT INTO layer0.tables (bib, table_name, datatype)
VALUES (%s, %s, %s)
Expand Down
3 changes: 2 additions & 1 deletion app/domain/adminapi/table_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def patch_table(self, r: adminapi.PatchTableRequest) -> adminapi.PatchTableRespo
self.layer0_repo.update_column_metadata(r.table_name, column_metadata)

if spec.modifiers is not None:
modifiers = [model.Modifier(column_name, m.modifier_name, m.params) for m in spec.modifiers]
modifiers = [model.Modifier(column_name, m.name, m.params) for m in spec.modifiers]
self.layer0_repo.set_modifiers(r.table_name, column_name, modifiers)

return adminapi.PatchTableResponse()
Expand Down Expand Up @@ -169,6 +169,7 @@ def get_table_list(self, r: adminapi.GetTableListRequest) -> adminapi.GetTableLi
description=item.description,
num_entries=item.num_entries,
num_fields=item.num_fields,
modification_dt=item.modification_dt,
)
for item in items
]
Expand Down
10 changes: 5 additions & 5 deletions app/lib/storage/postgres/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
class PgStorageConfig(config.ConfigSettings):
model_config = settings.SettingsConfigDict(env_prefix="STORAGE_")

endpoint: str
port: int
dbname: str
user: str
password: str
endpoint: str = "localhost"
port: int = 6432
dbname: str = "hyperleda"
user: str = "hyperleda"
password: str = "password"

def get_dsn(self) -> str:
# TODO: SSL and other options like transaction timeout
Expand Down
Loading
Loading