Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions astrodata/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,13 +859,12 @@ def _my_attribute(attr):

if attribute == DEFAULT_EXTENSION:
raise AttributeError(
f"{attribute} extensions should be "
"appended with .append"
f"{attribute} extensions should be appended with .append"
)

if attribute in {"DQ", "VAR"}:
raise AttributeError(
f"{attribute} should be set on the " "nddata object"
f"{attribute} should be set on the nddata object"
)

add_to = self.nddata if self.is_single else None
Expand Down Expand Up @@ -898,7 +897,7 @@ def __delattr__(self, attribute):
del self._tables[attribute]
else:
raise AttributeError(
f"'{attribute}' is not a global table " "for this instance"
f"'{attribute}' is not a global table for this instance"
)

def __contains__(self, attribute):
Expand Down Expand Up @@ -1381,8 +1380,7 @@ def _append_nddata(self, new_nddata, name, add_to):
"""
if add_to is not None:
raise TypeError(
"You can only append NDData derived instances "
"at the top level"
"You can only append NDData derived instances at the top level"
)

hd = new_nddata.meta["header"]
Expand Down Expand Up @@ -1488,8 +1486,7 @@ def _append_astrodata(self, ad, name, header, add_to):

if not ad.is_single:
raise ValueError(
"Cannot append AstroData instances that are "
"not single slices"
"Cannot append AstroData instances that are not single slices"
)

if add_to is not None:
Expand Down Expand Up @@ -1564,8 +1561,7 @@ def append(self, ext, name=None, header=None):
"""
if self.is_sliced:
raise TypeError(
"Can't append objects to slices, use "
"'ext.NAME = obj' instead"
"Can't append objects to slices, use 'ext.NAME = obj' instead"
)

# NOTE: Most probably, if we want to copy the input argument, we
Expand Down
30 changes: 16 additions & 14 deletions astrodata/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ def wcs_to_asdftablehdu(wcs, extver=None):
except jsonschema.exceptions.ValidationError as err:
# (The original traceback also gets printed here)
raise TypeError(
f"Cannot serialize model(s) for 'WCS' extension " f"{extver or ''}"
f"Cannot serialize model(s) for 'WCS' extension {extver or ''}"
) from err

# ASDF can only dump YAML to a binary file object, so do that and read
Expand Down Expand Up @@ -1270,19 +1270,21 @@ def asdftablehdu_to_wcs(hdu):

return None

with af:
try:
wcs = af.tree["wcs"]

except KeyError as err:
LOGGER.warning(
"Ignoring 'WCS' extension %s: missing "
"'wcs' dict entry. Error was %s",
ver,
err,
)

return None
else:
with af:
try:
wcs = af.tree["wcs"]

except KeyError as err:
LOGGER.warning(
"Ignoring 'WCS' extension %s: missing "
"'wcs' dict entry. (got exception: %s: %s)",
ver,
type(err),
err,
)

return None

else:
LOGGER.warning("Ignoring non-FITS-table 'WCS' extension %s", ver)
Expand Down
18 changes: 14 additions & 4 deletions astrodata/nddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def _slice_wcs(self, slices):
"Only one ellipsis can be specified in a slice"
)

ell_index = slices.index(Ellipsis) + 1
ell_index = slices.index(Ellipsis)
slice_fill = [slice(None)] * (ndim - len(slices) + 1)
slices[ell_index:ell_index] = slice_fill
slices[ell_index : ell_index + 1] = slice_fill

slices.extend([slice(None)] * (ndim - len(slices)))

Expand All @@ -154,15 +154,25 @@ def _slice_wcs(self, slices):
if slice_.start:
start = (
length + slice_.start
if slice_.start < 1
if slice_.start < 0
else slice_.start
)
if start > 0:
model.append(models.Shift(start))
mapped_axes.append(max(mapped_axes) + 1 if mapped_axes else 0)

elif isinstance(slice_, INTEGER_TYPES):
model.append(models.Const1D(slice_))
model.append(
models.Const1D((length + slice_) if slice_ < 0 else slice_)
)
mapped_axes.append(-1)

# Equivalent to slice(None, None, None)
elif slice_ is None:
mapped_axes.append(
max(mapped_axes) + 1 if mapped_axes else None
)

