From 8d6117665df5b43b87b05cc1c6e07ef519dd55dc Mon Sep 17 00:00:00 2001 From: Sam Bianco Date: Mon, 16 Feb 2026 16:53:11 -0500 Subject: [PATCH 1/3] Enable cloud by default, tests --- astroquery/exceptions.py | 8 + astroquery/mast/__init__.py | 4 + astroquery/mast/cloud.py | 17 +- astroquery/mast/observations.py | 77 +++- astroquery/mast/tests/test_mast.py | 457 +++++++++++----------- astroquery/mast/tests/test_mast_remote.py | 80 ++-- 6 files changed, 334 insertions(+), 309 deletions(-) diff --git a/astroquery/exceptions.py b/astroquery/exceptions.py index 3b466a5c39..5a3f9b2041 100644 --- a/astroquery/exceptions.py +++ b/astroquery/exceptions.py @@ -119,3 +119,11 @@ class BlankResponseWarning(AstropyWarning): Astroquery warning to be raised if one or more rows in a table are bad, but not all rows are. """ + pass + + +class CloudAccessWarning(AstropyWarning): + """ + Astroquery warning to be raised if cloud access cannot be enabled. + """ + pass diff --git a/astroquery/mast/__init__.py b/astroquery/mast/__init__.py index 96028a89e4..68745f7fcc 100644 --- a/astroquery/mast/__init__.py +++ b/astroquery/mast/__init__.py @@ -29,6 +29,10 @@ class Conf(_config.ConfigNamespace): pagesize = _config.ConfigItem( 50000, 'Number of results to request at once from the STScI server.') + enable_cloud_dataset = _config.ConfigItem( + True, + 'Enable access to cloud-hosted datasets (e.g. on AWS S3) by default. ' + 'Requires the `boto3` and `botocore` packages to be installed.') conf = Conf() diff --git a/astroquery/mast/cloud.py b/astroquery/mast/cloud.py index 2f77dc59bf..e1e7b8ebc1 100644 --- a/astroquery/mast/cloud.py +++ b/astroquery/mast/cloud.py @@ -13,12 +13,22 @@ from astroquery import log from astropy.utils.console import ProgressBarOrSpinner from astropy.utils.exceptions import AstropyDeprecationWarning -from botocore.exceptions import ClientError, BotoCoreError from ..exceptions import RemoteServiceError, NoResultsWarning from . 
import utils +try: + import boto3 + HAS_BOTO3 = True +except ImportError: + HAS_BOTO3 = False +try: + import botocore + from botocore.exceptions import ClientError, BotoCoreError + HAS_BOTOCORE = True +except ImportError: + HAS_BOTOCORE = False __all__ = [] @@ -44,15 +54,14 @@ def __init__(self, provider="AWS", profile=None, verbose=False): verbose : bool Default False. Display extra info and warnings if true. """ + if not HAS_BOTO3 or not HAS_BOTOCORE: + raise ImportError("Please install the `boto3` and `botocore` packages to enable cloud dataset access.") # Dealing with deprecated argument if profile is not None: warnings.warn(("MAST Open Data on AWS is now free to access and does " "not require an AWS account"), AstropyDeprecationWarning) - import boto3 - import botocore - self.boto3 = boto3 self.botocore = botocore self.config = botocore.client.Config(signature_version=botocore.UNSIGNED) diff --git a/astroquery/mast/observations.py b/astroquery/mast/observations.py index 2ae5fc9c6b..3a4a2a45a4 100644 --- a/astroquery/mast/observations.py +++ b/astroquery/mast/observations.py @@ -18,7 +18,6 @@ import astropy.units as u import astropy.coordinates as coord -from botocore.exceptions import ClientError, BotoCoreError from astropy.table import Table, Row, vstack from astroquery import log @@ -27,13 +26,17 @@ from ..utils import async_to_sync from ..utils.class_or_instance import class_or_instance -from ..exceptions import (InvalidQueryError, RemoteServiceError, NoResultsWarning, InputWarning) +from ..exceptions import (InvalidQueryError, RemoteServiceError, NoResultsWarning, InputWarning, CloudAccessWarning) -from . import utils +from . 
import utils, conf from .core import MastQueryWithLogin -__all__ = ['Observations', 'ObservationsClass', - 'MastClass', 'Mast'] +try: + from botocore.exceptions import ClientError, BotoCoreError +except ImportError: + ClientError = BotoCoreError = () + +__all__ = ['Observations', 'ObservationsClass', 'MastClass', 'Mast'] @async_to_sync @@ -51,6 +54,24 @@ class ObservationsClass(MastQueryWithLogin): _caom_filtered = 'Mast.Caom.Filtered' _caom_products = 'Mast.Caom.Products' + def __init__(self, mast_token=None): + super().__init__(mast_token) + self._cloud_enabled_explicitly = None # Track whether cloud access was explicitly enabled by the user + + def _ensure_cloud_access(self): + """Ensure cloud access is initialized if appropriate.""" + # User explicitly disabled + if self._cloud_enabled_explicitly is False: + return + + # Already initialized + if self._cloud_connection is not None: + return + + # Default behavior is to enable cloud access if the config option is set, so we check that here + if self._cloud_enabled_explicitly is None and conf.enable_cloud_dataset: + self.enable_cloud_dataset(_internal=True) + def _parse_result(self, responses, *, verbose=False): # Used by the async_to_sync decorator functionality """ Parse the results of a list of `~requests.Response` objects and returns an `~astropy.table.Table` of results. @@ -180,7 +201,7 @@ def _parse_caom_criteria(self, *, resolver=None, **criteria): return position, mashup_filters - def enable_cloud_dataset(self, provider="AWS", profile=None, verbose=True): + def enable_cloud_dataset(self, provider="AWS", profile=None, verbose=True, *, _internal=False): """ Enable downloading public files from S3 instead of MAST. Requires the boto3 library to function. @@ -196,13 +217,21 @@ def enable_cloud_dataset(self, provider="AWS", profile=None, verbose=True): Default True. Logger to display extra info and warning. 
""" - self._cloud_connection = CloudAccess(provider, profile, verbose) + try: + self._cloud_connection = CloudAccess(provider, profile, verbose) + if not _internal: + self._cloud_enabled_explicitly = True + except ImportError as e: + # boto3 or botocore is not installed + self._cloud_connection = None + warnings.warn(e.msg, CloudAccessWarning) def disable_cloud_dataset(self): """ Disables downloading public files from S3 instead of MAST. """ self._cloud_connection = None + self._cloud_enabled_explicitly = False @class_or_instance def query_region_async(self, coordinates, *, radius=0.2*u.deg, pagesize=None, page=None): @@ -656,6 +685,9 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou url : str The full url download path """ + # Ensure cloud access is enabled + self._ensure_cloud_access() + if not uri or not isinstance(uri, str): raise InvalidQueryError("A valid data product URI must be provided.") @@ -693,8 +725,9 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou NoResultsWarning) return 'SKIPPED', None, None - warnings.warn(f'The product {uri} was not found in the cloud. ' - 'Falling back to MAST download.', InputWarning) + if self._cloud_enabled_explicitly: + warnings.warn(f'The product {uri} was not found in the cloud. ' + 'Falling back to MAST download.', InputWarning) self._download_file(escaped_url, local_path, cache=cache, head_safe=True, verbose=verbose) except (ClientError, BotoCoreError) as ex: # Should be in cloud, but download failed @@ -703,8 +736,9 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou NoResultsWarning) return 'SKIPPED', None, None - warnings.warn(f'Could not download {uri} from cloud: {ex}. Falling back to MAST download.', - InputWarning) + if self._cloud_enabled_explicitly: + warnings.warn(f'Could not download {uri} from cloud: {ex}. 
Falling back to MAST download.', + InputWarning) self._download_file(escaped_url, local_path, cache=cache, head_safe=True, verbose=verbose) else: if cloud_only: @@ -771,7 +805,6 @@ def _download_files(self, products, base_dir, *, flat=False, cache=True, cloud_o status, msg, url = 'ERROR', None, None cloud_uri = cloud_uri_map.get(mast_uri) if cloud_uri_map else None - if cloud_uri: try: self._cloud_connection.download_file_from_cloud(cloud_uri, local_path, cache, verbose) @@ -784,8 +817,9 @@ def _download_files(self, products, base_dir, *, flat=False, cache=True, cloud_o status = 'SKIPPED' msg = str(ex) else: - warnings.warn(f'Could not download {cloud_uri} from cloud: {ex}. ' - 'Falling back to MAST download.', InputWarning) + if self._cloud_enabled_explicitly: + warnings.warn(f'Could not download {cloud_uri} from cloud: {ex}. ' + 'Falling back to MAST download.', InputWarning) status, msg, url = self.download_file(mast_uri, local_path=local_path, cache=cache, force_on_prem=True, verbose=verbose) else: @@ -797,8 +831,9 @@ def _download_files(self, products, base_dir, *, flat=False, cache=True, cloud_o status = 'SKIPPED' msg = 'Product not found in cloud' else: - warnings.warn(f'The product {mast_uri} was not found in the cloud. ' - 'Falling back to MAST download.', InputWarning) + if self._cloud_enabled_explicitly: + warnings.warn(f'The product {mast_uri} was not found in the cloud. ' + 'Falling back to MAST download.', InputWarning) status, msg, url = self.download_file(mast_uri, local_path=local_path, cache=cache, force_on_prem=True, verbose=verbose) else: @@ -899,6 +934,9 @@ def download_products(self, products, *, download_dir=None, flat=False, response : `~astropy.table.Table` The manifest of files downloaded, or status of files on disk if curl option chosen. 
""" + # Ensure cloud access is enabled + self._ensure_cloud_access() + # If the products list is a row we need to cast it as a table if isinstance(products, Row): products = Table(products, masked=True) @@ -961,6 +999,9 @@ def list_cloud_datasets(self): response : list List of dataset prefixes that support cloud data access. """ + # Ensure cloud access is enabled + self._ensure_cloud_access() + if self._cloud_connection is None: raise RemoteServiceError( 'Please enable anonymous cloud access by calling `enable_cloud_dataset` method. ' @@ -1027,6 +1068,8 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa List of URIs generated from the data products. May contain entries that are None if data_products includes products not found in the cloud. """ + # Ensure cloud access is enabled + self._ensure_cloud_access() if self._cloud_connection is None: raise RemoteServiceError( @@ -1110,6 +1153,8 @@ def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False): Cloud URI generated from the data product. If the product cannot be found in the cloud, None is returned. 
""" + # Ensure cloud access is enabled + self._ensure_cloud_access() if self._cloud_connection is None: raise RemoteServiceError( diff --git a/astroquery/mast/tests/test_mast.py b/astroquery/mast/tests/test_mast.py index 286a92cf90..04fecd2c63 100644 --- a/astroquery/mast/tests/test_mast.py +++ b/astroquery/mast/tests/test_mast.py @@ -5,7 +5,7 @@ import re import warnings from shutil import copyfile -from unittest.mock import patch +from unittest.mock import MagicMock, patch from pathlib import Path import astropy.units as u @@ -14,16 +14,20 @@ from astropy.table import Table, unique from astropy.coordinates import SkyCoord from astropy.io import fits -from botocore.exceptions import ClientError from astropy.utils.exceptions import AstropyDeprecationWarning from requests import HTTPError, Response from astroquery.mast import (Catalogs, MastMissions, Observations, Tesscut, Zcut, Mast, utils, services, - discovery_portal, auth, core) + discovery_portal, auth, core, cloud) from astroquery.mast.cloud import CloudAccess from astroquery.utils.mocks import MockResponse from astroquery.exceptions import (BlankResponseWarning, InvalidQueryError, InputWarning, MaxResultsWarning, - NoResultsWarning, RemoteServiceError, ResolverError) + NoResultsWarning, RemoteServiceError, ResolverError, CloudAccessWarning) + +try: + from botocore.exceptions import ClientError +except ImportError: + ClientError = BotoCoreError = () DATA_FILES = {'Mast.Caom.Cone': 'caom.json', 'Mast.Name.Lookup': 'resolver.json', @@ -68,7 +72,7 @@ def data_path(filename): return os.path.join(data_dir, filename) -@pytest.fixture +@pytest.fixture(autouse=True) def patch_post(request): mp = request.getfixturevalue("monkeypatch") @@ -84,6 +88,32 @@ def patch_post(request): return mp +@pytest.fixture() +def patch_boto3(monkeypatch): + """Fixture to patch boto3 client and resource for cloud access tests.""" + pytest.importorskip('boto3') + mock_client = MagicMock() + mock_client.head_object.return_value = 
{'ContentLength': 12345} + + mock_resource = MagicMock() + mock_resource.Bucket.return_value.download_file.return_value = None + + monkeypatch.setattr('boto3.client', lambda *args, **kwargs: mock_client) + monkeypatch.setattr('boto3.resource', lambda *args, **kwargs: mock_resource) + + return mock_client, mock_resource + + +@pytest.fixture() +def reset_cloud_state(): + """Reset the cloud dataset access state in Observations before and after each test.""" + Observations.disable_cloud_dataset() + Observations._cloud_enabled_explicitly = None + yield + Observations.disable_cloud_dataset() + Observations._cloud_enabled_explicitly = None + + def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwargs): if "columnsconfig" in url: if "Mast.Catalogs.Tess.Cone" in data: @@ -210,23 +240,23 @@ def zcut_download_mockreturn(url, file_path): ########################### -def test_missions_query_region_async(patch_post): +def test_missions_query_region_async(): responses = MastMissions.query_region_async(regionCoords, radius=0.002, sci_pi_last_name='GORDON') assert isinstance(responses, MockResponse) -def test_missions_query_object_async(patch_post): +def test_missions_query_object_async(): responses = MastMissions.query_object_async("M101", radius="0.002 deg") assert isinstance(responses, MockResponse) -def test_missions_query_object(patch_post): +def test_missions_query_object(): result = MastMissions.query_object("M101", radius=".002 deg") assert isinstance(result, Table) assert len(result) > 0 -def test_missions_query_region(patch_post): +def test_missions_query_region(): result = MastMissions.query_region(regionCoords, sci_instrume=['ACS', 'WFPC'], radius=0.002 * u.deg, @@ -236,7 +266,7 @@ def test_missions_query_region(patch_post): assert len(result) > 0 -def test_missions_query_criteria_async(patch_post): +def test_missions_query_criteria_async(): responses = MastMissions.query_criteria_async( coordinates=regionCoords, radius=3, @@ -248,7 +278,7 @@ 
def test_missions_query_criteria_async(patch_post): assert isinstance(responses, MockResponse) -def test_missions_query_criteria_async_with_missing_results(patch_post): +def test_missions_query_criteria_async_with_missing_results(): with pytest.raises(KeyError): responses = MastMissions.query_criteria_async( coordinates=regionCoords, @@ -262,7 +292,7 @@ def test_missions_query_criteria_async_with_missing_results(patch_post): services._json_to_table(json.loads(responses), 'results') -def test_missions_query_criteria(patch_post): +def test_missions_query_criteria(): result = MastMissions.query_criteria( coordinates=regionCoords, radius=3, @@ -300,7 +330,7 @@ def test_missions_query_criteria(patch_post): ) -def test_missions_parse_select_cols(patch_post): +def test_missions_parse_select_cols(): # Default columns cols = MastMissions._parse_select_cols(None) # Default columns for HST assert cols is None @@ -350,7 +380,7 @@ def test_missions_parse_select_cols(patch_post): assert col in ullyses_cols -def test_missions_get_product_list_async(patch_post): +def test_missions_get_product_list_async(): # String input result = MastMissions.get_product_list_async('Z14Z0104T') assert isinstance(result, list) @@ -385,7 +415,7 @@ def test_missions_get_product_list_async(patch_post): missions.get_product_list_async(Table({'a': [1, 2, 3]})) -def test_missions_get_product_list(patch_post): +def test_missions_get_product_list(): # String input result = MastMissions.get_product_list('Z14Z0104T') assert isinstance(result, Table) @@ -410,7 +440,7 @@ def test_missions_get_product_list(patch_post): assert isinstance(result, Table) -def test_missions_get_unique_product_list(patch_post, caplog): +def test_missions_get_unique_product_list(caplog): unique_products = MastMissions.get_unique_product_list('Z14Z0104T') assert isinstance(unique_products, Table) assert (len(unique_products) == len(unique(unique_products, keys='filename'))) @@ -419,7 +449,7 @@ def 
test_missions_get_unique_product_list(patch_post, caplog): assert caplog.text == '' -def test_missions_filter_products(patch_post): +def test_missions_filter_products(): # Filter products list by column products = MastMissions.get_product_list('Z14Z0104T') filtered = MastMissions.filter_products(products, category='CALIBRATED') @@ -503,7 +533,7 @@ def test_missions_filter_products(patch_post): MastMissions.filter_products(products, non_existing='value') -def test_missions_download_products(patch_post, tmp_path): +def test_missions_download_products(tmp_path): # Check string input test_dataset_id = 'Z14Z0104T' result = MastMissions.download_products(test_dataset_id, download_dir=tmp_path) @@ -556,7 +586,7 @@ def test_missions_download_products(patch_post, tmp_path): @patch.object(Path, 'is_file', return_value=True) -def test_missions_download_file(mock_is_file, patch_post, tmp_path): +def test_missions_download_file(mock_is_file, tmp_path): # JWST download missions = MastMissions() missions.mission = 'JWST' @@ -573,7 +603,7 @@ def test_missions_download_file(mock_is_file, patch_post, tmp_path): missions.download_file('classy_test_file.fits', local_path=tmp_path) -def test_missions_download_no_auth(patch_post, caplog): +def test_missions_download_no_auth(caplog): # Exclusive access products should not be downloaded if user is not authenticated # User is not authenticated uri = 'unauthorized.fits' @@ -596,7 +626,7 @@ def test_missions_download_no_auth(patch_post, caplog): assert 'You do not have access to download this data' in caplog.text -def test_missions_get_dataset_kwd(patch_post, caplog): +def test_missions_get_dataset_kwd(caplog): m = MastMissions() # Default is HST @@ -626,7 +656,7 @@ def test_missions_get_dataset_kwd(patch_post, caplog): [['query_region', dict()], ['query_criteria', dict(ang_sep=0.6)]] ) -def test_missions_radius_too_large(method, kwargs, patch_post): +def test_missions_radius_too_large(method, kwargs): m = MastMissions(mission='jwst') 
coordinates = SkyCoord(0, 0, unit=u.deg) radius = m._max_query_radius + 0.1 * u.deg @@ -641,14 +671,14 @@ def test_missions_radius_too_large(method, kwargs, patch_post): ################### -def test_list_missions(patch_post): +def test_list_missions(): missions = Observations.list_missions() assert isinstance(missions, list) for m in ['HST', 'HLA', 'GALEX', 'Kepler']: assert m in missions -def test_mast_service_request_async(patch_post): +def test_mast_service_request_async(): service = 'Mast.Name.Lookup' params = {'input': "M103", 'format': 'json'} @@ -660,7 +690,7 @@ def test_mast_service_request_async(patch_post): assert output -def test_mast_service_request(patch_post): +def test_mast_service_request(): service = 'Mast.Caom.Cone' params = {'ra': 23.34086, 'dec': 60.658, @@ -670,7 +700,7 @@ def test_mast_service_request(patch_post): assert isinstance(result, Table) -def test_mast_query(patch_post): +def test_mast_query(): # cone search result = Mast.mast_query('Mast.Caom.Cone', ra=23.34086, dec=60.658, radius=0.2) assert isinstance(result, Table) @@ -703,7 +733,7 @@ def test_mast_query(patch_post): Mast.mast_query('Mast.Caom.Filtered', s_ra={'min': 10.0}) -def test_resolve_object_single(patch_post): +def test_resolve_object_single(): obj = "TIC 307210830" tic_coord = SkyCoord(124.531756290083, -68.3129998725044, unit="deg") simbad_coord = SkyCoord(124.5317560026638, -68.3130014904408, unit="deg") @@ -752,7 +782,7 @@ def test_resolve_object_single(patch_post): assert isinstance(loc, dict) -def test_resolve_object_multi(patch_post): +def test_resolve_object_multi(): objects = ["TIC 307210830", "M1", "Barnard's Star"] # No resolver specified @@ -792,7 +822,7 @@ def test_resolve_object_multi(patch_post): Mast.resolve_object(["nonexisting1", "nonexisting2"]) -def test_login_logout(patch_post): +def test_login_logout(): test_token = "56a9cf3df4c04052atest43feb87f282" Mast.login(token=test_token) @@ -804,7 +834,7 @@ def test_login_logout(patch_post): assert not 
Mast._session.cookies.get("mast_token") -def test_session_info(patch_post): +def test_session_info(): info = Mast.session_info(verbose=False) assert isinstance(info, dict) assert info['ezid'] == 'alice' @@ -819,27 +849,27 @@ def test_session_info(patch_post): # query functions -def test_observations_query_region_async(patch_post): +def test_observations_query_region_async(): responses = Observations.query_region_async(regionCoords, radius=0.2) assert isinstance(responses, list) -def test_observations_query_region(patch_post): +def test_observations_query_region(): result = Observations.query_region(regionCoords, radius=0.2 * u.deg) assert isinstance(result, Table) -def test_observations_query_object_async(patch_post): +def test_observations_query_object_async(): responses = Observations.query_object_async("M103", radius="0.2 deg") assert isinstance(responses, list) -def test_observations_query_object(patch_post): +def test_observations_query_object(): result = Observations.query_object("M103", radius=".02 deg") assert isinstance(result, Table) -def test_query_observations_criteria_async(patch_post): +def test_query_observations_criteria_async(): # without position responses = Observations.query_criteria_async(dataproduct_type=["image"], proposal_pi="Ost*", @@ -852,7 +882,7 @@ def test_query_observations_criteria_async(patch_post): assert isinstance(responses, list) -def test_observations_query_criteria(patch_post): +def test_observations_query_criteria(): # without position result = Observations.query_criteria(dataproduct_type=["image"], proposal_pi="Ost*", @@ -874,17 +904,17 @@ def test_observations_query_criteria(patch_post): # count functions -def test_observations_query_region_count(patch_post): +def test_observations_query_region_count(): result = Observations.query_region_count(regionCoords, radius="0.2 deg") assert result == 599 -def test_observations_query_object_count(patch_post): +def test_observations_query_object_count(): result = 
Observations.query_object_count("M8", radius=0.2*u.deg) assert result == 599 -def test_observations_query_criteria_count(patch_post): +def test_observations_query_criteria_count(): result = Observations.query_criteria_count(dataproduct_type=["image"], proposal_pi="Ost*", s_dec=[43.5, 45.5]) @@ -901,7 +931,7 @@ def test_observations_query_criteria_count(patch_post): # product functions -def test_observations_get_product_list_async(patch_post): +def test_observations_get_product_list_async(): responses = Observations.get_product_list_async('2003738726') assert isinstance(responses, list) @@ -916,7 +946,7 @@ def test_observations_get_product_list_async(patch_post): assert isinstance(responses, list) -def test_observations_get_product_list(patch_post): +def test_observations_get_product_list(): result = Observations.get_product_list('2003738726') assert isinstance(result, Table) @@ -939,7 +969,7 @@ def test_observations_get_product_list(patch_post): Observations.get_product_list([' ']) -def test_observations_filter_products(patch_post): +def test_observations_filter_products(): products = Observations.get_product_list('2003738726') filtered = Observations.filter_products(products, productType=["sCiEnCE"], @@ -970,85 +1000,69 @@ def test_observations_filter_products(patch_post): Observations.filter_products(products, invalid=True) -def test_observations_download_products(patch_post, tmpdir): - # actually download the products - result = Observations.download_products('2003738726', - download_dir=str(tmpdir), - productType=["SCIENCE"], - mrp_only=False) +@patch.object(Path, "is_file", return_value=True) +def test_observations_download_products(mock_is_file, patch_boto3, monkeypatch, reset_cloud_state): + mock_resource = patch_boto3[1] + obsid = '2003738726' + data_uri = 'mast:HST/product/u9o40504m_c3m.fits' + + # Actually download the products + result = Observations.download_products(obsid, + dataURI=data_uri) assert isinstance(result, Table) - # just get the curl 
script - result = Observations.download_products('2003738726', - download_dir=str(tmpdir), + # Just get the curl script + result = Observations.download_products(obsid, curl_flag=True, productType=["SCIENCE"], mrp_only=False) assert isinstance(result, Table) - # without console output - result = Observations.download_products('2003738726', - download_dir=str(tmpdir), - productType=["SCIENCE"], + # Without console output, flat + result = Observations.download_products(obsid, + dataURI=data_uri, + flat=True, verbose=False) assert isinstance(result, Table) - # passing row product - products = Observations.get_product_list('2003738726') - result1 = Observations.download_products(products[0], download_dir=str(tmpdir)) + # Passing row product + products = Observations.get_product_list(obsid) + result1 = Observations.download_products(products[0]) assert isinstance(result1, Table) # Warn if no products to download with pytest.warns(NoResultsWarning, match='No products to download'): - result = Observations.download_products('2003738726', - download_dir=str(tmpdir), - productType=["INVALID_TYPE"]) + result = Observations.download_products(obsid, productType=["INVALID_TYPE"]) assert result is None # Warn if curl_flag and flags are both set with pytest.warns(InputWarning, match='flat=True has no effect on curl downloads.'): - result = Observations.download_products('2003738726', + result = Observations.download_products(obsid, curl_flag=True, flat=True) assert isinstance(result, Table) - -@patch('boto3.resource') -@patch('boto3.client') -@patch.object(Path, "is_file", return_value=True) -def test_observations_download_products_cloud(mock_is_file, mock_client, mock_resource, patch_post, - monkeypatch): - pytest.importorskip("boto3") - mock_client.return_value.head_object.return_value = {'ContentLength': 12345} - mock_resource.return_value.Bucket.return_value.download_file.return_value = None - obsid = '2003738726' - data_uri = 'mast:HST/product/u9o40504m_c3m.fits' - - # 
Enable access to public AWS S3 bucket - Observations.enable_cloud_dataset() - result = Observations.download_products(obsid, dataURI=data_uri) assert isinstance(result, Table) assert result[0]['Status'] == 'COMPLETE' - # Mock cloud download failure, fallback to on-prem + # Mock cloud download failure + Observations.enable_cloud_dataset() # enable cloud dataset to emit warning client_err = ClientError({'Error': {'Code': '500', 'Message': 'Internal Server Error'}}, 'HeadObject') - mock_resource.return_value.Bucket.return_value.download_file.side_effect = client_err - # Check that info message is logged + mock_resource.Bucket.return_value.download_file.side_effect = client_err + # Warn and fall back to on-prem download with pytest.warns(InputWarning, match='Falling back to MAST download'): - result = Observations.download_products(obsid, - dataURI=data_uri) + result = Observations.download_products(obsid, dataURI=data_uri) assert result[0]['Status'] == 'COMPLETE' - - # Cloud download failure, do not fallback to on-prem + # Do not fall back to on-prem download, skip instead with pytest.warns(NoResultsWarning, match='Skipping download.'): result = Observations.download_products(obsid, dataURI=data_uri, cloud_only=True) assert result[0]['Status'] == 'SKIPPED' - # Products not found in cloud, skip download + # Products not found in cloud monkeypatch.setattr(Observations, 'get_cloud_uris', lambda *a, **k: {}) with pytest.warns(NoResultsWarning, match='was not found in the cloud. Skipping download.'): result = Observations.download_products(obsid, @@ -1056,15 +1070,13 @@ def test_observations_download_products_cloud(mock_is_file, mock_client, mock_re cloud_only=True) assert result[0]['Status'] == 'SKIPPED' assert result[0]['Message'] == 'Product not found in cloud' - - # Products not found in cloud, fall back + # Warn and fall back to on-prem download if products not found in cloud and cloud_only is False with pytest.warns(InputWarning, match='was not found in the cloud. 
Falling back to MAST download'): result = Observations.download_products(obsid, dataURI=data_uri) assert result[0]['Status'] == 'COMPLETE' - Observations.disable_cloud_dataset() - # Cloud access not enabled, warn if cloud_only is True + Observations.disable_cloud_dataset() with pytest.warns(InputWarning, match='cloud data access is not enabled'): result = Observations.download_products('2003738726', dataURI='mast:HST/product/u9o40504m_c3m.fits', @@ -1073,38 +1085,15 @@ def test_observations_download_products_cloud(mock_is_file, mock_client, mock_re @patch.object(Path, "is_file", return_value=True) -def test_observations_download_file(mock_is_file, patch_post, tmpdir): +def test_observations_download_file(mock_is_file, patch_boto3, reset_cloud_state, tmpdir): + mock_client, mock_resource = patch_boto3 + mock_client.head_object.return_value = {'ContentLength': 12345} + mock_resource.Bucket.return_value.download_file.return_value = None mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' result = Observations.download_file(mast_uri, local_path=tmpdir) assert result == ('COMPLETE', None, None) - unauth_uri = 'mast:HST/product/unauthorized.fits' - result = Observations.download_file(unauth_uri) - assert result[0] == 'ERROR' - assert 'HTTPError' in result[1] - - -def test_observations_download_file_not_found(patch_post, tmpdir): - mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' - - result = Observations.download_file(mast_uri, local_path=tmpdir) - assert result[0] == 'ERROR' - assert result[1] == 'File was not downloaded' - - -@patch('boto3.resource') -@patch('boto3.client') -@patch.object(Path, "is_file", return_value=True) -def test_observations_download_file_cloud(mock_is_file, mock_client, mock_resource, patch_post): - pytest.importorskip("boto3") - mock_client.return_value.head_object.return_value = {'ContentLength': 12345} - mock_resource.return_value.Bucket.return_value.download_file.return_value = None - - # Enable access to public AWS S3 bucket - 
Observations.enable_cloud_dataset() - mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' - # Warn if both cloud_only and force_on_prem are True with pytest.raises(InvalidQueryError, match='Invalid argument combination'): result = Observations.download_file(mast_uri, cloud_only=True, force_on_prem=True) @@ -1116,6 +1105,7 @@ def test_observations_download_file_cloud(mock_is_file, mock_client, mock_resour assert result == ('SKIPPED', None, None) # Use on-prem download if cloud_only is False and file is not in cloud + Observations.enable_cloud_dataset() # enable cloud dataset to emit warning with pytest.warns(InputWarning, match=f'The product {nonexistent_uri} was not found in the cloud'): result = Observations.download_file(nonexistent_uri, cloud_only=False) assert result == ('COMPLETE', None, None) @@ -1124,67 +1114,69 @@ def test_observations_download_file_cloud(mock_is_file, mock_client, mock_resour with pytest.raises(InvalidQueryError, match='A valid data product URI'): Observations.download_file(12345, cloud_only=True) - Observations.disable_cloud_dataset() + # Mock cloud download failure, fallback to on-prem + client_err = ClientError({'Error': {'Code': '500', 'Message': 'Internal Server Error'}}, 'HeadObject') + mock_resource.Bucket.return_value.download_file.side_effect = client_err + with pytest.warns(InputWarning, match='Falling back to MAST download'): + result = Observations.download_file(mast_uri) + assert result == ('COMPLETE', None, None) + + # Skip if cloud download fails and cloud_only is True + with pytest.warns(NoResultsWarning, match='Skipping download.'): + result = Observations.download_file(mast_uri, cloud_only=True) + assert result == ('SKIPPED', None, None) # Warning if cloud dataset is not enabled + Observations.disable_cloud_dataset() with pytest.warns(InputWarning, match='cloud data access is not enabled'): result = Observations.download_file(mast_uri, cloud_only=True) assert result == ('COMPLETE', None, None) -@patch('boto3.client') -def 
test_observations_list_cloud_missions(mock_client, patch_post): - pytest.importorskip('boto3') - mock_client.return_value.list_objects_v2.return_value = { +def test_observations_download_file_not_found(patch_boto3, reset_cloud_state): + mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' + result = Observations.download_file(mast_uri) + assert result[0] == 'ERROR' + assert result[1] == 'File was not downloaded' + + mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' + result = Observations.download_file(mast_uri) + assert result[0] == 'ERROR' + assert result[1] == 'File was not downloaded' + + +def test_observations_list_cloud_missions(patch_boto3, reset_cloud_state): + mock_client = patch_boto3[0] + mock_client.list_objects_v2.return_value = { 'CommonPrefixes': [{'Prefix': 'hst/'}, {'Prefix': 'jwst/'}, {'Prefix': 'mast/'}] } - with pytest.raises(RemoteServiceError): - Observations.list_cloud_datasets() - - Observations.enable_cloud_dataset() supported = Observations.list_cloud_datasets() assert isinstance(supported, list) assert 'hst' in supported assert 'jwst' in supported assert 'mast' in supported - Observations.disable_cloud_dataset() - - -@patch('boto3.client') -def test_observations_list_cloud_missions_error(mock_client, patch_post, caplog): - pytest.importorskip('boto3') - - # Error without cloud connection - with pytest.raises(RemoteServiceError): - Observations.list_cloud_datasets() +def test_observations_list_cloud_missions_error(patch_boto3, reset_cloud_state): # Mock an error when listing objects + mock_client = patch_boto3[0] client_error = ClientError({'Error': {'Code': 'AWS error'}}, 'ListObjectsV2') - mock_client.return_value.list_objects_v2.side_effect = client_error + mock_client.list_objects_v2.side_effect = client_error - Observations.enable_cloud_dataset() supported = Observations.list_cloud_datasets() assert supported == [] + # Cloud access not enabled Observations.disable_cloud_dataset() + with pytest.raises(RemoteServiceError, match='Please 
enable anonymous cloud access'): + Observations.list_cloud_datasets() -@patch('boto3.client') -def test_observations_get_cloud_uri(mock_client, patch_post): - pytest.importorskip("boto3") - +def test_observations_get_cloud_uri(patch_boto3, reset_cloud_state): mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' expected = 's3://stpubdata/hst/public/u9o4/u9o40504m/u9o40504m_c3m.fits' - # Error without cloud connection - with pytest.raises(RemoteServiceError): - Observations.get_cloud_uri('mast:HST/product/u9o40504m_c3m.fits') - - # Enable access to public AWS S3 bucket - Observations.enable_cloud_dataset() - # Row input product = Table() product['dataURI'] = [mast_uri] @@ -1200,23 +1192,16 @@ def test_observations_get_cloud_uri(mock_client, patch_post): with pytest.warns(NoResultsWarning, match='Failed to retrieve cloud path'): Observations.get_cloud_uri('mast:HST/product/does_not_exist.fits') + # Cloud access not enabled Observations.disable_cloud_dataset() + with pytest.raises(RemoteServiceError, match='Please enable anonymous cloud access'): + Observations.get_cloud_uri(mast_uri) -@patch('boto3.client') -def test_observations_get_cloud_uris(mock_client, patch_post): - pytest.importorskip("boto3") - +def test_observations_get_cloud_uris(patch_boto3, reset_cloud_state): mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' expected = 's3://stpubdata/hst/public/u9o4/u9o40504m/u9o40504m_c3m.fits' - # Error without cloud connection - with pytest.raises(RemoteServiceError): - Observations.get_cloud_uris(['mast:HST/product/u9o40504m_c3m.fits']) - - # Enable access to public AWS S3 bucket - Observations.enable_cloud_dataset() - # Get the cloud URIs # Table input product = Table() @@ -1252,40 +1237,33 @@ def test_observations_get_cloud_uris(mock_client, patch_post): with pytest.warns(NoResultsWarning, match='Failed to retrieve cloud path'): Observations.get_cloud_uris(['mast:HST/product/does_not_exist.fits']) + # Cloud access not enabled Observations.disable_cloud_dataset() + with 
pytest.raises(RemoteServiceError, match='Please enable anonymous cloud access'): + Observations.get_cloud_uris([mast_uri]) -@patch('boto3.client') -def test_observations_get_cloud_uris_error(mock_client, patch_post): - pytest.importorskip("boto3") +def test_observations_get_cloud_uris_error(patch_boto3, reset_cloud_state): + mock_client = patch_boto3[0] # Mock head_object to raise an exception # Raise the error if not a 404 exc = ClientError({'Error': {'Code': '500', 'Message': 'Internal Server Error'}}, 'HeadObject') - mock_client.return_value.head_object.side_effect = exc + mock_client.head_object.side_effect = exc - Observations.enable_cloud_dataset() with pytest.raises(ClientError): Observations.get_cloud_uris(['mast:HST/product/u9o40504m_c3m.fits']) # Only warn if the error is a 404 exc = ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject') - mock_client.return_value.head_object.side_effect = exc + mock_client.head_object.side_effect = exc with pytest.warns(NoResultsWarning, match='Failed to retrieve cloud path'): uris = Observations.get_cloud_uris(['mast:HST/product/u9o40504m_c3m.fits']) assert uris == [] - Observations.disable_cloud_dataset() - - -@patch('boto3.client') -def test_observations_get_cloud_uris_query(mock_client, patch_post): - pytest.importorskip("boto3") - - # enable access to public AWS S3 bucket - Observations.enable_cloud_dataset() +def test_observations_get_cloud_uris_query(patch_boto3, reset_cloud_state): # get uris with streamlined function uris = Observations.get_cloud_uris(target_name=234295610, filter_products={'productSubGroupDescription': 'C3M'}) @@ -1300,7 +1278,28 @@ def test_observations_get_cloud_uris_query(mock_client, patch_post): Observations.get_cloud_uris(target_name=234295610, filter_products={'productSubGroupDescription': 'LC'}) + +def test_observations_enable_cloud_dataset(patch_boto3, reset_cloud_state): + # Enable cloud dataset + Observations.enable_cloud_dataset() + assert 
Observations._cloud_connection is not None + assert Observations._cloud_enabled_explicitly is True + + # Force an import error when connecting to cloud dataset + cloud.HAS_BOTO3 = False + Observations.disable_cloud_dataset() # reset state + with pytest.warns(CloudAccessWarning): + Observations.enable_cloud_dataset() + + # Reset cloud dataset state for other tests + cloud.HAS_BOTO3 = True + + +def test_observations_disable_cloud_dataset(patch_boto3, reset_cloud_state): + # Explicitly disable cloud dataset Observations.disable_cloud_dataset() + assert Observations._cloud_connection is None + assert Observations._cloud_enabled_explicitly is False ###################### @@ -1308,17 +1307,17 @@ def test_observations_get_cloud_uris_query(mock_client, patch_post): ###################### -def test_catalogs_query_region_async(patch_post): +def test_catalogs_query_region_async(): responses = Catalogs.query_region_async(regionCoords, radius=0.002) assert isinstance(responses, list) -def test_catalogs_fabric_query_region_async(patch_post): +def test_catalogs_fabric_query_region_async(): responses = Catalogs.query_region_async(regionCoords, radius=0.002, catalog="panstarrs", table="mean") assert isinstance(responses, MockResponse) -def test_catalogs_query_region(patch_post): +def test_catalogs_query_region(): result = Catalogs.query_region(regionCoords, radius=0.002 * u.deg) assert isinstance(result, Table) @@ -1346,32 +1345,32 @@ def test_catalogs_query_region(patch_post): assert isinstance(result, Table) -def test_catalogs_fabric_query_region(patch_post): +def test_catalogs_fabric_query_region(): result = Catalogs.query_region(regionCoords, radius=0.002 * u.deg, catalog="panstarrs", table="mean") assert isinstance(result, Table) -def test_catalogs_query_object_async(patch_post): +def test_catalogs_query_object_async(): responses = Catalogs.query_object_async("M101", radius="0.002 deg") assert isinstance(responses, list) -def 
test_catalogs_fabric_query_object_async(patch_post): +def test_catalogs_fabric_query_object_async(): responses = Catalogs.query_object_async("M101", radius="0.002 deg", catalog="panstarrs", table="mean") assert isinstance(responses, MockResponse) -def test_catalogs_query_object(patch_post): +def test_catalogs_query_object(): result = Catalogs.query_object("M101", radius=".002 deg") assert isinstance(result, Table) -def test_catalogs_fabric_query_object(patch_post): +def test_catalogs_fabric_query_object(): result = Catalogs.query_object("M101", radius=".002 deg", catalog="panstarrs", table="mean") assert isinstance(result, Table) -def test_catalogs_query_criteria_async(patch_post): +def test_catalogs_query_criteria_async(): responses = Catalogs.query_criteria_async(catalog="Tic", Bmag=[30, 50], objType="STAR") assert isinstance(responses, list) @@ -1407,7 +1406,7 @@ def test_catalogs_query_criteria_async(patch_post): assert "one of objectname and coordinates" in str(invalid_query.value) -def test_catalogs_query_criteria(patch_post): +def test_catalogs_query_criteria(): # without position result = Catalogs.query_criteria(catalog="Tic", Bmag=[30, 50], objType="STAR") @@ -1430,7 +1429,7 @@ def test_catalogs_query_criteria(patch_post): assert "non-positional" in str(invalid_query.value) -def test_catalogs_query_hsc_matchid_async(patch_post): +def test_catalogs_query_hsc_matchid_async(): responses = Catalogs.query_hsc_matchid_async(82371983) assert isinstance(responses, list) @@ -1442,22 +1441,22 @@ def test_catalogs_query_hsc_matchid_async(patch_post): assert "Invalid HSC version number" in str(i_w[0].message) -def test_catalogs_query_hsc_matchid(patch_post): +def test_catalogs_query_hsc_matchid(): result = Catalogs.query_hsc_matchid(82371983) assert isinstance(result, Table) -def test_catalogs_get_hsc_spectra_async(patch_post): +def test_catalogs_get_hsc_spectra_async(): responses = Catalogs.get_hsc_spectra_async() assert isinstance(responses, list) -def 
test_catalogs_get_hsc_spectra(patch_post): +def test_catalogs_get_hsc_spectra(): result = Catalogs.get_hsc_spectra() assert isinstance(result, Table) -def test_catalogs_download_hsc_spectra(patch_post, tmpdir): +def test_catalogs_download_hsc_spectra(tmpdir): allSpectra = Catalogs.get_hsc_spectra() # actually download the products @@ -1474,7 +1473,7 @@ def test_catalogs_download_hsc_spectra(patch_post, tmpdir): # TesscutClass tests # ###################### -def test_tesscut_get_sector(patch_post): +def test_tesscut_get_sector(): coord = SkyCoord(324.24368, -27.01029, unit="deg") sector_table = Tesscut.get_sectors(coordinates=coord) assert isinstance(sector_table, Table) @@ -1536,7 +1535,7 @@ def test_tesscut_get_sector(patch_post): assert "Input product must be SPOC." in str(invalid_query.value) -def test_tesscut_download_cutouts(patch_post, tmpdir): +def test_tesscut_download_cutouts(tmpdir): coord = SkyCoord(107.27, -70.0, unit="deg") # Testing with inflate @@ -1592,7 +1591,7 @@ def test_tesscut_download_cutouts(patch_post, tmpdir): assert "Input product must be SPOC." in str(invalid_query.value) -def test_tesscut_get_cutouts(patch_post, tmpdir): +def test_tesscut_get_cutouts(tmpdir): coord = SkyCoord(107.27, -70.0, unit="deg") cutout_hdus_list = Tesscut.get_cutouts(coordinates=coord, size=5) assert isinstance(cutout_hdus_list, list) @@ -1633,7 +1632,7 @@ def test_tesscut_get_cutouts(patch_post, tmpdir): assert "Input product must be SPOC." in str(invalid_query.value) -def test_tesscut_get_cutouts_mt_no_sector(patch_post): +def test_tesscut_get_cutouts_mt_no_sector(): """Test get_cutouts with moving target but no sector specified. 
When sector is not specified for moving targets, the method should @@ -1647,7 +1646,7 @@ def test_tesscut_get_cutouts_mt_no_sector(patch_post): assert isinstance(cutout_hdus_list[0], fits.HDUList) -def test_tesscut_download_cutouts_mt_no_sector(patch_post, tmpdir): +def test_tesscut_download_cutouts_mt_no_sector(tmpdir): """Test download_cutouts with moving target but no sector specified. When sector is not specified for moving targets, the method should @@ -1664,7 +1663,7 @@ def test_tesscut_download_cutouts_mt_no_sector(patch_post, tmpdir): assert os.path.isfile(manifest[0]["Local Path"]) -def test_tesscut_get_cutouts_mt_no_sector_empty_results(patch_post, monkeypatch): +def test_tesscut_get_cutouts_mt_no_sector_empty_results(monkeypatch): """Test get_cutouts with moving target when no sectors are available. When get_sectors returns an empty table, the method should warn and return an empty list. @@ -1679,7 +1678,7 @@ def test_tesscut_get_cutouts_mt_no_sector_empty_results(patch_post, monkeypatch) assert len(cutout_hdus_list) == 0 -def test_tesscut_download_cutouts_mt_no_sector_empty_results(patch_post, tmpdir, monkeypatch): +def test_tesscut_download_cutouts_mt_no_sector_empty_results(tmpdir, monkeypatch): """Test download_cutouts with moving target when no sectors are available. When get_sectors returns an empty table, the method should warn and return an empty Table. 
@@ -1701,7 +1700,7 @@ def test_tesscut_download_cutouts_mt_no_sector_empty_results(patch_post, tmpdir, ###################### -def test_zcut_get_survey(patch_post): +def test_zcut_get_survey(): coord = SkyCoord(189.49206, 62.20615, unit="deg") survey_list = Zcut.get_surveys(coordinates=coord) @@ -1719,7 +1718,7 @@ def test_zcut_get_survey(patch_post): assert survey_list[2] == 'goods_north' -def test_zcut_download_cutouts(patch_post, tmpdir): +def test_zcut_download_cutouts(tmpdir): coord = SkyCoord(189.49206, 62.20615, unit="deg") @@ -1747,7 +1746,7 @@ def test_zcut_download_cutouts(patch_post, tmpdir): assert os.path.isfile(cutout_table[0]['Local Path']) -def test_zcut_get_cutouts(patch_post, tmpdir): +def test_zcut_get_cutouts(tmpdir): coord = SkyCoord(189.49206, 62.20615, unit="deg") cutout_list = Zcut.get_cutouts(coordinates=coord, size=5) @@ -1761,7 +1760,7 @@ def test_zcut_get_cutouts(patch_post, tmpdir): ################ -def test_parse_input_location(patch_post): +def test_parse_input_location(): # Test with coordinates coord = SkyCoord(23.34086, 60.658, unit="deg") loc = utils.parse_input_location(coordinates=coord) @@ -1790,7 +1789,7 @@ def test_parse_input_location(patch_post): assert isinstance(loc, SkyCoord) -def test_json_to_table_fallback_type_coercion(patch_post): +def test_json_to_table_fallback_type_coercion(): json_obj = {'info': [{'name': 'test_int', 'type': 'int'}], 'data': [['1'], ['2'], ['not_an_int'], ['3'], [-999]]} @@ -1816,49 +1815,45 @@ def test_json_to_table_fallback_type_coercion(patch_post): # Cloud tests # ################ -@patch("boto3.resource") -@patch("boto3.client") -def test_download_file_from_cloud(mock_client, mock_resource, patch_post): - pytest.importorskip("boto3") +def test_cloud_access_init(): + cloud.HAS_BOTO3 = False + with pytest.raises(ImportError, match='Please install the `boto3` and `botocore` packages'): + CloudAccess() + + # Restore the original state for other tests + cloud.HAS_BOTO3 = True + +def 
test_download_file_from_cloud(patch_boto3): + mock_client, mock_resource = patch_boto3 cloud = CloudAccess() - mock_client.return_value.head_object.return_value = {'ContentLength': 123} - mock_resource.return_value.Bucket.return_value.download_file.return_value = None + mock_client.head_object.return_value = {'ContentLength': 123} + mock_resource.Bucket.return_value.download_file.return_value = None cloud.download_file_from_cloud( "s3://stpubdata/hst/public/u9o4/u9o40504m/u9o40504m_c3m.fits", "local.fits", verbose=False ) - mock_resource.return_value.Bucket.return_value.download_file.assert_called_once() + mock_resource.Bucket.return_value.download_file.assert_called_once() -@patch("boto3.resource") -@patch("boto3.client") -def test_download_file_from_cloud_not_found(mock_client, mock_resource, patch_post): - pytest.importorskip("boto3") - +def test_download_file_from_cloud_not_found(patch_boto3): cloud = CloudAccess() # Force get_cloud_uri_list to return [None] cloud.get_cloud_uri_list = lambda *a, **k: [None] with pytest.raises(RemoteServiceError, match='was not found in the cloud'): - cloud.download_file_from_cloud( - "mast:HST/product/missing.fits", - "local.fits", - ) + cloud.download_file_from_cloud("mast:HST/product/missing.fits", "local.fits") @patch('os.path.exists', return_value=True) @patch('os.path.getsize', return_value=123) -@patch('boto3.resource') -@patch('boto3.client') -def test_download_file_from_cloud_existing(mock_client, mock_resource, mock_getsize, mock_exists, patch_post): - pytest.importorskip("boto3") - - mock_client.return_value.head_object.return_value = {'ContentLength': 123} +def test_download_file_from_cloud_existing(mock_getsize, mock_exists, patch_boto3): + mock_client, mock_resource = patch_boto3 + mock_client.head_object.return_value = {'ContentLength': 123} cloud = CloudAccess() # File exists locally with same size @@ -1868,7 +1863,7 @@ def test_download_file_from_cloud_existing(mock_client, mock_resource, mock_gets verbose=False 
) # No download should be attempted - mock_resource.return_value.Bucket.return_value.download_file.assert_not_called() + mock_resource.Bucket.return_value.download_file.assert_not_called() # File exists locally with different size mock_getsize.return_value = 456 @@ -1878,17 +1873,15 @@ def test_download_file_from_cloud_existing(mock_client, mock_resource, mock_gets verbose=False ) # Download should be attempted - mock_resource.return_value.Bucket.return_value.download_file.assert_called_once() + mock_resource.Bucket.return_value.download_file.assert_called_once() -@patch("boto3.resource") -@patch("boto3.client") -def test_download_file_from_cloud_verbose(mock_client, mock_resource, patch_post): - pytest.importorskip("boto3") +def test_download_file_from_cloud_verbose(patch_boto3): + mock_client, mock_resource = patch_boto3 cloud = CloudAccess() - mock_client.return_value.head_object.return_value = {'ContentLength': 123} - mock_resource.return_value.Bucket.return_value.download_file.return_value = None + mock_client.head_object.return_value = {'ContentLength': 123} + mock_resource.Bucket.return_value.download_file.return_value = None cloud.download_file_from_cloud( "s3://stpubdata/hst/public/u9o4/u9o40504m/u9o40504m_c3m.fits", @@ -1896,5 +1889,5 @@ def test_download_file_from_cloud_verbose(mock_client, mock_resource, patch_post verbose=True ) # Ensure callback was supplied - _, kwargs = mock_resource.return_value.Bucket.return_value.download_file.call_args + _, kwargs = mock_resource.Bucket.return_value.download_file.call_args assert "Callback" in kwargs diff --git a/astroquery/mast/tests/test_mast_remote.py b/astroquery/mast/tests/test_mast_remote.py index 0627cccdd2..1992a24728 100644 --- a/astroquery/mast/tests/test_mast_remote.py +++ b/astroquery/mast/tests/test_mast_remote.py @@ -39,6 +39,17 @@ def msa_product_table(): return products +@pytest.fixture() +def reset_cloud_state(): + pytest.importorskip('boto3') + # Reset cloud dataset state before and after each 
test + Observations._cloud_enabled_explicitly = None + Observations._cloud_connection = None + yield + Observations._cloud_enabled_explicitly = None + Observations._cloud_connection = None + + @pytest.mark.remote_data class TestMast: @@ -783,9 +794,8 @@ def test_observations_download_products_no_duplicates(self, tmp_path, caplog, ms with caplog.at_level("INFO", logger="astroquery"): assert "products were duplicates" in caplog.text - def test_observations_download_products_cloud(self, tmp_path, msa_product_table): - pytest.importorskip('boto3') - + def test_observations_download_products_cloud(self, tmp_path, msa_product_table, reset_cloud_state): + # Explicitly enable cloud dataset Observations.enable_cloud_dataset() # Adding a product that's not in the cloud to test mixed downloads @@ -820,8 +830,6 @@ def test_observations_download_products_cloud(self, tmp_path, msa_product_table) assert Path(result['Local Path'][0]).exists() assert Path(result['Local Path'][1]).exists() - Observations.disable_cloud_dataset() - def test_observations_download_file(self, tmp_path): def check_result(result, path): @@ -863,22 +871,16 @@ def check_result(result, path): 'MISDR1_18916_0459-fd-flagstar.fits.gz', 'mast:HST/product/u24r0102t_c3m.fits' ]) - def test_observations_download_file_cloud(self, tmp_path, in_uri): - pytest.importorskip("boto3") - - Observations.enable_cloud_dataset() - + def test_observations_download_file_cloud(self, tmp_path, in_uri, reset_cloud_state): filename = Path(in_uri).name result = Observations.download_file(uri=in_uri, cloud_only=True, local_path=tmp_path) assert result == ('COMPLETE', None, None) assert Path(tmp_path, filename).exists() - Observations.disable_cloud_dataset() - - def test_observations_download_file_cloud_not_found(self, tmp_path): - pytest.importorskip("boto3") + def test_observations_download_file_cloud_not_found(self, tmp_path, reset_cloud_state): in_uri = 'mast:IUE/url/pub/vospectra/iue2/swp18830mxlo_vo.fits' + # Explicitly enable cloud
dataset Observations.enable_cloud_dataset() # Warn and fallback @@ -894,8 +896,6 @@ def test_observations_download_file_cloud_not_found(self, tmp_path): assert result == ('SKIPPED', None, None) assert not Path(tmp_path, Path(in_uri).name).exists() - Observations.disable_cloud_dataset() - def test_observations_download_file_escaped(self, tmp_path): # test that `download_file` correctly escapes a URI in_uri = 'mast:HLA/url/cgi-bin/fitscut.cgi?' \ @@ -926,15 +926,13 @@ def test_observations_download_file_no_length(self, tmp_path, caplog): assert result == ("COMPLETE", None, None) assert Path(tmp_path, filename).exists() - def test_observations_list_cloud_missions(self): - pytest.importorskip('boto3') - Observations.enable_cloud_dataset() + def test_observations_list_cloud_missions(self, reset_cloud_state): + # Test that the function to list missions with cloud datasets returns expected missions missions = Observations.list_cloud_datasets() assert isinstance(missions, list) assert len(missions) > 0 for m in ['hst', 'jwst', 'panstarrs', 'galex', 'tess']: assert m in missions - Observations.disable_cloud_dataset() @pytest.mark.parametrize("test_data_uri, expected_cloud_uri", [ ("mast:HST/product/u24r0102t_c1f.fits", @@ -943,13 +941,10 @@ def test_observations_list_cloud_missions(self): "s3://stpubdata/panstarrs/ps1/public/rings.v3.skycell/1334/061/" "rings.v3.skycell.1334.061.stk.r.unconv.exp.fits") ]) - def test_observations_get_cloud_uri(self, test_data_uri, expected_cloud_uri): - pytest.importorskip("boto3") + def test_observations_get_cloud_uri(self, test_data_uri, expected_cloud_uri, reset_cloud_state): # get a product list product = Table() product['dataURI'] = [test_data_uri] - # enable access to public AWS S3 bucket - Observations.enable_cloud_dataset() # get uri uri = Observations.get_cloud_uri(product[0]) @@ -961,12 +956,8 @@ def test_observations_get_cloud_uri(self, test_data_uri, expected_cloud_uri): uri = Observations.get_cloud_uri(test_data_uri) assert uri 
== expected_cloud_uri, f'Cloud URI does not match expected. ({uri} != {expected_cloud_uri})' - Observations.disable_cloud_dataset() - @pytest.mark.parametrize("test_obs_id", ["25568122", "31411", "107604081"]) - def test_observations_get_cloud_uris(self, test_obs_id): - pytest.importorskip("boto3") - + def test_observations_get_cloud_uris(self, test_obs_id, reset_cloud_state): # get a product list index = 24 if test_obs_id == '25568122' else 0 products = Observations.get_product_list(test_obs_id)[index:index + 2] @@ -974,9 +965,6 @@ def test_observations_get_cloud_uris(self, test_obs_id): assert len(products) > 0, (f'No products found for OBSID {test_obs_id}. ' 'Unable to move forward with getting URIs from the cloud.') - # enable access to public AWS S3 bucket - Observations.enable_cloud_dataset() - # get uris uris = Observations.get_cloud_uris(products) @@ -987,19 +975,13 @@ def test_observations_get_cloud_uris(self, test_obs_id): Observations.get_cloud_uris(products, extension='png') - Observations.disable_cloud_dataset() - - def test_observations_get_cloud_uris_list_input(self): - pytest.importorskip("boto3") + def test_observations_get_cloud_uris_list_input(self, reset_cloud_state): uri_list = ['mast:HST/product/u24r0102t_c1f.fits', 'mast:PS1/product/rings.v3.skycell.1334.061.stk.r.unconv.exp.fits'] expected = ['s3://stpubdata/hst/public/u24r/u24r0102t/u24r0102t_c1f.fits', 's3://stpubdata/panstarrs/ps1/public/rings.v3.skycell/1334/061/rings.v3.skycell.1334.' '061.stk.r.unconv.exp.fits'] - # enable access to public AWS S3 bucket - Observations.enable_cloud_dataset() - # list of URI strings as input uris = Observations.get_cloud_uris(uri_list) assert len(uris) > 0, f'Products for URI list {uri_list} were not found in the cloud.' 
@@ -1021,14 +1003,7 @@ def test_observations_get_cloud_uris_list_input(self): with pytest.warns(NoResultsWarning, match='Failed to retrieve cloud path'): Observations.get_cloud_uris(['mast:HST/product/does_not_exist.fits']) - Observations.disable_cloud_dataset() - - def test_observations_get_cloud_uris_query(self): - pytest.importorskip("boto3") - - # enable access to public AWS S3 bucket - Observations.enable_cloud_dataset() - + def test_observations_get_cloud_uris_query(self, reset_cloud_state): # get uris with other functions obs = Observations.query_criteria(target_name=234295610) prod = Observations.get_product_list(obs) @@ -1048,25 +1023,16 @@ def test_observations_get_cloud_uris_query(self): with pytest.warns(NoResultsWarning): Observations.get_cloud_uris(target_name=234295611) - Observations.disable_cloud_dataset() - - def test_observations_get_cloud_uris_no_duplicates(self, msa_product_table): - pytest.importorskip("boto3") - + def test_observations_get_cloud_uris_no_duplicates(self, msa_product_table, reset_cloud_state): # Get a product list with 6 duplicate JWST MSA config files products = msa_product_table assert len(products) == 6 - # enable access to public AWS S3 bucket - Observations.enable_cloud_dataset(provider='AWS') - # Check that only one URI is returned uris = Observations.get_cloud_uris(products) assert len(uris) == 1 - Observations.disable_cloud_dataset() - ###################### # CatalogClass tests # ###################### From 80e07e79d4b7dd054a2a4a5cd1b11029a44bc10d Mon Sep 17 00:00:00 2001 From: Sam Bianco Date: Mon, 16 Feb 2026 17:11:19 -0500 Subject: [PATCH 2/3] Changelog, fix docs build --- CHANGES.rst | 3 +++ astroquery/mast/__init__.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index 373df434d1..a8a284b20c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -93,6 +93,9 @@ mast - Added a new ``Observations.list_cloud_datasets()`` method for querying cloud-supported MAST datasets, 
alongside improvements to cloud download handling. [#3488] +- The cloud dataset in ``Observations`` is now enabled by default if the ``boto3`` and ``botocore`` packages are installed. This + default can be overridden by setting the ``enable_cloud_dataset`` configuration option to False. [#3534] + jplspec ^^^^^^^ diff --git a/astroquery/mast/__init__.py b/astroquery/mast/__init__.py index 68745f7fcc..fa0b9e8ff8 100644 --- a/astroquery/mast/__init__.py +++ b/astroquery/mast/__init__.py @@ -32,7 +32,7 @@ class Conf(_config.ConfigNamespace): enable_cloud_dataset = _config.ConfigItem( True, 'Enable access to cloud-hosted datasets (e.g. on AWS S3) by default. ' - 'Requires the `boto3` and `botocore` packages to be installed.') + 'Requires the ``boto3`` and ``botocore`` packages to be installed.') conf = Conf() From 62832b0c5c06a9710f4b937e137b8bf8209ab59d Mon Sep 17 00:00:00 2001 From: Sam Bianco Date: Tue, 17 Feb 2026 13:04:01 -0500 Subject: [PATCH 3/3] Fixture placement --- astroquery/mast/tests/test_mast.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/astroquery/mast/tests/test_mast.py b/astroquery/mast/tests/test_mast.py index 04fecd2c63..1bc4a18dd3 100644 --- a/astroquery/mast/tests/test_mast.py +++ b/astroquery/mast/tests/test_mast.py @@ -89,7 +89,7 @@ def patch_post(request): @pytest.fixture() -def patch_boto3(monkeypatch): +def patch_boto3(monkeypatch, reset_cloud_state): """Fixture to patch boto3 client and resource for cloud access tests.""" pytest.importorskip('boto3') mock_client = MagicMock() @@ -1001,7 +1001,7 @@ def test_observations_filter_products(): @patch.object(Path, "is_file", return_value=True) -def test_observations_download_products(mock_is_file, patch_boto3, monkeypatch, reset_cloud_state): +def test_observations_download_products(mock_is_file, patch_boto3, monkeypatch): mock_resource = patch_boto3[1] obsid = '2003738726' data_uri = 'mast:HST/product/u9o40504m_c3m.fits' @@ -1085,7 +1085,7 @@ def 
test_observations_download_products(mock_is_file, patch_boto3, monkeypatch, @patch.object(Path, "is_file", return_value=True) -def test_observations_download_file(mock_is_file, patch_boto3, reset_cloud_state, tmpdir): +def test_observations_download_file(mock_is_file, patch_boto3, tmpdir): mock_client, mock_resource = patch_boto3 mock_client.head_object.return_value = {'ContentLength': 12345} mock_resource.Bucket.return_value.download_file.return_value = None @@ -1133,7 +1133,7 @@ def test_observations_download_file(mock_is_file, patch_boto3, reset_cloud_state assert result == ('COMPLETE', None, None) -def test_observations_download_file_not_found(patch_boto3, reset_cloud_state): +def test_observations_download_file_not_found(patch_boto3): mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' result = Observations.download_file(mast_uri) assert result[0] == 'ERROR' @@ -1145,7 +1145,7 @@ def test_observations_download_file_not_found(patch_boto3, reset_cloud_state): assert result[1] == 'File was not downloaded' -def test_observations_list_cloud_missions(patch_boto3, reset_cloud_state): +def test_observations_list_cloud_missions(patch_boto3): mock_client = patch_boto3[0] mock_client.list_objects_v2.return_value = { 'CommonPrefixes': [{'Prefix': 'hst/'}, {'Prefix': 'jwst/'}, {'Prefix': 'mast/'}] @@ -1158,7 +1158,7 @@ def test_observations_list_cloud_missions(patch_boto3, reset_cloud_state): assert 'mast' in supported -def test_observations_list_cloud_missions_error(patch_boto3, reset_cloud_state): +def test_observations_list_cloud_missions_error(patch_boto3): # Mock an error when listing objects mock_client = patch_boto3[0] client_error = ClientError({'Error': {'Code': 'AWS error'}}, 'ListObjectsV2') @@ -1173,7 +1173,7 @@ def test_observations_list_cloud_missions_error(patch_boto3, reset_cloud_state): Observations.list_cloud_datasets() -def test_observations_get_cloud_uri(patch_boto3, reset_cloud_state): +def test_observations_get_cloud_uri(patch_boto3): mast_uri = 
'mast:HST/product/u9o40504m_c3m.fits' expected = 's3://stpubdata/hst/public/u9o4/u9o40504m/u9o40504m_c3m.fits' @@ -1198,7 +1198,7 @@ def test_observations_get_cloud_uri(patch_boto3, reset_cloud_state): Observations.get_cloud_uri(mast_uri) -def test_observations_get_cloud_uris(patch_boto3, reset_cloud_state): +def test_observations_get_cloud_uris(patch_boto3): mast_uri = 'mast:HST/product/u9o40504m_c3m.fits' expected = 's3://stpubdata/hst/public/u9o4/u9o40504m/u9o40504m_c3m.fits' @@ -1243,7 +1243,7 @@ def test_observations_get_cloud_uris(patch_boto3, reset_cloud_state): Observations.get_cloud_uris([mast_uri]) -def test_observations_get_cloud_uris_error(patch_boto3, reset_cloud_state): +def test_observations_get_cloud_uris_error(patch_boto3): mock_client = patch_boto3[0] # Mock head_object to raise an exception @@ -1263,7 +1263,7 @@ def test_observations_get_cloud_uris_error(patch_boto3, reset_cloud_state): assert uris == [] -def test_observations_get_cloud_uris_query(patch_boto3, reset_cloud_state): +def test_observations_get_cloud_uris_query(patch_boto3): # get uris with streamlined function uris = Observations.get_cloud_uris(target_name=234295610, filter_products={'productSubGroupDescription': 'C3M'}) @@ -1279,7 +1279,7 @@ def test_observations_get_cloud_uris_query(patch_boto3, reset_cloud_state): filter_products={'productSubGroupDescription': 'LC'}) -def test_observations_enable_cloud_dataset(patch_boto3, reset_cloud_state): +def test_observations_enable_cloud_dataset(patch_boto3): # Enable cloud dataset Observations.enable_cloud_dataset() assert Observations._cloud_connection is not None @@ -1295,7 +1295,7 @@ def test_observations_enable_cloud_dataset(patch_boto3, reset_cloud_state): cloud.HAS_BOTO3 = True -def test_observations_disable_cloud_dataset(patch_boto3, reset_cloud_state): +def test_observations_disable_cloud_dataset(patch_boto3): # Explicitly disable cloud dataset Observations.disable_cloud_dataset() assert Observations._cloud_connection is None