Source code for mpl_animators.image

import matplotlib as mpl

from .base import ArrayAnimator

# Names exported by ``from <this module> import *``.
__all__ = ['ImageAnimator']


class ImageAnimator(ArrayAnimator):
    """
    Create a matplotlib backend independent data explorer for 2D images.

    The following keyboard shortcuts are defined in the viewer:

    * 'left': previous step on active slider.
    * 'right': next step on active slider.
    * 'top': change the active slider up one.
    * 'bottom': change the active slider down one.
    * 'p': play/pause active slider.

    This viewer can have user defined buttons added by specifying the labels
    and functions called when those buttons are clicked as keyword arguments.

    Parameters
    ----------
    data: `numpy.ndarray`
        The data to be visualized.
    image_axes: `list`, optional
        A list of the axes order that make up the image.
        Defaults to ``[-2, -1]``.
    axis_ranges: `list` of physical coordinates for the `numpy.ndarray`, optional
        Defaults to `None` and array indices will be used for all axes.
        The `list` should contain one element for each axis of the `numpy.ndarray`.
        For the image axes a ``[min, max]`` pair should be specified which will be
        passed to `matplotlib.pyplot.imshow` as an extent.
        For the slider axes a ``[min, max]`` pair can be specified or an array the
        same length as the axis which will provide all values for that slider.

    Raises
    ------
    ValueError
        If ``image_axes`` does not contain exactly two axes.

    Notes
    -----
    Extra keywords are passed to `~mpl_animators.base.ArrayAnimator`.
    """

    def __init__(self, data, image_axes=None, axis_ranges=None, **kwargs):
        # Use a None sentinel rather than a list literal as the default so
        # that one list object is not shared between every call.
        if image_axes is None:
            image_axes = [-2, -1]

        # Check that number of axes is 2.
        if len(image_axes) != 2:
            raise ValueError("There can only be two spatial axes")

        # Define number of slider axes.
        self.naxis = data.ndim
        self.num_sliders = self.naxis - 2

        # Define marker to determine if plot axes values are supplied via array of
        # pixel values or min max pair. This will determine the type of image produced
        # and hence how to plot and update it.
        self._non_regular_plot_axis = False

        # Run init for parent class
        super().__init__(data, image_axes=image_axes, axis_ranges=axis_ranges, **kwargs)

    def _get_nonuniform_image_data(self):
        """
        Return the current frame oriented for ``NonUniformImage.set_data``.

        The frame is transposed when the first image axis index is lower than
        the second, and returned unchanged otherwise.
        """
        frame = self.data[self.frame_index]
        if self.image_axes[0] < self.image_axes[1]:
            return frame.transpose()
        return frame

    def plot_start_image(self, ax):
        """
        Sets up plot of initial image.

        Parameters
        ----------
        ax
            The axes the image is drawn on; ``imshow``, ``add_image``,
            ``set_xlim`` and ``set_ylim`` are called on it.

        Returns
        -------
        The created image: a `matplotlib.image.NonUniformImage` when any image
        axis range has more than two entries, otherwise the result of
        ``ax.imshow``.
        """
        # Reverse because numpy is in y-x and extent is x-y.
        plot_axes = self.image_axes[::-1]

        # An image axis described by more than a [min, max] pair cannot be drawn
        # with a plain extent. The flag is only ever switched on here, never off.
        if any(len(self.axis_ranges[i]) > 2 for i in plot_axes):
            self._non_regular_plot_axis = True

        # Create extent arg
        extent = []
        for i in plot_axes:
            extent.append(self.axis_ranges[i][0])
            extent.append(self.axis_ranges[i][-1])

        imshow_args = {'interpolation': 'nearest', 'origin': 'lower'}
        imshow_args.update(self.imshow_kwargs)

        # If value along an axis is set with an array, generate a NonUniformImage
        if self._non_regular_plot_axis:
            # Initialize a NonUniformImage with the relevant data and axis values and
            # add the image to the axes.
            im = mpl.image.NonUniformImage(ax, **imshow_args)
            im.set_data(self.axis_ranges[self.image_axes[0]],
                        self.axis_ranges[self.image_axes[1]],
                        self._get_nonuniform_image_data())
            ax.add_image(im)
            # Define the xlim and ylim from the pixel edges.
            # NOTE(review): self.extent is not assigned in this class; it is
            # presumably provided by the parent class - confirm.
            ax.set_xlim(self.extent[0], self.extent[1])
            ax.set_ylim(self.extent[2], self.extent[3])
        else:
            # Else produce a more basic plot with regular axes.
            imshow_args['extent'] = extent
            im = ax.imshow(self.data[self.frame_index], **imshow_args)

        if self.if_colorbar:
            self._add_colorbar(im)

        return im

    def update_plot(self, val, im, slider):
        """
        Updates plot based on slider/array dimension being iterated.

        Parameters
        ----------
        val
            The new slider value; it is converted with `int` to index the
            array axis attached to ``slider``.
        im
            The image returned by `plot_start_image`, updated in place.
        slider
            The slider that changed. Its ``slider_ind`` selects the array axis
            and its ``cval`` caches the last value that was drawn.
        """
        ind = int(val)
        ax_ind = self.slider_axes[slider.slider_ind]
        self.frame_slice[ax_ind] = ind

        # Only redraw when the slider has actually moved.
        if val != slider.cval:
            if self._non_regular_plot_axis:
                im.set_data(self.axis_ranges[self.image_axes[0]],
                            self.axis_ranges[self.image_axes[1]],
                            self._get_nonuniform_image_data())
            else:
                im.set_array(self.data[self.frame_index])
            slider.cval = val

        # Update slider label to reflect real world values in axis_ranges.
        super().update_plot(val, im, slider)