Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 64 additions & 14 deletions astrodata/adfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import logging
import os
import traceback
from contextlib import contextmanager
from copy import deepcopy

Expand Down Expand Up @@ -74,6 +75,7 @@ def _open_file(source):
)

# try vs all handlers
exception_list = []
for func in AstroDataFactory._file_openers:
try:
fp = func(source)
Expand All @@ -83,27 +85,52 @@ def _open_file(source):
except KeyboardInterrupt:
raise

except FileNotFoundError:
raise

except Exception as err: # noqa
LOGGER.error(
LOGGER.debug(
"Failed to open %s with %s, got error: %s",
source,
func,
err,
)

# Handle nonexistent files.
if isinstance(err, FileNotFoundError):
raise err
exception_list.append(
(
func.__name__,
type(err),
err,
"".join(
traceback.format_exception(
None, err, err.__traceback__
)
).splitlines(),
)
)

else:
if hasattr(fp, "close"):
fp.close()

return

raise AstroDataError(
message_lines = [
f"No access, or not supported format for: {source}"
)
]

if exception_list:
n_err = len(exception_list)
message_lines.append(f"Got {n_err} exceptions while opening:")

for adclass, errname, err, trace_lines in exception_list:
message_lines.append(f"+ {adclass}: {errname}: {str(err)}")
message_lines.extend(
f" {trace_line}" for trace_line in trace_lines
)

message = "\n".join(message_lines)

raise AstroDataError(message)

yield source

Expand Down Expand Up @@ -172,6 +199,7 @@ def get_astro_data(self, source):
An AstroData instance.
"""
candidates = []
exception_list = []
with self._open_file(source) as opened:
for adclass in self._registry:
try:
Expand All @@ -182,13 +210,26 @@ def get_astro_data(self, source):
raise

except Exception as err:
LOGGER.error(
LOGGER.debug(
"Failed to open %s with %s, got error: %s",
source,
adclass,
err,
)

exception_list.append(
(
adclass.__name__,
type(err),
err,
"".join(
traceback.format_exception(
None, err, err.__traceback__
)
).splitlines(),
)
)

# For every candidate in the list, remove the ones that are base
# classes for other candidates. That way we keep only the more
# specific ones.
Expand All @@ -207,19 +248,28 @@ def get_astro_data(self, source):
)

if not final_candidates:
raise AstroDataError("No class matches this dataset")
message_lines = ["No class matches this dataset"]
if exception_list:
n_err = len(exception_list)
message_lines.append(f"Got {n_err} exceptions while matching:")

for adclass, errname, err, trace_lines in exception_list:
message_lines.append(f"+ {adclass}: {errname}: {str(err)}")
message_lines.extend(
f" {trace_line}" for trace_line in trace_lines
)

message = "\n".join(message_lines)

raise AstroDataError(message)

return final_candidates[0].read(source)

@deprecated(
"Renamed to create_from_scratch, please use that method instead: "
"astrodata.factory.AstroDataFactory.create_from_scratch"
)
def createFromScratch(
self,
phu,
extensions=None,
): # noqa
def createFromScratch(self, phu, extensions=None): # noqa
"""Create an AstroData object from a collection of objects.

Deprecated, see |create_from_scratch|.
Expand Down
101 changes: 95 additions & 6 deletions tests/unit/test_adfactory.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# Disable pylint
# pylint: skip-file

import pytest

from copy import deepcopy
import os

import astrodata
from astrodata import adfactory
from astrodata import AstroData

from astropy.io import fits

import pytest


factory = adfactory.AstroDataFactory

Expand Down Expand Up @@ -58,7 +57,6 @@ def example_dir(tmp_path) -> str:
os.remove(dirname)

os.mkdir(dirname)

return dirname


Expand All @@ -77,3 +75,94 @@ def test__open_file_file_not_found(nonexistent_file, example_dir):
with pytest.raises(FileNotFoundError):
with factory._open_file(example_dir) as _:
pass

def test_report_all_exceptions_on_failure_get_astro_data(example_fits_file, monkeypatch,):
"""Tests that all exceptions are reported if file fails to open.

This test tries to capture errors that were previously discarded. It does
this by checking what is sent to stderr/stdout.

In the future, when support for python 3.10 is dropped, exception groups
would vastly simplify this.
"""
# Use local adfactory to avoid spoiling the main one.
factory = astrodata.adfactory.AstroDataFactory()
monkeypatch.setattr(astrodata, "factory", factory)

class AD1(AstroData):
_message = "This_is_exception_1"
@staticmethod
def _matches_data(source):
raise ValueError(AD1._message)

class AD2(AstroData):
_message = "This_is_exception_2"
@staticmethod
def _matches_data(source):
raise IndexError(AD2._message)

class AD3(AstroData):
_message = "This_is_exception_3"
@staticmethod
def _matches_data(source):
raise Exception(AD3._message)

classes = (AD1, AD2, AD3)

for _cls in classes:
astrodata.factory.add_class(_cls)

with pytest.raises(astrodata.AstroDataError) as exception_info:
astrodata.from_file(example_fits_file)


caught_err = exception_info.value
assert str(caught_err)
assert "No class matches this dataset" in str(caught_err)

for message in (_cls._message for _cls in classes):
assert message in str(caught_err), str(caught_err)

def test_report_all_exceptions_on_failure__open_file(example_fits_file, monkeypatch, ):
"""Tests that all exceptions are reported if file fails to open.

This test tries to capture errors that were previously discarded. It does
this by checking what is sent to stderr/stdout.

In the future, when support for python 3.10 is dropped, exception groups
would vastly simplify this.
"""
# Use local adfactory to avoid spoiling the main one.
factory = astrodata.adfactory.AstroDataFactory()

def _open1(source):
raise ValueError("Exception_1")

def _open2(source):
raise IndexError("Exception_2")

def _open3(source):
raise Exception("Exception_3")

factory._file_openers = (_open1, _open2, _open3)

class AD(AstroData):
@staticmethod
def _matches_data(source):
return True

monkeypatch.setattr(astrodata, "factory", factory)
monkeypatch.setattr(astrodata.adfactory.AstroDataFactory, "_file_openers", factory._file_openers)


with pytest.raises(astrodata.AstroDataError) as exception_info:
astrodata.from_file(example_fits_file)


caught_err = exception_info.value
assert str(caught_err)
assert "No access, or not supported format for: " in str(caught_err)

n_openers = len(factory._file_openers)
for message in (f"Exception_{i}" for i in range(1, n_openers+1)):
assert message in str(caught_err)
Loading