else:
raise IndexError("Slice not an integer or range")
if model:
Expand Down
24 changes: 13 additions & 11 deletions astrodata/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,11 @@ def add_history(ad, timestamp_start, timestamp_stop, primitive, args):
if hasattr(ad, "HISTORY"):
colsize = max(
colsize,
(max(len(ph[args_col_idx]) for ph in ad.HISTORY) + 1)
if args_col_idx is not None
else 16,
(
(max(len(ph[args_col_idx]) for ph in ad.HISTORY) + 1)
if args_col_idx is not None
else 16
),
)

timestamp_start_arr = [
Expand Down Expand Up @@ -335,16 +337,16 @@ def provenance_summary(ad, provenance=True, history=True):

# Titles
retval += (
f'{"Primitive":<{primitive_col_size}} '
f'{"Args":<{args_col_size}} '
f'{"Start":<{timestamp_start_col_size}} {"Stop"}\n'
f"{'Primitive':<{primitive_col_size}} "
f"{'Args':<{args_col_size}} "
f"{'Start':<{timestamp_start_col_size}} {'Stop'}\n"
)
# now the lines
retval += (
f'{"":{"-"}<{primitive_col_size}} '
f'{"":{"-"}<{args_col_size}} '
f'{"":{"-"}<{timestamp_start_col_size}} '
f'{"":{"-"}<{timestamp_stop_col_size}}\n'
f"{'':{'-'}<{primitive_col_size}} "
f"{'':{'-'}<{args_col_size}} "
f"{'':{'-'}<{timestamp_start_col_size}} "
f"{'':{'-'}<{timestamp_stop_col_size}}\n"
)

# Rows, looping over args lines
Expand All @@ -368,7 +370,7 @@ def provenance_summary(ad, provenance=True, history=True):
)

else:
retval += f'{"":<{primitive_col_size}} {argrow}\n'
retval += f"{'':<{primitive_col_size}} {argrow}\n"
# prep for additional arg rows without duplicating the
# other values
first = False
Expand Down
76 changes: 60 additions & 16 deletions astrodata/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def assert_most_equal(actual, desired, max_miss, verbose=True):
verbose : bool, optional
If True, the conflicting values are appended to the error message.

Raises
Raises
------
AssertionError
If actual and desired are not equal.
Expand Down Expand Up @@ -534,7 +534,7 @@ def download_multiple_files(
"testing cache."
)

# This is cleaned up once the program finishes.
# This does not persist after the Python session finishes.
os.environ[env_var] = str(path)

if not os.path.isdir(path) and os.path.exists(path):
Expand Down Expand Up @@ -621,7 +621,7 @@ def download_from_archive(
warnings.warn(
"sub_path is None, so the file will be saved to the root of the "
"cache directory. To suppress this warning, set sub_path to a "
"valid path (e.g., empty string instead of None)."
"valid path (e.g., empty string/'.' instead of None)."
)

# Check that the environment variable is a valid name.
Expand Down Expand Up @@ -743,6 +743,9 @@ class ADCompare:
)
fits_keys.update([f"CD{i}_{j}" for i in range(1, 6) for j in range(1, 6)])

# Add PROCSVER and PROCSOFT (DRAGONS processing version/software kws)
fits_keys.update(["PROCSVER", "PROCSOFT", "PROCMODE"])

def __init__(self, ad1, ad2):
self.ad1 = ad1
self.ad2 = ad2
Expand Down Expand Up @@ -980,19 +983,64 @@ def attributes(self):
return errorlist

def _attributes(self, ext1, ext2):
"""Check the attributes of two extensions."""
"""Check and compare attributes."""
errorlist = []
for attr in ["data", "mask", "variance", "OBJMASK", "OBJCAT"]:
attr1 = getattr(ext1, attr, None)
attr2 = getattr(ext2, attr, None)

if all(attr is None for attr in [attr1, attr2]):
continue
if hasattr(attr1, "shape") and hasattr(attr2, "shape"):
if attr1.shape != attr2.shape:
errorlist.append(
f"Attribute error for {attr}: "
f"Mismatching {attr}.shape values: {attr1} v {attr2}"
)

if not np.array_equal(attr1, attr2):
errorlist.append(f"{attr} mismatch: {attr1} v {attr2}")
continue
continue

if (attr1 is None) ^ (attr2 is None):
errorlist.append(
f"Attribute error for {attr}: "
f"{attr1 is not None} v {attr2 is not None}"
)

