# Source code for ndcube.wcs.wrappers.compound_wcs
from functools import reduce
import numpy as np
from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper
__all__ = ['CompoundLowLevelWCS']
def tuplesum(lists):
    """
    Concatenate an iterable of sequences into a single flat tuple.

    Parameters
    ----------
    lists : iterable
        An iterable whose items are themselves iterables. Each item is
        converted to a `tuple` before being concatenated.

    Returns
    -------
    `tuple`
        The items of every input sequence, in order. An empty input gives
        an empty tuple.
    """
    # The explicit ``()`` initial value makes an empty input return ``()``
    # instead of raising ``TypeError`` from ``reduce``.
    return reduce(tuple.__add__, map(tuple, lists), ())
class Mapping:
    """
    Allows inputs to be reordered, duplicated or dropped.

    This is a very stripped down version of `astropy.modeling.models.Mapping`
    to be able to handle input of arbitrary type.

    Parameters
    ----------
    mapping : tuple
        A sequence of integers representing indices of the inputs to this
        mapping to return and in what order to return them. See
        :ref:`compound-model-mappings` for more details.

    Raises
    ------
    ValueError
        If ``mapping`` is empty.
    """

    def __init__(self, mapping):
        # Store an immutable copy: ``inverse`` relies on ``tuple.index`` and
        # the cached sizes below must not go stale if the caller later
        # mutates the sequence that was passed in.
        mapping = tuple(mapping)
        if not mapping:
            raise ValueError("mapping must contain at least one index.")
        self.mapping = mapping
        # The largest index referenced fixes how many inputs are expected.
        self.n_inputs = max(mapping) + 1
        self.n_outputs = len(mapping)

    def __call__(self, *values):
        """Return ``values`` reordered/duplicated/dropped as per the mapping."""
        return tuple(values[idx] for idx in self.mapping)

    @property
    def inverse(self):
        """
        A `Mapping` which undoes this one.

        Where an input is duplicated by this mapping, the inverse takes the
        value from the first output that refers to it.

        Raises
        ------
        ValueError
            If this mapping drops one of its inputs, in which case it can not
            be inverted.
        """
        inverse_mapping = []
        for idx in range(self.n_inputs):
            if idx not in self.mapping:
                raise ValueError(
                    f"Mapping {self.mapping} drops input {idx} and therefore has no inverse.")
            inverse_mapping.append(self.mapping.index(idx))
        return type(self)(tuple(inverse_mapping))

    def __repr__(self):
        return f'<Mapping({self.mapping})>'
