"""
This module contains functions that will segment images for granule detection.
"""
import logging
import matplotlib as mpl
import numpy as np
import scipy
import skimage
import sunpy
import sunpy.map
# Names exported as the public API of this module; the underscore-prefixed
# helpers defined below are intentionally left out.
__all__ = ["segment", "segments_overlap_fraction"]
[docs]
def segment(smap, *, skimage_method="li", mark_dim_centers=False, bp_min_flux=None):
    """
    Segment an optical image of the solar photosphere into tri-value maps with:

     * 0 as intergranule
     * 1 as granule
     * 2 as brightpoint

    If mark_dim_centers is set to True, an additional label, 3, will be assigned to
    dim granule centers.

    Parameters
    ----------
    smap : `~sunpy.map.GenericMap`
        `~sunpy.map.GenericMap` containing data to segment. Must have square pixels.
    skimage_method : {"li", "otsu", "isodata", "mean", "minimum", "yen", "triangle"}, optional
        scikit-image thresholding method, defaults to "li".
        Depending on input data, one or more of these methods may be
        significantly better or worse than the others. Typically, 'li', 'otsu',
        'mean', and 'isodata' are good choices, 'yen' and 'triangle' over-
        identify intergranule material, and 'minimum' over identifies granules.
    mark_dim_centers : `bool`, optional
        Whether to mark dim granule centers as a separate category for future exploration.
    bp_min_flux : `float`, optional
        Minimum value per pixel for a region to be considered a brightpoint.
        The limit is compared against the locally histogram-equalized image,
        not against the raw map data.
        Default is `None` which will use the mean of the equalized image
        + 1.25 * its standard deviation.

    Returns
    -------
    segmented_map : `~sunpy.map.GenericMap`
        `~sunpy.map.GenericMap` containing a segmented image, built from the
        input map's WCS, with a four-colour colormap and norm stored in its
        ``plot_settings``.

    Raises
    ------
    TypeError
        If ``smap`` is not a `~sunpy.map.GenericMap`.
    ValueError
        If the two pixel scales of ``smap`` differ (non-square pixels).
    """
    if not isinstance(smap, sunpy.map.GenericMap):
        msg = "Input must be an instance of a sunpy.map.GenericMap"
        raise TypeError(msg)
    if smap.scale[0].value != smap.scale[1].value:
        msg = "Currently only maps with square pixels are supported."
        raise ValueError(msg)
    resolution = smap.scale[0].value

    # Obtain local histogram equalization of map.
    # min-max normalization to [0, 1]; the extrema are computed once and reused.
    data_min = np.nanmin(smap.data)
    data_max = np.nanmax(smap.data)
    map_norm = (smap.data - data_min) / (data_max - data_min)
    map_he = skimage.filters.rank.equalize(
        skimage.util.img_as_ubyte(map_norm),
        footprint=skimage.morphology.disk(radius=100),
    )

    # Apply initial skimage threshold on a median-smoothed copy of the equalized image.
    median_filtered = scipy.ndimage.median_filter(map_he, size=3)
    threshold = _get_threshold(median_filtered, skimage_method)
    segmented_image = np.uint8(median_filtered > threshold)

    # Fix the extra intergranule material bits in the middle of granules.
    seg_im_fixed = _trim_intergranules(segmented_image, mark=mark_dim_centers)

    # Mark brightpoint and get final granule and brightpoint count.
    # Note that the un-smoothed equalized image is what the brightness limit is applied to.
    seg_im_markbp, brightpoint_count, granule_count = _mark_brightpoint(
        seg_im_fixed,
        smap.data,
        map_he,
        resolution,
        bp_min_flux,
    )
    # Lazy %-style arguments: the message is only formatted if the record is emitted.
    logging.info(
        "Segmentation has identified %s granules and %s brightpoint",
        granule_count,
        brightpoint_count,
    )

    # Create output map using input wcs and adding colormap such that
    # 0 (intergranules) = black, 1 (granule) = white, 2 (brightpoints) = yellow,
    # 3 (dim_centers) = blue.
    segmented_map = sunpy.map.Map(seg_im_markbp, smap.wcs)
    cmap = mpl.colors.ListedColormap(["black", "white", "#ffc406", "blue"])
    norm = mpl.colors.BoundaryNorm(boundaries=[-0.5, 0.5, 1.5, 2.5, 3.5], ncolors=cmap.N)
    segmented_map.plot_settings["cmap"] = cmap
    segmented_map.plot_settings["norm"] = norm
    return segmented_map
