From 395740c34adc2dfac43f39e70958a288b8ab2aea Mon Sep 17 00:00:00 2001
From: arjunsavel
Date: Tue, 4 Mar 2025 20:17:05 -0500
Subject: [PATCH] add DB refresh

---
Reviewer notes (this region, between the "---" separator and the diffstat,
is not part of the commit message):

* This patch was recovered from a copy in which all line breaks had been
  collapsed. Line breaks, indentation, and the positions of blank context
  lines were reconstructed so that every hunk matches its "@@" line counts
  and the diffstat below. TODO confirm against the target tree with
  `git apply --check` (in particular the blank context lines in
  pyproject.toml and src/scope/input_output.py, and the indent width of the
  pyproject.toml dependency entries).
* refresh_db() downloads the "pscomppars" table through astroquery's
  NasaExoplanetArchive and writes it to
  data/default_params_exoplanet_archive.csv next to input_output.py.
  test_refresh_db() calls refresh_db() as-is, so running the test repeats
  that download and overwrites that CSV file.
* The new "astroquery" dependency entry has no trailing comma, unlike the
  entries above it.

 pyproject.toml             |  1 +
 src/scope/input_output.py  | 23 +++++++++++++++++++++++
 src/scope/tests/test_io.py |  9 +++++++++
 3 files changed, 33 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 5dc41a9..b45347e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
     "exoplanet-core",
     "pymc>=4",
     "schwimmbad",
+    "astroquery"
 ]
 dynamic = ["version"]
 
diff --git a/src/scope/input_output.py b/src/scope/input_output.py
index f2cf7c9..e2b811f 100644
--- a/src/scope/input_output.py
+++ b/src/scope/input_output.py
@@ -9,12 +9,15 @@ from datetime import datetime
 
 import pandas as pd
 
+from astroquery.ipac.nexsci.nasa_exoplanet_archive import NasaExoplanetArchive
 from scope.calc_quantities import *
 from scope.logger import *
 
 
 logger = get_logger()
 
+data_dir = os.path.join(os.path.dirname(__file__), "./data")
+
 
 class ScopeConfigError(Exception):
     def __init__(self, message="scope input file error:"):
@@ -423,3 +426,23 @@ def parse_arguments():
     )
 
     return parser.parse_args()
+
+
+def refresh_db():
+    """
+    Refresh the database with the latest exoplanet data.
+    """
+    # Download the latest exoplanet data
+    # Update the database file
+
+    table = NasaExoplanetArchive.query_criteria(
+        table="pscomppars", select="*", where="pl_name is not null"
+    )
+
+    # Convert to Pandas DataFrame for easier handling
+    df = table.to_pandas()
+
+    # Save to CSV
+    filepath = os.path.join(data_dir, "default_params_exoplanet_archive.csv")
+    df.to_csv(filepath, index=False)
+    return df
diff --git a/src/scope/tests/test_io.py b/src/scope/tests/test_io.py
index 81c68e6..a2e51f5 100644
--- a/src/scope/tests/test_io.py
+++ b/src/scope/tests/test_io.py
@@ -10,6 +10,7 @@
     write_input_file,
     ScopeConfigError,
     parameter_mapping,
+    refresh_db,
 )
 
 test_data_path = os.path.join(os.path.dirname(__file__), "../data")
@@ -210,3 +211,11 @@ def test_database_columns():
 
     for value in parameter_mapping.values():
         assert value in db.columns
+
+
+def test_refresh_db():
+    """
+    just test that a reasonable db comes back
+    """
+    df = refresh_db()
+    assert len(df) > 0 and "pl_name" in df.columns