[docs]
class CompoundLowLevelWCS(BaseWCSWrapper):
    """
    A wrapper that takes multiple low level WCS objects and makes a compound
    WCS that combines them.

    The world axes of the compound WCS are the world axes of each input WCS
    concatenated in the order the WCSes were given. The pixel axes are
    defined by ``mapping``, which allows a pixel axis to be shared between
    several of the input WCSes.

    Parameters
    ----------
    *wcs : `~astropy.wcs.wcsapi.BaseLowLevelWCS`
        The WCSes to combine
    mapping : `tuple`
        The pixel dimension mapping between the input pixel dimensions of the
        compound WCS and the input pixel dimensions to the underlying WCSes.
        This should have length equal to the total number of pixel dimensions
        in all input WCSes and have a maximum of the number of input pixel
        dimensions to the resulting compound WCS -1 (counts from 0). For
        example ``(0, 1, 2, 1)`` would end up with the second and fourth pixel
        dimensions in the input WCSes being shared, so the compound WCS would
        have 3 pixel dimensions ``(2 + 1)``.
        See :ref:`compound-model-mappings` for more examples of this input
        format. If not given (or empty) no pixel dimensions are shared.
    pixel_atol : `float`
        A tolerance used to check that the resulting pixel coordinates from
        ``world_to_pixel`` are the same from all WCSes.

    Raises
    ------
    ValueError
        If the length of ``mapping`` does not equal the total number of pixel
        dimensions of the input WCSes, or if the pixel shapes or pixel bounds
        of the input WCSes disagree on a shared pixel dimension.
    """

    def __init__(self, *wcs, mapping=None, pixel_atol=1e-8):
        self._wcs = wcs

        # ``len`` rather than truthiness so array-like mappings are accepted.
        if mapping is None or len(mapping) == 0:
            mapping = range(self._all_pixel_n_dim)
        # ``Mapping.inverse`` needs ``tuple.index``, so normalise here.
        mapping = tuple(mapping)

        if len(mapping) != self._all_pixel_n_dim:
            raise ValueError(
                "The length of the mapping must equal the total number of pixel dimensions in all input WCSes.")

        self.mapping = Mapping(mapping)
        self.atol = pixel_atol

        # Validate the pixel bounds and shape are consistent: evaluating the
        # properties raises ValueError if shared dimensions disagree.
        self.pixel_bounds
        self.pixel_shape

    @property
    def _all_pixel_n_dim(self):
        """Total number of pixel dimensions summed over all input WCSes."""
        return sum(w.pixel_n_dim for w in self._wcs)

    @property
    def pixel_n_dim(self):
        return self.mapping.n_inputs

    @property
    def world_n_dim(self):
        return sum(w.world_n_dim for w in self._wcs)

    @property
    def world_axis_physical_types(self):
        return tuplesum([w.world_axis_physical_types for w in self._wcs])

    @property
    def world_axis_units(self):
        return tuplesum([w.world_axis_units for w in self._wcs])

    def pixel_to_world_values(self, *pixel_arrays):
        """
        Convert pixel coordinates of the compound WCS to world coordinates.

        The pixel inputs are distributed to the input WCSes according to the
        mapping and the world outputs of every WCS are concatenated.
        """
        # Expand the compound pixel inputs into the flat sequence of pixel
        # inputs expected by the input WCSes, one after another.
        pixel_arrays = self.mapping(*pixel_arrays)

        world_arrays = []
        start = 0
        for w in self._wcs:
            stop = start + w.pixel_n_dim
            world_arrays_sub = w.pixel_to_world_values(*pixel_arrays[start:stop])
            start = stop
            # The result is only treated as a sequence of arrays when the
            # WCS has more than one world dimension; otherwise it is taken
            # to be a single value.
            if w.world_n_dim > 1:
                world_arrays.extend(world_arrays_sub)
            else:
                world_arrays.append(world_arrays_sub)

        return tuple(world_arrays)

    def world_to_pixel_values(self, *world_arrays):
        """
        Convert world coordinates to pixel coordinates of the compound WCS.

        Raises
        ------
        ValueError
            If the input WCSes which share a pixel dimension do not agree
            (within ``pixel_atol``) on the resulting pixel coordinate.
        """
        pixel_arrays = []
        start = 0
        for w in self._wcs:
            stop = start + w.world_n_dim
            pixel_arrays_sub = w.world_to_pixel_values(*world_arrays[start:stop])
            start = stop
            # See pixel_to_world_values: a single pixel dimension is taken
            # to be returned as a single value rather than a sequence.
            if w.pixel_n_dim > 1:
                pixel_arrays.extend(pixel_arrays_sub)
            else:
                pixel_arrays.append(pixel_arrays_sub)

        # Collect, for every compound pixel axis, the positions in the flat
        # ``pixel_arrays`` list which map onto it. This is done in pure
        # Python because the mapping is a tuple, for which ``== int`` is
        # never elementwise.
        positions_by_axis = {}
        for position, mapped_axis in enumerate(self.mapping.mapping):
            positions_by_axis.setdefault(mapped_axis, []).append(position)

        # Every WCS sharing a pixel axis must give the same pixel coordinate.
        for positions in positions_by_axis.values():
            first, *others = positions
            for other in others:
                if not np.allclose(pixel_arrays[first], pixel_arrays[other],
                                   atol=self.atol, equal_nan=True):
                    raise ValueError(
                        "The world inputs for shared pixel axes did not result in a pixel "
                        f"coordinate to within {self.atol} relative accuracy."
                    )

        return self.mapping.inverse(*pixel_arrays)

    @property
    def world_axis_object_components(self):
        # The first element of each component is suffixed with the index of
        # the WCS it came from, matching the keys generated in
        # ``world_axis_object_classes``.
        all_components = []
        for iw, w in enumerate(self._wcs):
            for component in w.world_axis_object_components:
                all_components.append((f'{component[0]}_{iw}',) + component[1:])
        return all_components

    @property
    def world_axis_object_classes(self):
        # TODO: deal with name conflicts
        all_classes = {}
        for iw, w in enumerate(self._wcs):
            for key, value in w.world_axis_object_classes.items():
                all_classes[f'{key}_{iw}'] = value
        return all_classes

    @property
    def pixel_shape(self):
        """
        The pixel shape of the compound WCS, or `None` if any input WCS does
        not define one.

        Raises
        ------
        ValueError
            If the input WCSes disagree on the size of a shared dimension.
        """
        shapes = [w.pixel_shape for w in self._wcs]
        if any(shape is None for shape in shapes):
            return None

        pixel_shape = tuplesum(shapes)
        out_shape = self.mapping.inverse(*pixel_shape)
        for i, ix in enumerate(self.mapping.mapping):
            if out_shape[ix] != pixel_shape[i]:
                raise ValueError(
                    "The pixel shapes of the supplied WCSes do not match for the dimensions shared by the supplied mapping.")
        return out_shape

    @property
    def pixel_bounds(self):
        """
        The pixel bounds of the compound WCS, or `None` if any input WCS does
        not define them.

        Raises
        ------
        ValueError
            If the input WCSes disagree on the bounds of a shared dimension.
        """
        bounds = [w.pixel_bounds for w in self._wcs]
        if any(bound is None for bound in bounds):
            return None

        pixel_bounds = tuplesum(bounds)
        out_bounds = self.mapping.inverse(*pixel_bounds)
        for i, ix in enumerate(self.mapping.mapping):
            if out_bounds[ix] != pixel_bounds[i]:
                raise ValueError(
                    "The pixel bounds of the supplied WCSes do not match for the dimensions shared by the supplied mapping.")
        return out_bounds

    @property
    def pixel_axis_names(self):
        """
        The names of the pixel axes of the compound WCS.

        Where the input WCSes give different names to a shared pixel axis,
        the distinct names are joined with ``' / '``.
        """
        pixel_names = tuplesum(w.pixel_axis_names for w in self._wcs)
        # Gather the distinct names given to each compound pixel axis, in
        # the order they are encountered.
        names_by_axis = [[] for _ in range(self.pixel_n_dim)]
        for name, ix in zip(pixel_names, self.mapping.mapping):
            if name not in names_by_axis[ix]:
                names_by_axis[ix].append(name)
        return tuple(' / '.join(names) for names in names_by_axis)

    @property
    def world_axis_names(self):
        return tuplesum(w.world_axis_names for w in self._wcs)

    @property
    def axis_correlation_matrix(self):
        # First build the block-diagonal matrix relating every world axis to
        # the flat (unshared) list of pixel axes of the input WCSes.
        full_matrix = np.zeros((self.world_n_dim, self._all_pixel_n_dim), dtype=bool)
        iw = ip = 0
        for w in self._wcs:
            full_matrix[iw:iw + w.world_n_dim, ip:ip + w.pixel_n_dim] = w.axis_correlation_matrix
            iw += w.world_n_dim
            ip += w.pixel_n_dim

        # Then fold the columns of pixel axes which are shared together: a
        # world axis depends on a compound pixel axis if it depends on any
        # of the input pixel axes mapped onto it.
        matrix = np.zeros((self.world_n_dim, self.pixel_n_dim), dtype=bool)
        for i, ix in enumerate(self.mapping.mapping):
            matrix[:, ix] = np.logical_or(matrix[:, ix], full_matrix[:, i])
        return matrix

    @property
    def serialized_classes(self):
        return any(w.serialized_classes for w in self._wcs)