diff --git a/astrodata/wcs.py b/astrodata/wcs.py index 8d02bed..359db52 100644 --- a/astrodata/wcs.py +++ b/astrodata/wcs.py @@ -253,7 +253,7 @@ def gwcs_to_fits(ndd, hdr=None): wcs_dict.update( { - f"CD{i+1}_{j+1}": 0.0 + f"CD{i + 1}_{j + 1}": 0.0 for j in range(nworld_axes) for i in range(nworld_axes) } @@ -300,8 +300,8 @@ def gwcs_to_fits(ndd, hdr=None): lat_axis = world_axes.index("lat") world_axes[lon_axis] = f"RA---{projcode}" world_axes[lat_axis] = f"DEC--{projcode}" - wcs_dict[f"CRVAL{lon_axis+1}"] = nat2cel.lon.value - wcs_dict[f"CRVAL{lat_axis+1}"] = nat2cel.lat.value + wcs_dict[f"CRVAL{lon_axis + 1}"] = nat2cel.lon.value + wcs_dict[f"CRVAL{lat_axis + 1}"] = nat2cel.lat.value # Remove projection parts so we can calculate the CD matrix if projcode: @@ -422,7 +422,7 @@ def gwcs_to_fits(ndd, hdr=None): # Require an inverse to write out wcs_dict.update( { - f"CD{i+1}_{j+1}": affine_matrix[i, j] + f"CD{i + 1}_{j + 1}": affine_matrix[i, j] for j, _ in enumerate(shape) for i, _ in enumerate(world_axes) } @@ -437,7 +437,7 @@ def gwcs_to_fits(ndd, hdr=None): } ) - crval = [wcs_dict[f"CRVAL{i+1}"] for i, _ in enumerate(world_axes)] + crval = [wcs_dict[f"CRVAL{i + 1}"] for i, _ in enumerate(world_axes)] try: crval[lon_axis] = 0 @@ -520,7 +520,7 @@ def gwcs_to_fits(ndd, hdr=None): # To ensure an invertable CD matrix, we need to get nonexistent pixel axes # "involved". for j in range(len(shape), nworld_axes): - wcs_dict[f"CD{nworld_axes}_{j+1}"] = 1 + wcs_dict[f"CD{nworld_axes}_{j + 1}"] = 1 return wcs_dict @@ -697,7 +697,7 @@ def read_wcs_from_header(header): this_ctype = header[f"CTYPE{i}"] except KeyError: - this_ctype = f"LINEAR{untyped_axes+1 if untyped_axes else ''}" + this_ctype = f"LINEAR{untyped_axes + 1 if untyped_axes else ''}" untyped_axes += 1 ctype.append(this_ctype) @@ -904,7 +904,13 @@ def make_fitswcs_transform(trans_input): other_models = fitswcs_other(wcs_info, other=other) all_models = other_models if sky_model: + i = -1 + + for i, m in enumerate(all_models): + m.meta["output_axes"] = [i] + all_models.append(sky_model) + sky_model.meta["output_axes"] = [i + 1, i + 2] # Now arrange the models so the inputs and outputs are in the right places all_models.sort(key=lambda m: m.meta["output_axes"][0]) @@ -1151,11 +1157,13 @@ def remove_axis_from_frame(frame, axis): new_frames.append(remove_axis_from_frame(f, axis)) else: - new_frames.append(deepcopy(f)) - f._axes_order = tuple( + new_frame = deepcopy(f) + new_frame._axes_order = tuple( x if x < axis else x - 1 for x in f.axes_order ) + new_frames.append(new_frame) + if len(new_frames) == 1: ret_frame = deepcopy(new_frames[0]) ret_frame.name = frame.name @@ -1358,8 +1366,9 @@ def remove_unused_world_axis(ext): """ ndim = len(ext.shape) affine = calculate_affine_matrices(ext.wcs.forward_transform, ext.shape) + # Check whether there's a single output that isn't affected by the input - removable_axes = np.all(affine.matrix[:, ndim - 1 :] == 0, axis=1) + removable_axes = np.all(affine.matrix == 0, axis=1) removable_axes = removable_axes[::-1] # xyz order if removable_axes.sum() == 1: @@ -1372,13 +1381,14 @@ def remove_unused_world_axis(ext): new_pipeline = [] for step in reversed(ext.wcs.pipeline): frame, transform = step.frame, step.transform - if axis < frame.naxes: - frame = remove_axis_from_frame(frame, axis) if transform is not None: if axis < transform.n_outputs: transform, axis = remove_axis_from_model(transform, axis) + if axis is not None and axis < frame.naxes: + frame = remove_axis_from_frame(frame, axis) + new_pipeline = [(frame, transform)] + new_pipeline if axis not in (ndim, None): diff --git a/tests/unit/test_wcs.py b/tests/unit/test_wcs.py index 67f2eff..33095ab 100644 --- a/tests/unit/test_wcs.py +++ b/tests/unit/test_wcs.py @@ -170,7 +170,7 @@ def test_remove_unused_world_axis(F2_IMAGE): assert_allclose(new_result, result) adwcs.remove_unused_world_axis(ad[0]) new_result = ad[0].wcs(900, 800) - assert_allclose(new_result, result[:2]) + assert_allclose(new_result, result[-2:]) for frame in ad[0].wcs.available_frames: assert getattr(ad[0].wcs, frame).naxes == 2