"""Source code for ``mpl_animators.wcs``: animate an array with an associated WCS."""

from functools import partial

import astropy.units as u
import numpy as np
from astropy.visualization import AsymmetricPercentileInterval
from astropy.wcs.wcsapi import BaseLowLevelWCS

from mpl_animators.extern import modest_image

from .base import ArrayAnimator

__all__ = ['ArrayAnimatorWCS']

class ArrayAnimatorWCS(ArrayAnimator):
    """
    Animate an array with associated `~astropy.wcs.wcsapi.BaseLowLevelWCS` object.

    The following keyboard shortcuts are defined in the viewer:

    * 'left': previous step on active slider.
    * 'right': next step on active slider.
    * 'top': change the active slider up one.
    * 'bottom': change the active slider down one.
    * 'p': play/pause active slider.

    Parameters
    ----------
    data : `numpy.ndarray`
        The data to be visualized.
    wcs : `astropy.wcs.wcsapi.BaseLowLevelWCS`
        The world coordinate object associated with the array.
    slices : `tuple` or `list`
        A list specifying which axes of the array should be plotted on which
        axes. The list should be the same length as the number of pixel
        dimensions with ``'x'`` and (optionally) ``'y'`` in the elements
        corresponding to the axes to be plotted. If only ``'x'`` is present a
        line plot will be drawn. All other elements should be ``0``.
    coord_params : `dict`, optional
        This dict allows you to override
        `~astropy.visualization.wcsaxes.WCSAxes` parameters for each world
        coordinate. The keys of this dictionary should be a value which can be
        looked up in ``WCSAxes.coords`` (e.g. ``em.wl`` or ``hpln``) and the
        values should be a dict which supports the following keys, and passes
        their values to the associated `~astropy.visualization.wcsaxes.WCSAxes`
        methods.

        * ``format_unit``: `~astropy.visualization.wcsaxes.CoordinateHelper.set_format_unit`
        * ``major_formatter``: `~astropy.visualization.wcsaxes.CoordinateHelper.set_major_formatter`
        * ``axislabel``: `~astropy.visualization.wcsaxes.CoordinateHelper.set_axislabel`
        * ``grid``: `~astropy.visualization.wcsaxes.CoordinateHelper.grid`
          (The value should be a dict of keyword arguments to ``grid()``,
          `True` to draw the grid with default settings, or `False` to
          switch the grid off).
        * ``ticks``: `dict` or `bool` the keyword arguments to the
          `~astropy.visualization.wcsaxes.CoordinateHelper.set_ticks` method,
          or `False` to display no ticks for this coord.

    ylim : `tuple` or `str`, optional
        The yaxis limits to use when drawing a line plot, if 'fixed' then use
        the global data limits, if 'dynamic' then set the y limit for each
        frame individually (meaning the y limits change as you animate).
    ylabel : `str`, optional
        The yaxis label to use when drawing a line plot. Setting the label on
        the y-axis on an image plot should be done via ``coord_params``.
    clip_interval : two-element `~astropy.units.Quantity`, optional
        If provided, the data for each step will be clipped to the percentile
        interval bounded by the two numbers.
    **kwargs
        Passed on to ``ArrayAnimator``. If ``slider_labels`` is given it is
        used as the complete list of slider labels; otherwise the labels are
        derived from the WCS, followed by the ``__name__`` of each entry in
        ``slider_functions`` (if any).
    """

    def __init__(self, data, wcs, slices, coord_params=None, ylim='dynamic', ylabel=None,
                 clip_interval: u.percent = None, **kwargs):
        if not isinstance(wcs, BaseLowLevelWCS):
            raise ValueError("A WCS object should be provided that implements the astropy WCS API.")
        if wcs.pixel_n_dim != data.ndim:
            raise ValueError("Dimensionality of the data and WCS object do not match.")
        if len(slices) != wcs.pixel_n_dim:
            raise ValueError("slices should be the same length as the number of pixel dimensions.")
        if "x" not in slices:
            raise ValueError(
                "slices should contain at least 'x' to indicate the axis to plot on the x axis.")

        # ``slices`` is in WCS pixel order, which is the reverse of the numpy
        # axis order of ``data``; reversing it turns positions into numpy axes.
        self.plot_dimensionality = 1
        image_axes = [slices[::-1].index("x")]
        if "y" in slices:
            image_axes.append(slices[::-1].index("y"))
            self.plot_dimensionality = 2

        self.naxis = data.ndim
        self.num_sliders = self.naxis - self.plot_dimensionality
        self.slices_wcsaxes = list(slices)
        self.wcs = wcs
        self.coord_params = coord_params
        self.ylim = ylim
        self.ylabel = ylabel
        if clip_interval is not None and len(clip_interval) != 2:
            raise ValueError('A range of 2 values must be specified for clip_interval.')
        self.clip_interval = clip_interval

        # A caller-supplied ``slider_labels`` is removed from kwargs so that it
        # is not passed to the base class twice (which raised TypeError).
        slider_labels = kwargs.pop("slider_labels", None)
        if slider_labels is None:
            extra_slider_labels = [func.__name__ for func in kwargs.get("slider_functions", [])]
            slider_labels = self._compute_slider_labels_from_wcs(slices) + extra_slider_labels

        super().__init__(data, image_axes=image_axes, axis_ranges=None,
                         slider_labels=slider_labels, **kwargs)

    def _get_wcs_labels(self):
        """
        Return one label per world axis.

        The WCS ``world_axis_names`` entry is used when it is set, otherwise
        the corresponding ``world_axis_physical_types`` entry.
        """
        return [name or physical_type
                for name, physical_type in zip(self.wcs.world_axis_names,
                                               self.wcs.world_axis_physical_types)]

    def _compute_slider_labels_from_wcs(self, slices):
        """
        Compute a label for each pixel dimension not used in the plot.

        Each label is built from the world axis labels correlated with that
        pixel dimension. There can be more than one (e.g. lat & lon), in
        which case they are joined with ``" / "``. The result is reversed
        from WCS order into numpy axis order.
        """
        world_labels = np.array(self._get_wcs_labels())
        is_slider_axis = np.array([axis not in ['x', 'y'] for axis in slices])
        # Rows: sliced pixel axes; columns: boolean "is this world axis correlated".
        correlations = self.wcs.axis_correlation_matrix[:, is_slider_axis].T
        labels = [" / ".join(map(str, world_labels[row])) for row in correlations]
        return labels[::-1]

    def _partial_pixel_to_world(self, pixel_dimension, pixel_coord):
        """
        Return the world coordinate along one axis, if it is only correlated to that axis.

        Parameters
        ----------
        pixel_dimension : `int`
            The numpy-order index of the slider axis.
        pixel_coord
            The pixel coordinate(s) along that axis.

        Returns
        -------
        `~astropy.units.Quantity`
            The world coordinate in the WCS world axis unit. If the pixel
            axis is not correlated with exactly one world axis, the pixel
            coordinate is returned in ``u.pix`` instead.
        """
        wcs_dimension = self.wcs.pixel_n_dim - pixel_dimension - 1
        corr = self.wcs.axis_correlation_matrix[:, wcs_dimension]
        world_axes = np.nonzero(corr)[0]

        # If more than one world axis is linked to this dimension we can't
        # display the world coordinate because we have no way of picking,
        # so we just display pixel index.
        if len(world_axes) != 1:
            return pixel_coord * u.pix
        # World axes are not necessarily ordered like pixel axes (nor equal in
        # number), so the output must be indexed by the correlated world axis.
        world_dimension = int(world_axes[0])

        # The other pixel coordinates are set to 0 on the assumption that the
        # world coordinate we care about does not depend on them.
        coords = [0] * self.wcs.pixel_n_dim
        coords[wcs_dimension] = pixel_coord
        world_values = self.wcs.pixel_to_world_values(*coords)
        # With a single world axis the API returns the value itself, not a tuple.
        if self.wcs.world_n_dim == 1:
            world_values = (world_values,)
        return u.Quantity(world_values[world_dimension],
                          unit=self.wcs.world_axis_units[world_dimension])

    def _sanitize_axis_ranges(self, *args):
        """
        Generate ``axis_ranges`` from the WCS instead of user-supplied ranges.

        This overrides the behaviour of ArrayAnimator. Each slider axis gets a
        callable mapping pixel coordinates to world coordinates, and every
        other axis gets `None`. The extent returned is always `None`.
        """
        axis_ranges = [None] * self.wcs.pixel_n_dim
        for axis in self.slider_axes:
            axis_ranges[axis] = partial(self._partial_pixel_to_world, axis)
        return axis_ranges, None

    def _apply_coord_params(self, axes):
        """
        Apply the user's ``coord_params`` to the coordinates of ``axes``.

        Raises
        ------
        TypeError
            If a ``ticks`` value is neither a `dict` nor a `bool`.
        """
        if self.coord_params is None:
            return

        for coord_name, params in self.coord_params.items():
            coord = axes.coords[coord_name]

            format_unit = params.get("format_unit", None)
            if format_unit:
                coord.set_format_unit(format_unit)

            major_formatter = params.get("major_formatter", None)
            if major_formatter:
                coord.set_major_formatter(major_formatter)

            axislabel = params.get("axislabel", None)
            if axislabel:
                coord.set_axislabel(axislabel)

            grid = params.get("grid", None)
            if grid is False:
                # Previously a False value fell through to ``grid()`` and
                # switched the grid on.
                coord.grid(draw_grid=False)
            elif grid is not None:
                if not isinstance(grid, dict):
                    grid = {}
                coord.grid(**grid)

            ticks = params.get("ticks", None)
            if ticks is not None:
                if isinstance(ticks, bool):
                    coord.set_ticks_visible(ticks)
                    coord.set_ticklabel_visible(ticks)
                elif isinstance(ticks, dict):
                    coord.set_ticks(**ticks)
                else:
                    raise TypeError(
                        "The 'ticks' value in the coord_params dictionary must be a dict or a boolean."
                    )

    def _setup_main_axes(self):
        """
        Create the main WCSAxes, projected with ``self.wcs`` and ``self.slices_wcsaxes``.
        """
        self.axes = self.fig.add_axes([0.1, 0.1, 0.8, 0.8], projection=self.wcs,
                                      slices=self.slices_wcsaxes)
        self._apply_coord_params(self.axes)

    def plot_start_image(self, ax):
        """
        Draw the first frame: a line plot for 1D slices, an image otherwise.
        """
        if self.plot_dimensionality == 1:
            return self.plot_start_image_1d(ax)
        return self.plot_start_image_2d(ax)

    def update_plot(self, val, artist, slider):
        """
        Update the plot when a slider changes.

        This method both updates the state of the Animator and also re-draws
        the matplotlib artist.
        """
        ind = int(val)
        if ind == int(slider.cval):
            return
        ax_ind = self.slider_axes[slider.slider_ind]
        self.frame_slice[ax_ind] = ind
        # slices_wcsaxes is in WCS order, the reverse of numpy axis order.
        self.slices_wcsaxes[self.wcs.pixel_n_dim - ax_ind - 1] = ind

        if self.plot_dimensionality == 1:
            self.update_plot_1d(val, artist, slider)
        elif self.plot_dimensionality == 2:
            self.update_plot_2d(val, artist, slider)

        # reset_wcs in the update methods re-creates the coordinates, so the
        # user's coord_params must be applied again.
        self._apply_coord_params(self.axes)
        return super().update_plot(val, artist, slider)

    def plot_start_image_1d(self, ax):
        """
        Set up a line plot.

        When plotting with WCSAxes, we always plot against pixel coordinate.
        """
        if self.ylim != 'dynamic':
            ylim = self.ylim
            if ylim == 'fixed':
                ylim = (self.data.min(), self.data.max())
            ax.set_ylim(ylim)

        if self.ylabel:
            ax.set_ylabel(self.ylabel)

        ydata = self.data[self.frame_index]
        line, = ax.plot(ydata, **self.imshow_kwargs)
        # Show the full pixel range even if masked points would otherwise
        # shrink autoscaling.
        # NOTE(review): the isinstance target was lost from the extracted
        # source; np.ma.MaskedArray is assumed here — confirm against upstream.
        if isinstance(self.data, np.ma.MaskedArray):
            ax.set_xlim((0, ydata.shape[0]))

        return line

    @property
    def data_transposed(self):
        """
        Return the current 2D frame, transposed if needed so that x is the last axis.
        """
        needs_transpose = self.slices_wcsaxes.index('y') < self.slices_wcsaxes.index("x")
        frame = self.data[self.frame_index]
        return frame.transpose() if needs_transpose else frame

    def update_plot_1d(self, val, line, slider):
        """
        Update the line plot.
        """
        self.axes.reset_wcs(wcs=self.wcs, slices=self.slices_wcsaxes)
        frame = self.data[self.frame_index]
        line.set_ydata(frame)

        # If we are not setting ylim globally then we set it per frame.
        if self.ylim == 'dynamic':
            self.axes.set_ylim(frame.min(), frame.max())
        slider.cval = val

    def plot_start_image_2d(self, ax):
        """
        Set up an image plot.
        """
        imshow_args = {'interpolation': 'nearest', 'origin': 'lower'}
        imshow_args.update(self.imshow_kwargs)
        if self.clip_interval is not None:
            imshow_args['vmin'], imshow_args['vmax'] = self._get_2d_plot_limits()

        im = modest_image.imshow(ax, self.data_transposed, **imshow_args)

        if 'extent' in imshow_args:
            ax.set_xlim(imshow_args['extent'][:2])
            ax.set_ylim(imshow_args['extent'][2:])
        else:
            ny, nx = self.data_transposed.shape
            ax.set_xlim(-0.5, nx - 0.5)
            ax.set_ylim(-0.5, ny - 0.5)

        # Keep the data limits consistent with the view limits just set.
        ax.dataLim.intervalx = ax.get_xlim()
        ax.dataLim.intervaly = ax.get_ylim()

        if self.if_colorbar:
            self._add_colorbar(im)

        return im

    def _get_2d_plot_limits(self):
        """
        Get vmin, vmax of a data slice when clip_interval is specified.
        """
        percent_limits = self.clip_interval.to('%').value
        values = self.data_transposed
        # Only let unmasked pixels influence the percentiles, independent of
        # how the interval class treats masked input.
        if isinstance(values, np.ma.MaskedArray):
            unmasked = values.compressed()
            if unmasked.size:
                values = unmasked
        vmin, vmax = AsymmetricPercentileInterval(*percent_limits).get_limits(values)
        return vmin, vmax

    def update_plot_2d(self, val, im, slider):
        """
        Update the image plot.
        """
        self.axes.reset_wcs(wcs=self.wcs, slices=self.slices_wcsaxes)
        im.set_array(self.data_transposed)
        if self.clip_interval is not None:
            vmin, vmax = self._get_2d_plot_limits()
            im.set_clim(vmin, vmax)
        slider.cval = val