From 8609c761d20dba74e5354485d833cf0acbc312ee Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 10:15:40 +0100 Subject: [PATCH 01/18] Add pydantic dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 31b1501..0956753 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "email-validator>=1.1", "imas-python", "numpy>=1.14", + "pydantic>=2.10.6", "python-dateutil>=2.6", "pyuda>=2.9.2", "pyyaml>=3.13", From 6f9181f65ee64a2271066e69aa5e67fb82b3db2c Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 11:35:35 +0100 Subject: [PATCH 02/18] Add tests for simulations post --- tests/remote/test_api.py | 82 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index 8fc9dfe..ca9a172 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -1,7 +1,10 @@ import base64 import importlib import os +import shutil import tempfile +import uuid +from datetime import datetime, timezone from pathlib import Path import pytest @@ -28,10 +31,13 @@ def client(): config = Config() config.load() db_fd, db_file = tempfile.mkstemp() + upload_dir = tempfile.mkdtemp() config.set_option("database.type", "sqlite") config.set_option("database.file", db_file) config.set_option("server.admin_password", TEST_PASSWORD) + config.set_option("server.upload_folder", upload_dir) config.set_option("authentication.type", "None") + config.set_option("server.copy_files", False) app = create_app(config=config, testing=True, debug=True) app.testing = True @@ -47,6 +53,7 @@ def client(): os.close(db_fd) Path(app.simdb_config.get_option("database.file")).unlink() + shutil.rmtree(upload_dir) @pytest.mark.skipif(not has_flask, reason="requires flask library") @@ -72,3 +79,78 @@ def test_get_simulations(client): assert rv.json["count"] == 100 assert len(rv.json["results"]) == len(SIMULATIONS) assert rv.status_code == 200 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations(client): + """Test POST endpoint for creating a new simulation.""" + # Create a new simulation data structure + sim_uuid = uuid.uuid4() + sim_uuid_hex = sim_uuid.hex + input_uuid = uuid.uuid4() + output_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_hex}, + "alias": "test-simulation", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [ + { + "uuid": {"_type": "uuid.UUID", "hex": input_uuid.hex}, + "type": "FILE", + "uri": "file:///path/to/input/data.txt", + "checksum": "abc123def456", + "datetime": datetime.now(timezone.utc).isoformat(), + "usage": "input_data", + "purpose": "test input file", + "sensitivity": "public", + "access": "open", + "embargo": None, + } + ], + "outputs": [ + { + "uuid": {"_type": "uuid.UUID", "hex": output_uuid.hex}, + "type": "FILE", + "uri": "file:///path/to/output/results.txt", + "checksum": "xyz789abc012", + "datetime": datetime.now(timezone.utc).isoformat(), + "usage": "output_data", + "purpose": "test output file", + "sensitivity": "public", + "access": "open", + "embargo": None, + } + ], + "metadata": [ + {"element": "machine", "value": "test-machine"}, + {"element": "code", "value": "test-code"}, + {"element": "description", "value": "Test simulation"}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + # POST the simulation + rv = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + + # Verify the response + if rv.status_code != 200: + print(f"Response status: {rv.status_code}") + print(f"Response data: {rv.data}") + print(f"Response json: {rv.json if rv.is_json else 'Not JSON'}") + + assert "ingested" in rv.json + assert rv.json["ingested"] == sim_uuid_hex + + # Verify the simulation was created by fetching it + rv_get = client.get(f"/v1.2/simulation/{sim_uuid_hex}", headers=HEADERS) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == "test-simulation" From 217ef6e2fcad927e462f6578b6dae46e235e84f1 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 11:37:04 +0100 Subject: [PATCH 03/18] Use pydantic for post data validation --- src/simdb/database/models/file.py | 18 +++++++++++ src/simdb/database/models/metadata.py | 6 ++++ src/simdb/database/models/simulation.py | 13 ++++++++ src/simdb/remote/apis/v1_2/simulations.py | 30 ++++++++--------- src/simdb/remote/models.py | 39 +++++++++++++++++++++++ 5 files changed, 90 insertions(+), 16 deletions(-) create mode 100644 src/simdb/remote/models.py diff --git a/src/simdb/database/models/file.py b/src/simdb/database/models/file.py index a05aaa5..a6894a6 100644 --- a/src/simdb/database/models/file.py +++ b/src/simdb/database/models/file.py @@ -14,6 +14,7 @@ from simdb.docstrings import inherit_docstrings from simdb.imas.checksum import checksum as imas_checksum from simdb.imas.utils import imas_timestamp +from simdb.remote.models import FileData from simdb.uda.checksum import checksum as uda_checksum from .base import Base @@ -125,6 +126,23 @@ def from_data(cls, data: Dict) -> "File": file.datetime = date_parser.parse(checked_get(data, "datetime", str)) return file + @classmethod + def from_data_model(cls, data: FileData) -> "File": + data_type = data.type + uri = data.uri + file = File( + DataObject.Type[data_type], urilib.URI(uri), perform_integrity_check=False + ) + file.uuid = data.uuid + file.usage = data.usage + file.checksum = data.checksum + file.purpose = data.purpose + file.sensitivity = data.sensitivity + file.access = data.access + file.embargo = data.embargo + file.datetime = data.datetime + return file + def data(self, recurse: bool = False) -> Dict[str, str]: data = { "uuid": self.uuid, diff --git a/src/simdb/database/models/metadata.py b/src/simdb/database/models/metadata.py index 628f158..7b975ba 100644 --- a/src/simdb/database/models/metadata.py +++ b/src/simdb/database/models/metadata.py @@ -4,6 +4,7 @@ from sqlalchemy import types as sql_types from simdb.docstrings import inherit_docstrings +from simdb.remote.models import MetadataData from .base import Base @@ -32,6 +33,11 @@ def from_data(cls, data: Dict) -> "MetaData": meta = MetaData(data["element"], data["value"]) return meta + @classmethod + def from_data_model(cls, data: MetadataData) -> "MetaData": + meta = MetaData(data.element, data.value) + return meta + def data(self, recurse: bool = False) -> Dict[str, str]: data = { "element": self.element, diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index 1ee21ab..cf92c87 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -9,6 +9,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union +from simdb.remote.models import SimulationData + if sys.version_info < (3, 11): from backports.datetime_fromisoformat import MonkeyPatch @@ -336,6 +338,17 @@ def from_data(cls, data: Dict[str, Union[str, Dict, List]]) -> "Simulation": simulation.meta.append(MetaData.from_data(el)) return simulation + @classmethod + def from_data_model(cls, data: SimulationData) -> "Simulation": + simulation = Simulation(None) + simulation.uuid = data.uuid + simulation.alias = data.alias + simulation.datetime = data.datetime + simulation.inputs = [File.from_data_model(el) for el in data.inputs] + simulation.outputs = [File.from_data_model(el) for el in data.outputs] + simulation.meta = [MetaData.from_data_model(el) for el in data.metadata] + return simulation + def data( self, recurse: bool = False, meta_keys: Optional[List[str]] = None ) -> Dict[str, Union[str, List]]: diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 01351e4..d496e6d 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, cast +import pydantic from flask import json as flask_json # fallback from flask import jsonify, request, send_file from flask_restx import Namespace, Resource @@ -24,6 +25,7 @@ from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path from simdb.remote.core.typing import current_app +from simdb.remote.models import SimulationPostData from simdb.uri import URI from simdb.validation import ValidationError, Validator from simdb.validation.file import find_file_validator @@ -155,7 +157,9 @@ def _build_trace(sim_id: str) -> Dict[str, Any]: return data -def _get_json_aware(force: bool = False, silent: bool = False): +def _get_json_aware( + force: bool = False, silent: bool = False +) -> Optional[dict[str, Any]]: """ Parse JSON like Flask's request.get_json, but handle Content-Encoding: gzip. - force/silent mimic request.get_json behavior. @@ -286,37 +290,31 @@ def post(self, user: User): # It returns None if the content type is not application/json. # If silent=True, it returns None instead of raising an error. # If force=True, it ignores the content type check. - data = _get_json_aware() - if not data: - return error("Invalid or missing JSON data") + d = SimulationPostData.model_validate(_get_json_aware()) - if "simulation" not in data: - return error("Simulation data not provided") - - add_watcher = data.get("add_watcher", True) - - simulation = models_sim.Simulation.from_data(data["simulation"]) + simulation = models_sim.Simulation.from_data_model(d.simulation) # Simulation Upload (Push) Date simulation.datetime = datetime.datetime.now() - if data["uploaded_by"] is not None: - simulation.set_meta("uploaded_by", data["uploaded_by"]) + if d.uploaded_by is not None: + simulation.set_meta("uploaded_by", d.uploaded_by) elif user.email is not None: simulation.set_meta("uploaded_by", user.email) elif user.name is not None: simulation.set_meta("uploaded_by", user.name) else: simulation.set_meta("uploaded_by", "anonymous") - if add_watcher: + + if d.add_watcher: simulation.watchers.append( models_watcher.Watcher( user.name, user.email, models_watcher.Notification.ALL ) ) - if "alias" in data["simulation"]: - alias = data["simulation"]["alias"] + if d.simulation.alias is not None: + alias = d.simulation.alias if alias is not None: (updated_alias, next_id) = _set_alias(alias) if updated_alias: @@ -430,7 +428,7 @@ def post(self, user: User): create_alias_dir(simulation) return jsonify(result) - except (DatabaseError, ValueError) as err: + except (DatabaseError, ValueError, pydantic.ValidationError) as err: return error(str(err)) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py new file mode 100644 index 0000000..6df923e --- /dev/null +++ b/src/simdb/remote/models.py @@ -0,0 +1,39 @@ +from datetime import datetime as dt +from datetime import timezone +from typing import Any, List, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + + +class FileData(BaseModel): + type: str + uri: str + uuid: UUID + checksum: str + datetime: dt + usage: Optional[str] + purpose: Optional[str] + sensitivity: Optional[str] + access: Optional[str] + embargo: Optional[str] + + +class MetadataData(BaseModel): + element: str + value: Any + + +class SimulationData(BaseModel): + uuid: UUID + alias: Optional[str] + datetime: dt = Field(default_factory=lambda: dt.now(timezone.utc)) + inputs: List[FileData] + outputs: List[FileData] + metadata: List[MetadataData] + + +class SimulationPostData(BaseModel): + simulation: SimulationData + add_watcher: bool + uploaded_by: Optional[str] From cf4dae0db5a41443a1052917c63d85441b7b0e92 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 13:37:35 +0100 Subject: [PATCH 04/18] Use response model --- src/simdb/remote/apis/v1_2/simulations.py | 55 +++++++++++------------ src/simdb/remote/models.py | 17 ++++++- 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index d496e6d..15197c2 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -25,7 +25,11 @@ from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path from simdb.remote.core.typing import current_app -from simdb.remote.models import SimulationPostData +from simdb.remote.models import ( + SimulationPostData, + SimulationPostResponse, + ValidationResult, +) from simdb.uri import URI from simdb.validation import ValidationError, Validator from simdb.validation.file import find_file_validator @@ -57,7 +61,7 @@ def _update_simulation_status( server.send_message(f"Simulation {simulation.alias}", msg, to_addresses) -def _validate(simulation, user) -> Dict: +def _validate(simulation, user) -> ValidationResult: schemas = Validator.validation_schemas(current_app.simdb_config, simulation) try: for schema in schemas: @@ -67,10 +71,7 @@ def _validate(simulation, user) -> Dict: ) except ValidationError as err: _update_simulation_status(simulation, models_sim.Simulation.Status.FAILED, user) - return { - "passed": False, - "error": str(err), - } + return ValidationResult(passed=False, error=str(err)) file_validator_type = current_app.simdb_config.get_string_option( "file_validation.type", default=None @@ -90,16 +91,11 @@ def _validate(simulation, user) -> Dict: _update_simulation_status( simulation, models_sim.Simulation.Status.FAILED, user ) - return { - "passed": False, - "error": str(err), - } + return ValidationResult(passed=False, error=str(err)) else: error("Invalid file validator specified in configuration") - return { - "passed": True, - } + return ValidationResult(passed=True, error=None) def _set_alias(alias: str): @@ -375,29 +371,32 @@ def post(self, user: User): path = Path(sim_file.uri.query["path"]) sim_file.uri = convert_uri(sim_file.uri, path, config) - result = { - "ingested": simulation.uuid.hex, - } + result = SimulationPostResponse( + ingested=simulation.uuid, error=None, validation=None + ) + + error_on_fail = current_app.simdb_config.get_option( + "validation.error_on_fail", default=False + ) if current_app.simdb_config.get_option( "validation.auto_validate", default=False ): - result["validation"] = _validate(simulation, user) + result.validation = _validate(simulation, user) - if current_app.simdb_config.get_option( - "validation.error_on_fail", default=False - ): - if simulation.status == models_sim.Simulation.Status.NOT_VALIDATED: - raise Exception( - "Validation config option error_on_fail=True without " - "auto_validate=True." + if not result.validation.passed and error_on_fail: + result.error = ( + f"Simulation validation failed and server has " + f"error_on_fail=True.\n{result.validation.error}" ) - elif simulation.status == models_sim.Simulation.Status.FAILED: - result["error"] = f"""Simulation validation failed and server has - error_on_fail=True.\n{result["validation"]["error"]}""" response = jsonify(result) response.status_code = 400 return response + elif error_on_fail: + raise RuntimeError( + "Validation config option error_on_fail=True without " + "auto_validate=True." + ) replaces = simulation.find_meta("replaces") if ( @@ -427,7 +426,7 @@ def post(self, user: User): with contextlib.suppress(OSError): create_alias_dir(simulation) - return jsonify(result) + return jsonify(result.model_dump(mode="json")) except (DatabaseError, ValueError, pydantic.ValidationError) as err: return error(str(err)) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 6df923e..777b6fd 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -1,9 +1,11 @@ from datetime import datetime as dt from datetime import timezone -from typing import Any, List, Optional +from typing import Annotated, Any, List, Optional from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PlainSerializer + +HexUUID = Annotated[UUID, PlainSerializer(lambda x: x.hex, return_type=str)] class FileData(BaseModel): @@ -37,3 +39,14 @@ class SimulationPostData(BaseModel): simulation: SimulationData add_watcher: bool uploaded_by: Optional[str] + + +class ValidationResult(BaseModel): + passed: bool + error: Optional[str] + + +class SimulationPostResponse(BaseModel): + ingested: HexUUID + error: Optional[str] + validation: Optional[ValidationResult] From 54aeeafe00b2bfdf1e4f0857facd9a033c2e98f3 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 13:37:43 +0100 Subject: [PATCH 05/18] Small cleanup --- src/simdb/remote/apis/v1_2/simulations.py | 70 ++++++++--------------- 1 file changed, 23 insertions(+), 47 deletions(-) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 15197c2..6636b84 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -293,14 +293,9 @@ def post(self, user: User): # Simulation Upload (Push) Date simulation.datetime = datetime.datetime.now() - if d.uploaded_by is not None: - simulation.set_meta("uploaded_by", d.uploaded_by) - elif user.email is not None: - simulation.set_meta("uploaded_by", user.email) - elif user.name is not None: - simulation.set_meta("uploaded_by", user.name) - else: - simulation.set_meta("uploaded_by", "anonymous") + uploaded_by = d.uploaded_by or user.email or user.name or "anonymous" + + simulation.set_meta("uploaded_by", uploaded_by) if d.add_watcher: simulation.watchers.append( @@ -310,16 +305,12 @@ def post(self, user: User): ) if d.simulation.alias is not None: - alias = d.simulation.alias - if alias is not None: - (updated_alias, next_id) = _set_alias(alias) - if updated_alias: - simulation.meta.append(models_meta.MetaData("seqid", next_id)) - simulation.alias = updated_alias - else: - simulation.alias = alias + (updated_alias, next_id) = _set_alias(d.simulation.alias) + if updated_alias: + simulation.meta.append(models_meta.MetaData("seqid", next_id)) + simulation.alias = updated_alias else: - simulation.alias = simulation.uuid.hex + simulation.alias = d.simulation.alias else: simulation.alias = simulation.uuid.hex @@ -328,15 +319,19 @@ def post(self, user: User): common_root = find_common_root(sim_file_paths) config = current_app.simdb_config + copy_files = config.get_option("server.copy_files", default=True) + imas_remote_host = config.get_option( + "server.imas_remote_host", default=None + ) - if config.get_option("server.copy_files", default=True): + if copy_files or imas_remote_host: staging_dir = ( Path(config.get_string_option("server.upload_folder")) / simulation.uuid.hex ) for sim_file in files: - if sim_file.uri.scheme == "file": + if copy_files and sim_file.uri.scheme == "file": path = secure_path(sim_file.uri.path, common_root, staging_dir) if not path.exists(): raise ValueError( @@ -344,32 +339,16 @@ def post(self, user: User): ) sim_file.uri = URI(scheme="file", path=path) elif sim_file.uri.scheme == "imas": - path = secure_path( - Path(sim_file.uri.query["path"]), - common_root, - staging_dir, - is_file=common_root is not None, - ) - sim_file.uri = convert_uri(sim_file.uri, path, config) - elif config.get_option("server.imas_remote_host", default=None): - staging_dir = ( - Path(config.get_string_option("server.upload_folder")) - / simulation.uuid.hex - ) - - for sim_file in files: - if sim_file.uri.scheme == "imas": - if config.get_option("server.copy_files", default=True): + if copy_files: path = secure_path( Path(sim_file.uri.query["path"]), common_root, staging_dir, is_file=common_root is not None, ) - sim_file.uri = convert_uri(sim_file.uri, path, config) else: path = Path(sim_file.uri.query["path"]) - sim_file.uri = convert_uri(sim_file.uri, path, config) + sim_file.uri = convert_uri(sim_file.uri, path, config) result = SimulationPostResponse( ingested=simulation.uuid, error=None, validation=None @@ -398,22 +377,19 @@ def post(self, user: User): "auto_validate=True." ) + disable_replaces = config.get_option( + "development.disable_replaces", default=False + ) replaces = simulation.find_meta("replaces") - if ( - not current_app.simdb_config.get_option( - "development.disable_replaces", default=False - ) - and replaces - and replaces[0].value - ): + + if not disable_replaces and replaces and replaces[0].value: sim_id = replaces[0].value try: replaces_sim = current_app.db.get_simulation(sim_id) except DatabaseError: replaces_sim = None - if replaces_sim is None: - pass - else: + + if replaces_sim is not None: _update_simulation_status( replaces_sim, models_sim.Simulation.Status.DEPRECATED, user ) From e642078f9e01e1cdecf8bcc59a5a42dc9c96bbae Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 13:52:23 +0100 Subject: [PATCH 06/18] Add edge case tests to post simulations --- tests/remote/test_api.py | 423 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 423 insertions(+) diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index ca9a172..7dcbe4d 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -154,3 +154,426 @@ def test_post_simulations(client): rv_get = client.get(f"/v1.2/simulation/{sim_uuid_hex}", headers=HEADERS) assert rv_get.status_code == 200 assert rv_get.json["alias"] == "test-simulation" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_with_alias_dash(client): + """Test POST endpoint with alias ending in dash (auto-increment).""" + sim_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "dashtest-", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [{"element": "test", "value": "dash"}], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + + assert rv.status_code == 200 + assert "ingested" in rv.json + + rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == "dashtest-1" + + # Check seqid metadata was added + metadata = rv_get.json["metadata"] + seqid_meta = [m for m in metadata if m["element"] == "seqid"] + assert len(seqid_meta) == 1 + assert seqid_meta[0]["value"] == 1 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_with_alias_hash(client): + """Test POST endpoint with alias ending in hash (auto-increment).""" + sim_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "hashtest#", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [{"element": "test", "value": "hash"}], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + + assert rv.status_code == 200 + assert "ingested" in rv.json + + rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == "hashtest#1" + + # Check seqid metadata was added + metadata = rv_get.json["metadata"] + seqid_meta = [m for m in metadata if m["element"] == "seqid"] + assert len(seqid_meta) == 1 + assert seqid_meta[0]["value"] == 1 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_alias_increment_sequence(client): + """Test multiple simulations with incrementing dash alias.""" + # Create first simulation with dash alias + sim_uuid_1 = uuid.uuid4() + simulation_data_1 = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_1.hex}, + "alias": "sequence-", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv1 = client.post( + "/v1.2/simulations", + json=simulation_data_1, + headers=HEADERS, + content_type="application/json", + ) + assert rv1.status_code == 200 + + # Create second simulation with same dash alias + sim_uuid_2 = uuid.uuid4() + simulation_data_2 = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_2.hex}, + "alias": "sequence-", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv2 = client.post( + "/v1.2/simulations", + json=simulation_data_2, + headers=HEADERS, + content_type="application/json", + ) + assert rv2.status_code == 200 + + # Verify aliases were incremented + rv_get1 = client.get(f"/v1.2/simulation/{sim_uuid_1.hex}", headers=HEADERS) + assert rv_get1.json["alias"] == "sequence-1" + + rv_get2 = client.get(f"/v1.2/simulation/{sim_uuid_2.hex}", headers=HEADERS) + assert rv_get2.json["alias"] == "sequence-2" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_no_alias(client): + """Test POST endpoint with no alias provided (should use uuid.hex).""" + sim_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": None, + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + + assert rv.status_code == 200 + rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == sim_uuid.hex + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_with_replaces(client): + """Test POST endpoint with replaces metadata (deprecates old simulation).""" + # Create initial simulation + old_sim_uuid = uuid.uuid4() + old_simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": old_sim_uuid.hex}, + "alias": "original-simulation", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [{"element": "version", "value": "1.0"}], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_old = client.post( + "/v1.2/simulations", + json=old_simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_old.status_code == 200 + + # Create new simulation that replaces the old one + new_sim_uuid = uuid.uuid4() + new_simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": new_sim_uuid.hex}, + "alias": "updated-simulation", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [ + {"element": "version", "value": "2.0"}, + {"element": "replaces", "value": old_sim_uuid.hex}, + {"element": "replaces_reason", "value": "Bug fixes and improvements"}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_new = client.post( + "/v1.2/simulations", + json=new_simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_new.status_code == 200 + + # Verify the old simulation is marked as DEPRECATED + rv_old_get = client.get(f"/v1.2/simulation/{old_sim_uuid.hex}", headers=HEADERS) + assert rv_old_get.status_code == 200 + old_metadata = rv_old_get.json["metadata"] + + status_meta = [m for m in old_metadata if m["element"] == "status"] + assert len(status_meta) == 1 + assert status_meta[0]["value"].lower() == "deprecated" + + # Check replaced_by metadata was added + replaced_by_meta = [m for m in old_metadata if m["element"] == "replaced_by"] + assert len(replaced_by_meta) == 1 + assert replaced_by_meta[0]["value"] == new_sim_uuid + + # Verify the new simulation has replaces metadata + rv_new_get = client.get(f"/v1.2/simulation/{new_sim_uuid.hex}", headers=HEADERS) + assert rv_new_get.status_code == 200 + new_metadata = rv_new_get.json["metadata"] + + replaces_meta = [m for m in new_metadata if m["element"] == "replaces"] + assert len(replaces_meta) == 1 + assert replaces_meta[0]["value"] == old_sim_uuid.hex + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_replaces_nonexistent(client): + """Test POST endpoint with replaces pointing to non-existent simulation.""" + # Create simulation that tries to replace a non-existent simulation + sim_uuid = uuid.uuid4() + fake_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "replaces-nothing", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [ + {"element": "replaces", "value": fake_uuid.hex}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + # Should still succeed (old simulation just doesn't exist to deprecate) + rv = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv.status_code == 200 + + # Verify the new simulation was created + rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == "replaces-nothing" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_with_watcher(client): + """Test POST endpoint with add_watcher set to true.""" + sim_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "watched-simulation", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [], + }, + "add_watcher": True, + "uploaded_by": "watcher-user", + } + + rv = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv.status_code == 200 + + # Verify the simulation was created + rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + assert rv_get.status_code == 200 + + # Note: We can't easily verify watchers were added without accessing the db directly + # but we can verify the request was successful and uploaded_by metadata is present + metadata = rv_get.json["metadata"] + uploaded_by_meta = [m for m in metadata if m["element"] == "uploaded_by"] + assert len(uploaded_by_meta) == 1 + assert uploaded_by_meta[0]["value"] == "watcher-user" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_uploaded_by(client): + """Test POST endpoint with uploaded_by field.""" + sim_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "upload-test", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [], + }, + "add_watcher": False, + "uploaded_by": "specific-user@example.com", + } + + rv = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv.status_code == 200 + + # Verify uploaded_by metadata + rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + assert rv_get.status_code == 200 + metadata = rv_get.json["metadata"] + uploaded_by_meta = [m for m in metadata if m["element"] == "uploaded_by"] + assert len(uploaded_by_meta) == 1 + assert uploaded_by_meta[0]["value"] == "specific-user@example.com" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_trace_with_replaces(client): + """Test the trace endpoint with a simulation that replaces another.""" + # Create original simulation + old_sim_uuid = uuid.uuid4() + old_simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": old_sim_uuid.hex}, + "alias": "trace-original", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [{"element": "version", "value": "1.0"}], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_old = client.post( + "/v1.2/simulations", + json=old_simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_old.status_code == 200 + + # Create new simulation that replaces it + new_sim_uuid = uuid.uuid4() + new_simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": new_sim_uuid.hex}, + "alias": "trace-updated", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [ + {"element": "version", "value": "2.0"}, + {"element": "replaces", "value": old_sim_uuid.hex}, + {"element": "replaces_reason", "value": "New features"}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_new = client.post( + "/v1.2/simulations", + json=new_simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_new.status_code == 200 + + # Get trace for the new simulation + rv_trace = client.get(f"/v1.2/trace/{new_sim_uuid.hex}", headers=HEADERS) + assert rv_trace.status_code == 200 + trace_data = rv_trace.json + + # Verify trace includes replaces information + assert "replaces" in trace_data + + replaces_uuid = trace_data["replaces"]["uuid"] + assert replaces_uuid == old_sim_uuid + assert "replaces_reason" in trace_data + assert trace_data["replaces_reason"] == "New features" + + with pytest.xfail("Deprecated on is not set, because replaced_on is never set"): + assert "deprecated_on" in trace_data["replaces"] From 8aa6eb4c718d7824f0320b78d58b1fec4e00f514 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 13:52:28 +0100 Subject: [PATCH 07/18] Fix failing tests --- src/simdb/remote/apis/v1_2/simulations.py | 7 ++--- tests/remote/test_api.py | 32 +++++++++++------------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 6636b84..a9edd0c 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -42,7 +42,8 @@ def _update_simulation_status( ) -> None: old_status = simulation.status simulation.status = status - if status != old_status and len(simulation.watchers) > 0: + watchers_list = list(simulation.watchers) + if status != old_status and len(watchers_list) > 0: server = EmailServer(current_app.simdb_config) msg = f"""\ Simulation status changed from {old_status} to {status}. @@ -51,7 +52,7 @@ def _update_simulation_status( Note: please don't reply to this email, replies to this address are not monitored. """ - to_addresses = [w.email for w in simulation.watchers] + to_addresses = [w.email for w in watchers_list] if to_addresses: if simulation.alias is None or simulation.alias == "": server.send_message( @@ -297,7 +298,7 @@ def post(self, user: User): simulation.set_meta("uploaded_by", uploaded_by) - if d.add_watcher: + if d.add_watcher and user.email: simulation.watchers.append( models_watcher.Watcher( user.name, user.email, models_watcher.Notification.ALL diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index 7dcbe4d..cf4744b 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -160,7 +160,7 @@ def test_post_simulations(client): def test_post_simulations_with_alias_dash(client): """Test POST endpoint with alias ending in dash (auto-increment).""" sim_uuid = uuid.uuid4() - + simulation_data = { "simulation": { "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, @@ -187,7 +187,7 @@ def test_post_simulations_with_alias_dash(client): rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) assert rv_get.status_code == 200 assert rv_get.json["alias"] == "dashtest-1" - + # Check seqid metadata was added metadata = rv_get.json["metadata"] seqid_meta = [m for m in metadata if m["element"] == "seqid"] @@ -199,7 +199,7 @@ def test_post_simulations_with_alias_dash(client): def test_post_simulations_with_alias_hash(client): """Test POST endpoint with alias ending in hash (auto-increment).""" sim_uuid = uuid.uuid4() - + simulation_data = { "simulation": { "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, @@ -226,7 +226,7 @@ def test_post_simulations_with_alias_hash(client): rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) assert rv_get.status_code == 200 assert rv_get.json["alias"] == "hashtest#1" - + # Check seqid metadata was added metadata = rv_get.json["metadata"] seqid_meta = [m for m in metadata if m["element"] == "seqid"] @@ -286,7 +286,7 @@ def test_post_simulations_alias_increment_sequence(client): # Verify aliases were incremented rv_get1 = client.get(f"/v1.2/simulation/{sim_uuid_1.hex}", headers=HEADERS) assert rv_get1.json["alias"] == "sequence-1" - + rv_get2 = client.get(f"/v1.2/simulation/{sim_uuid_2.hex}", headers=HEADERS) assert rv_get2.json["alias"] == "sequence-2" @@ -295,7 +295,7 @@ def test_post_simulations_alias_increment_sequence(client): def test_post_simulations_no_alias(client): """Test POST endpoint with no alias provided (should use uuid.hex).""" sim_uuid = uuid.uuid4() - + simulation_data = { "simulation": { "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, @@ -379,11 +379,11 @@ def test_post_simulations_with_replaces(client): rv_old_get = client.get(f"/v1.2/simulation/{old_sim_uuid.hex}", headers=HEADERS) assert rv_old_get.status_code == 200 old_metadata = rv_old_get.json["metadata"] - + status_meta = [m for m in old_metadata if m["element"] == "status"] assert len(status_meta) == 1 assert status_meta[0]["value"].lower() == "deprecated" - + # Check replaced_by metadata was added replaced_by_meta = [m for m in old_metadata if m["element"] == "replaced_by"] assert len(replaced_by_meta) == 1 @@ -393,7 +393,7 @@ def test_post_simulations_with_replaces(client): rv_new_get = client.get(f"/v1.2/simulation/{new_sim_uuid.hex}", headers=HEADERS) assert rv_new_get.status_code == 200 new_metadata = rv_new_get.json["metadata"] - + replaces_meta = [m for m in new_metadata if m["element"] == "replaces"] assert len(replaces_meta) == 1 assert replaces_meta[0]["value"] == old_sim_uuid.hex @@ -405,7 +405,7 @@ def test_post_simulations_replaces_nonexistent(client): # Create simulation that tries to replace a non-existent simulation sim_uuid = uuid.uuid4() fake_uuid = uuid.uuid4() - + simulation_data = { "simulation": { "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, @@ -429,7 +429,7 @@ def test_post_simulations_replaces_nonexistent(client): content_type="application/json", ) assert rv.status_code == 200 - + # Verify the new simulation was created rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) assert rv_get.status_code == 200 @@ -440,7 +440,7 @@ def test_post_simulations_replaces_nonexistent(client): def test_post_simulations_with_watcher(client): """Test POST endpoint with add_watcher set to true.""" sim_uuid = uuid.uuid4() - + simulation_data = { "simulation": { "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, @@ -465,7 +465,7 @@ def test_post_simulations_with_watcher(client): # Verify the simulation was created rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) assert rv_get.status_code == 200 - + # Note: We can't easily verify watchers were added without accessing the db directly # but we can verify the request was successful and uploaded_by metadata is present metadata = rv_get.json["metadata"] @@ -478,7 +478,7 @@ def test_post_simulations_with_watcher(client): def test_post_simulations_uploaded_by(client): """Test POST endpoint with uploaded_by field.""" sim_uuid = uuid.uuid4() - + simulation_data = { "simulation": { "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, @@ -566,7 +566,7 @@ def test_post_simulations_trace_with_replaces(client): rv_trace = client.get(f"/v1.2/trace/{new_sim_uuid.hex}", headers=HEADERS) assert rv_trace.status_code == 200 trace_data = rv_trace.json - + # Verify trace includes replaces information assert "replaces" in trace_data @@ -574,6 +574,6 @@ def test_post_simulations_trace_with_replaces(client): assert replaces_uuid == old_sim_uuid assert "replaces_reason" in trace_data assert trace_data["replaces_reason"] == "New features" - + with pytest.xfail("Deprecated on is not set, because replaced_on is never set"): assert "deprecated_on" in trace_data["replaces"] From 473994a11027f931c143bf94424b8e14b9a2da2a Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 14:54:34 +0100 Subject: [PATCH 08/18] Add more test for simulations get --- tests/remote/test_api.py | 454 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 446 insertions(+), 8 deletions(-) diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index cf4744b..d4e101e 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -73,14 +73,6 @@ def test_get_api_root(client): assert rv.status_code == 308 -@pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_get_simulations(client): - rv = client.get("/v1.2/simulations", headers=HEADERS) - assert rv.json["count"] == 100 - assert len(rv.json["results"]) == len(SIMULATIONS) - assert rv.status_code == 200 - - @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_post_simulations(client): """Test POST endpoint for creating a new simulation.""" @@ -577,3 +569,449 @@ def test_post_simulations_trace_with_replaces(client): with pytest.xfail("Deprecated on is not set, because replaced_on is never set"): assert "deprecated_on" in trace_data["replaces"] + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_basic(client): + """Test basic GET request to /v1.2/simulations endpoint.""" + rv = client.get("/v1.2/simulations", headers=HEADERS) + + assert rv.status_code == 200 + assert rv.is_json + + data = rv.json + assert "count" in data + assert "page" in data + assert "limit" in data + assert "results" in data + + # Should return paginated results + assert data["page"] == 1 + assert data["limit"] == 100 + assert isinstance(data["results"], list) + assert data["count"] >= 100 # At least the 100 fixture simulations + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_pagination_limit(client): + """Test GET request with custom limit.""" + custom_limit = 10 + headers_with_limit = {**HEADERS, "simdb-result-limit": str(custom_limit)} + + rv = client.get("/v1.2/simulations", headers=headers_with_limit) + + assert rv.status_code == 200 + data = rv.json + + assert data["limit"] == custom_limit + assert len(data["results"]) <= custom_limit + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_pagination_page(client): + """Test GET request with custom page number.""" + headers_page_2 = {**HEADERS, "simdb-result-limit": "10", "simdb-page": "2"} + + rv = client.get("/v1.2/simulations", headers=headers_page_2) + + assert rv.status_code == 200 + data = rv.json + + assert data["page"] == 2 + assert data["limit"] == 10 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_pagination_multiple_pages(client): + """Test pagination across multiple pages.""" + limit = 20 + + # Get first page + headers_page_1 = {**HEADERS, "simdb-result-limit": str(limit), "simdb-page": "1"} + rv1 = client.get("/v1.2/simulations", headers=headers_page_1) + assert rv1.status_code == 200 + page1_data = rv1.json + + # Get second page + headers_page_2 = {**HEADERS, "simdb-result-limit": str(limit), "simdb-page": "2"} + rv2 = client.get("/v1.2/simulations", headers=headers_page_2) + assert rv2.status_code == 200 + page2_data = rv2.json + + # Both should have same count and limit + assert page1_data["count"] == page2_data["count"] + assert page1_data["limit"] == page2_data["limit"] == limit + + # Pages should be different + assert page1_data["page"] == 1 + assert page2_data["page"] == 2 + + # Results should be different (assuming we have enough data) + if page1_data["count"] > limit: + page1_uuids = {item["uuid"] for item in page1_data["results"]} + page2_uuids = {item["uuid"] for item in page2_data["results"]} + assert page1_uuids != page2_uuids + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_filter_by_alias(client): + """Test filtering simulations by alias.""" + # First create a simulation with a known alias + sim_uuid = uuid.uuid4() + test_alias = "filter-test-alias" + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": test_alias, + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [{"element": "test_key", "value": "test_value"}], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Now filter by alias + rv = client.get(f"/v1.2/simulations?alias={test_alias}", headers=HEADERS) + + assert rv.status_code == 200 + data = rv.json + + assert data["count"] >= 1 + # Check that the filtered result contains our simulation + aliases = [item.get("alias") for item in data["results"]] + assert test_alias in aliases + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_filter_by_uuid(client): + """Test filtering simulations by UUID.""" + # Create a simulation with a known UUID + sim_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "uuid-filter-test", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Filter by UUID + rv = client.get(f"/v1.2/simulations?uuid={sim_uuid.hex}", headers=HEADERS) + + assert rv.status_code == 200 + data = rv.json + + assert data["count"] >= 1 + # Check that the filtered result contains our simulation + uuids = [item.get("uuid") for item in data["results"]] + assert sim_uuid in uuids + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_filter_by_metadata(client): + """Test filtering simulations by metadata.""" + # Create simulations with specific metadata + sim_uuid_1 = uuid.uuid4() + sim_uuid_2 = uuid.uuid4() + test_machine = "test-machine-xyz" + + simulation_data_1 = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_1.hex}, + "alias": "metadata-filter-1", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [ + {"element": "machine", "value": test_machine}, + {"element": "code", "value": "test-code"}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + simulation_data_2 = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_2.hex}, + "alias": "metadata-filter-2", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [ + {"element": "machine", "value": test_machine}, + {"element": "code", "value": "different-code"}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_post_1 = client.post( + "/v1.2/simulations", + json=simulation_data_1, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post_1.status_code == 200 + + rv_post_2 = client.post( + "/v1.2/simulations", + json=simulation_data_2, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post_2.status_code == 200 + + # Filter by machine metadata + rv = client.get(f"/v1.2/simulations?machine={test_machine}", headers=HEADERS) + + assert rv.status_code == 200 + data = rv.json + + assert data["count"] >= 2 + + # Check that both simulations are in the results + results_uuids = [item.get("uuid") for item in data["results"]] + assert sim_uuid_1 in results_uuids + assert sim_uuid_2 in results_uuids + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_filter_multiple_metadata(client): + """Test filtering simulations by multiple metadata fields.""" + # Create a simulation with multiple metadata fields + sim_uuid = uuid.uuid4() + test_machine = "multi-filter-machine" + test_code = "multi-filter-code" + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "multi-metadata-filter", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [ + {"element": "machine", "value": test_machine}, + {"element": "code", "value": test_code}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Filter by both machine and code + rv = client.get( + f"/v1.2/simulations?machine={test_machine}&code={test_code}", headers=HEADERS + ) + + assert rv.status_code == 200 + data = rv.json + + assert data["count"] >= 1 + results_uuids = [item.get("uuid") for item in data["results"]] + assert sim_uuid in results_uuids + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_sorting_asc(client): + """Test sorting simulations in ascending order.""" + # Create simulations with sortable aliases + for i in range(3): + sim_uuid = uuid.uuid4() + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": f"sort-test-{i:03d}", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Get simulations sorted by alias ascending + headers_sorted = {**HEADERS, "simdb-sort-by": "alias", "simdb-sort-asc": "true"} + + rv = client.get("/v1.2/simulations?alias=sort-test-%", headers=headers_sorted) + + assert rv.status_code == 200 + data = rv.json + + # Filter to only our test simulations + test_sims = [ + item + for item in data["results"] + if item.get("alias", "").startswith("sort-test-") + ] + + if len(test_sims) >= 2: + # Check that results are sorted in ascending order + aliases = [item.get("alias") for item in test_sims] + assert aliases == sorted(aliases) + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_sorting_desc(client): + """Test sorting simulations in descending order.""" + # Get simulations sorted by alias descending + headers_sorted = {**HEADERS, "simdb-sort-by": "alias", "simdb-sort-asc": "false"} + + rv = client.get("/v1.2/simulations", headers=headers_sorted) + + assert rv.status_code == 200 + data = rv.json + + # Just verify the request succeeded and returned data + assert "results" in data + assert isinstance(data["results"], list) + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_empty_result(client): + """Test GET request with filters that return no results.""" + # Use a filter that shouldn't match anything + rv = client.get( + "/v1.2/simulations?alias=non-existent-simulation-12345xyz", headers=HEADERS + ) + + assert rv.status_code == 200 + data = rv.json + + assert data["count"] == 0 + assert data["results"] == [] + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_with_metadata_keys(client): + """Test requesting specific metadata keys in results.""" + # Create a simulation with known metadata + sim_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "meta-keys-test", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [ + {"element": "machine", "value": "machine-x"}, + {"element": "code", "value": "code-y"}, + {"element": "description", "value": "test description"}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Request simulations with specific metadata keys + rv = client.get( + "/v1.2/simulations?alias=meta-keys-test&machine&code", headers=HEADERS + ) + + assert rv.status_code == 200 + data = rv.json + + assert data["count"] >= 1 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_combined_pagination_sorting_filtering(client): + """Test GET request with pagination, sorting, and filtering combined.""" + # Create multiple simulations for testing + test_prefix = "combined-test" + for i in range(5): + sim_uuid = uuid.uuid4() + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": f"{test_prefix}-{i:02d}", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [{"element": "test_group", "value": "combined"}], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Request with all features combined + headers_combined = { + **HEADERS, + "simdb-result-limit": "3", + "simdb-page": "1", + "simdb-sort-by": "alias", + "simdb-sort-asc": "true", + } + + rv = client.get( + f"/v1.2/simulations?alias={test_prefix}-%", headers=headers_combined + ) + + assert rv.status_code == 200 + data = rv.json + + assert data["page"] == 1 + assert data["limit"] == 3 + assert len(data["results"]) <= 3 From b2450c2b87945b79c007c2ee339fb5e7c4fa8539 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 14:59:01 +0100 Subject: [PATCH 09/18] Use pydantic model for get response --- src/simdb/remote/apis/v1_2/simulations.py | 10 +++++- src/simdb/remote/models.py | 43 ++++++++++++++++++++--- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index a9edd0c..280f040 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -26,6 +26,8 @@ from simdb.remote.core.path import find_common_root, secure_path from simdb.remote.core.typing import current_app from simdb.remote.models import ( + PaginatedResponse, + SimulationListItem, SimulationPostData, SimulationPostResponse, ValidationResult, @@ -277,7 +279,13 @@ def get(self, user: User): sort_asc=sort_asc, ) - return jsonify({"count": count, "page": page, "limit": limit, "results": data}) + serialized_data = [SimulationListItem.model_validate(item) for item in data] + + return jsonify( + PaginatedResponse( + count=count, page=page, limit=limit, results=serialized_data + ).model_dump(mode="json") + ) @requires_auth() def post(self, user: User): diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 777b6fd..20f4a5c 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -1,17 +1,35 @@ from datetime import datetime as dt from datetime import timezone -from typing import Annotated, Any, List, Optional +from typing import Annotated, Any, Generic, List, Optional, TypeVar from uuid import UUID -from pydantic import BaseModel, Field, PlainSerializer +from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer HexUUID = Annotated[UUID, PlainSerializer(lambda x: x.hex, return_type=str)] +def _deserialize_custom_uuid(v: Any) -> UUID: + """Deserialize CustomUUID format back to UUID.""" + if isinstance(v, UUID): + return v + if isinstance(v, dict) and "hex" in v: + return UUID(hex=v["hex"]) + if isinstance(v, str): + return UUID(v) + raise ValueError(f"Cannot deserialize {v} to UUID") + + +CustomUUID = Annotated[ + UUID, + BeforeValidator(_deserialize_custom_uuid), + PlainSerializer(lambda x: {"_type": "uuid.UUID", "hex": x.hex}), +] + + class FileData(BaseModel): type: str uri: str - uuid: UUID + uuid: CustomUUID checksum: str datetime: dt usage: Optional[str] @@ -27,7 +45,7 @@ class MetadataData(BaseModel): class SimulationData(BaseModel): - uuid: UUID + uuid: CustomUUID alias: Optional[str] datetime: dt = Field(default_factory=lambda: dt.now(timezone.utc)) inputs: List[FileData] @@ -50,3 +68,20 @@ class SimulationPostResponse(BaseModel): ingested: HexUUID error: Optional[str] validation: Optional[ValidationResult] + + +class SimulationListItem(BaseModel): + uuid: CustomUUID + alias: Optional[str] + datetime: str + metadata: Optional[List[MetadataData]] = None + + +T = TypeVar("T") + + +class PaginatedResponse(BaseModel, Generic[T]): + count: int + page: int + limit: int + results: T From e4f94d17f60a3776b9c354bf0d10d4e8500068d9 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 16:13:03 +0100 Subject: [PATCH 10/18] Add tests for get simulation by id --- tests/remote/test_api.py | 317 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 317 insertions(+) diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index d4e101e..d8a8efc 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -1015,3 +1015,320 @@ def test_get_simulations_combined_pagination_sorting_filtering(client): assert data["page"] == 1 assert data["limit"] == 3 assert len(data["results"]) <= 3 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_by_uuid(client): + """Test GET /v1.2/simulation/{simulation_id} endpoint - retrieve by UUID.""" + # Create a simulation with known properties + sim_uuid = uuid.uuid4() + input_uuid = uuid.uuid4() + output_uuid = uuid.uuid4() + test_alias = "get-test-simulation" + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": test_alias, + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [ + { + "uuid": {"_type": "uuid.UUID", "hex": input_uuid.hex}, + "type": "FILE", + "uri": "file:///test/input.dat", + "checksum": "input123", + "datetime": datetime.now(timezone.utc).isoformat(), + "usage": "input_data", + "purpose": "test input", + "sensitivity": "public", + "access": "open", + "embargo": None, + } + ], + "outputs": [ + { + "uuid": {"_type": "uuid.UUID", "hex": output_uuid.hex}, + "type": "FILE", + "uri": "file:///test/output.dat", + "checksum": "output456", + "datetime": datetime.now(timezone.utc).isoformat(), + "usage": "output_data", + "purpose": "test output", + "sensitivity": "public", + "access": "open", + "embargo": None, + } + ], + "metadata": [ + {"element": "machine", "value": "test-machine"}, + {"element": "code", "value": "test-code"}, + {"element": "description", "value": "Test simulation for GET"}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + # Create the simulation + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Test GET by UUID + rv = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + + assert rv.status_code == 200 + assert rv.is_json + + data = rv.json + + # Verify basic fields + assert "uuid" in data + assert data["uuid"] == sim_uuid + assert data["alias"] == test_alias + + # Verify datetime field exists + assert "datetime" in data + + # Verify inputs and outputs + assert "inputs" in data + assert len(data["inputs"]) == 1 + assert data["inputs"][0]["uuid"] == input_uuid + assert data["inputs"][0]["uri"] == "file:/test/input.dat" + assert data["inputs"][0]["checksum"] == "input123" + + assert "outputs" in data + assert len(data["outputs"]) == 1 + assert data["outputs"][0]["uuid"] == output_uuid + assert data["outputs"][0]["uri"] == "file:/test/output.dat" + assert data["outputs"][0]["checksum"] == "output456" + + # Verify metadata + assert "metadata" in data + assert len(data["metadata"]) >= 3 # At least our 3 metadata items + metadata_dict = {m["element"]: m["value"] for m in data["metadata"]} + assert metadata_dict["machine"] == "test-machine" + assert metadata_dict["code"] == "test-code" + assert metadata_dict["description"] == "Test simulation for GET" + + # Verify children and parents fields exist + assert "children" in data + assert "parents" in data + assert isinstance(data["children"], list) + assert isinstance(data["parents"], list) + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_by_alias(client): + """Test GET /v1.2/simulation/{simulation_id} endpoint - retrieve by alias.""" + # Create a simulation with a unique alias + sim_uuid = uuid.uuid4() + test_alias = "get-by-alias-test" + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": test_alias, + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [{"element": "test", "value": "alias retrieval"}], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Test GET by alias + rv = client.get(f"/v1.2/simulation/{test_alias}", headers=HEADERS) + + assert rv.status_code == 200 + data = rv.json + + assert data["uuid"] == sim_uuid + assert data["alias"] == test_alias + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_not_found(client): + """Test GET /v1.2/simulation/{simulation_id} endpoint - non-existent simulation.""" + # Try to get a non-existent simulation + fake_uuid = uuid.uuid4() + + rv = client.get(f"/v1.2/simulation/{fake_uuid.hex}", headers=HEADERS) + + assert rv.status_code == 400 + data = rv.json + + # Should contain an error message + assert "error" in data or data.get("message") == "Simulation not found" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_with_parents_and_children(client): + """Test GET /v1.2/simulation/{simulation_id} endpoint - verify parents/children.""" + # Create parent simulation + parent_uuid = uuid.uuid4() + parent_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": parent_uuid.hex}, + "alias": "parent-simulation", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [{"element": "role", "value": "parent"}], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_parent = client.post( + "/v1.2/simulations", + json=parent_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_parent.status_code == 200 + + # Create child simulation that references parent + child_uuid = uuid.uuid4() + child_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": child_uuid.hex}, + "alias": "child-simulation", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [], + "outputs": [], + "metadata": [ + {"element": "role", "value": "child"}, + {"element": "parent", "value": parent_uuid.hex}, + ], + }, + "add_watcher": False, + "uploaded_by": "test-user", + } + + rv_child = client.post( + "/v1.2/simulations", + json=child_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_child.status_code == 200 + + # Get child simulation and verify parents field + rv = client.get(f"/v1.2/simulation/{child_uuid.hex}", headers=HEADERS) + + assert rv.status_code == 200 + data = rv.json + + # Verify the parents/children structure + assert "parents" in data + assert "children" in data + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_full_response_structure(client): + """Test GET /v1.2/simulation/{simulation_id} endpoint - verify complete response + structure.""" + # Create a comprehensive simulation + sim_uuid = uuid.uuid4() + + simulation_data = { + "simulation": { + "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, + "alias": "complete-structure-test", + "datetime": datetime.now(timezone.utc).isoformat(), + "inputs": [ + { + "uuid": {"_type": "uuid.UUID", "hex": uuid.uuid4().hex}, + "type": "FILE", + "uri": "file:///complete/input.dat", + "checksum": "complete123", + "datetime": datetime.now(timezone.utc).isoformat(), + "usage": "input", + "purpose": "complete input", + "sensitivity": "public", + "access": "open", + "embargo": None, + } + ], + "outputs": [ + { + "uuid": {"_type": "uuid.UUID", "hex": uuid.uuid4().hex}, + "type": "FILE", + "uri": "file:///complete/output.dat", + "checksum": "complete456", + "datetime": datetime.now(timezone.utc).isoformat(), + "usage": "output", + "purpose": "complete output", + "sensitivity": "public", + "access": "open", + "embargo": None, + } + ], + "metadata": [ + {"element": "machine", "value": "complete-machine"}, + {"element": "version", "value": "1.0"}, + ], + }, + "add_watcher": False, + "uploaded_by": "complete-user", + } + + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data, + headers=HEADERS, + content_type="application/json", + ) + assert rv_post.status_code == 200 + + # Get the simulation + rv = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + + assert rv.status_code == 200 + data = rv.json + + # Verify all required top-level fields are present + required_fields = [ + "uuid", + "alias", + "datetime", + "inputs", + "outputs", + "metadata", + "children", + "parents", + ] + for field in required_fields: + assert field in data, f"Required field '{field}' missing from response" + + # Verify inputs structure + assert len(data["inputs"]) == 1 + input_required_fields = ["uuid", "type", "uri", "checksum", "datetime"] + for field in input_required_fields: + assert field in data["inputs"][0], f"Required input field '{field}' missing" + + # Verify outputs structure + assert len(data["outputs"]) == 1 + output_required_fields = ["uuid", "type", "uri", "checksum", "datetime"] + for field in output_required_fields: + assert field in data["outputs"][0], f"Required output field '{field}' missing" + + # Verify metadata structure + assert len(data["metadata"]) >= 2 + for meta in data["metadata"]: + assert "element" in meta + assert "value" in meta From d97f044b895571cd233f1b96e8686821a09df03e Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 10 Feb 2026 16:13:21 +0100 Subject: [PATCH 11/18] Use pydantic for get simulation by id --- src/simdb/database/database.py | 37 +++++++++++++++++++++++ src/simdb/database/models/file.py | 14 +++++++++ src/simdb/database/models/metadata.py | 3 ++ src/simdb/database/models/simulation.py | 21 +++++++++++++ src/simdb/remote/apis/v1_2/simulations.py | 13 +++++--- src/simdb/remote/models.py | 16 +++++++--- 6 files changed, 95 insertions(+), 9 deletions(-) diff --git a/src/simdb/database/database.py b/src/simdb/database/database.py index 37ba007..20a9b0c 100644 --- a/src/simdb/database/database.py +++ b/src/simdb/database/database.py @@ -16,6 +16,7 @@ from simdb.config import Config from simdb.query import QueryType, query_compare +from simdb.remote.models import SimulationReference from .models import Base from .models.file import File @@ -571,6 +572,24 @@ def get_simulation_parents(self, simulation: "Simulation") -> List[dict]: ) return [{"uuid": r.uuid, "alias": r.alias} for r in query.all()] + def get_simulation_parents_ref( + self, simulation: "Simulation" + ) -> List[SimulationReference]: + subquery = ( + self.session.query(File.checksum) + .filter(File.checksum != "") + .filter(File.input_for.contains(simulation)) + .subquery() + ) + query = ( + self.session.query(Simulation.uuid, Simulation.alias) + .join(Simulation.outputs) + .filter(File.checksum.in_(subquery)) + .filter(Simulation.alias != simulation.alias) + .distinct() + ) + return [SimulationReference(uuid=r.uuid, alias=r.alias) for r in query.all()] + def get_simulation_children(self, simulation: "Simulation") -> List[dict]: subquery = ( self.session.query(File.checksum) @@ -587,6 +606,24 @@ def get_simulation_children(self, simulation: "Simulation") -> List[dict]: ) return [{"uuid": r.uuid, "alias": r.alias} for r in query.all()] + def get_simulation_children_ref( + self, simulation: "Simulation" + ) -> List[SimulationReference]: + subquery = ( + self.session.query(File.checksum) + .filter(File.checksum != "") + .filter(File.output_of.contains(simulation)) + .subquery() + ) + query = ( + self.session.query(Simulation.uuid, Simulation.alias) + .join(Simulation.inputs) + .filter(File.checksum.in_(subquery)) + .filter(Simulation.alias != simulation.alias) + .distinct() + ) + return [SimulationReference(uuid=r.uuid, alias=r.alias) for r in query.all()] + def get_file(self, file_uuid_str: str) -> "File": """ Get the specified file from the database. diff --git a/src/simdb/database/models/file.py b/src/simdb/database/models/file.py index a6894a6..5b6a14d 100644 --- a/src/simdb/database/models/file.py +++ b/src/simdb/database/models/file.py @@ -157,3 +157,17 @@ def data(self, recurse: bool = False) -> Dict[str, str]: "datetime": self.datetime.isoformat(), } return data + + def to_model(self) -> FileData: + return FileData( + type=self.type.name, + uri=str(self.uri), + uuid=self.uuid, + checksum=self.checksum, + datetime=self.datetime, + usage=self.usage, + purpose=self.purpose, + sensitivity=self.sensitivity, + access=self.access, + embargo=self.embargo, + ) diff --git a/src/simdb/database/models/metadata.py b/src/simdb/database/models/metadata.py index 7b975ba..81bcde9 100644 --- a/src/simdb/database/models/metadata.py +++ b/src/simdb/database/models/metadata.py @@ -45,5 +45,8 @@ def data(self, recurse: bool = False) -> Dict[str, str]: } return data + def to_model(self) -> MetadataData: + return MetadataData(element=self.element, value=self.value) + Index("metadata_index", MetaData.sim_id, MetaData.element, unique=True) diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index cf92c87..69ab660 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -367,6 +367,27 @@ def data( ] return data + def to_model( + self, recurse: bool = False, meta_keys: Optional[List[str]] = None + ) -> SimulationData: + inputs = [] + outputs = [] + metadata = [] + if recurse: + inputs = [f.to_model() for f in self.inputs] + outputs = [f.to_model() for f in self.outputs] + metadata = [m.to_model() for m in self.meta] + elif meta_keys: + metadata = [m.to_model() for m in self.meta if m.element in meta_keys] + return SimulationData( + uuid=self.uuid, + alias=self.alias, + datetime=self.datetime, + inputs=inputs, + outputs=outputs, + metadata=metadata, + ) + def meta_dict(self) -> Dict[str, Union[Dict, Any]]: meta = {m.element: m.value for m in self.meta} return unflatten_dict(meta) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 280f040..f0960f7 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -27,6 +27,7 @@ from simdb.remote.core.typing import current_app from simdb.remote.models import ( PaginatedResponse, + SimulationDataResponse, SimulationListItem, SimulationPostData, SimulationPostResponse, @@ -424,12 +425,14 @@ def get(self, sim_id: str, user: User): try: simulation = current_app.db.get_simulation(sim_id) if simulation: - sim_data = simulation.data(recurse=True) - sim_data["children"] = current_app.db.get_simulation_children( - simulation + sim_data = simulation.to_model(recurse=True) + children = current_app.db.get_simulation_children_ref(simulation) + parents = current_app.db.get_simulation_parents_ref(simulation) + return jsonify( + SimulationDataResponse( + **sim_data.model_dump(), children=children, parents=parents + ).model_dump(mode="json") ) - sim_data["parents"] = current_app.db.get_simulation_parents(simulation) - return jsonify(sim_data) return error("Simulation not found") except DatabaseError as err: return error(str(err)) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 20f4a5c..2c78713 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -1,6 +1,6 @@ from datetime import datetime as dt from datetime import timezone -from typing import Annotated, Any, Generic, List, Optional, TypeVar +from typing import Annotated, Any, Generic, List, Optional, TypeVar, Union from uuid import UUID from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer @@ -14,8 +14,6 @@ def _deserialize_custom_uuid(v: Any) -> UUID: return v if isinstance(v, dict) and "hex" in v: return UUID(hex=v["hex"]) - if isinstance(v, str): - return UUID(v) raise ValueError(f"Cannot deserialize {v} to UUID") @@ -41,7 +39,12 @@ class FileData(BaseModel): class MetadataData(BaseModel): element: str - value: Any + value: Union[CustomUUID, Any] + + +class SimulationReference(BaseModel): + uuid: CustomUUID + alias: Optional[str] class SimulationData(BaseModel): @@ -53,6 +56,11 @@ class SimulationData(BaseModel): metadata: List[MetadataData] +class SimulationDataResponse(SimulationData): + parents: List[SimulationReference] + children: List[SimulationReference] + + class SimulationPostData(BaseModel): simulation: SimulationData add_watcher: bool From a8bda2f01a6587c3cb6492a185914c6c559e644c Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Fri, 13 Feb 2026 10:49:14 +0100 Subject: [PATCH 12/18] Set defaults for models --- src/simdb/remote/models.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 2c78713..9669fe5 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -1,7 +1,7 @@ from datetime import datetime as dt from datetime import timezone from typing import Annotated, Any, Generic, List, Optional, TypeVar, Union -from uuid import UUID +from uuid import UUID, uuid1 from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer @@ -27,14 +27,14 @@ def _deserialize_custom_uuid(v: Any) -> UUID: class FileData(BaseModel): type: str uri: str - uuid: CustomUUID + uuid: CustomUUID = Field(default_factory=lambda: uuid1()) checksum: str datetime: dt - usage: Optional[str] - purpose: Optional[str] - sensitivity: Optional[str] - access: Optional[str] - embargo: Optional[str] + usage: Optional[str] = None + purpose: Optional[str] = None + sensitivity: Optional[str] = None + access: Optional[str] = None + embargo: Optional[str] = None class MetadataData(BaseModel): @@ -44,16 +44,16 @@ class MetadataData(BaseModel): class SimulationReference(BaseModel): uuid: CustomUUID - alias: Optional[str] + alias: Optional[str] = None class SimulationData(BaseModel): - uuid: CustomUUID - alias: Optional[str] + uuid: CustomUUID = Field(default_factory=lambda: uuid1()) + alias: Optional[str] = None datetime: dt = Field(default_factory=lambda: dt.now(timezone.utc)) - inputs: List[FileData] - outputs: List[FileData] - metadata: List[MetadataData] + inputs: List[FileData] = [] + outputs: List[FileData] = [] + metadata: List[MetadataData] = [] class SimulationDataResponse(SimulationData): @@ -64,23 +64,23 @@ class SimulationDataResponse(SimulationData): class SimulationPostData(BaseModel): simulation: SimulationData add_watcher: bool - uploaded_by: Optional[str] + uploaded_by: Optional[str] = None class ValidationResult(BaseModel): passed: bool - error: Optional[str] + error: Optional[str] = None class SimulationPostResponse(BaseModel): ingested: HexUUID - error: Optional[str] - validation: Optional[ValidationResult] + error: Optional[str] = None + validation: Optional[ValidationResult] = None class SimulationListItem(BaseModel): uuid: CustomUUID - alias: Optional[str] + alias: Optional[str] = None datetime: str metadata: Optional[List[MetadataData]] = None From 4766c977b4bd9cc43c8d67ef6165113881158696 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Fri, 13 Feb 2026 10:54:11 +0100 Subject: [PATCH 13/18] Cleanup simulation post tests --- tests/remote/test_api.py | 420 +++++++++++++-------------------------- 1 file changed, 133 insertions(+), 287 deletions(-) diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index d8a8efc..5c31043 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -13,6 +13,12 @@ from simdb.config import Config from simdb.database.models import Simulation from simdb.remote.app import create_app +from simdb.remote.models import ( + FileData, + MetadataData, + SimulationData, + SimulationPostData, +) has_flask = importlib.util.find_spec("flask") is not None @@ -56,6 +62,25 @@ def client(): shutil.rmtree(upload_dir) +def generate_simulation_data( + add_watcher=False, uploaded_by=None, **overrides +) -> SimulationPostData: + simulation_data = SimulationData(**overrides) + data = SimulationPostData( + simulation=simulation_data, add_watcher=add_watcher, uploaded_by=uploaded_by + ) + return data + + +def generate_simulation_file() -> FileData: + return FileData( + type="FILE", + uri="file:///path/to/file", + checksum="fake_checksum", + datetime=datetime.now(timezone.utc), + ) + + @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_root(client): rv = client.get("/") @@ -76,148 +101,56 @@ def test_get_api_root(client): @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_post_simulations(client): """Test POST endpoint for creating a new simulation.""" - # Create a new simulation data structure - sim_uuid = uuid.uuid4() - sim_uuid_hex = sim_uuid.hex - input_uuid = uuid.uuid4() - output_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_hex}, - "alias": "test-simulation", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [ - { - "uuid": {"_type": "uuid.UUID", "hex": input_uuid.hex}, - "type": "FILE", - "uri": "file:///path/to/input/data.txt", - "checksum": "abc123def456", - "datetime": datetime.now(timezone.utc).isoformat(), - "usage": "input_data", - "purpose": "test input file", - "sensitivity": "public", - "access": "open", - "embargo": None, - } - ], - "outputs": [ - { - "uuid": {"_type": "uuid.UUID", "hex": output_uuid.hex}, - "type": "FILE", - "uri": "file:///path/to/output/results.txt", - "checksum": "xyz789abc012", - "datetime": datetime.now(timezone.utc).isoformat(), - "usage": "output_data", - "purpose": "test output file", - "sensitivity": "public", - "access": "open", - "embargo": None, - } - ], - "metadata": [ - {"element": "machine", "value": "test-machine"}, - {"element": "code", "value": "test-code"}, - {"element": "description", "value": "Test simulation"}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + simulation_data = generate_simulation_data( + alias="test-simulation", + inputs=[generate_simulation_file()], + outputs=[generate_simulation_file()], + ) # POST the simulation rv = client.post( "/v1.2/simulations", - json=simulation_data, + json=simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) # Verify the response - if rv.status_code != 200: - print(f"Response status: {rv.status_code}") - print(f"Response data: {rv.data}") - print(f"Response json: {rv.json if rv.is_json else 'Not JSON'}") - - assert "ingested" in rv.json - assert rv.json["ingested"] == sim_uuid_hex + assert rv.status_code == 200 + assert rv.json["ingested"] == simulation_data.simulation.uuid.hex # Verify the simulation was created by fetching it - rv_get = client.get(f"/v1.2/simulation/{sim_uuid_hex}", headers=HEADERS) + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_get.status_code == 200 assert rv_get.json["alias"] == "test-simulation" @pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_post_simulations_with_alias_dash(client): - """Test POST endpoint with alias ending in dash (auto-increment).""" - sim_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "dashtest-", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [{"element": "test", "value": "dash"}], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - - rv = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", +@pytest.mark.parametrize("suffix", ["-", "#"]) +def test_post_simulations_with_alias_auto_increment(client, suffix): + """Test POST endpoint with alias ending in dash or hashtag (auto-increment).""" + random_name = uuid.uuid4().hex + simulation_data = generate_simulation_data( + alias=f"{random_name}{suffix}", ) - assert rv.status_code == 200 - assert "ingested" in rv.json - - rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) - assert rv_get.status_code == 200 - assert rv_get.json["alias"] == "dashtest-1" - - # Check seqid metadata was added - metadata = rv_get.json["metadata"] - seqid_meta = [m for m in metadata if m["element"] == "seqid"] - assert len(seqid_meta) == 1 - assert seqid_meta[0]["value"] == 1 - - -@pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_post_simulations_with_alias_hash(client): - """Test POST endpoint with alias ending in hash (auto-increment).""" - sim_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "hashtest#", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [{"element": "test", "value": "hash"}], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - rv = client.post( "/v1.2/simulations", - json=simulation_data, + json=simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv.status_code == 200 - assert "ingested" in rv.json + assert rv.json["ingested"] == simulation_data.simulation.uuid.hex - rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_get.status_code == 200 - assert rv_get.json["alias"] == "hashtest#1" + assert rv_get.json["alias"] == f"{random_name}{suffix}1" # Check seqid metadata was added metadata = rv_get.json["metadata"] @@ -230,145 +163,99 @@ def test_post_simulations_with_alias_hash(client): def test_post_simulations_alias_increment_sequence(client): """Test multiple simulations with incrementing dash alias.""" # Create first simulation with dash alias - sim_uuid_1 = uuid.uuid4() - simulation_data_1 = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_1.hex}, - "alias": "sequence-", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + simulation_data_1 = generate_simulation_data( + alias="sequence-", + ) rv1 = client.post( "/v1.2/simulations", - json=simulation_data_1, + json=simulation_data_1.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv1.status_code == 200 - # Create second simulation with same dash alias - sim_uuid_2 = uuid.uuid4() - simulation_data_2 = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_2.hex}, - "alias": "sequence-", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + simulation_data_2 = generate_simulation_data( + alias="sequence-", + ) rv2 = client.post( "/v1.2/simulations", - json=simulation_data_2, + json=simulation_data_2.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv2.status_code == 200 # Verify aliases were incremented - rv_get1 = client.get(f"/v1.2/simulation/{sim_uuid_1.hex}", headers=HEADERS) + rv_get1 = client.get( + f"/v1.2/simulation/{simulation_data_1.simulation.uuid.hex}", headers=HEADERS + ) assert rv_get1.json["alias"] == "sequence-1" - rv_get2 = client.get(f"/v1.2/simulation/{sim_uuid_2.hex}", headers=HEADERS) + rv_get2 = client.get( + f"/v1.2/simulation/{simulation_data_2.simulation.uuid.hex}", headers=HEADERS + ) assert rv_get2.json["alias"] == "sequence-2" @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_post_simulations_no_alias(client): """Test POST endpoint with no alias provided (should use uuid.hex).""" - sim_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": None, - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + simulation_data = generate_simulation_data() rv = client.post( "/v1.2/simulations", - json=simulation_data, + json=simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv.status_code == 200 - rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_get.status_code == 200 - assert rv_get.json["alias"] == sim_uuid.hex + assert rv_get.json["alias"] == simulation_data.simulation.uuid.hex @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_post_simulations_with_replaces(client): """Test POST endpoint with replaces metadata (deprecates old simulation).""" # Create initial simulation - old_sim_uuid = uuid.uuid4() - old_simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": old_sim_uuid.hex}, - "alias": "original-simulation", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [{"element": "version", "value": "1.0"}], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + old_simulation_data = generate_simulation_data(alias="old_simulation") rv_old = client.post( "/v1.2/simulations", - json=old_simulation_data, + json=old_simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv_old.status_code == 200 # Create new simulation that replaces the old one - new_sim_uuid = uuid.uuid4() - new_simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": new_sim_uuid.hex}, - "alias": "updated-simulation", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [ - {"element": "version", "value": "2.0"}, - {"element": "replaces", "value": old_sim_uuid.hex}, - {"element": "replaces_reason", "value": "Bug fixes and improvements"}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + new_simulation_data = generate_simulation_data( + alias="updated-simulation", + metadata=[ + MetadataData( + element="replaces", value=old_simulation_data.simulation.uuid.hex + ), + MetadataData(element="replaces_reason", value="Test replacement"), + ], + ) rv_new = client.post( "/v1.2/simulations", - json=new_simulation_data, + json=new_simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv_new.status_code == 200 # Verify the old simulation is marked as DEPRECATED - rv_old_get = client.get(f"/v1.2/simulation/{old_sim_uuid.hex}", headers=HEADERS) + rv_old_get = client.get( + f"/v1.2/simulation/{old_simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_old_get.status_code == 200 old_metadata = rv_old_get.json["metadata"] @@ -379,51 +266,45 @@ def test_post_simulations_with_replaces(client): # Check replaced_by metadata was added replaced_by_meta = [m for m in old_metadata if m["element"] == "replaced_by"] assert len(replaced_by_meta) == 1 - assert replaced_by_meta[0]["value"] == new_sim_uuid + assert replaced_by_meta[0]["value"] == new_simulation_data.simulation.uuid # Verify the new simulation has replaces metadata - rv_new_get = client.get(f"/v1.2/simulation/{new_sim_uuid.hex}", headers=HEADERS) + rv_new_get = client.get( + f"/v1.2/simulation/{new_simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_new_get.status_code == 200 new_metadata = rv_new_get.json["metadata"] replaces_meta = [m for m in new_metadata if m["element"] == "replaces"] assert len(replaces_meta) == 1 - assert replaces_meta[0]["value"] == old_sim_uuid.hex + assert replaces_meta[0]["value"] == old_simulation_data.simulation.uuid.hex @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_post_simulations_replaces_nonexistent(client): """Test POST endpoint with replaces pointing to non-existent simulation.""" # Create simulation that tries to replace a non-existent simulation - sim_uuid = uuid.uuid4() - fake_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "replaces-nothing", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [ - {"element": "replaces", "value": fake_uuid.hex}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + simulation_data = generate_simulation_data( + alias="replaces-nothing", + metadata=[ + MetadataData(element="replaces", value=uuid.uuid1().hex), + MetadataData(element="replaces_reason", value="Test replacement"), + ], + ) # Should still succeed (old simulation just doesn't exist to deprecate) rv = client.post( "/v1.2/simulations", - json=simulation_data, + json=simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv.status_code == 200 # Verify the new simulation was created - rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_get.status_code == 200 assert rv_get.json["alias"] == "replaces-nothing" @@ -431,31 +312,22 @@ def test_post_simulations_replaces_nonexistent(client): @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_post_simulations_with_watcher(client): """Test POST endpoint with add_watcher set to true.""" - sim_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "watched-simulation", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [], - }, - "add_watcher": True, - "uploaded_by": "watcher-user", - } + simulation_data = generate_simulation_data( + add_watcher=True, uploaded_by="watcher-user" + ) rv = client.post( "/v1.2/simulations", - json=simulation_data, + json=simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv.status_code == 200 # Verify the simulation was created - rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_get.status_code == 200 # Note: We can't easily verify watchers were added without accessing the db directly @@ -469,93 +341,67 @@ def test_post_simulations_with_watcher(client): @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_post_simulations_uploaded_by(client): """Test POST endpoint with uploaded_by field.""" - sim_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "upload-test", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [], - }, - "add_watcher": False, - "uploaded_by": "specific-user@example.com", - } + """Test POST endpoint with add_watcher set to true.""" + simulation_data = generate_simulation_data(uploaded_by="test-user") rv = client.post( "/v1.2/simulations", - json=simulation_data, + json=simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv.status_code == 200 - # Verify uploaded_by metadata - rv_get = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + # Verify the simulation was created + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_get.status_code == 200 + metadata = rv_get.json["metadata"] uploaded_by_meta = [m for m in metadata if m["element"] == "uploaded_by"] assert len(uploaded_by_meta) == 1 - assert uploaded_by_meta[0]["value"] == "specific-user@example.com" + assert uploaded_by_meta[0]["value"] == "test-user" @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_post_simulations_trace_with_replaces(client): """Test the trace endpoint with a simulation that replaces another.""" # Create original simulation - old_sim_uuid = uuid.uuid4() - old_simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": old_sim_uuid.hex}, - "alias": "trace-original", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [{"element": "version", "value": "1.0"}], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + # Create initial simulation + old_simulation_data = generate_simulation_data(alias="trace-original") rv_old = client.post( "/v1.2/simulations", - json=old_simulation_data, + json=old_simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv_old.status_code == 200 - # Create new simulation that replaces it - new_sim_uuid = uuid.uuid4() - new_simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": new_sim_uuid.hex}, - "alias": "trace-updated", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [ - {"element": "version", "value": "2.0"}, - {"element": "replaces", "value": old_sim_uuid.hex}, - {"element": "replaces_reason", "value": "New features"}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + # Create new simulation that replaces the old one + new_simulation_data = generate_simulation_data( + alias="trace-updated", + metadata=[ + MetadataData( + element="replaces", value=old_simulation_data.simulation.uuid.hex + ), + MetadataData(element="replaces_reason", value="New features"), + ], + ) rv_new = client.post( "/v1.2/simulations", - json=new_simulation_data, + json=new_simulation_data.model_dump(mode="json"), headers=HEADERS, content_type="application/json", ) assert rv_new.status_code == 200 # Get trace for the new simulation - rv_trace = client.get(f"/v1.2/trace/{new_sim_uuid.hex}", headers=HEADERS) + rv_trace = client.get( + f"/v1.2/trace/{new_simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv_trace.status_code == 200 trace_data = rv_trace.json @@ -563,7 +409,7 @@ def test_post_simulations_trace_with_replaces(client): assert "replaces" in trace_data replaces_uuid = trace_data["replaces"]["uuid"] - assert replaces_uuid == old_sim_uuid + assert replaces_uuid == old_simulation_data.simulation.uuid assert "replaces_reason" in trace_data assert trace_data["replaces_reason"] == "New features" From eb4a4743f87dd7a9fa777d0e761248a858576e85 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Fri, 13 Feb 2026 14:40:04 +0100 Subject: [PATCH 14/18] Cleanup simulation get tests --- src/simdb/remote/models.py | 11 +- tests/remote/test_api.py | 827 +++++++++---------------------------- 2 files changed, 210 insertions(+), 628 deletions(-) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 9669fe5..9a4c3f9 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -1,6 +1,7 @@ +from urllib.parse import urlencode from datetime import datetime as dt from datetime import timezone -from typing import Annotated, Any, Generic, List, Optional, TypeVar, Union +from typing import Annotated, Any, Dict, Generic, List, Optional, TypeVar, Union from uuid import UUID, uuid1 from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer @@ -41,6 +42,12 @@ class MetadataData(BaseModel): element: str value: Union[CustomUUID, Any] + def as_dict(self): + return {self.element: self.value} + + def as_querystring(self): + return urlencode(self.as_dict()) + class SimulationReference(BaseModel): uuid: CustomUUID @@ -92,4 +99,4 @@ class PaginatedResponse(BaseModel, Generic[T]): count: int page: int limit: int - results: T + results: List[T] diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index 5c31043..057329b 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -16,7 +16,10 @@ from simdb.remote.models import ( FileData, MetadataData, + PaginatedResponse, SimulationData, + SimulationDataResponse, + SimulationListItem, SimulationPostData, ) @@ -81,6 +84,16 @@ def generate_simulation_file() -> FileData: ) +def post_simulation(client, simulation_data, headers=HEADERS): + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data.model_dump(mode="json"), + headers=headers, + content_type="application/json", + ) + return rv_post + + @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_root(client): rv = client.get("/") @@ -108,12 +121,7 @@ def test_post_simulations(client): ) # POST the simulation - rv = client.post( - "/v1.2/simulations", - json=simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv = post_simulation(client, simulation_data) # Verify the response assert rv.status_code == 200 @@ -136,12 +144,7 @@ def test_post_simulations_with_alias_auto_increment(client, suffix): alias=f"{random_name}{suffix}", ) - rv = client.post( - "/v1.2/simulations", - json=simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv = post_simulation(client, simulation_data) assert rv.status_code == 200 assert rv.json["ingested"] == simulation_data.simulation.uuid.hex @@ -167,24 +170,14 @@ def test_post_simulations_alias_increment_sequence(client): alias="sequence-", ) - rv1 = client.post( - "/v1.2/simulations", - json=simulation_data_1.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv1 = post_simulation(client, simulation_data_1) assert rv1.status_code == 200 simulation_data_2 = generate_simulation_data( alias="sequence-", ) - rv2 = client.post( - "/v1.2/simulations", - json=simulation_data_2.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv2 = post_simulation(client, simulation_data_2) assert rv2.status_code == 200 # Verify aliases were incremented @@ -204,12 +197,7 @@ def test_post_simulations_no_alias(client): """Test POST endpoint with no alias provided (should use uuid.hex).""" simulation_data = generate_simulation_data() - rv = client.post( - "/v1.2/simulations", - json=simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv = post_simulation(client, simulation_data) assert rv.status_code == 200 rv_get = client.get( @@ -225,12 +213,7 @@ def test_post_simulations_with_replaces(client): # Create initial simulation old_simulation_data = generate_simulation_data(alias="old_simulation") - rv_old = client.post( - "/v1.2/simulations", - json=old_simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv_old = post_simulation(client, old_simulation_data) assert rv_old.status_code == 200 # Create new simulation that replaces the old one @@ -244,12 +227,7 @@ def test_post_simulations_with_replaces(client): ], ) - rv_new = client.post( - "/v1.2/simulations", - json=new_simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv_new = post_simulation(client, new_simulation_data) assert rv_new.status_code == 200 # Verify the old simulation is marked as DEPRECATED @@ -293,12 +271,7 @@ def test_post_simulations_replaces_nonexistent(client): ) # Should still succeed (old simulation just doesn't exist to deprecate) - rv = client.post( - "/v1.2/simulations", - json=simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv = post_simulation(client, simulation_data) assert rv.status_code == 200 # Verify the new simulation was created @@ -316,12 +289,7 @@ def test_post_simulations_with_watcher(client): add_watcher=True, uploaded_by="watcher-user" ) - rv = client.post( - "/v1.2/simulations", - json=simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv = post_simulation(client, simulation_data) assert rv.status_code == 200 # Verify the simulation was created @@ -344,12 +312,7 @@ def test_post_simulations_uploaded_by(client): """Test POST endpoint with add_watcher set to true.""" simulation_data = generate_simulation_data(uploaded_by="test-user") - rv = client.post( - "/v1.2/simulations", - json=simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv = post_simulation(client, simulation_data) assert rv.status_code == 200 # Verify the simulation was created @@ -371,12 +334,7 @@ def test_post_simulations_trace_with_replaces(client): # Create initial simulation old_simulation_data = generate_simulation_data(alias="trace-original") - rv_old = client.post( - "/v1.2/simulations", - json=old_simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv_old = post_simulation(client, old_simulation_data) assert rv_old.status_code == 200 # Create new simulation that replaces the old one @@ -390,12 +348,7 @@ def test_post_simulations_trace_with_replaces(client): ], ) - rv_new = client.post( - "/v1.2/simulations", - json=new_simulation_data.model_dump(mode="json"), - headers=HEADERS, - content_type="application/json", - ) + rv_new = post_simulation(client, new_simulation_data) assert rv_new.status_code == 200 # Get trace for the new simulation @@ -425,17 +378,11 @@ def test_get_simulations_basic(client): assert rv.status_code == 200 assert rv.is_json - data = rv.json - assert "count" in data - assert "page" in data - assert "limit" in data - assert "results" in data + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - # Should return paginated results - assert data["page"] == 1 - assert data["limit"] == 100 - assert isinstance(data["results"], list) - assert data["count"] >= 100 # At least the 100 fixture simulations + assert data.page == 1 + assert data.limit == 100 + assert data.count >= 100 @pytest.mark.skipif(not has_flask, reason="requires flask library") @@ -447,10 +394,10 @@ def test_get_simulations_pagination_limit(client): rv = client.get("/v1.2/simulations", headers=headers_with_limit) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - assert data["limit"] == custom_limit - assert len(data["results"]) <= custom_limit + assert data.limit == custom_limit + assert len(data.results) <= custom_limit @pytest.mark.skipif(not has_flask, reason="requires flask library") @@ -461,10 +408,10 @@ def test_get_simulations_pagination_page(client): rv = client.get("/v1.2/simulations", headers=headers_page_2) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - assert data["page"] == 2 - assert data["limit"] == 10 + assert data.page == 2 + assert data.limit == 10 @pytest.mark.skipif(not has_flask, reason="requires flask library") @@ -476,66 +423,46 @@ def test_get_simulations_pagination_multiple_pages(client): headers_page_1 = {**HEADERS, "simdb-result-limit": str(limit), "simdb-page": "1"} rv1 = client.get("/v1.2/simulations", headers=headers_page_1) assert rv1.status_code == 200 - page1_data = rv1.json + page1_data = PaginatedResponse[SimulationListItem].model_validate(rv1.json) # Get second page headers_page_2 = {**HEADERS, "simdb-result-limit": str(limit), "simdb-page": "2"} rv2 = client.get("/v1.2/simulations", headers=headers_page_2) assert rv2.status_code == 200 - page2_data = rv2.json + page2_data = PaginatedResponse[SimulationListItem].model_validate(rv2.json) # Both should have same count and limit - assert page1_data["count"] == page2_data["count"] - assert page1_data["limit"] == page2_data["limit"] == limit + assert page1_data.count == page2_data.count + assert page1_data.limit == page2_data.limit == limit # Pages should be different - assert page1_data["page"] == 1 - assert page2_data["page"] == 2 + assert page1_data.page == 1 + assert page2_data.page == 2 - # Results should be different (assuming we have enough data) - if page1_data["count"] > limit: - page1_uuids = {item["uuid"] for item in page1_data["results"]} - page2_uuids = {item["uuid"] for item in page2_data["results"]} - assert page1_uuids != page2_uuids + page1_uuids = {item.uuid for item in page1_data.results} + page2_uuids = {item.uuid for item in page2_data.results} + assert page1_uuids != page2_uuids @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_simulations_filter_by_alias(client): """Test filtering simulations by alias.""" # First create a simulation with a known alias - sim_uuid = uuid.uuid4() test_alias = "filter-test-alias" + simulation_data = generate_simulation_data(alias=test_alias) - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": test_alias, - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [{"element": "test_key", "value": "test_value"}], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", - ) + rv_post = post_simulation(client, simulation_data) assert rv_post.status_code == 200 # Now filter by alias rv = client.get(f"/v1.2/simulations?alias={test_alias}", headers=HEADERS) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - assert data["count"] >= 1 + assert data.count == 1 # Check that the filtered result contains our simulation - aliases = [item.get("alias") for item in data["results"]] + aliases = [item.alias for item in data.results] assert test_alias in aliases @@ -543,218 +470,158 @@ def test_get_simulations_filter_by_alias(client): def test_get_simulations_filter_by_uuid(client): """Test filtering simulations by UUID.""" # Create a simulation with a known UUID - sim_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "uuid-filter-test", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + simulation_data = generate_simulation_data() - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", - ) + rv_post = post_simulation(client, simulation_data) assert rv_post.status_code == 200 # Filter by UUID - rv = client.get(f"/v1.2/simulations?uuid={sim_uuid.hex}", headers=HEADERS) + rv = client.get( + f"/v1.2/simulations?uuid={simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - assert data["count"] >= 1 + assert data.count == 1 # Check that the filtered result contains our simulation - uuids = [item.get("uuid") for item in data["results"]] - assert sim_uuid in uuids + uuids = [item.uuid for item in data.results] + assert simulation_data.simulation.uuid in uuids @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_simulations_filter_by_metadata(client): """Test filtering simulations by metadata.""" # Create simulations with specific metadata - sim_uuid_1 = uuid.uuid4() - sim_uuid_2 = uuid.uuid4() - test_machine = "test-machine-xyz" - - simulation_data_1 = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_1.hex}, - "alias": "metadata-filter-1", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [ - {"element": "machine", "value": test_machine}, - {"element": "code", "value": "test-code"}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - - simulation_data_2 = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid_2.hex}, - "alias": "metadata-filter-2", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [ - {"element": "machine", "value": test_machine}, - {"element": "code", "value": "different-code"}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + test_metadata = MetadataData(element="machine", value="test_machine") - rv_post_1 = client.post( - "/v1.2/simulations", - json=simulation_data_1, - headers=HEADERS, - content_type="application/json", - ) + simulation_data_1 = generate_simulation_data(metadata=[test_metadata]) + simulation_data_2 = generate_simulation_data(metadata=[test_metadata]) + rv_post_1 = post_simulation(client, simulation_data_1) assert rv_post_1.status_code == 200 - rv_post_2 = client.post( - "/v1.2/simulations", - json=simulation_data_2, - headers=HEADERS, - content_type="application/json", - ) + rv_post_2 = post_simulation(client, simulation_data_2) assert rv_post_2.status_code == 200 # Filter by machine metadata - rv = client.get(f"/v1.2/simulations?machine={test_machine}", headers=HEADERS) + rv = client.get( + f"/v1.2/simulations?{test_metadata.as_querystring()}", + headers=HEADERS, + ) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - assert data["count"] >= 2 + assert data.count == 2 # Check that both simulations are in the results - results_uuids = [item.get("uuid") for item in data["results"]] - assert sim_uuid_1 in results_uuids - assert sim_uuid_2 in results_uuids + results_uuids = [item.uuid for item in data.results] + assert simulation_data_1.simulation.uuid in results_uuids + assert simulation_data_2.simulation.uuid in results_uuids @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_simulations_filter_multiple_metadata(client): """Test filtering simulations by multiple metadata fields.""" # Create a simulation with multiple metadata fields - sim_uuid = uuid.uuid4() - test_machine = "multi-filter-machine" - test_code = "multi-filter-code" - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "multi-metadata-filter", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [ - {"element": "machine", "value": test_machine}, - {"element": "code", "value": test_code}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + test_metadata = [ + MetadataData(element="machine", value="multi-filter-machine"), + MetadataData(element="code", value="multi-filter-code"), + ] - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", - ) + simulation_data = generate_simulation_data(metadata=test_metadata) + + rv_post = post_simulation(client, simulation_data) assert rv_post.status_code == 200 # Filter by both machine and code rv = client.get( - f"/v1.2/simulations?machine={test_machine}&code={test_code}", headers=HEADERS + f"/v1.2/simulations?{'&'.join([m.as_querystring() for m in test_metadata])}", + headers=HEADERS, ) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - assert data["count"] >= 1 - results_uuids = [item.get("uuid") for item in data["results"]] - assert sim_uuid in results_uuids + assert data.count == 1 + results_uuids = [item.uuid for item in data.results] + assert simulation_data.simulation.uuid in results_uuids @pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_get_simulations_sorting_asc(client): - """Test sorting simulations in ascending order.""" +@pytest.mark.xfail(reason="Only sorting by metadata keys works for now") +def test_get_simulations_alias_sorting_asc(client): + """Test sorting simulations in ascending order by alias.""" # Create simulations with sortable aliases for i in range(3): - sim_uuid = uuid.uuid4() - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": f"sort-test-{i:03d}", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", - ) + simulation_data = generate_simulation_data(alias=f"alias-sort-test-{i:03d}") + + rv_post = post_simulation(client, simulation_data) assert rv_post.status_code == 200 # Get simulations sorted by alias ascending headers_sorted = {**HEADERS, "simdb-sort-by": "alias", "simdb-sort-asc": "true"} - rv = client.get("/v1.2/simulations?alias=sort-test-%", headers=headers_sorted) + rv = client.get( + "/v1.2/simulations?alias=IN%3Aalias-sort-test-", headers=headers_sorted + ) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - # Filter to only our test simulations - test_sims = [ - item - for item in data["results"] - if item.get("alias", "").startswith("sort-test-") - ] - - if len(test_sims) >= 2: - # Check that results are sorted in ascending order - aliases = [item.get("alias") for item in test_sims] - assert aliases == sorted(aliases) + # Check that results are sorted in ascending order + aliases = [item.alias for item in data.results if item.alias is not None] + assert aliases == sorted(aliases) + assert len(aliases) == 3 @pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_get_simulations_sorting_desc(client): - """Test sorting simulations in descending order.""" - # Get simulations sorted by alias descending - headers_sorted = {**HEADERS, "simdb-sort-by": "alias", "simdb-sort-asc": "false"} +@pytest.mark.parametrize("ascending", [True, False]) +def test_get_simulations_metadata_sorting(client, ascending): + """Test sorting simulations in ascending order.""" + # Create simulations with sortable aliases + # post them in the order: 2 1 0 5 4 3 + # ascending should result in 0 1 2 3 4 5 + # descending should result in 5 4 3 2 1 0 + for i in reversed(range(3)): + simulation_data = generate_simulation_data( + alias=f"sort-test-{ascending}-{i:03d}", + metadata=[MetadataData(element="sort-test", value=i)], + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 - rv = client.get("/v1.2/simulations", headers=headers_sorted) + for i in reversed(range(3, 6)): + simulation_data = generate_simulation_data( + alias=f"sort-test-{ascending}-{i:03d}", + metadata=[MetadataData(element="sort-test", value=i)], + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Get simulations sorted by alias ascending + headers_sorted = { + **HEADERS, + "simdb-sort-by": "sort-test", + "simdb-sort-asc": "true" if ascending else "false", + } + + rv = client.get( + f"/v1.2/simulations?alias=IN%3Asort-test-{ascending}&sort-test", + headers=headers_sorted, + ) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - # Just verify the request succeeded and returned data - assert "results" in data - assert isinstance(data["results"], list) + # Check that results are sorted in the correct order + metadata = [ + item.metadata[0].value for item in data.results if item.metadata is not None + ] + assert metadata == sorted(metadata, reverse=not ascending) + assert len(metadata) == 6 @pytest.mark.skipif(not has_flask, reason="requires flask library") @@ -766,41 +633,26 @@ def test_get_simulations_empty_result(client): ) assert rv.status_code == 200 - data = rv.json + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) - assert data["count"] == 0 - assert data["results"] == [] + assert data.count == 0 + assert len(data.results) == 0 @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_simulations_with_metadata_keys(client): """Test requesting specific metadata keys in results.""" # Create a simulation with known metadata - sim_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "meta-keys-test", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [ - {"element": "machine", "value": "machine-x"}, - {"element": "code", "value": "code-y"}, - {"element": "description", "value": "test description"}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", + simulation_data = generate_simulation_data( + alias="meta-keys-test", + metadata=[ + MetadataData(element="machine", value="machine-x"), + MetadataData(element="code", value="code-y"), + ], ) + + rv_post = post_simulation(client, simulation_data) assert rv_post.status_code == 200 # Request simulations with specific metadata keys @@ -809,208 +661,90 @@ def test_get_simulations_with_metadata_keys(client): ) assert rv.status_code == 200 - data = rv.json - - assert data["count"] >= 1 - - -@pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_get_simulations_combined_pagination_sorting_filtering(client): - """Test GET request with pagination, sorting, and filtering combined.""" - # Create multiple simulations for testing - test_prefix = "combined-test" - for i in range(5): - sim_uuid = uuid.uuid4() - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": f"{test_prefix}-{i:02d}", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [{"element": "test_group", "value": "combined"}], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", - ) - assert rv_post.status_code == 200 - - # Request with all features combined - headers_combined = { - **HEADERS, - "simdb-result-limit": "3", - "simdb-page": "1", - "simdb-sort-by": "alias", - "simdb-sort-asc": "true", - } - - rv = client.get( - f"/v1.2/simulations?alias={test_prefix}-%", headers=headers_combined - ) - - assert rv.status_code == 200 - data = rv.json + data: PaginatedResponse[SimulationListItem] = PaginatedResponse[ + SimulationListItem + ].model_validate(rv.json) - assert data["page"] == 1 - assert data["limit"] == 3 - assert len(data["results"]) <= 3 + assert data.count == 1 + assert len(data.results) == 1 @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_simulation_by_uuid(client): """Test GET /v1.2/simulation/{simulation_id} endpoint - retrieve by UUID.""" # Create a simulation with known properties - sim_uuid = uuid.uuid4() - input_uuid = uuid.uuid4() - output_uuid = uuid.uuid4() - test_alias = "get-test-simulation" - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": test_alias, - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [ - { - "uuid": {"_type": "uuid.UUID", "hex": input_uuid.hex}, - "type": "FILE", - "uri": "file:///test/input.dat", - "checksum": "input123", - "datetime": datetime.now(timezone.utc).isoformat(), - "usage": "input_data", - "purpose": "test input", - "sensitivity": "public", - "access": "open", - "embargo": None, - } - ], - "outputs": [ - { - "uuid": {"_type": "uuid.UUID", "hex": output_uuid.hex}, - "type": "FILE", - "uri": "file:///test/output.dat", - "checksum": "output456", - "datetime": datetime.now(timezone.utc).isoformat(), - "usage": "output_data", - "purpose": "test output", - "sensitivity": "public", - "access": "open", - "embargo": None, - } - ], - "metadata": [ - {"element": "machine", "value": "test-machine"}, - {"element": "code", "value": "test-code"}, - {"element": "description", "value": "Test simulation for GET"}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } + simulation_data = generate_simulation_data(uploaded_by="test-uploader") - # Create the simulation - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", - ) + rv_post = post_simulation(client, simulation_data) assert rv_post.status_code == 200 # Test GET by UUID - rv = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) + rv = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) assert rv.status_code == 200 assert rv.is_json - data = rv.json + # Validate full data model + SimulationDataResponse.model_validate(rv.json) + + simulation_data_received = SimulationData.model_validate(rv.json, extra="ignore") + simulation_data_check = simulation_data.simulation.model_copy() - # Verify basic fields - assert "uuid" in data - assert data["uuid"] == sim_uuid - assert data["alias"] == test_alias - - # Verify datetime field exists - assert "datetime" in data - - # Verify inputs and outputs - assert "inputs" in data - assert len(data["inputs"]) == 1 - assert data["inputs"][0]["uuid"] == input_uuid - assert data["inputs"][0]["uri"] == "file:/test/input.dat" - assert data["inputs"][0]["checksum"] == "input123" - - assert "outputs" in data - assert len(data["outputs"]) == 1 - assert data["outputs"][0]["uuid"] == output_uuid - assert data["outputs"][0]["uri"] == "file:/test/output.dat" - assert data["outputs"][0]["checksum"] == "output456" - - # Verify metadata - assert "metadata" in data - assert len(data["metadata"]) >= 3 # At least our 3 metadata items - metadata_dict = {m["element"]: m["value"] for m in data["metadata"]} - assert metadata_dict["machine"] == "test-machine" - assert metadata_dict["code"] == "test-code" - assert metadata_dict["description"] == "Test simulation for GET" - - # Verify children and parents fields exist - assert "children" in data - assert "parents" in data - assert isinstance(data["children"], list) - assert isinstance(data["parents"], list) + # fill fields that are filled by the server + simulation_data_check.alias = simulation_data_check.uuid.hex + simulation_data_check.metadata = [ + MetadataData(element="uploaded_by", value=simulation_data.uploaded_by) + ] + + # datetime gets updated by the server + simulation_data_check.datetime = simulation_data_received.datetime + + assert simulation_data_received == simulation_data_check @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_simulation_by_alias(client): """Test GET /v1.2/simulation/{simulation_id} endpoint - retrieve by alias.""" # Create a simulation with a unique alias - sim_uuid = uuid.uuid4() - test_alias = "get-by-alias-test" - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": test_alias, - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [{"element": "test", "value": "alias retrieval"}], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", + simulation_data = generate_simulation_data( + alias="test-get-alias", uploaded_by="test-uploader" ) + + rv_post = post_simulation(client, simulation_data) assert rv_post.status_code == 200 # Test GET by alias - rv = client.get(f"/v1.2/simulation/{test_alias}", headers=HEADERS) + rv = client.get( + f"/v1.2/simulation/{simulation_data.simulation.alias}", headers=HEADERS + ) assert rv.status_code == 200 - data = rv.json + assert rv.is_json - assert data["uuid"] == sim_uuid - assert data["alias"] == test_alias + # Validate full data model + SimulationDataResponse.model_validate(rv.json) + + simulation_data_received = SimulationData.model_validate(rv.json, extra="ignore") + simulation_data_check = simulation_data.simulation.model_copy() + + # fill fields that are filled by the server + simulation_data_check.metadata = [ + MetadataData(element="uploaded_by", value=simulation_data.uploaded_by) + ] + + # datetime gets updated by the server + simulation_data_check.datetime = simulation_data_received.datetime + + assert simulation_data_received == simulation_data_check @pytest.mark.skipif(not has_flask, reason="requires flask library") def test_get_simulation_not_found(client): """Test GET /v1.2/simulation/{simulation_id} endpoint - non-existent simulation.""" # Try to get a non-existent simulation - fake_uuid = uuid.uuid4() + fake_uuid = uuid.uuid1() rv = client.get(f"/v1.2/simulation/{fake_uuid.hex}", headers=HEADERS) @@ -1019,162 +753,3 @@ def test_get_simulation_not_found(client): # Should contain an error message assert "error" in data or data.get("message") == "Simulation not found" - - -@pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_get_simulation_with_parents_and_children(client): - """Test GET /v1.2/simulation/{simulation_id} endpoint - verify parents/children.""" - # Create parent simulation - parent_uuid = uuid.uuid4() - parent_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": parent_uuid.hex}, - "alias": "parent-simulation", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [{"element": "role", "value": "parent"}], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - - rv_parent = client.post( - "/v1.2/simulations", - json=parent_data, - headers=HEADERS, - content_type="application/json", - ) - assert rv_parent.status_code == 200 - - # Create child simulation that references parent - child_uuid = uuid.uuid4() - child_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": child_uuid.hex}, - "alias": "child-simulation", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [], - "outputs": [], - "metadata": [ - {"element": "role", "value": "child"}, - {"element": "parent", "value": parent_uuid.hex}, - ], - }, - "add_watcher": False, - "uploaded_by": "test-user", - } - - rv_child = client.post( - "/v1.2/simulations", - json=child_data, - headers=HEADERS, - content_type="application/json", - ) - assert rv_child.status_code == 200 - - # Get child simulation and verify parents field - rv = client.get(f"/v1.2/simulation/{child_uuid.hex}", headers=HEADERS) - - assert rv.status_code == 200 - data = rv.json - - # Verify the parents/children structure - assert "parents" in data - assert "children" in data - - -@pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_get_simulation_full_response_structure(client): - """Test GET /v1.2/simulation/{simulation_id} endpoint - verify complete response - structure.""" - # Create a comprehensive simulation - sim_uuid = uuid.uuid4() - - simulation_data = { - "simulation": { - "uuid": {"_type": "uuid.UUID", "hex": sim_uuid.hex}, - "alias": "complete-structure-test", - "datetime": datetime.now(timezone.utc).isoformat(), - "inputs": [ - { - "uuid": {"_type": "uuid.UUID", "hex": uuid.uuid4().hex}, - "type": "FILE", - "uri": "file:///complete/input.dat", - "checksum": "complete123", - "datetime": datetime.now(timezone.utc).isoformat(), - "usage": "input", - "purpose": "complete input", - "sensitivity": "public", - "access": "open", - "embargo": None, - } - ], - "outputs": [ - { - "uuid": {"_type": "uuid.UUID", "hex": uuid.uuid4().hex}, - "type": "FILE", - "uri": "file:///complete/output.dat", - "checksum": "complete456", - "datetime": datetime.now(timezone.utc).isoformat(), - "usage": "output", - "purpose": "complete output", - "sensitivity": "public", - "access": "open", - "embargo": None, - } - ], - "metadata": [ - {"element": "machine", "value": "complete-machine"}, - {"element": "version", "value": "1.0"}, - ], - }, - "add_watcher": False, - "uploaded_by": "complete-user", - } - - rv_post = client.post( - "/v1.2/simulations", - json=simulation_data, - headers=HEADERS, - content_type="application/json", - ) - assert rv_post.status_code == 200 - - # Get the simulation - rv = client.get(f"/v1.2/simulation/{sim_uuid.hex}", headers=HEADERS) - - assert rv.status_code == 200 - data = rv.json - - # Verify all required top-level fields are present - required_fields = [ - "uuid", - "alias", - "datetime", - "inputs", - "outputs", - "metadata", - "children", - "parents", - ] - for field in required_fields: - assert field in data, f"Required field '{field}' missing from response" - - # Verify inputs structure - assert len(data["inputs"]) == 1 - input_required_fields = ["uuid", "type", "uri", "checksum", "datetime"] - for field in input_required_fields: - assert field in data["inputs"][0], f"Required input field '{field}' missing" - - # Verify outputs structure - assert len(data["outputs"]) == 1 - output_required_fields = ["uuid", "type", "uri", "checksum", "datetime"] - for field in output_required_fields: - assert field in data["outputs"][0], f"Required output field '{field}' missing" - - # Verify metadata structure - assert len(data["metadata"]) >= 2 - for meta in data["metadata"]: - assert "element" in meta - assert "value" in meta From e17494003da89c26b70daaa6f6d70cb3831548bd Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Fri, 13 Feb 2026 15:36:38 +0100 Subject: [PATCH 15/18] Custom list models --- src/simdb/database/models/simulation.py | 24 ++++++------ src/simdb/remote/models.py | 49 +++++++++++++++++++++---- tests/remote/test_api.py | 25 +++++-------- 3 files changed, 65 insertions(+), 33 deletions(-) diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index 69ab660..74baa5f 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union -from simdb.remote.models import SimulationData +from simdb.remote.models import FileDataList, MetadataDataList, SimulationData if sys.version_info < (3, 11): from backports.datetime_fromisoformat import MonkeyPatch @@ -344,9 +344,9 @@ def from_data_model(cls, data: SimulationData) -> "Simulation": simulation.uuid = data.uuid simulation.alias = data.alias simulation.datetime = data.datetime - simulation.inputs = [File.from_data_model(el) for el in data.inputs] - simulation.outputs = [File.from_data_model(el) for el in data.outputs] - simulation.meta = [MetaData.from_data_model(el) for el in data.metadata] + simulation.inputs = [File.from_data_model(el) for el in data.inputs.root] + simulation.outputs = [File.from_data_model(el) for el in data.outputs.root] + simulation.meta = [MetaData.from_data_model(el) for el in data.metadata.root] return simulation def data( @@ -370,15 +370,17 @@ def data( def to_model( self, recurse: bool = False, meta_keys: Optional[List[str]] = None ) -> SimulationData: - inputs = [] - outputs = [] - metadata = [] + inputs = FileDataList() + outputs = FileDataList() + metadata = MetadataDataList() if recurse: - inputs = [f.to_model() for f in self.inputs] - outputs = [f.to_model() for f in self.outputs] - metadata = [m.to_model() for m in self.meta] + inputs = FileDataList([f.to_model() for f in self.inputs]) + outputs = FileDataList([f.to_model() for f in self.outputs]) + metadata = MetadataDataList([m.to_model() for m in self.meta]) elif meta_keys: - metadata = [m.to_model() for m in self.meta if m.element in meta_keys] + metadata = MetadataDataList( + [m.to_model() for m in self.meta if m.element in meta_keys] + ) return SimulationData( uuid=self.uuid, alias=self.alias, diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 9a4c3f9..aa56c46 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -1,10 +1,17 @@ -from urllib.parse import urlencode from datetime import datetime as dt from datetime import timezone -from typing import Annotated, Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import Annotated, Any, Generic, List, Optional, TypeVar, Union +from urllib.parse import urlencode from uuid import UUID, uuid1 -from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer +from pydantic import ( + BaseModel, + BeforeValidator, + Field, + PlainSerializer, + RootModel, + model_validator, +) HexUUID = Annotated[UUID, PlainSerializer(lambda x: x.hex, return_type=str)] @@ -38,6 +45,14 @@ class FileData(BaseModel): embargo: Optional[str] = None +class FileDataList(RootModel): + root: List[FileData] = [] + + # Allows indexing: users[0] + def __getitem__(self, item) -> FileData: + return self.root[item] + + class MetadataData(BaseModel): element: str value: Union[CustomUUID, Any] @@ -49,6 +64,26 @@ def as_querystring(self): return urlencode(self.as_dict()) +class MetadataDataList(RootModel): + root: List[MetadataData] = [] + + def __getitem__(self, item) -> MetadataData: + return self.root[item] + + def as_dict(self): + return {m.element: m.value for m in self.root} + + @model_validator(mode="before") + @classmethod + def parse_dictionary(cls, data: Any): + if isinstance(data, dict): + return [{"element": k, "value": v} for (k, v) in data.items()] + return data + + def as_querystring(self): + return urlencode(self.as_dict()) + + class SimulationReference(BaseModel): uuid: CustomUUID alias: Optional[str] = None @@ -58,9 +93,9 @@ class SimulationData(BaseModel): uuid: CustomUUID = Field(default_factory=lambda: uuid1()) alias: Optional[str] = None datetime: dt = Field(default_factory=lambda: dt.now(timezone.utc)) - inputs: List[FileData] = [] - outputs: List[FileData] = [] - metadata: List[MetadataData] = [] + inputs: FileDataList = FileDataList() + outputs: FileDataList = FileDataList() + metadata: MetadataDataList = MetadataDataList() class SimulationDataResponse(SimulationData): @@ -89,7 +124,7 @@ class SimulationListItem(BaseModel): uuid: CustomUUID alias: Optional[str] = None datetime: str - metadata: Optional[List[MetadataData]] = None + metadata: Optional[MetadataDataList] = None T = TypeVar("T") diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index 057329b..b834341 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -16,6 +16,7 @@ from simdb.remote.models import ( FileData, MetadataData, + MetadataDataList, PaginatedResponse, SimulationData, SimulationDataResponse, @@ -524,10 +525,7 @@ def test_get_simulations_filter_by_metadata(client): def test_get_simulations_filter_multiple_metadata(client): """Test filtering simulations by multiple metadata fields.""" # Create a simulation with multiple metadata fields - test_metadata = [ - MetadataData(element="machine", value="multi-filter-machine"), - MetadataData(element="code", value="multi-filter-code"), - ] + test_metadata = {"machine": "multi-filter-machine", "code": "multi-filter-code"} simulation_data = generate_simulation_data(metadata=test_metadata) @@ -536,7 +534,7 @@ def test_get_simulations_filter_multiple_metadata(client): # Filter by both machine and code rv = client.get( - f"/v1.2/simulations?{'&'.join([m.as_querystring() for m in test_metadata])}", + f"/v1.2/simulations?{simulation_data.simulation.metadata.as_querystring()}", headers=HEADERS, ) @@ -646,10 +644,7 @@ def test_get_simulations_with_metadata_keys(client): simulation_data = generate_simulation_data( alias="meta-keys-test", - metadata=[ - MetadataData(element="machine", value="machine-x"), - MetadataData(element="code", value="code-y"), - ], + metadata={"machine": "machine-x", "code": "code-y"}, ) rv_post = post_simulation(client, simulation_data) @@ -694,9 +689,9 @@ def test_get_simulation_by_uuid(client): # fill fields that are filled by the server simulation_data_check.alias = simulation_data_check.uuid.hex - simulation_data_check.metadata = [ - MetadataData(element="uploaded_by", value=simulation_data.uploaded_by) - ] + simulation_data_check.metadata = MetadataDataList.model_validate( + {"uploaded_by": simulation_data.uploaded_by} + ) # datetime gets updated by the server simulation_data_check.datetime = simulation_data_received.datetime @@ -730,9 +725,9 @@ def test_get_simulation_by_alias(client): simulation_data_check = simulation_data.simulation.model_copy() # fill fields that are filled by the server - simulation_data_check.metadata = [ - MetadataData(element="uploaded_by", value=simulation_data.uploaded_by) - ] + simulation_data_check.metadata = MetadataDataList.model_validate( + {"uploaded_by": simulation_data.uploaded_by} + ) # datetime gets updated by the server simulation_data_check.datetime = simulation_data_received.datetime From 4dea628dd5a89d0241152e6f1364903cae1c95d0 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 17 Feb 2026 13:39:49 +0100 Subject: [PATCH 16/18] Change metadata to pydantic model --- src/simdb/remote/apis/v1_2/simulations.py | 7 ++++++- tests/remote/test_api.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index f0960f7..2cc279d 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -26,6 +26,7 @@ from simdb.remote.core.path import find_common_root, secure_path from simdb.remote.core.typing import current_app from simdb.remote.models import ( + MetadataDataList, PaginatedResponse, SimulationDataResponse, SimulationListItem, @@ -491,7 +492,11 @@ def get(self, sim_id: str, user: User): try: simulation = current_app.db.get_simulation(sim_id) if simulation: - return jsonify([meta.data() for meta in simulation.meta]) + return jsonify( + MetadataDataList.model_validate( + [meta.data() for meta in simulation.meta] + ).model_dump(mode="json") + ) return error("Simulation not found") except DatabaseError as err: return error(str(err)) diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index b834341..a2752e8 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -748,3 +748,25 @@ def test_get_simulation_not_found(client): # Should contain an error message assert "error" in data or data.get("message") == "Simulation not found" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_metadata(client): + """Test GET /v1.2/simulation/metadata/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data( + metadata={"metadata-a": "abc", "metadata-b": "123"}, uploaded_by="test-user" + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.get( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + data = MetadataDataList.model_validate(rv.json) + check_data = simulation_data.simulation.metadata.model_copy() + check_data.root.append(MetadataData(element="uploaded_by", value="test-user")) + assert data == simulation_data.simulation.metadata From 8f0cc20a6480a370da1e11581f927bd8c7fba482 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 17 Feb 2026 15:58:23 +0100 Subject: [PATCH 17/18] Finish v1.2/simulations tests --- src/simdb/cli/simdb.py | 4 +- src/simdb/database/database.py | 4 +- src/simdb/remote/models.py | 54 ++++++++++++ tests/remote/test_api.py | 157 +++++++++++++++++++++++++++++++++ 4 files changed, 215 insertions(+), 4 deletions(-) diff --git a/src/simdb/cli/simdb.py b/src/simdb/cli/simdb.py index c475b0b..2b3214b 100644 --- a/src/simdb/cli/simdb.py +++ b/src/simdb/cli/simdb.py @@ -1,6 +1,6 @@ import copy import sys -from typing import IO +from typing import TextIO import click @@ -53,7 +53,7 @@ def list_commands(self, ctx): @click.option("-v", "--verbose", is_flag=True, help="Run with verbose output.") @click.option("-c", "--config-file", type=click.File("r"), help="Config file to load.") @click.pass_context -def cli(ctx, debug: bool, verbose: bool, config_file: IO): +def cli(ctx, debug: bool, verbose: bool, config_file: TextIO): if not ctx.obj: ctx.obj = Config() ctx.obj.load(config_file) diff --git a/src/simdb/database/database.py b/src/simdb/database/database.py index 20a9b0c..ff3ea9c 100644 --- a/src/simdb/database/database.py +++ b/src/simdb/database/database.py @@ -227,8 +227,8 @@ def _find_simulation(self, sim_ref: str) -> "Simulation": ) except SQLAlchemyError: simulation = None - if not simulation: - raise DatabaseError(f"Simulation {sim_ref} not found.") from None + if not simulation: + raise DatabaseError(f"Simulation {sim_ref} not found.") from None return simulation def remove(self): diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index aa56c46..e2b5693 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -32,6 +32,19 @@ def _deserialize_custom_uuid(v: Any) -> UUID: ] +class StatusPatchData(BaseModel): + status: str + + +class DeletedSimulation(BaseModel): + uuid: UUID + files: List[str] + + +class SimulationDeleteResponse(BaseModel): + deleted: DeletedSimulation + + class FileData(BaseModel): type: str uri: str @@ -64,6 +77,15 @@ def as_querystring(self): return urlencode(self.as_dict()) +class MetadataPatchData(BaseModel): + key: str + value: str + + +class MetadataDeleteData(BaseModel): + key: str + + class MetadataDataList(RootModel): root: List[MetadataData] = [] @@ -135,3 +157,35 @@ class PaginatedResponse(BaseModel, Generic[T]): page: int limit: int results: List[T] + + +class PaginationData(BaseModel): + limit: int + page: int + sort_by: str + sort_asc: bool + + @model_validator(mode="before") + @classmethod + def parse_headers(cls, data: Any): + if not isinstance(data, dict): + return data + new_data = { + "limit": data.get("simdb-result-limit", 100), + "page": data.get("simdb-page", 1), + "sort_by": data.get("simdb-sort-by", ""), + "sort_asc": data.get("simdb-sort-asc", False), + } + return new_data + + +class SimulationTraceData(SimulationData): + status: Optional[str] = None + passed_on: Optional[Any] = None + failed_on: Optional[Any] = None + deprecated_on: Optional[Any] = None + accepted_on: Optional[Any] = None + not_validated_on: Optional[Any] = None + deleted_on: Optional[Any] = None + replaces: Optional["SimulationTraceData"] = None + replaces_reason: Optional[Any] = None diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index a2752e8..668a8bc 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -17,11 +17,15 @@ FileData, MetadataData, MetadataDataList, + MetadataDeleteData, + MetadataPatchData, PaginatedResponse, SimulationData, SimulationDataResponse, SimulationListItem, SimulationPostData, + SimulationTraceData, + StatusPatchData, ) has_flask = importlib.util.find_spec("flask") is not None @@ -48,6 +52,7 @@ def client(): config.set_option("server.upload_folder", upload_dir) config.set_option("authentication.type", "None") config.set_option("server.copy_files", False) + config.set_option("role.admin.users", "admin,admin2") app = create_app(config=config, testing=True, debug=True) app.testing = True @@ -770,3 +775,155 @@ def test_get_simulation_metadata(client): check_data = simulation_data.simulation.metadata.model_copy() check_data.root.append(MetadataData(element="uploaded_by", value="test-user")) assert data == simulation_data.simulation.metadata + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_patch_simulation(client): + """Test PATCH /v1.2/simulation/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data() + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.patch( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", + json=StatusPatchData(status="failed").model_dump(mode="json"), + headers=HEADERS, + ) + + assert rv.status_code == 200 + + # Status is never returned, so we can't check if it is set + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_delete_simulation(client): + """Test DELETE /v1.2/simulation/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data() + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.delete( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + + rv = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + + assert rv.status_code == 400 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_patch_simulation_metadata(client): + """Test PATCH /v1.2/simulation/metadata/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data( + metadata={"metadata-a": "abc"}, uploaded_by="test-user" + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.patch( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + json=MetadataPatchData(key="metadata-a", value="def").model_dump(mode="json"), + headers=HEADERS, + ) + + assert rv.status_code == 200 + + rv = client.get( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + data = MetadataDataList.model_validate(rv.json) + check_data = simulation_data.simulation.metadata.model_copy() + check_data[0].value = "def" + check_data.root.append(MetadataData(element="uploaded_by", value="test-user")) + assert data == simulation_data.simulation.metadata + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_delete_simulation_metadata(client): + """Test DELETE /v1.2/simulation/metadata/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data( + metadata={"metadata-a": "abc"}, uploaded_by="test-user" + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.delete( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + json=MetadataDeleteData(key="metadata-a").model_dump(mode="json"), + headers=HEADERS, + ) + + assert rv.status_code == 200 + + rv = client.get( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + data = MetadataDataList.model_validate(rv.json) + check_data = simulation_data.simulation.metadata.model_copy() + check_data.root.pop() + check_data.root.append(MetadataData(element="uploaded_by", value="test-user")) + assert data == simulation_data.simulation.metadata + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_trace_endpoint(client): + """Test trace endpoint returns valid SimulationTraceData and handles replacement + chains.""" + # Create v1 -> v2 -> v3 replacement chain + sim_v1 = generate_simulation_data(alias="trace-v1") + rv1 = post_simulation(client, sim_v1) + assert rv1.status_code == 200 + + sim_v2 = generate_simulation_data( + alias="trace-v2", + metadata=[ + MetadataData(element="replaces", value=sim_v1.simulation.uuid.hex), + MetadataData(element="replaces_reason", value="Bug fixes"), + ], + ) + rv2 = post_simulation(client, sim_v2) + assert rv2.status_code == 200 + + sim_v3 = generate_simulation_data( + alias="trace-v3", + metadata=[ + MetadataData(element="replaces", value=sim_v2.simulation.uuid.hex), + MetadataData(element="replaces_reason", value="Performance"), + ], + ) + rv3 = post_simulation(client, sim_v3) + assert rv3.status_code == 200 + + # Test trace for v3 (full chain) + rv_trace = client.get(f"/v1.2/trace/{sim_v3.simulation.uuid.hex}", headers=HEADERS) + assert rv_trace.status_code == 200 + + trace = SimulationTraceData.model_validate(rv_trace.json) + + # Verify v3 + assert trace.uuid == sim_v3.simulation.uuid + assert trace.alias == "trace-v3" + assert trace.replaces_reason == "Performance" + + # Verify v2 (nested) + assert trace.replaces.uuid == sim_v2.simulation.uuid + assert trace.replaces.replaces_reason == "Bug fixes" + + # Verify v1 (double nested) + assert trace.replaces.replaces.uuid == sim_v1.simulation.uuid + assert trace.replaces.replaces.replaces is None From d2a9894fe0b4dd62852046b2f8f8ef194efbd860 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 17 Feb 2026 15:59:43 +0100 Subject: [PATCH 18/18] Update endpoints to use validation --- src/simdb/remote/apis/v1_2/simulations.py | 66 +++++++++++------------ 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 2cc279d..d7bccea 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -27,11 +27,17 @@ from simdb.remote.core.typing import current_app from simdb.remote.models import ( MetadataDataList, + MetadataDeleteData, + MetadataPatchData, PaginatedResponse, + PaginationData, SimulationDataResponse, + SimulationDeleteResponse, SimulationListItem, SimulationPostData, SimulationPostResponse, + SimulationTraceData, + StatusPatchData, ValidationResult, ) from simdb.uri import URI @@ -124,11 +130,8 @@ def _set_alias(alias: str): return alias, next_id -def _build_trace(sim_id: str) -> Dict[str, Any]: - try: - simulation = current_app.db.get_simulation(sim_id) - except DatabaseError as err: - return {"error": str(err)} +def _build_trace(sim_id: str) -> SimulationTraceData: + simulation = current_app.db.get_simulation(sim_id) data: Dict[str, Any] = cast(Dict[str, Any], simulation.data(recurse=False)) status = simulation.find_meta("status") @@ -155,7 +158,7 @@ def _build_trace(sim_id: str) -> Dict[str, Any]: if replaces_reason: data["replaces_reason"] = replaces_reason[0].value - return data + return SimulationTraceData.model_validate(data) def _get_json_aware( @@ -243,13 +246,13 @@ class SimulationList(Resource): @requires_auth() # @cache.cached(key_prefix=cache_key) def get(self, user: User): - limit = int(request.headers.get(SimulationList.LIMIT_HEADER) or 100) - page = int(request.headers.get(SimulationList.PAGE_HEADER) or 1) - sort_by = request.headers.get(SimulationList.SORT_BY_HEADER, "") - sort_asc = ( - request.headers.get(SimulationList.SORT_ASC_HEADER, "false").lower() - == "true" + pd = PaginationData.model_validate( + {k.lower(): v for (k, v) in request.headers.items()} ) + limit = pd.limit + page = pd.page + sort_by = pd.sort_by + sort_asc = pd.sort_asc names = [] constraints = [] if request.args: @@ -447,13 +450,11 @@ def get(self, sim_id: str, user: User): @requires_auth("admin") def patch(self, sim_id: str, user: Optional[User] = None): try: - data = request.get_json() or {} - if "status" not in data: - return error("Status not provided") + data = StatusPatchData.model_validate(request.json) simulation = current_app.db.get_simulation(sim_id) if simulation is None: raise ValueError(f"Simulation {sim_id} not found.") - status = models_sim.Simulation.Status(data["status"]) + status = models_sim.Simulation.Status(data.status) _update_simulation_status(simulation, status, user) current_app.db.insert_simulation(simulation) clear_cache() @@ -479,7 +480,11 @@ def delete(self, sim_id: str, user: User): directory = first_file.uri.path.parent if directory != Path() and directory != Path("/"): directory.rmdir() - return jsonify({"deleted": {"simulation": simulation.uuid, "files": files}}) + return jsonify( + SimulationDeleteResponse.model_validate( + {"deleted": {"uuid": simulation.uuid, "files": files}} + ).model_dump(mode="json") + ) except DatabaseError as err: return error(str(err)) @@ -511,16 +516,10 @@ def get(self, sim_id: str, user: User): @requires_auth("admin") def patch(self, sim_id: str, user: Optional[User] = None): try: - data = request.get_json() or {} - - if "key" not in data: - return error("Metadata key not provided") - - if "value" not in data: - return error("New metadata value not provided") + data = MetadataPatchData.model_validate(request.json) - key = data["key"] - value = data["value"].lower() + key = data.key + value = data.value.lower() simulation = current_app.db.get_simulation(sim_id) if simulation is None: raise ValueError(f"Simulation {sim_id} not found.") @@ -546,18 +545,13 @@ def patch(self, sim_id: str, user: Optional[User] = None): @requires_auth("admin") def delete(self, sim_id: str, user: Optional[User] = None): try: - data = request.get_json() or {} - - if "key" not in data: - return error("Metadata key not provided") - - key = data["key"] + data = MetadataDeleteData.model_validate(request.json) simulation = current_app.db.get_simulation(sim_id) if simulation is None: raise ValueError(f"Simulation {sim_id} not found.") - simulation.remove_meta(key) + simulation.remove_meta(data.key) current_app.db.insert_simulation(simulation) clear_cache() return {} @@ -574,7 +568,7 @@ def post(self, sim_id, user: User): result = _validate(simulation, user) current_app.db.insert_simulation(simulation) clear_cache() - return jsonify(result) + return jsonify(result.model_dump(mode="json")) except DatabaseError as err: return error(str(err)) @@ -585,8 +579,8 @@ class SimulationTrace(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] def get(self, sim_id: str, user: User): try: - data = _build_trace(sim_id) - return jsonify(data) + trace_data = _build_trace(sim_id) + return jsonify(trace_data.model_dump(mode="json")) except DatabaseError as err: return error(str(err))