Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions examples/Tutorial_nerdss.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/complexenumeration-env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from IPython.display import Image\n",
"\n",
Expand Down Expand Up @@ -118,15 +127,45 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading PDB Data -------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/complexenumeration-env/lib/python3.11/site-packages/MDAnalysis/topology/PDBParser.py:350: UserWarning: Element information is missing, elements attribute will not be populated. If needed these can be guessed using universe.guess_TopologyAttrs(context='default', to_guess=['elements']).\n",
" warnings.warn(\"Element information is missing, elements attribute \"\n",
"/opt/anaconda3/envs/complexenumeration-env/lib/python3.11/site-packages/MDAnalysis/topology/PDBParser.py:350: UserWarning: Element information is missing, elements attribute will not be populated. If needed these can be guessed using universe.guess_TopologyAttrs(context='default', to_guess=['elements']).\n",
" warnings.warn(\"Element information is missing, elements attribute \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Filtering: translation -------------\n",
"Formatting trajectory data...\n",
"Converting Trajectory Data to JSON -------------\n",
"Sanitizing JSON for NaNs and NumPy types...\n",
"Writing sanitized JSON...\n",
"Saved clean JSON to example_virus.simularium\n"
]
}
],
"source": [
"converter = NerdssConverter(example_data)\n",
"# this _filter is just roughly centering the data in the box\n",
"_filter = TranslateFilter(default_translation=np.array([-80.0, -80.0, -80.0]))\n",
"filtered_data = converter.filter_data([_filter])\n",
"JsonWriter.save(filtered_data, \"example_virus\", False)"
"JsonWriter.save_replacing_nan(filtered_data, \"example_virus\", False)"
]
},
{
Expand Down Expand Up @@ -167,7 +206,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.13"
},
"vscode": {
"interpreter": {
Expand Down
74 changes: 74 additions & 0 deletions simulariumio/tests/writers/test_json_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
test_json_writer.py

Unit tests for the JsonWriter class in simulariumio.writers.

Currently only tests `save_replacing_nan()` to ensure that:
* All NaN values are replaced with null (None in Python).
* Output JSON is valid and does not contain any NaN.
* Output file is written correctly in `.simularium` format.

Status:
- Uses pytest with fixtures to generate a minimal TrajectoryData mock.
- Tests are self-contained and write only to a temporary directory that is removed afterwards.
- Assumes simulariumio is installed or importable via editable mode (`pip install -e .`).

To run:
> pytest simulariumio/tests/writers/test_json_writer.py
"""
import json
import os
import tempfile
import numpy as np
import pytest
from simulariumio import TrajectoryData, AgentData, UnitData, DisplayData
from simulariumio.constants import DISPLAY_TYPE, VALUES_PER_3D_POINT, VIZ_TYPE
from simulariumio.writers import JsonWriter

@pytest.fixture
def trajectory_with_nan():
    """
    Build a one-frame, one-agent TrajectoryData whose position, radius and
    subpoints each contain a NaN, for exercising NaN handling in the writer.
    """
    frame = 0
    agent = 0

    # NOTE(review): dimensions are passed as a plain tuple here; confirm that
    # AgentData.from_dimensions accepts this form.
    agents = AgentData.from_dimensions(dimensions=(1, 1, VALUES_PER_3D_POINT))

    # Bookkeeping for the single agent in the single frame.
    agents.n_agents[frame] = 1
    agents.unique_ids[frame][agent] = 1
    agents.types[frame].append("A")
    agents.viz_types[frame][agent] = VIZ_TYPE.FIBER

    # Spatial values, each seeded with one NaN.
    agents.positions[frame][agent] = [np.nan, 1.0, 2.0]
    agents.radii[frame][agent] = np.nan
    # NOTE(review): six values are written while the last requested dimension
    # is VALUES_PER_3D_POINT; confirm the subpoints row is wide enough.
    agents.n_subpoints[frame][agent] = 6
    agents.subpoints[frame][agent] = [0.0, np.nan, 0.0, 1.0, 1.0, 1.0]

    # No meta data and no plots; only the agent data and units are supplied.
    return TrajectoryData(
        meta_data=None,
        agent_data=agents,
        time_units=UnitData(name="s"),
        spatial_units=UnitData(name="nm"),
        plots=None,
    )

def test_save_replacing_nan(trajectory_with_nan):
    """
    Save the NaN-seeded trajectory and check that the written file exists,
    parses as JSON, and holds null where the NaNs were.
    """
    with tempfile.TemporaryDirectory() as workdir:
        # The writer is given the path without the extension.
        stem = os.path.join(workdir, "test_output")
        JsonWriter.save_replacing_nan(
            trajectory_with_nan, stem, validate_ids=False
        )

        written = f"{stem}.simularium"
        assert os.path.exists(written)

        with open(written, "r") as handle:
            payload = json.load(handle)

        # First frame, first agent of each per-agent field.
        first_position = payload["trajectoryInfo"]["agentData"]["positions"][0][0]
        first_radius = payload["trajectoryInfo"]["agentData"]["radii"][0][0]
        first_subpoints = payload["trajectoryInfo"]["agentData"]["subpoints"][0][0]

        # NaN entries must come back as None; finite entries stay floats.
        assert first_position[0] is None
        assert isinstance(first_position[1], float)
        assert first_radius is None
        assert first_subpoints[1] is None
53 changes: 53 additions & 0 deletions simulariumio/writers/json_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
from typing import Any, Dict, List
import math

import numpy as np

Expand Down Expand Up @@ -185,6 +186,58 @@ def save(
with open(f"{output_path}.simularium", "w+") as outfile:
json.dump(json_data, outfile)
print(f"saved to {output_path}.simularium")

@staticmethod
def save_replacing_nan(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could do this with just an additional function argument to the main save function rather than having an entirely different function?

trajectory_data: "TrajectoryData", output_path: str, validate_ids: bool
) -> None:
"""
Save simularium data in JSON format, replacing all NaN with null and
converting NumPy arrays to native types.

Parameters
----------
trajectory_data: TrajectoryData
The data to save.
output_path: str
Output file path (without extension).
validate_ids: bool
Whether to perform agent ID validation.
"""

def _sanitize_for_json(obj):
    """
    Recursively convert ``obj`` into a structure that ``json.dump`` accepts
    when called with ``allow_nan=False``:
    - NaN (Python float or NumPy floating scalar) -> None
    - numpy arrays -> nested lists
    - numpy floating / integer / bool scalars -> native Python scalars
    - lists and tuples -> lists, sanitized element-wise
    - dicts -> dicts with sanitized values (keys are left as they are)
    Any other object is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return _sanitize_for_json(obj.tolist())
    # Unwrap NumPy scalars *before* the NaN check: np.float32 and friends
    # are not subclasses of float, so checking float first lets their NaN
    # slip through as a native nan.
    if isinstance(obj, (np.floating, np.integer, np.bool_)):
        obj = obj.item()
    if isinstance(obj, float):
        # Only NaN is replaced; +/-inf is left untouched and will still be
        # rejected by json.dump(..., allow_nan=False).
        return None if math.isnan(obj) else obj
    if isinstance(obj, (list, tuple)):
        # json writes tuples as arrays anyway; walking them here makes sure
        # NaNs nested inside a tuple are replaced too.
        return [_sanitize_for_json(item) for item in obj]
    if isinstance(obj, dict):
        return {key: _sanitize_for_json(value) for key, value in obj.items()}
    return obj

if validate_ids:
Writer._validate_ids(trajectory_data)

print("Formatting trajectory data...")
json_data = JsonWriter.format_trajectory_data(trajectory_data)

print("Sanitizing JSON for NaNs and NumPy types...")
sanitized_data = _sanitize_for_json(json_data)

print("Writing sanitized JSON...")
with open(f"{output_path}.simularium", "w") as outfile:
json.dump(sanitized_data, outfile, indent=2, allow_nan=False)
print(f"Saved clean JSON to {output_path}.simularium")

@staticmethod
def save_plot_data(plot_data: List[Dict[str, Any]], output_path: str):
Expand Down