Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 59 additions & 13 deletions omtool/core/datamodel/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
"""
Struct that holds together particle set and timestamp that it describes.
"""

import contextlib
import pandas as pd
from amuse.datamodel.particles import Particles
from amuse.lab import units
from amuse.units.quantities import ScalarQuantity
from astropy.io import fits

fields = {
"x": units.kpc,
"y": units.kpc,
"z": units.kpc,
"vx": units.kms,
"vy": units.kms,
"vz": units.kms,
"mass": units.MSun,
"x": 1 | units.kpc,
"y": 1 | units.kpc,
"z": 1 | units.kpc,
"vx": 1 | units.kms,
"vy": 1 | units.kms,
"vz": 1 | units.kms,
"mass": 1 | units.MSun,
}

optional_fields = {
"is_barion": None,
}

Expand All @@ -29,9 +34,50 @@ def __init__(
particles: Particles = Particles(),
timestamp: ScalarQuantity = 0 | units.Myr,
):
self.particles = particles
self._particles_df = pd.DataFrame(columns=fields.keys())

for field, unit in fields.items():
if unit is None:
self._particles_df[field] = getattr(particles, field)
else:
self._particles_df[field] = getattr(particles, field) / unit

for field, unit in optional_fields.items():
with contextlib.suppress(AttributeError):
if unit is None:
self._particles_df[field] = getattr(particles, field)
else:
self._particles_df[field] = getattr(particles, field) / unit

self._particles = particles
self.timestamp = timestamp

def get_amuse_particles(self) -> Particles:
particles = Particles(len(self._particles_df))

for column in self._particles_df.columns:
unit = fields[column]
if unit is None:
setattr(particles, column, self._particles_df[column])
else:
setattr(particles, column, self._particles_df[column] * unit)

return particles

@property
def particles(self) -> Particles:
"""
Returns AMUSE Particles object.
"""
return self._particles

@particles.setter
def particles(self, particles: Particles):
"""
Deprecated. Sets particles from AMUSE Particles object. Exists only for backwards compatibility.
"""
self._particles = particles

def __getitem__(self, value) -> "Snapshot":
return Snapshot(self.particles[value], self.timestamp)

Expand Down Expand Up @@ -62,7 +108,7 @@ def to_fits(self, filename: str, append: bool = False):
"""
cols = []

for (key, val) in fields.items():
for (key, val) in fields.items() + optional_fields.items():
if not hasattr(self.particles, key):
continue

Expand Down Expand Up @@ -91,15 +137,15 @@ def to_fits(self, filename: str, append: bool = False):
def to_csv(self, filename: str):
df = pd.DataFrame(columns=fields.keys())

for key, val in fields.items():
for key, val in fields.items() + optional_fields.items():
if not hasattr(self.particles, key):
continue

array = getattr(self.particles, key)

if val is not None:
array = array.value_in(val)
if val is None:
val = 1

df[key] = array
df[key] = array / val

df.to_csv(filename)
20 changes: 18 additions & 2 deletions omtool/core/utils/base_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,27 @@ def setUp(self):
def assertNdarraysEqual(self, first: np.ndarray, second: np.ndarray):
np.testing.assert_array_equal(first, second)

def assertAmuseParticlesEqual(self, first: Particles, second: Particles):
self.assertEqual(len(first), len(second))

if len(first) == 0:
return

self.assertNdarraysEqual(first.x, second.x)
self.assertNdarraysEqual(first.y, second.y)
self.assertNdarraysEqual(first.z, second.z)

self.assertNdarraysEqual(first.vx, second.vx)
self.assertNdarraysEqual(first.vy, second.vy)
self.assertNdarraysEqual(first.vz, second.vz)

self.assertNdarraysEqual(first.mass, second.mass)

def assertSnapshotsEqual(self, first: Snapshot, second: Snapshot, test_kinematics: bool = True):
self.assertEqual(len(first.particles), len(second.particles))

self.assertEqual(first.timestamp, second.timestamp)

if len(first.particles) == 0:
return

Expand All @@ -31,8 +49,6 @@ def assertSnapshotsEqual(self, first: Snapshot, second: Snapshot, test_kinematic

self.assertNdarraysEqual(first.particles.mass, second.particles.mass)

self.assertEqual(first.timestamp, second.timestamp)

def _generate_snapshot(self, N: int = 100) -> Snapshot:
snapshot = Snapshot(Particles(N))
snapshot.particles.mass = [10 * x + 1 for x in range(N)] | units.MSun
Expand Down
Empty file added tests/core/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions tests/core/snapshot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from omtool.core.utils import BaseTestCase
from omtool.core.datamodel import Snapshot
from amuse.lab import Particles, units, VectorQuantity
import numpy as np


class TestSnapshot(BaseTestCase):
def test_get_amuse_particles_empty(self):
particles = Particles()
particles.position = VectorQuantity([], units.kpc)
particles.velocity = VectorQuantity([], units.kms)
particles.mass = VectorQuantity([], units.MSun)
snapshot = Snapshot(particles)

actual = snapshot.get_amuse_particles()
self.assertEqual(len(actual), 0)

def test_get_amuse_particles_two_particles(self):
particles = Particles(2)
particles.position = VectorQuantity(np.array([[1, 1, 1], [2, 2, 2]]), units.kpc)
particles.velocity = VectorQuantity(np.array([[3, 3, 3], [4, 4, 4]]), units.kms)
particles.mass = VectorQuantity(np.array([10, 20]), units.MSun)
snapshot = Snapshot(particles)

actual = snapshot.get_amuse_particles()
self.assertAmuseParticlesEqual(actual, particles)
3 changes: 1 addition & 2 deletions tests/models/fits_model_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Iterator
from unittest.mock import patch

from amuse.lab import Particles, ScalarQuantity, units
from amuse.lab import Particles, units

from omtool.core.datamodel import Snapshot
from omtool.core.utils import BaseTestCase
Expand Down