Source code for nifti_overlay.image

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Classes for overlays generating plot data from single image files.
"""

from abc import ABC, abstractmethod
import pathlib

import nibabel as nib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from skimage import feature

class Image(ABC):
    """Abstract base class representing a single image to be plotted.

    **This is an abstract class, and should not be used directly.**
    Subclasses must implement ``get_slice`` and ``plot_slice``.
    """

    def __init__(self, src: str | pathlib.Path | nib.Nifti1Image):
        """
        Load the source image and store it in closest-canonical orientation.

        Parameters
        ----------
        src : str path, pathlib.Path, or nibabel.Nifti1Image
            Source image.

        Raises
        ------
        TypeError
            `src` is not an accepted type.
        ValueError
            Input image is not 3D.

        Returns
        -------
        None.

        """
        # Resolve the input into (nifti object, path string or None).
        if isinstance(src, nib.Nifti1Image):
            loaded, location = src, None
        elif isinstance(src, str):
            loaded, location = nib.load(src), src
        elif isinstance(src, pathlib.Path):
            location = str(src)
            loaded = nib.load(location)
        else:
            raise TypeError('Acceptable types for initialization are '
                            'str, pathlib.Path, or nibabel.Nifti1Image, '
                            f'not {type(src)}')

        self.path = location

        # Reorient, then pull the voxel values as floating point data.
        self.nifti = nib.as_closest_canonical(loaded)
        self.data = self.nifti.get_fdata()

        # Only three-dimensional images are accepted.
        if len(self.shape) != 3:
            raise ValueError(f'Image must be 3D; image dimensions are: {self.shape}')

    @property
    def shape(self):
        """Tuple specifying the length of each dimension (read from the header)."""
        header = self.nifti.header
        return header.get_data_shape()

    def aspect(self, dimension: int):
        """For a given dimension (int[0-2]), get the aspect ratio of the
        remaining two dimensions, as indicated by the voxel sizes stored in
        the header's ``pixdim`` field."""
        sizes = list(self.nifti.header["pixdim"][1:4])
        sizes.pop(dimension)  # keep only the two in-plane voxel sizes
        first, second = sizes
        return second / first

    def dimension_shape(self, dimension: int):
        """For a given dimension (int[0-2]), get the size of the remaining
        2 dimensions.

        The two remaining lengths are returned in swapped order, which matches
        the shape of a slice after the 90-degree rotation applied by the
        ``get_slice`` implementations in this module."""
        remaining = list(self.shape)
        remaining.pop(dimension)
        return remaining[1], remaining[0]

    @abstractmethod
    def plot_slice(
            self,
            dimension: int,
            position: int,
            ax: matplotlib.axes.Axes | None,
            **kwargs):
        """Plot a 2D slice from one axis (`dimension`) at a given depth (`position`)."""
        ...

    @abstractmethod
    def get_slice(self, dimension: int, position: int):
        """Return the data for a 2D slice from one axis (`dimension`) at a given depth (`position`)."""
        ...