def _get_threshold(data, method):
    """
    Get the threshold value using given skimage segmentation type.

    Parameters
    ----------
    data : `numpy.ndarray`
        Data to threshold.
    method : {"li", "otsu", "isodata", "mean", "minimum", "yen", "triangle"}
        scikit-image thresholding method (case-insensitive).

    Returns
    -------
    threshold : `float`
        Threshold value.

    Raises
    ------
    TypeError
        If ``data`` is not a `numpy.ndarray`.
    ValueError
        If ``method`` is not one of the supported method names.

    Notes
    -----
    For inputs with more than 500**2 elements the threshold is computed on a
    random 500x500 sample drawn from an unseeded generator, so the returned
    value can differ slightly between calls on the same large image.
    """
    if not isinstance(data, np.ndarray):
        msg = "Input data must be an instance of a np.ndarray"
        raise TypeError(msg)

    # Validate the method name up front so an invalid name fails before any
    # sampling work is done.
    method = method.lower()
    supported_methods = ("li", "otsu", "yen", "mean", "minimum", "triangle", "isodata")
    if method not in supported_methods:
        raise ValueError("Method must be one of: " + ", ".join(supported_methods))

    sample_side = 500
    if data.size > sample_side**2:
        logging.info(
            "Input image is large (> 500**2), so threshold computation will be based on a random 500x500 sample of pixels",
        )
        rng = np.random.default_rng()
        # Computing threshold based on random sample works well and saves significant computational time.
        # ``ravel`` gives a 1-D view where possible, avoiding a full copy of a large image.
        data = rng.choice(data.ravel(), (sample_side, sample_side))

    # Every supported name maps onto ``skimage.filters.threshold_<name>``.
    threshold_func = getattr(skimage.filters, f"threshold_{method}")
    return threshold_func(data)
def _trim_intergranules(segmented_image, *, mark=False):
    """
    Remove the erroneous identification of intergranule material in the middle
    of granules that the pure threshold segmentation produces.

    The largest connected region that contains at least one 0-valued pixel of
    the input is treated as the real intergranule network. Every other region
    made up purely of 0-valued input pixels is re-labelled.

    Parameters
    ----------
    segmented_image : `numpy.ndarray`
        The segmented image containing incorrect extra intergranules.
    mark : `bool`
        If `False` (the default), remove erroneous intergranules (set them to 1).
        If `True`, mark them as 3 instead (for later examination).

    Returns
    -------
    segmented_image_fixed : `numpy.ndarray`
        Float copy of the segmented image without incorrect extra intergranules.
        When the image has at least 200 rows, a border of
        ``int(n_rows / 200)`` pixels is set to 0 in the returned array.

    Raises
    ------
    ValueError
        If ``segmented_image`` contains more than two distinct values.
    """
    if len(np.unique(segmented_image)) > 2:
        msg = "segmented_image must only have values of 1 and 0."
        raise ValueError(msg)
    # Float conversion for correct region labeling.
    segmented_image_fixed = np.copy(segmented_image).astype(float)

    # Add padding of intergranule around edges.
    # Avoids the case where all edge pixels are granule,
    # which will result in all dim centers as intergranules.
    pad = int(np.shape(segmented_image)[0] / 200)
    # A zero-width pad must be skipped entirely: ``[-0:]`` is the same slice as
    # ``[0:]``, so the trailing-edge assignments would otherwise blank the
    # whole image for inputs with fewer than 200 rows.
    if pad > 0:
        segmented_image_fixed[:, :pad] = 0
        segmented_image_fixed[:pad, :] = 0
        segmented_image_fixed[:, -pad:] = 0
        segmented_image_fixed[-pad:, :] = 0

    # Shift values by one so that no pixel is treated as label background.
    labeled_seg = skimage.measure.label(segmented_image_fixed + 1, connectivity=2)
    values = np.unique(labeled_seg)

    # Find value of the large continuous 0-valued region.
    real_ig_value = None
    largest_size = 0
    for value in values:
        region = labeled_seg == value
        region_size = np.count_nonzero(region)
        if region_size > largest_size and np.any(segmented_image[region] == 0):
            real_ig_value = value
            largest_size = region_size

    # Set all other 0 regions to mark value (3), or to granule (1) when not marking.
    replacement = 3 if mark else 1
    for value in values:
        if real_ig_value is not None and value == real_ig_value:
            continue
        region = labeled_seg == value
        if np.sum(segmented_image[region]) == 0:
            segmented_image_fixed[region] = replacement
    return segmented_image_fixed
