# Source code for nifti_overlay.core

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module containing code for generating images of NIFTI images overlaid on each other.
"""

from itertools import cycle
import os
import pathlib
from typing import Tuple

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

from nifti_overlay.image import Anatomy, Edges, Mask
from nifti_overlay.multiimage import CheckerBoard, MultiImage

class NiftiOverlay:
    """The main tool of the package, which creates tiled plots of one or more NIFTI images.

    Images are registered with the ``add_...()`` methods and drawn, in the
    order they were added, onto a grid of matplotlib panels by ``plot()`` or
    ``generate()``.
    """

    def __init__(
            self,
            planes: str = 'xyz',
            nslices: int = 7,
            transpose: bool = False,
            min_all: None | float = None,
            max_all: None | float = None,
            minx: float = 0.15,
            maxx: float = 0.85,
            miny: float = 0.15,
            maxy: float = 0.85,
            minz: float = 0.15,
            maxz: float = 0.85,
            background: str = 'black',
            figsize: Tuple | str = 'automatic',
            dpi: int = 200,
            verbose: bool = False
    ):
        """Initialize a NiftiOverlay object.

        Parameters
        ----------
        planes : str, optional
            String specifying the number of rows being plotted, and the
            dimension of the image being plotted on each row. The length of
            the string indicates the number of rows being plotted, and the
            specific characters indicate the image dimensions being plotted
            ("x", "y", or "z") and their order. Characters can be repeated
            or omitted.
        nslices : int, optional
            Integer specifying the number of slices to plot per plane, by
            default 7.
        transpose : bool, optional
            Make the overlay have shape [slices, planes] instead of
            [planes, slices], by default False.
        min_all : None | float, optional
            Proportion specifying the minimum extent of the image dimension
            over which to sample slices, by default None, in which case the
            minx, miny, and minz parameters are used.
        max_all : None | float, optional
            Proportion specifying the maximum extent of the image dimension
            over which to sample slices, by default None, in which case the
            maxx, maxy, and maxz parameters are used.
        minx : float, optional
            Minimum slice sampling limit (proportion) in the x dimension,
            by default 0.15
        maxx : float, optional
            Maximum slice sampling limit (proportion) in the x dimension,
            by default 0.85
        miny : float, optional
            Minimum slice sampling limit (proportion) in the y dimension,
            by default 0.15
        maxy : float, optional
            Maximum slice sampling limit (proportion) in the y dimension,
            by default 0.85
        minz : float, optional
            Minimum slice sampling limit (proportion) in the z dimension,
            by default 0.15
        maxz : float, optional
            Maximum slice sampling limit (proportion) in the z dimension,
            by default 0.85
        background : str, optional
            Color passed to matplotlib which sets the background color of
            panels, by default 'black'
        figsize : Tuple | str, optional
            Figure size, by default 'automatic', which provides a standard
            setting. Otherwise, a 2-tuple can be passed specifying the
            figure dimensions in inches.
        dpi : int, optional
            Figure DPI, by default 200
        verbose : bool, optional
            Report progress while generating overlays, by default False
        """

        # user-supplied attributes
        self.planes = planes
        self.nslices = nslices
        self.transpose = transpose
        self.minx = minx
        self.miny = miny
        self.minz = minz
        self.maxx = maxx
        self.maxy = maxy
        self.maxz = maxz
        self.min_all = min_all
        self.max_all = max_all
        self.background = background
        self.figsize = figsize
        self.dpi = dpi
        self.verbose = verbose

        # matplotlib stuff
        self.fig = None
        self.axes = None

        # holder for images to be plotted
        self.images = []

        # other variables
        self.planes_to_idx = {'x': 0, 'y': 1, 'z': 2}
        self.color_cycle = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
        self.automatic_figsize_scale = 1.0

    def print(self, *args, **kwargs):
        """Forward to the builtin ``print`` only when ``self.verbose`` is truthy.

        ``verbose`` is read on every call, so changing it after construction
        takes effect immediately.
        """
        if self.verbose:
            print(*args, **kwargs)

    @property
    def nrows(self):
        """Number of rows of panels that will be plotted."""
        return self.nslices if self.transpose else len(self.planes)

    @property
    def ncols(self):
        """Number of columns of panels that will be plotted."""
        return len(self.planes) if self.transpose else self.nslices

    @property
    def paddings(self):
        """Get the slice sampling extent in each dimension.

        ``min_all`` / ``max_all``, when not None, take precedence over the
        per-dimension limits.

        Returns
        -------
        dict
            Dictionary with keys "x", "y", and "z". For each, the value is
            the slice sampling limits provided as a 2-tuple (minimum, maximum).
        """
        lower = {'x': self.minx, 'y': self.miny, 'z': self.minz}
        upper = {'x': self.maxx, 'y': self.maxy, 'z': self.maxz}

        if self.min_all is not None:
            lower = dict.fromkeys('xyz', self.min_all)
        if self.max_all is not None:
            upper = dict.fromkeys('xyz', self.max_all)

        return {axis: (lower[axis], upper[axis]) for axis in 'xyz'}

    def add_anat(
            self,
            src: str | pathlib.Path | nib.Nifti1Image,
            color='gist_gray',
            alpha=1,
            scale_panel=False,
            drop_zero=False,
            vmin=None,
            vmax=None
    ) -> Anatomy:
        """Add an anatomical image to be plotted.

        This is typically a continuously valued image which will be
        represented with a continuous colormap.

        Parameters
        ----------
        src : str | pathlib.Path | nib.Nifti1Image
            Source image.
        color : str, optional
            Colormap key from matplotlib. The default is 'gist_gray'.
        alpha : float, optional
            Color transparency. The default is 1.0.
        scale_panel : bool, optional
            When plotting, scale the colormap to be the extent of
            intensities observed in each individual panel, rather than
            across the whole 3D volume. Can be used to maximize the dynamic
            range within each panel, but makes comparisons across panels
            more difficult. The default is False.
        drop_zero : bool, optional
            Don't plot voxels which have a value of zero. The default is False.
        vmin : float | None, optional
            Bottom limit of color range. The default is None.
        vmax : float | None, optional
            Top limit of color range. The default is None.

        Returns
        -------
        nifti_overlay.image.Anatomy
            Object for managing the image being plotted.
        """
        img = Anatomy(src=src, color=color, alpha=alpha, scale_panel=scale_panel,
                      drop_zero=drop_zero, vmin=vmin, vmax=vmax)
        self.images.append(img)
        return img

    def add_checkerboard(
            self,
            src: str | pathlib.Path | nib.Nifti1Image,
            boxes: int = 10,
            color: str = 'gist_gray',
            alpha: float = 1.0,
            normalize: bool = True,
            histogram_matching: bool = True
    ) -> CheckerBoard:
        """Add a checkerboard to be plotted.

        Parameters
        ----------
        src : Sequence[str, pathlib.Path, nib.Nifti1Image]
            Collection of images.
        boxes : int, optional
            The number of boxes across *in the shortest dimension of the
            slice being plotted*, by default 10
        color : str, optional
            Matplotlib colormap to use, by default 'gist_gray'
        alpha : float, optional
            Color opacity, by default 1.0
        normalize : bool, optional
            Normalize slices being plotted to be between 0 and 1, by
            default True
        histogram_matching : bool, optional
            Use `skimage.exposure.histogram_matching` to normalize the
            intensities of images being plotted, by default True. As one
            colormap is applied to all images, this can be good to make
            sure all images are visible. May not be necessary if the images
            are all coming in with the same dynamic range.

        Returns
        -------
        CheckerBoard
            Object for managing the image being plotted.
        """
        img = CheckerBoard(src=src, boxes=boxes, color=color, alpha=alpha,
                           normalize=normalize, histogram_matching=histogram_matching)
        self.images.append(img)
        return img

    def add_edges(
            self,
            src: str | pathlib.Path | nib.Nifti1Image,
            color: str = 'yellow',
            alpha: float = 1.0,
            sigma: float = 1.0,
            interpolation: str = 'none',
    ) -> Edges:
        """Add an edges image to be plotted.

        The contours/edges of each slice are automatically detected.

        Parameters
        ----------
        src : str path, pathlib.Path, or nibabel.Nifti1Image
            Source image.
        color : str, RGB, RGBA, optional
            Color understood by matplotlib. The default is 'yellow'.
        alpha : float, optional
            Color transparency. The default is 1.0.
        sigma : float, optional
            Standard deviation of the Gaussian filter of the Canny edge
            detector. The default is 1.0. See `skimage.feature.canny`.
        interpolation : str, optional
            Type of interpolation for plotting images. The default is
            'none'. See matplotlib for more details (`plt.imshow`).

        Returns
        -------
        Edges
            Object for managing the image being plotted.
        """
        img = Edges(src=src, color=color, alpha=alpha, sigma=sigma,
                    interpolation=interpolation)
        self.images.append(img)
        return img

    def add_mask(
            self,
            src: str | pathlib.Path | nib.Nifti1Image,
            color: str | None = None,
            alpha: float = 1.0,
            mask_value: float = 1.0,
    ) -> Mask:
        """Add a mask to be plotted; a binary image plotted with a single color.

        Parameters
        ----------
        src : str path, pathlib.Path, or nibabel.Nifti1Image
            Source image.
        color : str, RGB, RGBA, None, optional
            Color understood by matplotlib. The default is None, in which
            case the default color cycle is used.
        alpha : float, optional
            Color transparency. The default is 1.0.
        mask_value : int or float, optional
            Value to be plotted. For segmentations with multiple labels,
            this can be used to specify the single label to be plotted.
            The default is 1. (Note: to plot all values for a multi-label
            segmentation, you can plot it as an Anatomy instead of a Mask.)

        Returns
        -------
        Mask
            Object for managing the image being plotted.
        """
        img = Mask(src=src, color=color, alpha=alpha, mask_value=mask_value)
        self.images.append(img)
        return img

    def _check_images(self):
        """Plot initialization helper to check if any images have been added.

        Raises
        ------
        RuntimeError
            If ``self.images`` is empty.
        """
        if not self.images:
            raise RuntimeError('No images have been added for plotting. Use '
                               'NiftiOverlay().add_anat() or NiftiOverlay.add_mask() '
                               'to add images to be plotted')

    def _check_mismatched_dimensions(self):
        """Plot initialization helper to check if the images have mismatched dimensions.

        Returns
        -------
        The shape shared by all added images, or None when no images have
        been added.

        Raises
        ------
        RuntimeError
            If the added images do not all report the same ``shape``.
        """
        # no images added
        if not self.images:
            return None

        shapes = [i.shape for i in self.images]
        shapeset = set(shapes)

        # all same shape
        if len(shapeset) == 1:
            return shapes[0]

        # different shape
        raise RuntimeError(f"DIMENSION ERROR. Found different image dimensions for different images: {shapeset}")

    def _get_figure_dimensions(self):
        """Return the figure dimensions in inches (x-dimensions, y-dimensions).

        When ``figsize="automatic"``, will return dimensions which are
        proportional to the number of rows and columns of panels being
        plotted. By default, the ratio is 1 inch for every row and 1 inch
        for every column. This ratio can be altered with the
        ``automatic_figsize_scale`` attribute.

        Raises
        ------
        ValueError
            If ``figsize`` is a string other than ``'automatic'``.
        """
        if isinstance(self.figsize, str):
            if self.figsize != 'automatic':
                # A stray string would otherwise be unpacked character by
                # character, producing a confusing failure (or none at all).
                raise ValueError(
                    f"figsize must be 'automatic' or a 2-tuple of inches, got {self.figsize!r}"
                )
            figx = self.ncols * self.automatic_figsize_scale
            figy = self.nrows * self.automatic_figsize_scale
        else:
            figx, figy = self.figsize

        return figx, figy

    def _init_figure(self):
        """Initialize a figure for plotting.

        This will get the figure dimensions, and create a matplotlib Figure
        and Axes with appropriate formatting. ``self.axes`` is always a 2D
        array with the same dimensions as the overlay panels
        (``[nrows, ncols]``).
        """
        figx, figy = self._get_figure_dimensions()

        self.print()
        self.print("Initializing figure:")
        self.print(f" Shape: {self.nrows}, {self.ncols}")
        self.print(f" Size: {figx} in., {figy} in.")
        self.print(f" DPI: {self.dpi}")

        # Release the previous figure from pyplot's registry; otherwise every
        # re-plot (generate() re-plots by default) leaves a figure behind.
        if self.fig is not None:
            plt.close(self.fig)
        self.fig, self.axes = None, None

        # squeeze=False guarantees a 2D axes array, even for 1x1, 1xN or Nx1.
        self.fig, self.axes = plt.subplots(self.nrows, self.ncols,
                                           figsize=(figx, figy), dpi=self.dpi,
                                           squeeze=False)
        self.fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
        self.fig.patch.set_facecolor(self.background)

    def _main_plot_loop(self):
        """Start the plot: draw every added image, in the order it was added."""
        total = len(self.images)
        for index, image in enumerate(self.images):
            self._plot_image(image, index, total)

    def _slice_indices(self, plane, dimension_size):
        """Return the (pre-rounding) slice indices to sample along one plane.

        Parameters
        ----------
        plane : str
            One of "x", "y", "z"; selects the sampling limits from ``paddings``.
        dimension_size : int
            Number of voxels along that dimension.

        Returns
        -------
        A single-element list holding the midpoint when ``nslices == 1``,
        otherwise ``nslices`` evenly spaced values from ``np.linspace``.
        """
        min_window, max_window = self.paddings[plane]
        min_slice = int(min_window * dimension_size)
        max_slice = int(max_window * dimension_size)

        if self.nslices == 1:
            return [int((max_slice + min_slice) / 2)]
        return np.linspace(min_slice, max_slice, self.nslices)

    @staticmethod
    def _clamp_position(index, dimension_size):
        """Truncate ``index`` to an int and clamp it to ``[0, dimension_size - 1]``.

        An upper sampling limit of 1.0 yields ``dimension_size`` itself,
        which is one past the last valid slice.
        """
        return min(max(int(index), 0), dimension_size - 1)

    def _plot_image(self, image, index, total):
        """Plot a single image.

        ``index`` and ``total`` are used just for printing purposes to show
        the plotting progress. All panels for one image are plotted with
        one call to this function.
        """
        n = index

        if isinstance(image, MultiImage):
            path = f'< {len(image.images)} image(s) >'
        else:
            path = image.path

        self.print()
        self.print("--------------------------------------------------")
        self.print(f"IMAGE {n+1} / {total}")
        self.print("--------------------------------------------------")
        self.print()
        self.print(f"Image path: {path}")
        self.print(f"Shape: {image.shape}")
        self.print(f"Image type: {image.__class__.__name__}")

        total_panels = self.nrows * self.ncols

        # A mask without an explicit color takes the next color of the cycle;
        # the same color is reused for every panel of this image.
        mask_color = None
        if isinstance(image, Mask) and image.color is None:
            mask_color = next(self.color_cycle)

        for i, p in enumerate(self.planes):

            dimension = self.planes_to_idx[p]
            dimension_size = image.shape[dimension]
            min_window, max_window = self.paddings[p]
            indices = self._slice_indices(p, dimension_size)

            self.print()
            self.print(f"Plotting row [{i}]")
            self.print(f"Axis = '{p}'")
            self.print(f"Minimum & Maximum extent: {min_window}, {max_window}")
            self.print(f"Slices to plot along dimension (pre-rounding): {list(indices)}")

            for j, idx in enumerate(indices):

                percentage = round(((i * len(indices) + j) / total_panels) * 100, 2)
                self.print()
                self.print(f'Plotting panel [{i}, {j}] ({percentage}%)')

                ax_x = i if not self.transpose else j
                ax_y = j if not self.transpose else i
                ax = self.axes[ax_x, ax_y]

                position = self._clamp_position(idx, dimension_size)
                panel_args = {'dimension': dimension, 'position': position, 'ax': ax}

                self.print("Call:")
                for k, v in panel_args.items():
                    self.print(f" {k}: {v}")

                if mask_color is not None:
                    panel_args['_override_color'] = mask_color

                image.plot_slice(**panel_args)
                ax.axis('off')
                ax.set_facecolor(self.background)

        self.print()
        self.print("Finished.")
        self.print()
        self.print("--------------------------------------------------")
        self.print()
        self.print("--------------------------------------------------")

    def plot(self):
        """Run the plot with the provided parameters (set during
        initialization) of NiftiOverlay and the images added with the
        ``add_...()`` methods.

        Raises
        ------
        RuntimeError
            If no images have been added, or if the added images have
            different dimensions.
        """
        self._check_images()
        self._check_mismatched_dimensions()
        self._init_figure()
        self._main_plot_loop()

    def generate(self, savepath: str, separate: bool = False, rerun: bool = True):
        """Create the plot and save it to a file (or files).

        Parameters
        ----------
        savepath : str
            Path to save output image file, with extension included. See
            matplotlib documentation for acceptable output formats.
        separate : bool, optional
            Save each panel as a separate image, by default False. In this
            case, a directory path should be provided instead of a single
            file path. Images will be saved as PNG. If the provided
            directory does not exist, it will be created (including any
            missing parent directories).
        rerun : bool, optional
            Regenerate the figure before saving, by default True. In some
            cases, you may have already called the ``plot()`` method and
            made no changes to the visualization, in which case rerunning
            is not strictly necessary. But changes are not automatically
            detected - so rerunning is default.
        """
        if self.fig is None or rerun:
            self.plot()

        if not separate:
            self.print(f"Saving output to {savepath}...")
            self.fig.savefig(savepath, facecolor=self.background)
            return

        # makedirs handles nested paths and an already-existing directory.
        os.makedirs(savepath, exist_ok=True)

        for r in range(self.nrows):
            for c in range(self.ncols):
                ax = self.axes[r, c]
                # Panel bounding box converted from display units to inches,
                # which is what savefig's bbox_inches expects.
                extent = ax.get_window_extent().transformed(self.fig.dpi_scale_trans.inverted())
                impath = os.path.join(savepath, f"panel_{r}x{c}.png")
                self.print(f"Saving panel at {impath}...")
                self.fig.savefig(impath, bbox_inches=extent, facecolor=self.background)