diff --git a/omtool/core/datamodel/snapshot.py b/omtool/core/datamodel/snapshot.py index 50dd883..bfa51ed 100644 --- a/omtool/core/datamodel/snapshot.py +++ b/omtool/core/datamodel/snapshot.py @@ -1,6 +1,8 @@ """ Struct that holds together particle set and timestamp that it describes. """ + +import contextlib import pandas as pd from amuse.datamodel.particles import Particles from amuse.lab import units @@ -8,13 +10,16 @@ from astropy.io import fits fields = { - "x": units.kpc, - "y": units.kpc, - "z": units.kpc, - "vx": units.kms, - "vy": units.kms, - "vz": units.kms, - "mass": units.MSun, + "x": 1 | units.kpc, + "y": 1 | units.kpc, + "z": 1 | units.kpc, + "vx": 1 | units.kms, + "vy": 1 | units.kms, + "vz": 1 | units.kms, + "mass": 1 | units.MSun, +} + +optional_fields = { "is_barion": None, } @@ -29,9 +34,50 @@ def __init__( particles: Particles = Particles(), timestamp: ScalarQuantity = 0 | units.Myr, ): - self.particles = particles + self._particles_df = pd.DataFrame(columns=fields.keys()) + + for field, unit in fields.items(): + if unit is None: + self._particles_df[field] = getattr(particles, field) + else: + self._particles_df[field] = getattr(particles, field) / unit + + for field, unit in optional_fields.items(): + with contextlib.suppress(AttributeError): + if unit is None: + self._particles_df[field] = getattr(particles, field) + else: + self._particles_df[field] = getattr(particles, field) / unit + + self._particles = particles self.timestamp = timestamp + def get_amuse_particles(self) -> Particles: + particles = Particles(len(self._particles_df)) + + for column in self._particles_df.columns: + unit = fields[column] + if unit is None: + setattr(particles, column, self._particles_df[column]) + else: + setattr(particles, column, self._particles_df[column] * unit) + + return particles + + @property + def particles(self) -> Particles: + """ + Returns AMUSE Particles object. + """ + return self._particles + + @particles.setter + def particles(self, particles: Particles): + """ + Deprecated. Sets particles from AMUSE Particles object. Exists only for backwards compatibility. + """ + self._particles = particles + def __getitem__(self, value) -> "Snapshot": return Snapshot(self.particles[value], self.timestamp) @@ -62,7 +108,7 @@ def to_fits(self, filename: str, append: bool = False): """ cols = [] - for (key, val) in fields.items(): + for (key, val) in fields.items() + optional_fields.items(): if not hasattr(self.particles, key): continue @@ -91,15 +137,15 @@ def to_fits(self, filename: str, append: bool = False): def to_csv(self, filename: str): df = pd.DataFrame(columns=fields.keys()) - for key, val in fields.items(): + for key, val in fields.items() + optional_fields.items(): if not hasattr(self.particles, key): continue array = getattr(self.particles, key) - if val is not None: - array = array.value_in(val) + if val is None: + val = 1 - df[key] = array + df[key] = array / val df.to_csv(filename) diff --git a/omtool/core/utils/base_test_case.py b/omtool/core/utils/base_test_case.py index f3b02d4..fafae4b 100644 --- a/omtool/core/utils/base_test_case.py +++ b/omtool/core/utils/base_test_case.py @@ -14,9 +14,27 @@ def setUp(self): def assertNdarraysEqual(self, first: np.ndarray, second: np.ndarray): np.testing.assert_array_equal(first, second) + def assertAmuseParticlesEqual(self, first: Particles, second: Particles): + self.assertEqual(len(first), len(second)) + + if len(first) == 0: + return + + self.assertNdarraysEqual(first.x, second.x) + self.assertNdarraysEqual(first.y, second.y) + self.assertNdarraysEqual(first.z, second.z) + + self.assertNdarraysEqual(first.vx, second.vx) + self.assertNdarraysEqual(first.vy, second.vy) + self.assertNdarraysEqual(first.vz, second.vz) + + self.assertNdarraysEqual(first.mass, second.mass) + def assertSnapshotsEqual(self, first: Snapshot, second: Snapshot, test_kinematics: bool = True): self.assertEqual(len(first.particles), len(second.particles)) + self.assertEqual(first.timestamp, second.timestamp) + if len(first.particles) == 0: return @@ -31,8 +49,6 @@ def assertSnapshotsEqual(self, first: Snapshot, second: Snapshot, test_kinematic self.assertNdarraysEqual(first.particles.mass, second.particles.mass) - self.assertEqual(first.timestamp, second.timestamp) - def _generate_snapshot(self, N: int = 100) -> Snapshot: snapshot = Snapshot(Particles(N)) snapshot.particles.mass = [10 * x + 1 for x in range(N)] | units.MSun diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/snapshot_test.py b/tests/core/snapshot_test.py new file mode 100644 index 0000000..a021cf0 --- /dev/null +++ b/tests/core/snapshot_test.py @@ -0,0 +1,26 @@ +from omtool.core.utils import BaseTestCase +from omtool.core.datamodel import Snapshot +from amuse.lab import Particles, units, VectorQuantity +import numpy as np + + +class TestSnapshot(BaseTestCase): + def test_get_amuse_particles_empty(self): + particles = Particles() + particles.position = VectorQuantity([], units.kpc) + particles.velocity = VectorQuantity([], units.kms) + particles.mass = VectorQuantity([], units.MSun) + snapshot = Snapshot(particles) + + actual = snapshot.get_amuse_particles() + self.assertEqual(len(actual), 0) + + def test_get_amuse_particles_two_particles(self): + particles = Particles(2) + particles.position = VectorQuantity(np.array([[1, 1, 1], [2, 2, 2]]), units.kpc) + particles.velocity = VectorQuantity(np.array([[3, 3, 3], [4, 4, 4]]), units.kms) + particles.mass = VectorQuantity(np.array([10, 20]), units.MSun) + snapshot = Snapshot(particles) + + actual = snapshot.get_amuse_particles() + self.assertAmuseParticlesEqual(actual, particles) diff --git a/tests/models/fits_model_test.py b/tests/models/fits_model_test.py index 36f93be..0c2ebaa 100644 --- a/tests/models/fits_model_test.py +++ b/tests/models/fits_model_test.py @@ -1,7 +1,6 @@ -from typing import Iterator from unittest.mock import patch -from amuse.lab import Particles, ScalarQuantity, units +from amuse.lab import Particles, units from omtool.core.datamodel import Snapshot from omtool.core.utils import BaseTestCase