def _mark_brightpoint(segmented_image, data, he_data, resolution, bp_min_flux=None):
    """
    Mark brightpoints separately from granules - give them a value of 2.

    Candidate regions are connected areas of ``he_data`` above the brightness
    limit. A candidate becomes a brightpoint when its pixel count lies
    strictly between the lower and upper size limits and its mean gradient
    exceeds the 95th percentile of the gradient map of ``data``.

    Parameters
    ----------
    segmented_image : `numpy.ndarray`
        The segmented image containing incorrect middles.
    data : `numpy.ndarray`
        The original image.
    he_data : `numpy.ndarray`
        Original image with local histogram equalization applied.
    resolution : `float`
        Spatial resolution (arcsec/pixel) of the data.
    bp_min_flux : `float`, optional
        Minimum value per pixel of ``he_data`` for a region to be considered a
        brightpoint.
        Default is `None` which will use the mean of ``he_data`` + 1.25 * its
        standard deviation.

    Returns
    -------
    segmented_image_fixed : `numpy.ndarray`
        Float copy of the segmented image with brightpoints marked as 2.
    brightpoint_count : `int`
        The number of brightpoints identified in the image.
    granule_count : `int`
        The number of labeled regions in the bright/dim candidate map, minus
        one (for the intergranule region) and minus ``brightpoint_count``.

    Raises
    ------
    ValueError
        If ``segmented_image`` contains more than three distinct values.
    """
    # Validate first so bad input fails before any numeric work is done.
    if len(np.unique(segmented_image)) > 3:
        msg = "segmented_image must have only values of 1, 0 and 3 (if dim centers marked)"
        raise ValueError(msg)

    # General size limits
    # Approximate max size of a photosphere bright point in square arcsec (see doi 10.3847/1538-4357/aab150)
    # NOTE(review): the limit is described as an area, yet it is divided by the
    # linear resolution before squaring - confirm the intended units.
    bp_size_limit = 0.1
    bp_pix_upper_limit = (bp_size_limit / resolution) ** 2  # Max area in pixels
    bp_pix_lower_limit = 4  # Very small bright regions are likely artifacts

    # General flux limit determined by visual inspection (set using equalized map)
    if bp_min_flux is None:
        stand_devs = 1.25
        bp_brightness_limit = np.nanmean(he_data) + stand_devs * np.nanstd(he_data)
    else:
        bp_brightness_limit = bp_min_flux

    # Obtain gradient map and set threshold for gradient on BP edges.
    # The gradient is evaluated once and both axis components are reused.
    gradients = np.gradient(data)
    grad = np.abs(gradients[0] + gradients[1])
    bp_min_grad = np.quantile(grad, 0.95)

    # Label all regions of flux greater than brightness limit (candidate regions)
    bright_dim_seg = np.zeros_like(data)
    bright_dim_seg[he_data > bp_brightness_limit] = 1
    # Shift values by one so that no pixel is treated as label background.
    labeled_bright_dim_seg = skimage.measure.label(bright_dim_seg + 1, connectivity=2)
    values = np.unique(labeled_bright_dim_seg)

    # From candidate regions, select those within pixel limit and gradient limit
    segmented_image_fixed = np.copy(segmented_image.astype(float))  # Make type float to enable adding float values
    bp_count = 0
    for value in values:
        region = labeled_bright_dim_seg == value
        # Skip regions that are below the brightness limit (the non-bp regions).
        if bright_dim_seg[region][0] != 1:
            continue
        # Check that region is within pixel limits (both bounds exclusive).
        region_size = np.count_nonzero(region)
        if not bp_pix_lower_limit < region_size < bp_pix_upper_limit:
            continue
        # Check that region has high average gradient (maybe try max gradient?)
        if np.mean(grad[region]) > bp_min_grad:
            segmented_image_fixed[region] = 2
            bp_count += 1
    gran_count = len(values) - 1 - bp_count  # Subtract 1 for IG region.
    return segmented_image_fixed, bp_count, gran_count
[docs]
def segments_overlap_fraction(segment1, segment2):
    """
    Compute the fraction of overlap between two segmented
    `~sunpy.map.GenericMap`.

    Designed for comparing output Map from `segment` with other segmentation methods.

    Parameters
    ----------
    segment1 : `~sunpy.map.GenericMap`
        Main `~sunpy.map.GenericMap` to compare against. Must have 0 = intergranule, 1 = granule.
    segment2 : `~sunpy.map.GenericMap`
        Comparison `~sunpy.map.GenericMap`. Must have 0 = intergranule, 1 = granule.
        As an example, this could come from a simple segment using sklearn.cluster.KMeans

    Returns
    -------
    confidence : `float`
        The numeric confidence metric: 0 = no agreement and 1 = complete agreement.
        It is the mean of two fractions: the share of ``segment1`` granule
        pixels that are also granule in ``segment2``, and the same for
        intergranule pixels.

    Raises
    ------
    ValueError
        If ``segment1`` contains no granule pixels or no intergranule pixels.
    """
    reference = np.array(segment1.data)
    comparison = np.array(segment2.data)

    # Pixel masks of the reference map; every fraction is relative to these.
    reference_granules = reference == 1
    reference_intergranules = reference == 0

    n_granules = np.count_nonzero(reference_granules)
    n_intergranules = np.count_nonzero(reference_intergranules)
    if n_granules == 0:
        msg = "No granules in `segment1`. It is possible the clustering failed."
        raise ValueError(msg)
    if n_intergranules == 0:
        msg = "No intergranules in `segment1`. It is possible the clustering failed."
        raise ValueError(msg)

    # Count the pixels on which both maps carry the same label.
    shared_granules = np.count_nonzero(reference_granules & (comparison == 1))
    shared_intergranules = np.count_nonzero(reference_intergranules & (comparison == 0))

    granule_fraction = shared_granules / n_granules
    intergranule_fraction = shared_intergranules / n_intergranules
    return np.mean([granule_fraction, intergranule_fraction])