class Anatomy(Image):
    """Class for plotting anatomical images, where the voxel values are
    typically continuous and should be plotted with a colormap."""

    def __init__(
            self,
            src: str | pathlib.Path | nib.Nifti1Image,
            color: str = 'gist_gray',
            alpha: float = 1.0,
            scale_panel: bool = False,
            drop_zero: bool = False,
            vmin: float | None = None,
            vmax: float | None = None
            ):
        """
        Instantiate an Anatomy image.

        Parameters
        ----------
        src : str | pathlib.Path | nib.Nifti1Image
            Source image.
        color : str, optional
            Colormap key from matplotlib. The default is 'gist_gray'.
        alpha : float, optional
            Color transparency. The default is 1.0.
        scale_panel : bool, optional
            When plotting, scale the colormap to be the extent of intensities
            observed in each individual panel, rather than across the whole
            3D volume.  Can be used to maximize the dynamic range within
            each panel, but makes comparisons across panels more difficult.
            The default is False.
        drop_zero : bool, optional
            Don't plot voxels which have a value of zero. The default is False.
        vmin : float | None, optional
            Bottom limit of color range. The default is None, in which case
            the limit is derived from the data at plotting time.
        vmax : float | None, optional
            Top limit of color range. The default is None, in which case
            the limit is derived from the data at plotting time.

        Returns
        -------
        None.

        """
        super().__init__(src)

        self.color = color
        self.alpha = alpha
        self.scale_panel = scale_panel
        self.drop_zero = drop_zero
        self.vmin = vmin
        self.vmax = vmax

    def get_slice(self, dimension: int, position: int):
        """Return the data to be plotted for a given slice.

        The data returned are the raw intensities of a slice through the
        3D volume, rotated by 90 degrees. If ``drop_zero`` is specified,
        the zero values will be replaced with NaN.

        Parameters
        ----------
        dimension : int
            Image dimension to plot over (0 for x, 1 for y, 2 for z).
        position : int
            Index of the slice to plot.

        Returns
        -------
        numpy array
            Imaging slice data.
        """
        xsect = np.rot90(np.take(self.data, indices=position, axis=dimension))

        # Substitute NaN on the 2D slice only, instead of rebuilding a
        # NaN-substituted copy of the whole volume on every call.
        if self.drop_zero:
            xsect = np.where(xsect == 0, np.nan, xsect)

        return xsect

    def plot_slice(self, dimension: int, position: int, ax: matplotlib.axes.Axes | None = None, **kwargs):
        """Plot the data for a given slice.

        Parameters
        ----------
        dimension : int
            Image dimension to plot over (0 for x, 1 for y, 2 for z).
        position : int
            Index of the slice to plot.
        ax : None | matplotlib.axes.Axes, optional
            Matplotlib axes to plot on, by default None, in which case
            ``plt.gca()`` is used.
        **kwargs
            Forwarded unchanged to ``ax.imshow``.

        Returns
        -------
        object
            Whatever ``ax.imshow`` returns (the image artist drawn on `ax`).
        """
        if ax is None:
            ax = plt.gca()

        xsect = self.get_slice(dimension, position)

        # Explicit limits win. Compare against None rather than relying on
        # truthiness so that a limit of exactly 0 is honoured.
        vmin, vmax = self.vmin, self.vmax
        if vmin is None or vmax is None:
            # Panel scaling looks only at this slice; otherwise use the
            # whole volume so all panels share one color range.
            reference = xsect if self.scale_panel else self.data
            # fmin/fmax skip NaN entries (e.g. those introduced by
            # ``drop_zero``) and return NaN, without warning, only when
            # every value is NaN.
            if vmin is None:
                vmin = np.fmin.reduce(reference, axis=None)
            if vmax is None:
                vmax = np.fmax.reduce(reference, axis=None)

        return ax.imshow(
            xsect,
            cmap=self.color,
            aspect=self.aspect(dimension),
            vmin=vmin,
            vmax=vmax,
            alpha=self.alpha,
            **kwargs)
class Edges(Image):
    """Class for plotting the contours/edges (automatically detected) of a slice."""

    def __init__(
            self,
            src: str | pathlib.Path | nib.Nifti1Image,
            color: str='yellow',
            alpha: float=1.0,
            sigma: float=1.0,
            interpolation: str='none'
            ):
        """
        Instantiate an Edges image.

        Parameters
        ----------
        src : str path, pathlib.Path, or nibabel.Nifti1Image
            Source image.
        color : str, RGB, RGBA, optional
            Color understood by matplotlib. The default is 'yellow'.
        alpha : float, optional
            Color transparency. The default is 1.0.
        sigma : float, optional
            Standard deviation of the Gaussian filter of the Canny edge
            detector. The default is 1.0.  See `skimage.feature.canny`.
        interpolation : str, optional
            Type of interpolation for plotting images. The default is 'none'.
            See matplotlib for more details (`plt.imshow`).

        Returns
        -------
        None.

        """
        super().__init__(src)

        self.color = color
        self.alpha = alpha
        self.sigma = sigma
        self.interpolation = interpolation

    def get_slice(self, dimension: int, position: int):
        """Return the data to be plotted for a given slice.

        The data for a slice is first indexed and rotated by 90 degrees,
        and then the Canny edge detector is run to generate the
        binary edge image for the slice.

        Parameters
        ----------
        dimension : int
            Image dimension to plot over (0 for x, 1 for y, 2 for z).
        position : int
            Index of the slice to plot.

        Returns
        -------
        numpy array
            Edge image produced by ``skimage.feature.canny``.
        """
        xsect = np.rot90(np.take(self.data, indices=position, axis=dimension))
        return feature.canny(xsect, sigma=self.sigma)

    def plot_slice(self, dimension: int, position: int, ax: matplotlib.axes.Axes | None = None, **kwargs):
        """Plot the data for a given slice.

        The edge image is converted into an RGBA image in which edge pixels
        carry the configured color and alpha, and every other pixel is
        fully transparent (all zeros).

        Parameters
        ----------
        dimension : int
            Image dimension to plot over (0 for x, 1 for y, 2 for z).
        position : int
            Index of the slice to plot.
        ax : None | matplotlib.axes.Axes, optional
            Matplotlib axes to plot on, by default None, in which case
            ``plt.gca()`` is used.
        **kwargs
            Forwarded unchanged to ``ax.imshow``.

        Returns
        -------
        object
            Whatever ``ax.imshow`` returns (the image artist drawn on `ax`).
        """
        if ax is None:
            ax = plt.gca()

        edges = self.get_slice(dimension, position)

        # RGBA canvas: zeros everywhere means transparent background.
        overlay = np.zeros(edges.shape + (4,))
        overlay[edges] = matplotlib.colors.to_rgba(self.color, alpha=self.alpha)

        # The transparency is already stored in the alpha channel above, so
        # it is deliberately NOT passed to imshow as well: doing both applies
        # alpha twice where matplotlib scales an RGBA image's alpha channel
        # by the artist alpha.
        return ax.imshow(
            overlay,
            aspect=self.aspect(dimension),
            interpolation=self.interpolation,
            **kwargs)