elif attr1 is not None:
if isinstance(attr1, Table):
if len(attr1) != len(attr2):
errorlist.append(
f"attr lengths differ: {len(attr1)} v {len(attr2)}"
)
else: # everything else is pixel-like
if attr1.dtype.name != attr2.dtype.name:
errorlist.append(
f"Datatype mismatch for {attr}: "
f"{attr1.dtype} v {attr2.dtype}"
)
if attr1.shape != attr2.shape:
errorlist.append(
f"Shape mismatch for {attr}: {attr1.shape} "
f"v {attr2.shape}"
)
if "int" in attr1.dtype.name:
try:
assert_most_equal(
attr1, attr2, max_miss=self.max_miss
)
except AssertionError as e:
errorlist.append(
f"Inequality for {attr}: " + str(e)
)
else:
try:
assert_most_close(
attr1,
attr2,
max_miss=self.max_miss,
rtol=self.rtol,
atol=self.atol,
)
except AssertionError as e:
errorlist.append(f"Mismatch for {attr}: " + str(e))
return errorlist

def wcs(self):
Expand Down Expand Up @@ -1037,7 +1085,7 @@ def compare_frames(frame1, frame2):

except AssertionError:
errorlist.compare(
f"Slice {i} {frame} differs: " f"{frame1} v {frame2}"
f"Slice {i} {frame} differs: {frame1} v {frame2}"
)

corners = get_corners(ext1.shape)
Expand Down Expand Up @@ -1083,13 +1131,9 @@ def ad_compare(ad1, ad2, **kwargs):
-------
bool: are the two AD instances basically the same?
"""
try:
ADCompare(ad1, ad2).run_comparison(**kwargs)

except AssertionError:
return False
compare = ADCompare(ad1, ad2).run_comparison(**kwargs)

return True
return compare == {}


_HDUL_LIKE_TYPE = fits.HDUList | list[fits.hdu.FitsHDU]
Expand Down
25 changes: 18 additions & 7 deletions astrodata/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def fitswcs_to_gwcs(input_data, *, raise_errors: bool = False):

except (IndexError, TypeError, ValueError) as err:
if not raise_errors:
logging.warning(
logging.debug(
"Could not create gWCS: %s: %s",
err.__class__.__name__,
err,
Expand Down Expand Up @@ -607,25 +607,29 @@ def calculate_affine_matrices(func, shape, origin=None):

if ndim > 1:
transformed = list(
zip(*list(func(*point[:indim]) for point in points.T))
zip(*list(func(*point[indim - 1 :: -1]) for point in points.T))
)

transformed = np.array(transformed).T

else:
transformed = np.array([func(*points)]).T

# Matrix of wcs derivatives wrt input coordinates in Python order
matrix = np.array(
[
[
0.5
* (transformed[j + 1, i] - transformed[indim + j + 1, i])
/ halfsize[j]
for j in range(indim)
for j in range(indim - 1, -1, -1)
]
for i in range(ndim)
]
)
offset = transformed[0] - np.dot(matrix, halfsize)

offset = transformed[0] - np.dot(matrix, halfsize[::-1])

return AffineMatrices(matrix[::-1, ::-1], offset[::-1])


Expand Down Expand Up @@ -657,7 +661,13 @@ def read_wcs_from_header(header):
try:
wcsaxes = header["WCSAXES"]

except KeyError:
except KeyError as err:
logging.debug(
"No WCSAXES in header; trying CTYPE/CD (exception: %s: %s)",
type(err),
err,
)

wcsaxes = 0

for kw in header["CTYPE*"]:
Expand Down Expand Up @@ -864,7 +874,7 @@ def make_fitswcs_transform(trans_input):

Arguments
---------
header : `astropy.io.fits.Header` or dict
header : `astropy.io.fits.Header`, dict, or NDData
FITS Header (or dict) with basic WCS information

Raises
Expand All @@ -889,7 +899,8 @@ def make_fitswcs_transform(trans_input):
wcs_info = read_wcs_from_header(trans_input.meta["header"])

except AttributeError as err:
msg = "Expected a FITS Header, dict, or NDData object"
input_type = type(trans_input)
msg = f"Expected a FITS Header, dict, or NDData, not {input_type}"
raise TypeError(msg) from err

other = trans_input.meta["other"]
Expand Down
Loading
Loading