diff --git a/examples/Tutorial_nerdss.ipynb b/examples/Tutorial_nerdss.ipynb index 8912257c..49246916 100644 --- a/examples/Tutorial_nerdss.ipynb +++ b/examples/Tutorial_nerdss.ipynb @@ -11,7 +11,16 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/anaconda3/envs/complexenumeration-env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "from IPython.display import Image\n", "\n", @@ -118,15 +127,45 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reading PDB Data -------------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/anaconda3/envs/complexenumeration-env/lib/python3.11/site-packages/MDAnalysis/topology/PDBParser.py:350: UserWarning: Element information is missing, elements attribute will not be populated. If needed these can be guessed using universe.guess_TopologyAttrs(context='default', to_guess=['elements']).\n", + " warnings.warn(\"Element information is missing, elements attribute \"\n", + "/opt/anaconda3/envs/complexenumeration-env/lib/python3.11/site-packages/MDAnalysis/topology/PDBParser.py:350: UserWarning: Element information is missing, elements attribute will not be populated. If needed these can be guessed using universe.guess_TopologyAttrs(context='default', to_guess=['elements']).\n", + " warnings.warn(\"Element information is missing, elements attribute \"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Filtering: translation -------------\n", + "Formatting trajectory data...\n", + "Converting Trajectory Data to JSON -------------\n", + "Sanitizing JSON for NaNs and NumPy types...\n", + "Writing sanitized JSON...\n", + "Saved clean JSON to example_virus.simularium\n" + ] + } + ], "source": [ "converter = NerdssConverter(example_data)\n", "# this _filter is just roughly centering the data in the box\n", "_filter = TranslateFilter(default_translation=np.array([-80.0, -80.0, -80.0]))\n", "filtered_data = converter.filter_data([_filter])\n", - "JsonWriter.save(filtered_data, \"example_virus\", False)" + "JsonWriter.save_replacing_nan(filtered_data, \"example_virus\", False)" ] }, { @@ -167,7 +206,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.11.13" }, "vscode": { "interpreter": { diff --git a/simulariumio/tests/writers/test_json_writer.py b/simulariumio/tests/writers/test_json_writer.py new file mode 100644 index 00000000..3c599749 --- /dev/null +++ b/simulariumio/tests/writers/test_json_writer.py @@ -0,0 +1,74 @@ +""" +test_json_writer.py + +Unit tests for the JsonWriter class in simulariumio.writers. + +Current only tests `save_replacing_nan()` to ensure that: + * All NaN values are replaced with null (None in Python). + * Output JSON is valid and does not contain any NaN. + * Output file is written correctly in `.simularium` format. + +Status: + - Uses pytest with fixtures to generate a minimal TrajectoryData mock. + - Tests are self-contained and do not require writing to real disk locations. + - Assumes simulariumio is installed or importable via editable mode (`pip install -e .`). + +To run: + > pytest simulariumio/tests/writers/test_json_writer.py +""" +import json +import os +import tempfile +import numpy as np +import pytest +from simulariumio import TrajectoryData, AgentData, UnitData, DisplayData +from simulariumio.constants import DISPLAY_TYPE, VALUES_PER_3D_POINT, VIZ_TYPE +from simulariumio.writers import JsonWriter + +@pytest.fixture +def trajectory_with_nan(): + # Create mock AgentData with NaNs + agent_data = AgentData.from_dimensions(dimensions=(1, 1, VALUES_PER_3D_POINT)) + agent_data.positions[0][0] = [np.nan, 1.0, 2.0] + agent_data.radii[0][0] = np.nan + agent_data.subpoints[0][0] = [0.0, np.nan, 0.0, 1.0, 1.0, 1.0] + agent_data.types[0].append("A") + agent_data.unique_ids[0][0] = 1 + agent_data.n_agents[0] = 1 + agent_data.n_subpoints[0][0] = 6 + agent_data.viz_types[0][0] = VIZ_TYPE.FIBER + + # Wrap into a TrajectoryData object + traj_data = TrajectoryData( + meta_data=None, + agent_data=agent_data, + time_units=UnitData(name="s"), + spatial_units=UnitData(name="nm"), + plots=None, + ) + return traj_data + +def test_save_replacing_nan(trajectory_with_nan): + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, "test_output") + + # Act: Call method under test + JsonWriter.save_replacing_nan(trajectory_with_nan, output_path, validate_ids=False) + + # Verify file exists + output_file = output_path + ".simularium" + assert os.path.exists(output_file) + + # Verify contents + with open(output_file, "r") as f: + json_data = json.load(f) + + # Assert that NaN was replaced with null + positions = json_data["trajectoryInfo"]["agentData"]["positions"][0][0] + radii = json_data["trajectoryInfo"]["agentData"]["radii"][0][0] + subpoints = json_data["trajectoryInfo"]["agentData"]["subpoints"][0][0] + + assert positions[0] is None + assert isinstance(positions[1], float) + assert radii is None + assert subpoints[1] is None diff --git a/simulariumio/writers/json_writer.py b/simulariumio/writers/json_writer.py index 0bab8b56..f61b4778 100644 --- a/simulariumio/writers/json_writer.py +++ b/simulariumio/writers/json_writer.py @@ -4,6 +4,7 @@ import json import logging from typing import Any, Dict, List +import math import numpy as np @@ -185,6 +186,58 @@ def save( with open(f"{output_path}.simularium", "w+") as outfile: json.dump(json_data, outfile) print(f"saved to {output_path}.simularium") + + @staticmethod + def save_replacing_nan( + trajectory_data: "TrajectoryData", output_path: str, validate_ids: bool + ) -> None: + """ + Save simularium data in JSON format, replacing all NaN with null and + converting NumPy arrays to native types. + + Parameters + ---------- + trajectory_data: TrajectoryData + The data to save. + output_path: str + Output file path (without extension). + validate_ids: bool + Whether to perform agent ID validation. + """ + + def _sanitize_for_json(obj): + """ + Recursively convert to JSON-safe structure: + - NaN → None + - numpy arrays → lists + - numpy scalars → native Python scalars + """ + if isinstance(obj, float): + return None if math.isnan(obj) else obj + elif isinstance(obj, (np.floating, np.integer)): + return obj.item() + elif isinstance(obj, np.ndarray): + return _sanitize_for_json(obj.tolist()) + elif isinstance(obj, list): + return [_sanitize_for_json(x) for x in obj] + elif isinstance(obj, dict): + return {k: _sanitize_for_json(v) for k, v in obj.items()} + else: + return obj + + if validate_ids: + Writer._validate_ids(trajectory_data) + + print("Formatting trajectory data...") + json_data = JsonWriter.format_trajectory_data(trajectory_data) + + print("Sanitizing JSON for NaNs and NumPy types...") + sanitized_data = _sanitize_for_json(json_data) + + print("Writing sanitized JSON...") + with open(f"{output_path}.simularium", "w") as outfile: + json.dump(sanitized_data, outfile, indent=2, allow_nan=False) + print(f"Saved clean JSON to {output_path}.simularium") @staticmethod def save_plot_data(plot_data: List[Dict[str, Any]], output_path: str):