class Mask(Image):
    """Class for plotting a binary mask, represented as a single color."""

    def __init__(
            self,
            src: str | pathlib.Path | nib.Nifti1Image,
            color: str = 'red',
            alpha: float = 1.0,
            mask_value: float = 1.0
            ):
        """
        Instantiate a Mask image.

        Parameters
        ----------
        src : str path, pathlib.Path, or nibabel.Nifti1Image
            Source image.
        color : str, RGB, RGBA, optional
            Color understood by matplotlib. The default is 'red'.
        alpha : float, optional
            Color transparency. The default is 1.0.
        mask_value : int or float, optional
            Value to be plotted.  For segmentations with multiple labels,
            this can be used to specify the single label to be plotted.
            The default is 1. (Note: to plot all values for a multi-label
            segmentation, you can plot it as an Anatomy instead of a Mask.)

        Returns
        -------
        None.

        """
        super().__init__(src)

        self.color = color
        self.alpha = alpha
        self.mask_value = mask_value

    def get_slice(self, dimension: int, position: int):
        """Return the data to be plotted for a given slice.

        The intensities for a slice are first extracted and rotated by
        90 degrees.  Then, all values matching the `mask_value` are replaced
        with `1`, and all other values are replaced with NaN.

        Parameters
        ----------
        dimension : int
            Image dimension to plot over (0 for x, 1 for y, 2 for z).
        position : int
            Index of the slice to plot.

        Returns
        -------
        numpy array
            Imaging slice data.
        """
        # Extract the 2D slice first so the comparison runs on one slice
        # rather than allocating a new full-volume array for every call.
        xsect = np.rot90(np.take(self.data, indices=position, axis=dimension))
        return np.where(xsect == self.mask_value, 1., np.nan)

    def plot_slice(
            self,
            dimension: int,
            position: int,
            ax: matplotlib.axes.Axes | None = None,
            _override_color: str | None = None,
            **kwargs):
        """Plot the data for a given slice.

        Parameters
        ----------
        dimension : int
            Image dimension to plot over (0 for x, 1 for y, 2 for z).
        position : int
            Index of the slice to plot.
        ax : None | matplotlib.axes.Axes, optional
            Matplotlib axes to plot on, by default None, in which case
            ``plt.gca()`` is used.
        _override_color : str | None
            Set a color which is only used at time of plotting.  Mostly for
            internal purposes.
        **kwargs
            Forwarded unchanged to ``ax.imshow``.

        Raises
        ------
        ValueError
            Both the ``color`` attribute and `_override_color` are None.

        Returns
        -------
        object
            Whatever ``ax.imshow`` returns (the image artist drawn on `ax`).
        """
        if ax is None:
            ax = plt.gca()

        if self.color is None and _override_color is None:
            raise ValueError('Either color attribute must be set, '
                             'or `_override_color` must be provided.')

        color = _override_color if _override_color else self.color

        # Two-bin colormap: values in [0, 0.5) map to black and values in
        # [0.5, 2) map to the mask color.  The slice itself only holds 1 and
        # NaN, so only the second bin is used by actual data.
        cmap = matplotlib.colors.ListedColormap(['black', color])
        bounds = [0, .5, 2]
        norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)

        xsect = self.get_slice(dimension, position)

        return ax.imshow(
            xsect,
            cmap=cmap,
            norm=norm,
            aspect=self.aspect(dimension),
            alpha=self.alpha,
            **kwargs)