diff --git a/pyproject.toml b/pyproject.toml index 31b1501..0956753 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "email-validator>=1.1", "imas-python", "numpy>=1.14", + "pydantic>=2.10.6", "python-dateutil>=2.6", "pyuda>=2.9.2", "pyyaml>=3.13", diff --git a/src/simdb/cli/simdb.py b/src/simdb/cli/simdb.py index c475b0b..2b3214b 100644 --- a/src/simdb/cli/simdb.py +++ b/src/simdb/cli/simdb.py @@ -1,6 +1,6 @@ import copy import sys -from typing import IO +from typing import TextIO import click @@ -53,7 +53,7 @@ def list_commands(self, ctx): @click.option("-v", "--verbose", is_flag=True, help="Run with verbose output.") @click.option("-c", "--config-file", type=click.File("r"), help="Config file to load.") @click.pass_context -def cli(ctx, debug: bool, verbose: bool, config_file: IO): +def cli(ctx, debug: bool, verbose: bool, config_file: TextIO): if not ctx.obj: ctx.obj = Config() ctx.obj.load(config_file) diff --git a/src/simdb/database/database.py b/src/simdb/database/database.py index 37ba007..ff3ea9c 100644 --- a/src/simdb/database/database.py +++ b/src/simdb/database/database.py @@ -16,6 +16,7 @@ from simdb.config import Config from simdb.query import QueryType, query_compare +from simdb.remote.models import SimulationReference from .models import Base from .models.file import File @@ -226,8 +227,8 @@ def _find_simulation(self, sim_ref: str) -> "Simulation": ) except SQLAlchemyError: simulation = None - if not simulation: - raise DatabaseError(f"Simulation {sim_ref} not found.") from None + if not simulation: + raise DatabaseError(f"Simulation {sim_ref} not found.") from None return simulation def remove(self): @@ -571,6 +572,24 @@ def get_simulation_parents(self, simulation: "Simulation") -> List[dict]: ) return [{"uuid": r.uuid, "alias": r.alias} for r in query.all()] + def get_simulation_parents_ref( + self, simulation: "Simulation" + ) -> List[SimulationReference]: + subquery = ( + self.session.query(File.checksum) + .filter(File.checksum != "") + .filter(File.input_for.contains(simulation)) + .subquery() + ) + query = ( + self.session.query(Simulation.uuid, Simulation.alias) + .join(Simulation.outputs) + .filter(File.checksum.in_(subquery)) + .filter(Simulation.alias != simulation.alias) + .distinct() + ) + return [SimulationReference(uuid=r.uuid, alias=r.alias) for r in query.all()] + def get_simulation_children(self, simulation: "Simulation") -> List[dict]: subquery = ( self.session.query(File.checksum) @@ -587,6 +606,24 @@ def get_simulation_children(self, simulation: "Simulation") -> List[dict]: ) return [{"uuid": r.uuid, "alias": r.alias} for r in query.all()] + def get_simulation_children_ref( + self, simulation: "Simulation" + ) -> List[SimulationReference]: + subquery = ( + self.session.query(File.checksum) + .filter(File.checksum != "") + .filter(File.output_of.contains(simulation)) + .subquery() + ) + query = ( + self.session.query(Simulation.uuid, Simulation.alias) + .join(Simulation.inputs) + .filter(File.checksum.in_(subquery)) + .filter(Simulation.alias != simulation.alias) + .distinct() + ) + return [SimulationReference(uuid=r.uuid, alias=r.alias) for r in query.all()] + def get_file(self, file_uuid_str: str) -> "File": """ Get the specified file from the database. diff --git a/src/simdb/database/models/file.py b/src/simdb/database/models/file.py index a05aaa5..5b6a14d 100644 --- a/src/simdb/database/models/file.py +++ b/src/simdb/database/models/file.py @@ -14,6 +14,7 @@ from simdb.docstrings import inherit_docstrings from simdb.imas.checksum import checksum as imas_checksum from simdb.imas.utils import imas_timestamp +from simdb.remote.models import FileData from simdb.uda.checksum import checksum as uda_checksum from .base import Base @@ -125,6 +126,23 @@ def from_data(cls, data: Dict) -> "File": file.datetime = date_parser.parse(checked_get(data, "datetime", str)) return file + @classmethod + def from_data_model(cls, data: FileData) -> "File": + data_type = data.type + uri = data.uri + file = File( + DataObject.Type[data_type], urilib.URI(uri), perform_integrity_check=False + ) + file.uuid = data.uuid + file.usage = data.usage + file.checksum = data.checksum + file.purpose = data.purpose + file.sensitivity = data.sensitivity + file.access = data.access + file.embargo = data.embargo + file.datetime = data.datetime + return file + def data(self, recurse: bool = False) -> Dict[str, str]: data = { "uuid": self.uuid, @@ -139,3 +157,17 @@ def data(self, recurse: bool = False) -> Dict[str, str]: "datetime": self.datetime.isoformat(), } return data + + def to_model(self) -> FileData: + return FileData( + type=self.type.name, + uri=str(self.uri), + uuid=self.uuid, + checksum=self.checksum, + datetime=self.datetime, + usage=self.usage, + purpose=self.purpose, + sensitivity=self.sensitivity, + access=self.access, + embargo=self.embargo, + ) diff --git a/src/simdb/database/models/metadata.py b/src/simdb/database/models/metadata.py index 628f158..81bcde9 100644 --- a/src/simdb/database/models/metadata.py +++ b/src/simdb/database/models/metadata.py @@ -4,6 +4,7 @@ from sqlalchemy import types as sql_types from simdb.docstrings import inherit_docstrings +from simdb.remote.models import MetadataData from .base import Base @@ -32,6 +33,11 @@ def from_data(cls, data: Dict) -> "MetaData": meta = MetaData(data["element"], data["value"]) return meta + @classmethod + def from_data_model(cls, data: MetadataData) -> "MetaData": + meta = MetaData(data.element, data.value) + return meta + def data(self, recurse: bool = False) -> Dict[str, str]: data = { "element": self.element, @@ -39,5 +45,8 @@ def data(self, recurse: bool = False) -> Dict[str, str]: } return data + def to_model(self) -> MetadataData: + return MetadataData(element=self.element, value=self.value) + Index("metadata_index", MetaData.sim_id, MetaData.element, unique=True) diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index 1ee21ab..74baa5f 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -9,6 +9,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union +from simdb.remote.models import FileDataList, MetadataDataList, SimulationData + if sys.version_info < (3, 11): from backports.datetime_fromisoformat import MonkeyPatch @@ -336,6 +338,17 @@ def from_data(cls, data: Dict[str, Union[str, Dict, List]]) -> "Simulation": simulation.meta.append(MetaData.from_data(el)) return simulation + @classmethod + def from_data_model(cls, data: SimulationData) -> "Simulation": + simulation = Simulation(None) + simulation.uuid = data.uuid + simulation.alias = data.alias + simulation.datetime = data.datetime + simulation.inputs = [File.from_data_model(el) for el in data.inputs.root] + simulation.outputs = [File.from_data_model(el) for el in data.outputs.root] + simulation.meta = [MetaData.from_data_model(el) for el in data.metadata.root] + return simulation + def data( self, recurse: bool = False, meta_keys: Optional[List[str]] = None ) -> Dict[str, Union[str, List]]: @@ -354,6 +367,29 @@ def data( ] return data + def to_model( + self, recurse: bool = False, meta_keys: Optional[List[str]] = None + ) -> SimulationData: + inputs = FileDataList() + outputs = FileDataList() + metadata = MetadataDataList() + if recurse: + inputs = FileDataList([f.to_model() for f in self.inputs]) + outputs = FileDataList([f.to_model() for f in self.outputs]) + metadata = MetadataDataList([m.to_model() for m in self.meta]) + elif meta_keys: + metadata = MetadataDataList( + [m.to_model() for m in self.meta if m.element in meta_keys] + ) + return SimulationData( + uuid=self.uuid, + alias=self.alias, + datetime=self.datetime, + inputs=inputs, + outputs=outputs, + metadata=metadata, + ) + def meta_dict(self) -> Dict[str, Union[Dict, Any]]: meta = {m.element: m.value for m in self.meta} return unflatten_dict(meta) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 01351e4..d7bccea 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, cast +import pydantic from flask import json as flask_json # fallback from flask import jsonify, request, send_file from flask_restx import Namespace, Resource @@ -24,6 +25,21 @@ from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path from simdb.remote.core.typing import current_app +from simdb.remote.models import ( + MetadataDataList, + MetadataDeleteData, + MetadataPatchData, + PaginatedResponse, + PaginationData, + SimulationDataResponse, + SimulationDeleteResponse, + SimulationListItem, + SimulationPostData, + SimulationPostResponse, + SimulationTraceData, + StatusPatchData, + ValidationResult, +) from simdb.uri import URI from simdb.validation import ValidationError, Validator from simdb.validation.file import find_file_validator @@ -36,7 +52,8 @@ def _update_simulation_status( ) -> None: old_status = simulation.status simulation.status = status - if status != old_status and len(simulation.watchers) > 0: + watchers_list = list(simulation.watchers) + if status != old_status and len(watchers_list) > 0: server = EmailServer(current_app.simdb_config) msg = f"""\ Simulation status changed from {old_status} to {status}. @@ -45,7 +62,7 @@ def _update_simulation_status( Note: please don't reply to this email, replies to this address are not monitored. """ - to_addresses = [w.email for w in simulation.watchers] + to_addresses = [w.email for w in watchers_list] if to_addresses: if simulation.alias is None or simulation.alias == "": server.send_message( @@ -55,7 +72,7 @@ def _update_simulation_status( server.send_message(f"Simulation {simulation.alias}", msg, to_addresses) -def _validate(simulation, user) -> Dict: +def _validate(simulation, user) -> ValidationResult: schemas = Validator.validation_schemas(current_app.simdb_config, simulation) try: for schema in schemas: @@ -65,10 +82,7 @@ def _validate(simulation, user) -> Dict: ) except ValidationError as err: _update_simulation_status(simulation, models_sim.Simulation.Status.FAILED, user) - return { - "passed": False, - "error": str(err), - } + return ValidationResult(passed=False, error=str(err)) file_validator_type = current_app.simdb_config.get_string_option( "file_validation.type", default=None @@ -88,16 +102,11 @@ def _validate(simulation, user) -> Dict: _update_simulation_status( simulation, models_sim.Simulation.Status.FAILED, user ) - return { - "passed": False, - "error": str(err), - } + return ValidationResult(passed=False, error=str(err)) else: error("Invalid file validator specified in configuration") - return { - "passed": True, - } + return ValidationResult(passed=True, error=None) def _set_alias(alias: str): @@ -121,11 +130,8 @@ def _set_alias(alias: str): return alias, next_id -def _build_trace(sim_id: str) -> Dict[str, Any]: - try: - simulation = current_app.db.get_simulation(sim_id) - except DatabaseError as err: - return {"error": str(err)} +def _build_trace(sim_id: str) -> SimulationTraceData: + simulation = current_app.db.get_simulation(sim_id) data: Dict[str, Any] = cast(Dict[str, Any], simulation.data(recurse=False)) status = simulation.find_meta("status") @@ -152,10 +158,12 @@ def _build_trace(sim_id: str) -> Dict[str, Any]: if replaces_reason: data["replaces_reason"] = replaces_reason[0].value - return data + return SimulationTraceData.model_validate(data) -def _get_json_aware(force: bool = False, silent: bool = False): +def _get_json_aware( + force: bool = False, silent: bool = False +) -> Optional[dict[str, Any]]: """ Parse JSON like Flask's request.get_json, but handle Content-Encoding: gzip. - force/silent mimic request.get_json behavior. @@ -238,13 +246,13 @@ class SimulationList(Resource): @requires_auth() # @cache.cached(key_prefix=cache_key) def get(self, user: User): - limit = int(request.headers.get(SimulationList.LIMIT_HEADER) or 100) - page = int(request.headers.get(SimulationList.PAGE_HEADER) or 1) - sort_by = request.headers.get(SimulationList.SORT_BY_HEADER, "") - sort_asc = ( - request.headers.get(SimulationList.SORT_ASC_HEADER, "false").lower() - == "true" + pd = PaginationData.model_validate( + {k.lower(): v for (k, v) in request.headers.items()} ) + limit = pd.limit + page = pd.page + sort_by = pd.sort_by + sort_asc = pd.sort_asc names = [] constraints = [] if request.args: @@ -276,7 +284,13 @@ def get(self, user: User): sort_asc=sort_asc, ) - return jsonify({"count": count, "page": page, "limit": limit, "results": data}) + serialized_data = [SimulationListItem.model_validate(item) for item in data] + + return jsonify( + PaginatedResponse( + count=count, page=page, limit=limit, results=serialized_data + ).model_dump(mode="json") + ) @requires_auth() def post(self, user: User): @@ -286,46 +300,31 @@ def post(self, user: User): # It returns None if the content type is not application/json. # If silent=True, it returns None instead of raising an error. # If force=True, it ignores the content type check. - data = _get_json_aware() - if not data: - return error("Invalid or missing JSON data") - - if "simulation" not in data: - return error("Simulation data not provided") - - add_watcher = data.get("add_watcher", True) + d = SimulationPostData.model_validate(_get_json_aware()) - simulation = models_sim.Simulation.from_data(data["simulation"]) + simulation = models_sim.Simulation.from_data_model(d.simulation) # Simulation Upload (Push) Date simulation.datetime = datetime.datetime.now() - if data["uploaded_by"] is not None: - simulation.set_meta("uploaded_by", data["uploaded_by"]) - elif user.email is not None: - simulation.set_meta("uploaded_by", user.email) - elif user.name is not None: - simulation.set_meta("uploaded_by", user.name) - else: - simulation.set_meta("uploaded_by", "anonymous") - if add_watcher: + uploaded_by = d.uploaded_by or user.email or user.name or "anonymous" + + simulation.set_meta("uploaded_by", uploaded_by) + + if d.add_watcher and user.email: simulation.watchers.append( models_watcher.Watcher( user.name, user.email, models_watcher.Notification.ALL ) ) - if "alias" in data["simulation"]: - alias = data["simulation"]["alias"] - if alias is not None: - (updated_alias, next_id) = _set_alias(alias) - if updated_alias: - simulation.meta.append(models_meta.MetaData("seqid", next_id)) - simulation.alias = updated_alias - else: - simulation.alias = alias + if d.simulation.alias is not None: + (updated_alias, next_id) = _set_alias(d.simulation.alias) + if updated_alias: + simulation.meta.append(models_meta.MetaData("seqid", next_id)) + simulation.alias = updated_alias else: - simulation.alias = simulation.uuid.hex + simulation.alias = d.simulation.alias else: simulation.alias = simulation.uuid.hex @@ -334,15 +333,19 @@ def post(self, user: User): common_root = find_common_root(sim_file_paths) config = current_app.simdb_config + copy_files = config.get_option("server.copy_files", default=True) + imas_remote_host = config.get_option( + "server.imas_remote_host", default=None + ) - if config.get_option("server.copy_files", default=True): + if copy_files or imas_remote_host: staging_dir = ( Path(config.get_string_option("server.upload_folder")) / simulation.uuid.hex ) for sim_file in files: - if sim_file.uri.scheme == "file": + if copy_files and sim_file.uri.scheme == "file": path = secure_path(sim_file.uri.path, common_root, staging_dir) if not path.exists(): raise ValueError( @@ -350,73 +353,57 @@ def post(self, user: User): ) sim_file.uri = URI(scheme="file", path=path) elif sim_file.uri.scheme == "imas": - path = secure_path( - Path(sim_file.uri.query["path"]), - common_root, - staging_dir, - is_file=common_root is not None, - ) - sim_file.uri = convert_uri(sim_file.uri, path, config) - elif config.get_option("server.imas_remote_host", default=None): - staging_dir = ( - Path(config.get_string_option("server.upload_folder")) - / simulation.uuid.hex - ) - - for sim_file in files: - if sim_file.uri.scheme == "imas": - if config.get_option("server.copy_files", default=True): + if copy_files: path = secure_path( Path(sim_file.uri.query["path"]), common_root, staging_dir, is_file=common_root is not None, ) - sim_file.uri = convert_uri(sim_file.uri, path, config) else: path = Path(sim_file.uri.query["path"]) - sim_file.uri = convert_uri(sim_file.uri, path, config) + sim_file.uri = convert_uri(sim_file.uri, path, config) - result = { - "ingested": simulation.uuid.hex, - } + result = SimulationPostResponse( + ingested=simulation.uuid, error=None, validation=None + ) + + error_on_fail = current_app.simdb_config.get_option( + "validation.error_on_fail", default=False + ) if current_app.simdb_config.get_option( "validation.auto_validate", default=False ): - result["validation"] = _validate(simulation, user) + result.validation = _validate(simulation, user) - if current_app.simdb_config.get_option( - "validation.error_on_fail", default=False - ): - if simulation.status == models_sim.Simulation.Status.NOT_VALIDATED: - raise Exception( - "Validation config option error_on_fail=True without " - "auto_validate=True." + if not result.validation.passed and error_on_fail: + result.error = ( + f"Simulation validation failed and server has " + f"error_on_fail=True.\n{result.validation.error}" ) - elif simulation.status == models_sim.Simulation.Status.FAILED: - result["error"] = f"""Simulation validation failed and server has - error_on_fail=True.\n{result["validation"]["error"]}""" response = jsonify(result) response.status_code = 400 return response + elif error_on_fail: + raise RuntimeError( + "Validation config option error_on_fail=True without " + "auto_validate=True." + ) + disable_replaces = config.get_option( + "development.disable_replaces", default=False + ) replaces = simulation.find_meta("replaces") - if ( - not current_app.simdb_config.get_option( - "development.disable_replaces", default=False - ) - and replaces - and replaces[0].value - ): + + if not disable_replaces and replaces and replaces[0].value: sim_id = replaces[0].value try: replaces_sim = current_app.db.get_simulation(sim_id) except DatabaseError: replaces_sim = None - if replaces_sim is None: - pass - else: + + if replaces_sim is not None: _update_simulation_status( replaces_sim, models_sim.Simulation.Status.DEPRECATED, user ) @@ -429,8 +416,8 @@ def post(self, user: User): with contextlib.suppress(OSError): create_alias_dir(simulation) - return jsonify(result) - except (DatabaseError, ValueError) as err: + return jsonify(result.model_dump(mode="json")) + except (DatabaseError, ValueError, pydantic.ValidationError) as err: return error(str(err)) @@ -442,12 +429,14 @@ def get(self, sim_id: str, user: User): try: simulation = current_app.db.get_simulation(sim_id) if simulation: - sim_data = simulation.data(recurse=True) - sim_data["children"] = current_app.db.get_simulation_children( - simulation + sim_data = simulation.to_model(recurse=True) + children = current_app.db.get_simulation_children_ref(simulation) + parents = current_app.db.get_simulation_parents_ref(simulation) + return jsonify( + SimulationDataResponse( + **sim_data.model_dump(), children=children, parents=parents + ).model_dump(mode="json") ) - sim_data["parents"] = current_app.db.get_simulation_parents(simulation) - return jsonify(sim_data) return error("Simulation not found") except DatabaseError as err: return error(str(err)) @@ -461,13 +450,11 @@ def get(self, sim_id: str, user: User): @requires_auth("admin") def patch(self, sim_id: str, user: Optional[User] = None): try: - data = request.get_json() or {} - if "status" not in data: - return error("Status not provided") + data = StatusPatchData.model_validate(request.json) simulation = current_app.db.get_simulation(sim_id) if simulation is None: raise ValueError(f"Simulation {sim_id} not found.") - status = models_sim.Simulation.Status(data["status"]) + status = models_sim.Simulation.Status(data.status) _update_simulation_status(simulation, status, user) current_app.db.insert_simulation(simulation) clear_cache() @@ -493,7 +480,11 @@ def delete(self, sim_id: str, user: User): directory = first_file.uri.path.parent if directory != Path() and directory != Path("/"): directory.rmdir() - return jsonify({"deleted": {"simulation": simulation.uuid, "files": files}}) + return jsonify( + SimulationDeleteResponse.model_validate( + {"deleted": {"uuid": simulation.uuid, "files": files}} + ).model_dump(mode="json") + ) except DatabaseError as err: return error(str(err)) @@ -506,7 +497,11 @@ def get(self, sim_id: str, user: User): try: simulation = current_app.db.get_simulation(sim_id) if simulation: - return jsonify([meta.data() for meta in simulation.meta]) + return jsonify( + MetadataDataList.model_validate( + [meta.data() for meta in simulation.meta] + ).model_dump(mode="json") + ) return error("Simulation not found") except DatabaseError as err: return error(str(err)) @@ -521,16 +516,10 @@ def get(self, sim_id: str, user: User): @requires_auth("admin") def patch(self, sim_id: str, user: Optional[User] = None): try: - data = request.get_json() or {} + data = MetadataPatchData.model_validate(request.json) - if "key" not in data: - return error("Metadata key not provided") - - if "value" not in data: - return error("New metadata value not provided") - - key = data["key"] - value = data["value"].lower() + key = data.key + value = data.value.lower() simulation = current_app.db.get_simulation(sim_id) if simulation is None: raise ValueError(f"Simulation {sim_id} not found.") @@ -556,18 +545,13 @@ def patch(self, sim_id: str, user: Optional[User] = None): @requires_auth("admin") def delete(self, sim_id: str, user: Optional[User] = None): try: - data = request.get_json() or {} - - if "key" not in data: - return error("Metadata key not provided") - - key = data["key"] + data = MetadataDeleteData.model_validate(request.json) simulation = current_app.db.get_simulation(sim_id) if simulation is None: raise ValueError(f"Simulation {sim_id} not found.") - simulation.remove_meta(key) + simulation.remove_meta(data.key) current_app.db.insert_simulation(simulation) clear_cache() return {} @@ -584,7 +568,7 @@ def post(self, sim_id, user: User): result = _validate(simulation, user) current_app.db.insert_simulation(simulation) clear_cache() - return jsonify(result) + return jsonify(result.model_dump(mode="json")) except DatabaseError as err: return error(str(err)) @@ -595,8 +579,8 @@ class SimulationTrace(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] def get(self, sim_id: str, user: User): try: - data = _build_trace(sim_id) - return jsonify(data) + trace_data = _build_trace(sim_id) + return jsonify(trace_data.model_dump(mode="json")) except DatabaseError as err: return error(str(err)) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py new file mode 100644 index 0000000..e2b5693 --- /dev/null +++ b/src/simdb/remote/models.py @@ -0,0 +1,191 @@ +from datetime import datetime as dt +from datetime import timezone +from typing import Annotated, Any, Generic, List, Optional, TypeVar, Union +from urllib.parse import urlencode +from uuid import UUID, uuid1 + +from pydantic import ( + BaseModel, + BeforeValidator, + Field, + PlainSerializer, + RootModel, + model_validator, +) + +HexUUID = Annotated[UUID, PlainSerializer(lambda x: x.hex, return_type=str)] + + +def _deserialize_custom_uuid(v: Any) -> UUID: + """Deserialize CustomUUID format back to UUID.""" + if isinstance(v, UUID): + return v + if isinstance(v, dict) and "hex" in v: + return UUID(hex=v["hex"]) + raise ValueError(f"Cannot deserialize {v} to UUID") + + +CustomUUID = Annotated[ + UUID, + BeforeValidator(_deserialize_custom_uuid), + PlainSerializer(lambda x: {"_type": "uuid.UUID", "hex": x.hex}), +] + + +class StatusPatchData(BaseModel): + status: str + + +class DeletedSimulation(BaseModel): + uuid: UUID + files: List[str] + + +class SimulationDeleteResponse(BaseModel): + deleted: DeletedSimulation + + +class FileData(BaseModel): + type: str + uri: str + uuid: CustomUUID = Field(default_factory=lambda: uuid1()) + checksum: str + datetime: dt + usage: Optional[str] = None + purpose: Optional[str] = None + sensitivity: Optional[str] = None + access: Optional[str] = None + embargo: Optional[str] = None + + +class FileDataList(RootModel): + root: List[FileData] = [] + + # Allows indexing: users[0] + def __getitem__(self, item) -> FileData: + return self.root[item] + + +class MetadataData(BaseModel): + element: str + value: Union[CustomUUID, Any] + + def as_dict(self): + return {self.element: self.value} + + def as_querystring(self): + return urlencode(self.as_dict()) + + +class MetadataPatchData(BaseModel): + key: str + value: str + + +class MetadataDeleteData(BaseModel): + key: str + + +class MetadataDataList(RootModel): + root: List[MetadataData] = [] + + def __getitem__(self, item) -> MetadataData: + return self.root[item] + + def as_dict(self): + return {m.element: m.value for m in self.root} + + @model_validator(mode="before") + @classmethod + def parse_dictionary(cls, data: Any): + if isinstance(data, dict): + return [{"element": k, "value": v} for (k, v) in data.items()] + return data + + def as_querystring(self): + return urlencode(self.as_dict()) + + +class SimulationReference(BaseModel): + uuid: CustomUUID + alias: Optional[str] = None + + +class SimulationData(BaseModel): + uuid: CustomUUID = Field(default_factory=lambda: uuid1()) + alias: Optional[str] = None + datetime: dt = Field(default_factory=lambda: dt.now(timezone.utc)) + inputs: FileDataList = FileDataList() + outputs: FileDataList = FileDataList() + metadata: MetadataDataList = MetadataDataList() + + +class SimulationDataResponse(SimulationData): + parents: List[SimulationReference] + children: List[SimulationReference] + + +class SimulationPostData(BaseModel): + simulation: SimulationData + add_watcher: bool + uploaded_by: Optional[str] = None + + +class ValidationResult(BaseModel): + passed: bool + error: Optional[str] = None + + +class SimulationPostResponse(BaseModel): + ingested: HexUUID + error: Optional[str] = None + validation: Optional[ValidationResult] = None + + +class SimulationListItem(BaseModel): + uuid: CustomUUID + alias: Optional[str] = None + datetime: str + metadata: Optional[MetadataDataList] = None + + +T = TypeVar("T") + + +class PaginatedResponse(BaseModel, Generic[T]): + count: int + page: int + limit: int + results: List[T] + + +class PaginationData(BaseModel): + limit: int + page: int + sort_by: str + sort_asc: bool + + @model_validator(mode="before") + @classmethod + def parse_headers(cls, data: Any): + if not isinstance(data, dict): + return data + new_data = { + "limit": data.get("simdb-result-limit", 100), + "page": data.get("simdb-page", 1), + "sort_by": data.get("simdb-sort-by", ""), + "sort_asc": data.get("simdb-sort-asc", False), + } + return new_data + + +class SimulationTraceData(SimulationData): + status: Optional[str] = None + passed_on: Optional[Any] = None + failed_on: Optional[Any] = None + deprecated_on: Optional[Any] = None + accepted_on: Optional[Any] = None + not_validated_on: Optional[Any] = None + deleted_on: Optional[Any] = None + replaces: Optional["SimulationTraceData"] = None + replaces_reason: Optional[Any] = None diff --git a/tests/remote/test_api.py b/tests/remote/test_api.py index 8fc9dfe..668a8bc 100644 --- a/tests/remote/test_api.py +++ b/tests/remote/test_api.py @@ -1,7 +1,10 @@ import base64 import importlib import os +import shutil import tempfile +import uuid +from datetime import datetime, timezone from pathlib import Path import pytest @@ -10,6 +13,20 @@ from simdb.config import Config from simdb.database.models import Simulation from simdb.remote.app import create_app +from simdb.remote.models import ( + FileData, + MetadataData, + MetadataDataList, + MetadataDeleteData, + MetadataPatchData, + PaginatedResponse, + SimulationData, + SimulationDataResponse, + SimulationListItem, + SimulationPostData, + SimulationTraceData, + StatusPatchData, +) has_flask = importlib.util.find_spec("flask") is not None @@ -28,10 +45,14 @@ def client(): config = Config() config.load() db_fd, db_file = tempfile.mkstemp() + upload_dir = tempfile.mkdtemp() config.set_option("database.type", "sqlite") config.set_option("database.file", db_file) config.set_option("server.admin_password", TEST_PASSWORD) + config.set_option("server.upload_folder", upload_dir) config.set_option("authentication.type", "None") + config.set_option("server.copy_files", False) + config.set_option("role.admin.users", "admin,admin2") app = create_app(config=config, testing=True, debug=True) app.testing = True @@ -47,6 +68,36 @@ def client(): os.close(db_fd) Path(app.simdb_config.get_option("database.file")).unlink() + shutil.rmtree(upload_dir) + + +def generate_simulation_data( + add_watcher=False, uploaded_by=None, **overrides +) -> SimulationPostData: + simulation_data = SimulationData(**overrides) + data = SimulationPostData( + simulation=simulation_data, add_watcher=add_watcher, uploaded_by=uploaded_by + ) + return data + + +def generate_simulation_file() -> FileData: + return FileData( + type="FILE", + uri="file:///path/to/file", + checksum="fake_checksum", + datetime=datetime.now(timezone.utc), + ) + + +def post_simulation(client, simulation_data, headers=HEADERS): + rv_post = client.post( + "/v1.2/simulations", + json=simulation_data.model_dump(mode="json"), + headers=headers, + content_type="application/json", + ) + return rv_post @pytest.mark.skipif(not has_flask, reason="requires flask library") @@ -67,8 +118,812 @@ def test_get_api_root(client): @pytest.mark.skipif(not has_flask, reason="requires flask library") -def test_get_simulations(client): +def test_post_simulations(client): + """Test POST endpoint for creating a new simulation.""" + simulation_data = generate_simulation_data( + alias="test-simulation", + inputs=[generate_simulation_file()], + outputs=[generate_simulation_file()], + ) + + # POST the simulation + rv = post_simulation(client, simulation_data) + + # Verify the response + assert rv.status_code == 200 + assert rv.json["ingested"] == simulation_data.simulation.uuid.hex + + # Verify the simulation was created by fetching it + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == "test-simulation" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +@pytest.mark.parametrize("suffix", ["-", "#"]) +def test_post_simulations_with_alias_auto_increment(client, suffix): + """Test POST endpoint with alias ending in dash or hashtag (auto-increment).""" + random_name = uuid.uuid4().hex + simulation_data = generate_simulation_data( + alias=f"{random_name}{suffix}", + ) + + rv = post_simulation(client, simulation_data) + + assert rv.status_code == 200 + assert rv.json["ingested"] == simulation_data.simulation.uuid.hex + + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == f"{random_name}{suffix}1" + + # Check seqid metadata was added + metadata = rv_get.json["metadata"] + seqid_meta = [m for m in metadata if m["element"] == "seqid"] + assert len(seqid_meta) == 1 + assert seqid_meta[0]["value"] == 1 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_alias_increment_sequence(client): + """Test multiple simulations with incrementing dash alias.""" + # Create first simulation with dash alias + simulation_data_1 = generate_simulation_data( + alias="sequence-", + ) + + rv1 = post_simulation(client, simulation_data_1) + assert rv1.status_code == 200 + + simulation_data_2 = generate_simulation_data( + alias="sequence-", + ) + + rv2 = post_simulation(client, simulation_data_2) + assert rv2.status_code == 200 + + # Verify aliases were incremented + rv_get1 = client.get( + f"/v1.2/simulation/{simulation_data_1.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_get1.json["alias"] == "sequence-1" + + rv_get2 = client.get( + f"/v1.2/simulation/{simulation_data_2.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_get2.json["alias"] == "sequence-2" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_no_alias(client): + """Test POST endpoint with no alias provided (should use uuid.hex).""" + simulation_data = generate_simulation_data() + + rv = post_simulation(client, simulation_data) + + assert rv.status_code == 200 + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == simulation_data.simulation.uuid.hex + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_with_replaces(client): + """Test POST endpoint with replaces metadata (deprecates old simulation).""" + # Create initial simulation + old_simulation_data = generate_simulation_data(alias="old_simulation") + + rv_old = post_simulation(client, old_simulation_data) + assert rv_old.status_code == 200 + + # Create new simulation that replaces the old one + new_simulation_data = generate_simulation_data( + alias="updated-simulation", + metadata=[ + MetadataData( + element="replaces", value=old_simulation_data.simulation.uuid.hex + ), + MetadataData(element="replaces_reason", value="Test replacement"), + ], + ) + + rv_new = post_simulation(client, new_simulation_data) + assert rv_new.status_code == 200 + + # Verify the old simulation is marked as DEPRECATED + rv_old_get = client.get( + f"/v1.2/simulation/{old_simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_old_get.status_code == 200 + old_metadata = rv_old_get.json["metadata"] + + status_meta = [m for m in old_metadata if m["element"] == "status"] + assert len(status_meta) == 1 + assert status_meta[0]["value"].lower() == "deprecated" + + # Check replaced_by metadata was added + replaced_by_meta = [m for m in old_metadata if m["element"] == "replaced_by"] + assert len(replaced_by_meta) == 1 + assert replaced_by_meta[0]["value"] == new_simulation_data.simulation.uuid + + # Verify the new simulation has replaces metadata + rv_new_get = client.get( + f"/v1.2/simulation/{new_simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_new_get.status_code == 200 + new_metadata = rv_new_get.json["metadata"] + + replaces_meta = [m for m in new_metadata if m["element"] == "replaces"] + assert len(replaces_meta) == 1 + assert replaces_meta[0]["value"] == old_simulation_data.simulation.uuid.hex + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_replaces_nonexistent(client): + """Test POST endpoint with replaces pointing to non-existent simulation.""" + # Create simulation that tries to replace a non-existent simulation + simulation_data = generate_simulation_data( + alias="replaces-nothing", + metadata=[ + MetadataData(element="replaces", value=uuid.uuid1().hex), + MetadataData(element="replaces_reason", value="Test replacement"), + ], + ) + + # Should still succeed (old simulation just doesn't exist to deprecate) + rv = post_simulation(client, simulation_data) + assert rv.status_code == 200 + + # Verify the new simulation was created + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_get.status_code == 200 + assert rv_get.json["alias"] == "replaces-nothing" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_with_watcher(client): + """Test POST endpoint with add_watcher set to true.""" + simulation_data = generate_simulation_data( + add_watcher=True, uploaded_by="watcher-user" + ) + + rv = post_simulation(client, simulation_data) + assert rv.status_code == 200 + + # Verify the simulation was created + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_get.status_code == 200 + + # Note: We can't easily verify watchers were added without accessing the db directly + # but we can verify the request was successful and uploaded_by metadata is present + metadata = rv_get.json["metadata"] + uploaded_by_meta = [m for m in metadata if m["element"] == "uploaded_by"] + assert len(uploaded_by_meta) == 1 + assert uploaded_by_meta[0]["value"] == "watcher-user" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_uploaded_by(client): + """Test POST endpoint with uploaded_by field.""" + """Test POST endpoint with add_watcher set to true.""" + simulation_data = generate_simulation_data(uploaded_by="test-user") + + rv = post_simulation(client, simulation_data) + assert rv.status_code == 200 + + # Verify the simulation was created + rv_get = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_get.status_code == 200 + + metadata = rv_get.json["metadata"] + uploaded_by_meta = [m for m in metadata if m["element"] == "uploaded_by"] + assert len(uploaded_by_meta) == 1 + assert uploaded_by_meta[0]["value"] == "test-user" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_post_simulations_trace_with_replaces(client): + """Test the trace endpoint with a simulation that replaces another.""" + # Create original simulation + # Create initial simulation + old_simulation_data = generate_simulation_data(alias="trace-original") + + rv_old = post_simulation(client, old_simulation_data) + assert rv_old.status_code == 200 + + # Create new simulation that replaces the old one + new_simulation_data = generate_simulation_data( + alias="trace-updated", + metadata=[ + MetadataData( + element="replaces", value=old_simulation_data.simulation.uuid.hex + ), + MetadataData(element="replaces_reason", value="New features"), + ], + ) + + rv_new = post_simulation(client, new_simulation_data) + assert rv_new.status_code == 200 + + # Get trace for the new simulation + rv_trace = client.get( + f"/v1.2/trace/{new_simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv_trace.status_code == 200 + trace_data = rv_trace.json + + # Verify trace includes replaces information + assert "replaces" in trace_data + + replaces_uuid = trace_data["replaces"]["uuid"] + assert replaces_uuid == old_simulation_data.simulation.uuid + assert "replaces_reason" in trace_data + assert trace_data["replaces_reason"] == "New features" + + with pytest.xfail("Deprecated on is not set, because replaced_on is never set"): + assert "deprecated_on" in trace_data["replaces"] + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_basic(client): + """Test basic GET request to /v1.2/simulations endpoint.""" rv = client.get("/v1.2/simulations", headers=HEADERS) - assert rv.json["count"] == 100 - assert len(rv.json["results"]) == len(SIMULATIONS) + + assert rv.status_code == 200 + assert rv.is_json + + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + assert data.page == 1 + assert data.limit == 100 + assert data.count >= 100 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_pagination_limit(client): + """Test GET request with custom limit.""" + custom_limit = 10 + headers_with_limit = {**HEADERS, "simdb-result-limit": str(custom_limit)} + + rv = client.get("/v1.2/simulations", headers=headers_with_limit) + + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + assert data.limit == custom_limit + assert len(data.results) <= custom_limit + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_pagination_page(client): + """Test GET request with custom page number.""" + headers_page_2 = {**HEADERS, "simdb-result-limit": "10", "simdb-page": "2"} + + rv = client.get("/v1.2/simulations", headers=headers_page_2) + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + assert data.page == 2 + assert data.limit == 10 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_pagination_multiple_pages(client): + """Test pagination across multiple pages.""" + limit = 20 + + # Get first page + headers_page_1 = {**HEADERS, "simdb-result-limit": str(limit), "simdb-page": "1"} + rv1 = client.get("/v1.2/simulations", headers=headers_page_1) + assert rv1.status_code == 200 + page1_data = PaginatedResponse[SimulationListItem].model_validate(rv1.json) + + # Get second page + headers_page_2 = {**HEADERS, "simdb-result-limit": str(limit), "simdb-page": "2"} + rv2 = client.get("/v1.2/simulations", headers=headers_page_2) + assert rv2.status_code == 200 + page2_data = PaginatedResponse[SimulationListItem].model_validate(rv2.json) + + # Both should have same count and limit + assert page1_data.count == page2_data.count + assert page1_data.limit == page2_data.limit == limit + + # Pages should be different + assert page1_data.page == 1 + assert page2_data.page == 2 + + page1_uuids = {item.uuid for item in page1_data.results} + page2_uuids = {item.uuid for item in page2_data.results} + assert page1_uuids != page2_uuids + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_filter_by_alias(client): + """Test filtering simulations by alias.""" + # First create a simulation with a known alias + test_alias = "filter-test-alias" + simulation_data = generate_simulation_data(alias=test_alias) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Now filter by alias + rv = client.get(f"/v1.2/simulations?alias={test_alias}", headers=HEADERS) + + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + assert data.count == 1 + # Check that the filtered result contains our simulation + aliases = [item.alias for item in data.results] + assert test_alias in aliases + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_filter_by_uuid(client): + """Test filtering simulations by UUID.""" + # Create a simulation with a known UUID + simulation_data = generate_simulation_data() + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Filter by UUID + rv = client.get( + f"/v1.2/simulations?uuid={simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + assert data.count == 1 + # Check that the filtered result contains our simulation + uuids = [item.uuid for item in data.results] + assert simulation_data.simulation.uuid in uuids + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_filter_by_metadata(client): + """Test filtering simulations by metadata.""" + # Create simulations with specific metadata + test_metadata = MetadataData(element="machine", value="test_machine") + + simulation_data_1 = generate_simulation_data(metadata=[test_metadata]) + simulation_data_2 = generate_simulation_data(metadata=[test_metadata]) + rv_post_1 = post_simulation(client, simulation_data_1) + assert rv_post_1.status_code == 200 + + rv_post_2 = post_simulation(client, simulation_data_2) + assert rv_post_2.status_code == 200 + + # Filter by machine metadata + rv = client.get( + f"/v1.2/simulations?{test_metadata.as_querystring()}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + assert data.count == 2 + + # Check that both simulations are in the results + results_uuids = [item.uuid for item in data.results] + assert simulation_data_1.simulation.uuid in results_uuids + assert simulation_data_2.simulation.uuid in results_uuids + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_filter_multiple_metadata(client): + """Test filtering simulations by multiple metadata fields.""" + # Create a simulation with multiple metadata fields + test_metadata = {"machine": "multi-filter-machine", "code": "multi-filter-code"} + + simulation_data = generate_simulation_data(metadata=test_metadata) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Filter by both machine and code + rv = client.get( + f"/v1.2/simulations?{simulation_data.simulation.metadata.as_querystring()}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + assert data.count == 1 + results_uuids = [item.uuid for item in data.results] + assert simulation_data.simulation.uuid in results_uuids + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +@pytest.mark.xfail(reason="Only sorting by metadata keys works for now") +def test_get_simulations_alias_sorting_asc(client): + """Test sorting simulations in ascending order by alias.""" + # Create simulations with sortable aliases + for i in range(3): + simulation_data = generate_simulation_data(alias=f"alias-sort-test-{i:03d}") + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Get simulations sorted by alias ascending + headers_sorted = {**HEADERS, "simdb-sort-by": "alias", "simdb-sort-asc": "true"} + + rv = client.get( + "/v1.2/simulations?alias=IN%3Aalias-sort-test-", headers=headers_sorted + ) + + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + # Check that results are sorted in ascending order + aliases = [item.alias for item in data.results if item.alias is not None] + assert aliases == sorted(aliases) + assert len(aliases) == 3 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +@pytest.mark.parametrize("ascending", [True, False]) +def test_get_simulations_metadata_sorting(client, ascending): + """Test sorting simulations in ascending order.""" + # Create simulations with sortable aliases + # post them in the order: 2 1 0 5 4 3 + # ascending should result in 0 1 2 3 4 5 + # descending should result in 5 4 3 2 1 0 + for i in reversed(range(3)): + simulation_data = generate_simulation_data( + alias=f"sort-test-{ascending}-{i:03d}", + metadata=[MetadataData(element="sort-test", value=i)], + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + for i in reversed(range(3, 6)): + simulation_data = generate_simulation_data( + alias=f"sort-test-{ascending}-{i:03d}", + metadata=[MetadataData(element="sort-test", value=i)], + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Get simulations sorted by alias ascending + headers_sorted = { + **HEADERS, + "simdb-sort-by": "sort-test", + "simdb-sort-asc": "true" if ascending else "false", + } + + rv = client.get( + f"/v1.2/simulations?alias=IN%3Asort-test-{ascending}&sort-test", + headers=headers_sorted, + ) + + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + # Check that results are sorted in the correct order + metadata = [ + item.metadata[0].value for item in data.results if item.metadata is not None + ] + assert metadata == sorted(metadata, reverse=not ascending) + assert len(metadata) == 6 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_empty_result(client): + """Test GET request with filters that return no results.""" + # Use a filter that shouldn't match anything + rv = client.get( + "/v1.2/simulations?alias=non-existent-simulation-12345xyz", headers=HEADERS + ) + + assert rv.status_code == 200 + data = PaginatedResponse[SimulationListItem].model_validate(rv.json) + + assert data.count == 0 + assert len(data.results) == 0 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulations_with_metadata_keys(client): + """Test requesting specific metadata keys in results.""" + # Create a simulation with known metadata + + simulation_data = generate_simulation_data( + alias="meta-keys-test", + metadata={"machine": "machine-x", "code": "code-y"}, + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Request simulations with specific metadata keys + rv = client.get( + "/v1.2/simulations?alias=meta-keys-test&machine&code", headers=HEADERS + ) + + assert rv.status_code == 200 + data: PaginatedResponse[SimulationListItem] = PaginatedResponse[ + SimulationListItem + ].model_validate(rv.json) + + assert data.count == 1 + assert len(data.results) == 1 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_by_uuid(client): + """Test GET /v1.2/simulation/{simulation_id} endpoint - retrieve by UUID.""" + # Create a simulation with known properties + simulation_data = generate_simulation_data(uploaded_by="test-uploader") + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Test GET by UUID + rv = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + + assert rv.status_code == 200 + assert rv.is_json + + # Validate full data model + SimulationDataResponse.model_validate(rv.json) + + simulation_data_received = SimulationData.model_validate(rv.json, extra="ignore") + simulation_data_check = simulation_data.simulation.model_copy() + + # fill fields that are filled by the server + simulation_data_check.alias = simulation_data_check.uuid.hex + simulation_data_check.metadata = MetadataDataList.model_validate( + {"uploaded_by": simulation_data.uploaded_by} + ) + + # datetime gets updated by the server + simulation_data_check.datetime = simulation_data_received.datetime + + assert simulation_data_received == simulation_data_check + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_by_alias(client): + """Test GET /v1.2/simulation/{simulation_id} endpoint - retrieve by alias.""" + # Create a simulation with a unique alias + simulation_data = generate_simulation_data( + alias="test-get-alias", uploaded_by="test-uploader" + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + # Test GET by alias + rv = client.get( + f"/v1.2/simulation/{simulation_data.simulation.alias}", headers=HEADERS + ) + + assert rv.status_code == 200 + assert rv.is_json + + # Validate full data model + SimulationDataResponse.model_validate(rv.json) + + simulation_data_received = SimulationData.model_validate(rv.json, extra="ignore") + simulation_data_check = simulation_data.simulation.model_copy() + + # fill fields that are filled by the server + simulation_data_check.metadata = MetadataDataList.model_validate( + {"uploaded_by": simulation_data.uploaded_by} + ) + + # datetime gets updated by the server + simulation_data_check.datetime = simulation_data_received.datetime + + assert simulation_data_received == simulation_data_check + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_not_found(client): + """Test GET /v1.2/simulation/{simulation_id} endpoint - non-existent simulation.""" + # Try to get a non-existent simulation + fake_uuid = uuid.uuid1() + + rv = client.get(f"/v1.2/simulation/{fake_uuid.hex}", headers=HEADERS) + + assert rv.status_code == 400 + data = rv.json + + # Should contain an error message + assert "error" in data or data.get("message") == "Simulation not found" + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_get_simulation_metadata(client): + """Test GET /v1.2/simulation/metadata/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data( + metadata={"metadata-a": "abc", "metadata-b": "123"}, uploaded_by="test-user" + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.get( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + data = MetadataDataList.model_validate(rv.json) + check_data = simulation_data.simulation.metadata.model_copy() + check_data.root.append(MetadataData(element="uploaded_by", value="test-user")) + assert data == simulation_data.simulation.metadata + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_patch_simulation(client): + """Test PATCH /v1.2/simulation/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data() + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.patch( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", + json=StatusPatchData(status="failed").model_dump(mode="json"), + headers=HEADERS, + ) + + assert rv.status_code == 200 + + # Status is never returned, so we can't check if it is set + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_delete_simulation(client): + """Test DELETE /v1.2/simulation/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data() + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.delete( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + + rv = client.get( + f"/v1.2/simulation/{simulation_data.simulation.uuid.hex}", headers=HEADERS + ) + + assert rv.status_code == 400 + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_patch_simulation_metadata(client): + """Test PATCH /v1.2/simulation/metadata/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data( + metadata={"metadata-a": "abc"}, uploaded_by="test-user" + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.patch( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + json=MetadataPatchData(key="metadata-a", value="def").model_dump(mode="json"), + headers=HEADERS, + ) + + assert rv.status_code == 200 + + rv = client.get( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + data = MetadataDataList.model_validate(rv.json) + check_data = simulation_data.simulation.metadata.model_copy() + check_data[0].value = "def" + check_data.root.append(MetadataData(element="uploaded_by", value="test-user")) + assert data == simulation_data.simulation.metadata + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_delete_simulation_metadata(client): + """Test DELETE /v1.2/simulation/metadata/{simulation_id} endpoint.""" + simulation_data = generate_simulation_data( + metadata={"metadata-a": "abc"}, uploaded_by="test-user" + ) + + rv_post = post_simulation(client, simulation_data) + assert rv_post.status_code == 200 + + rv = client.delete( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + json=MetadataDeleteData(key="metadata-a").model_dump(mode="json"), + headers=HEADERS, + ) + + assert rv.status_code == 200 + + rv = client.get( + f"/v1.2/simulation/metadata/{simulation_data.simulation.uuid.hex}", + headers=HEADERS, + ) + + assert rv.status_code == 200 + data = MetadataDataList.model_validate(rv.json) + check_data = simulation_data.simulation.metadata.model_copy() + check_data.root.pop() + check_data.root.append(MetadataData(element="uploaded_by", value="test-user")) + assert data == simulation_data.simulation.metadata + + +@pytest.mark.skipif(not has_flask, reason="requires flask library") +def test_trace_endpoint(client): + """Test trace endpoint returns valid SimulationTraceData and handles replacement + chains.""" + # Create v1 -> v2 -> v3 replacement chain + sim_v1 = generate_simulation_data(alias="trace-v1") + rv1 = post_simulation(client, sim_v1) + assert rv1.status_code == 200 + + sim_v2 = generate_simulation_data( + alias="trace-v2", + metadata=[ + MetadataData(element="replaces", value=sim_v1.simulation.uuid.hex), + MetadataData(element="replaces_reason", value="Bug fixes"), + ], + ) + rv2 = post_simulation(client, sim_v2) + assert rv2.status_code == 200 + + sim_v3 = generate_simulation_data( + alias="trace-v3", + metadata=[ + MetadataData(element="replaces", value=sim_v2.simulation.uuid.hex), + MetadataData(element="replaces_reason", value="Performance"), + ], + ) + rv3 = post_simulation(client, sim_v3) + assert rv3.status_code == 200 + + # Test trace for v3 (full chain) + rv_trace = client.get(f"/v1.2/trace/{sim_v3.simulation.uuid.hex}", headers=HEADERS) + assert rv_trace.status_code == 200 + + trace = SimulationTraceData.model_validate(rv_trace.json) + + # Verify v3 + assert trace.uuid == sim_v3.simulation.uuid + assert trace.alias == "trace-v3" + assert trace.replaces_reason == "Performance" + + # Verify v2 (nested) + assert trace.replaces.uuid == sim_v2.simulation.uuid + assert trace.replaces.replaces_reason == "Bug fixes" + + # Verify v1 (double nested) + assert trace.replaces.replaces.uuid == sim_v1.simulation.uuid + assert trace.replaces.replaces.replaces is None