Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions astrodata/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def gwcs_to_fits(ndd, hdr=None):

wcs_dict.update(
{
f"CD{i+1}_{j+1}": 0.0
f"CD{i + 1}_{j + 1}": 0.0
for j in range(nworld_axes)
for i in range(nworld_axes)
}
Expand Down Expand Up @@ -300,8 +300,8 @@ def gwcs_to_fits(ndd, hdr=None):
lat_axis = world_axes.index("lat")
world_axes[lon_axis] = f"RA---{projcode}"
world_axes[lat_axis] = f"DEC--{projcode}"
wcs_dict[f"CRVAL{lon_axis+1}"] = nat2cel.lon.value
wcs_dict[f"CRVAL{lat_axis+1}"] = nat2cel.lat.value
wcs_dict[f"CRVAL{lon_axis + 1}"] = nat2cel.lon.value
wcs_dict[f"CRVAL{lat_axis + 1}"] = nat2cel.lat.value

# Remove projection parts so we can calculate the CD matrix
if projcode:
Expand Down Expand Up @@ -422,7 +422,7 @@ def gwcs_to_fits(ndd, hdr=None):
# Require an inverse to write out
wcs_dict.update(
{
f"CD{i+1}_{j+1}": affine_matrix[i, j]
f"CD{i + 1}_{j + 1}": affine_matrix[i, j]
for j, _ in enumerate(shape)
for i, _ in enumerate(world_axes)
}
Expand All @@ -437,7 +437,7 @@ def gwcs_to_fits(ndd, hdr=None):
}
)

crval = [wcs_dict[f"CRVAL{i+1}"] for i, _ in enumerate(world_axes)]
crval = [wcs_dict[f"CRVAL{i + 1}"] for i, _ in enumerate(world_axes)]

try:
crval[lon_axis] = 0
Expand Down Expand Up @@ -520,7 +520,7 @@ def gwcs_to_fits(ndd, hdr=None):
# To ensure an invertable CD matrix, we need to get nonexistent pixel axes
# "involved".
for j in range(len(shape), nworld_axes):
wcs_dict[f"CD{nworld_axes}_{j+1}"] = 1
wcs_dict[f"CD{nworld_axes}_{j + 1}"] = 1

return wcs_dict

Expand Down Expand Up @@ -697,7 +697,7 @@ def read_wcs_from_header(header):
this_ctype = header[f"CTYPE{i}"]

except KeyError:
this_ctype = f"LINEAR{untyped_axes+1 if untyped_axes else ''}"
this_ctype = f"LINEAR{untyped_axes + 1 if untyped_axes else ''}"
untyped_axes += 1

ctype.append(this_ctype)
Expand Down Expand Up @@ -904,7 +904,13 @@ def make_fitswcs_transform(trans_input):
other_models = fitswcs_other(wcs_info, other=other)
all_models = other_models
if sky_model:
i = -1

for i, m in enumerate(all_models):
m.meta["output_axes"] = [i]

all_models.append(sky_model)
sky_model.meta["output_axes"] = [i + 1, i + 2]

# Now arrange the models so the inputs and outputs are in the right places
all_models.sort(key=lambda m: m.meta["output_axes"][0])
Expand Down Expand Up @@ -1151,11 +1157,13 @@ def remove_axis_from_frame(frame, axis):
new_frames.append(remove_axis_from_frame(f, axis))

else:
new_frames.append(deepcopy(f))
f._axes_order = tuple(
new_frame = deepcopy(f)
new_frame._axes_order = tuple(
x if x < axis else x - 1 for x in f.axes_order
)

new_frames.append(new_frame)

if len(new_frames) == 1:
ret_frame = deepcopy(new_frames[0])
ret_frame.name = frame.name
Expand Down Expand Up @@ -1358,8 +1366,9 @@ def remove_unused_world_axis(ext):
"""
ndim = len(ext.shape)
affine = calculate_affine_matrices(ext.wcs.forward_transform, ext.shape)

# Check whether there's a single output that isn't affected by the input
removable_axes = np.all(affine.matrix[:, ndim - 1 :] == 0, axis=1)
removable_axes = np.all(affine.matrix == 0, axis=1)
removable_axes = removable_axes[::-1] # xyz order

if removable_axes.sum() == 1:
Expand All @@ -1372,13 +1381,14 @@ def remove_unused_world_axis(ext):
new_pipeline = []
for step in reversed(ext.wcs.pipeline):
frame, transform = step.frame, step.transform
if axis < frame.naxes:
frame = remove_axis_from_frame(frame, axis)

if transform is not None:
if axis < transform.n_outputs:
transform, axis = remove_axis_from_model(transform, axis)

if axis is not None and axis < frame.naxes:
frame = remove_axis_from_frame(frame, axis)

new_pipeline = [(frame, transform)] + new_pipeline

if axis not in (ndim, None):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_remove_unused_world_axis(F2_IMAGE):
assert_allclose(new_result, result)
adwcs.remove_unused_world_axis(ad[0])
new_result = ad[0].wcs(900, 800)
assert_allclose(new_result, result[:2])
assert_allclose(new_result, result[-2:])
for frame in ad[0].wcs.available_frames:
assert getattr(ad[0].wcs, frame).naxes == 2

Expand Down
Loading