# Source code for PartSegImage.image

from __future__ import annotations

import re
import sys
import typing
import warnings
from contextlib import suppress
from copy import copy
from dataclasses import dataclass
from functools import wraps
from itertools import cycle, zip_longest

import numpy as np

from PartSegImage.channel_class import Channel

# Spacing / shift tuple (one entry per spatial axis).
Spacing = typing.Tuple[typing.Union[float, int], ...]
# Image payload: one array, or a list with one array per channel.
_IMAGE_DATA = typing.Union[typing.List[np.ndarray], np.ndarray]

# Sentinel meaning "argument not passed" where ``None`` is itself a meaningful value.
_DEF = object()
# Default width of the zero-filled margin added around cut-out images.
FRAME_THICKNESS = 2

# Multiplier applied to spacing values when producing normalized scaling.
DEFAULT_SCALE_FACTOR = 10**9

# Extra ``dataclass`` options; ``kw_only`` and ``slots`` exist only on Python 3.10+.
ch_par: dict[str, bool] = {"kw_only": True, "slots": True} if sys.version_info >= (3, 10) else {}


@dataclass(**ch_par)
class ChannelInfo:
    """
    User supplied, possibly incomplete, per-channel metadata.

    Fields left as ``None`` are filled in with defaults by :class:`Image`
    when the image is constructed.
    """

    # channel label
    name: str
    # colour name / hex string, colour array, or sequence convertible to an array
    color_map: None | str | np.ndarray | tuple | list = None
    # (min, max) display range
    contrast_limits: None | tuple[float, float] = None
@dataclass(**ch_par)
class ChannelInfoFull:
    """
    Fully resolved per-channel metadata; every field is required.

    A ``color_map`` that is neither a string nor an ``np.ndarray`` (e.g. a tuple
    or list of components) is converted to an ``np.ndarray`` after init.
    """

    name: str
    color_map: str | np.ndarray
    contrast_limits: tuple[float, float]

    def __post_init__(self):
        if isinstance(self.color_map, (str, np.ndarray)):
            return
        # sequence-like colour specification -> array
        self.color_map = np.array(self.color_map)
def minimal_dtype(val: int):
    """
    Calculate the smallest unsigned integer dtype able to hold ``val``.

    The thresholds keep a small margin below each dtype's maximum value.

    :param val: largest value that has to be representable
    :return: ``np.uint8``, ``np.uint16`` or ``np.uint32``
    """
    if val < 250:
        return np.uint8
    return np.uint16 if val < 2**16 - 5 else np.uint32


def reduce_array(
    array: np.ndarray,
    components: typing.Collection[int] | None = None,
    max_val: int | None = None,
    dtype=None,
) -> np.ndarray:
    """
    Relabel components to consecutive numbers, keeping their order.

    If ``0`` is among the components it stays ``0`` and the others are numbered
    from ``1``; otherwise numbering starts at ``1``. Values that are not in
    ``components`` are mapped to ``0``.

    :param array: array to relabel, needs to be of integer type
    :param components: components to keep; if None then all values present in ``array`` are kept
    :param max_val: maximum value present in ``array``; computed when absent
        (passing it avoids one pass over the whole array)
    :param dtype: dtype of the returned array; if None the minimal dtype is calculated
    :return: relabeled array
    """
    if components is None:
        components = np.unique(array.flat)
        if max_val is None:
            max_val = np.max(components)
    if max_val is None:
        max_val = np.max(array)
    # lookup table indexed by original label -> new label
    translate = np.zeros(max_val + 1, dtype=dtype or minimal_dtype(len(components) + 1))
    for i, val in enumerate(sorted(components), start=0 if 0 in components else 1):
        translate[val] = i
    return translate[array]


def rename_argument(from_name: str, to_name: str, since_version: str):
    """
    Decorator factory accepting a deprecated keyword argument name.

    When ``from_name`` is passed as a keyword, a ``DeprecationWarning`` is emitted
    and the value is forwarded as ``to_name``.

    :param from_name: deprecated keyword name
    :param to_name: current keyword name
    :param since_version: version mentioned in the warning
    """

    def decorator(fun):
        @wraps(fun)
        def _fun(*args, **kwargs):
            if from_name in kwargs:
                warnings.warn(
                    f"Argument {from_name} is deprecated since {since_version}. Use {to_name} instead",
                    DeprecationWarning,
                    stacklevel=2,
                )
                kwargs[to_name] = kwargs.pop(from_name)
            return fun(*args, **kwargs)

        return _fun

    return decorator


def positional_to_named(fun):
    """
    Convert legacy positional arguments of ``Image.__init__`` into keyword arguments.

    The first two positional arguments (``self`` and ``data``) are kept positional;
    the rest are mapped, in the legacy signature order, onto keyword names and a
    ``DeprecationWarning`` is emitted.
    """

    @wraps(fun)
    def _fun(*args, **kwargs):
        if len(args) > 2:
            warnings.warn(
                "Since PartSeg 0.15.4 all arguments, except first one, should be named",
                DeprecationWarning,
                stacklevel=2,
            )
            for name, arg in zip(
                (
                    "spacing",
                    "file_path",
                    "mask",
                    "default_coloring",
                    "ranges",
                    # legacy order had channel_names between ranges and axes_order
                    # (same order as ``Image.substitute``); without it later arguments shift
                    "channel_names",
                    "axes_order",
                    "shift",
                    "name",
                    "metadata_dict",
                ),
                args[2:],  # start from 2 because first two arguments are self and data
            ):
                kwargs[name] = arg
        return fun(*args[:2], **kwargs)

    return _fun


def merge_into_channel_info(fun):
    """
    Translate deprecated ``channel_names``/``default_coloring``/``ranges`` keywords into ``channel_info``.

    If ``channel_info`` is passed explicitly the call is forwarded unchanged.
    Otherwise any non-empty deprecated keyword triggers a ``DeprecationWarning``
    and the three lists are zipped (padding with ``None``) into ``ChannelInfo`` objects.
    """

    @wraps(fun)
    def _fun(*args, **kwargs):
        if "channel_info" in kwargs:
            # forward the wrapped function's result instead of discarding it
            return fun(*args, **kwargs)
        channel_names = kwargs.pop("channel_names", [])
        default_coloring = kwargs.pop("default_coloring", [])
        ranges = kwargs.pop("ranges", [])
        if any([channel_names, default_coloring, ranges]):
            if isinstance(channel_names, str):
                channel_names = [channel_names]
            if channel_names is None:
                channel_names = []
            if default_coloring is None:
                default_coloring = []
            if ranges is None:
                ranges = []
            warnings.warn(
                "Using channel_names, default_coloring and ranges is deprecated since PartSeg 0.15.4",
                DeprecationWarning,
                stacklevel=2,
            )
            channel_info = [
                ChannelInfo(name=name, color_map=color, contrast_limits=contrast_limits)
                for name, color, contrast_limits in zip_longest(channel_names, default_coloring, ranges)
            ]
            kwargs["channel_info"] = channel_info
        return fun(*args, **kwargs)

    return _fun
class Image:
    """
    Base class for Images used in PartSeg

    :param data: image data; either one array whose axes are described by ``axes_order``,
        or a list of per-channel arrays (then ``axes_order`` must start with ``C``)
    :param spacing: spacing for z, y, x (missing leading values are filled with 1.0)
    :param file_path: path to image on disc
    :param mask: mask array matching the data shape without the channel axis
        (singleton axes may be omitted)
    :param channel_info: list of metadata stored per channel
    :param axes_order: allow to create Image object form data with different axes order, or missed axes
    :param shift: offset for z, y, x (missing leading values are filled with 0)
    :param name: image name
    :param metadata_dict: additional metadata, stored as a shallow copy

    :cvar str ~.axis_order: internal order of axes

    It is prepared for subclassing with changed internal order. Eg:

    >>> class ImageJImage(Image):
    >>>     axis_order = "TZCYX"
    """

    _image_spacing: Spacing
    axis_order = "CTZYX"
    array_axis_order: str

    def __new__(cls, *args, **kwargs):
        if hasattr(cls, "return_order"):  # pragma: no cover
            warnings.warn("Using return_order is deprecated since PartSeg 0.11.0", DeprecationWarning, stacklevel=2)
            cls.axis_order = cls.return_order
        # per-channel arrays are stored without the channel axis
        cls.array_axis_order = cls.axis_order.replace("C", "")
        return super().__new__(cls)

    @positional_to_named
    @rename_argument("image_spacing", "spacing", "0.15.4")
    @merge_into_channel_info
    def __init__(
        self,
        data: _IMAGE_DATA,
        *,
        spacing: Spacing,
        file_path=None,
        mask: None | np.ndarray = None,
        channel_info: list[ChannelInfo | ChannelInfoFull] | None = None,
        axes_order: str | None = None,
        shift: Spacing | None = None,
        name: str = "",
        metadata_dict: dict | None = None,
    ):
        # TODO add time distance to image spacing
        if axes_order is None:  # pragma: no cover
            warnings.warn(
                f"axes_order should be provided, Currently it uses {self.__class__}.axis_order",
                category=DeprecationWarning,
                stacklevel=2,
            )
            axes_order = self.axis_order
        self._check_data_dimensionality(data, axes_order)
        if not isinstance(spacing, tuple):
            spacing = tuple(spacing)
        self._channel_arrays = self._split_data_on_channels(data, axes_order)
        # pad to 3 values (z, y, x) and replace non-positive values with a tiny positive one
        self._image_spacing = (1.0,) * (3 - len(spacing)) + spacing
        self._image_spacing = tuple(el if el > 0 else 10**-6 for el in self._image_spacing)
        if shift is None:
            self._shift = (0,) * len(self._image_spacing)
        else:
            shift = tuple(shift)
            # pad the same way as spacing so ``shift`` and ``spacing`` stay aligned for 2D data
            self._shift = (0,) * (len(self._image_spacing) - len(shift)) + shift
        self.name = name
        self.file_path = file_path
        self._mask_array = self._fit_mask(mask, data, axes_order)
        self._channel_info = self._adjust_channel_info(channel_info, self._channel_arrays)
        self.metadata = dict(metadata_dict) if metadata_dict is not None else {}

    @staticmethod
    def _adjust_channel_info(
        channel_info: list[ChannelInfo | ChannelInfoFull] | None,
        channel_array: list[np.ndarray],
        default_colors=("red", "blue", "green", "yellow", "magenta", "cyan"),
    ) -> list[ChannelInfoFull]:
        """Build one ``ChannelInfoFull`` per channel, filling missing fields with defaults."""
        default_colors = cycle(default_colors)
        if channel_info is None:
            ranges = [(np.min(x), np.max(x)) for x in channel_array]
            return [
                ChannelInfoFull(name=f"channel {i}", color_map=x[0], contrast_limits=x[1])
                for i, x in enumerate(zip(default_colors, ranges), start=1)
            ]
        # extra entries (more info than channels) are ignored
        channel_info = channel_info[: len(channel_array)]
        res = [
            ChannelInfoFull(
                name=ch_inf.name or f"channel {i+1}",
                color_map=(
                    ch_inf.color_map if ch_inf.color_map is not None else next(default_colors)  # skipcq: PTC-W0063
                ),
                contrast_limits=(
                    ch_inf.contrast_limits
                    if ch_inf.contrast_limits is not None
                    else (np.min(channel_array[i]), np.max(channel_array[i]))
                ),
            )
            for i, ch_inf in enumerate(channel_info)
        ]
        # channels without any provided info get fully default metadata
        res.extend(
            ChannelInfoFull(
                name=f"channel {i+1}",
                color_map=next(default_colors),  # skipcq: PTC-W0063
                contrast_limits=(np.min(arr), np.max(arr)),
            )
            for i, arr in enumerate(channel_array[len(res) :], start=len(channel_info))
        )
        return res

    @staticmethod
    def _check_data_dimensionality(data, axes_order):
        """Raise ``ValueError`` when ``data`` rank does not match ``axes_order`` (list items lack the C axis)."""
        if (isinstance(data, list) and any(x.ndim + 1 != len(axes_order) for x in data)) or (
            not isinstance(data, list) and data.ndim != len(axes_order)
        ):
            if isinstance(data, list):
                ndim = ", ".join([f"{x.ndim} + 1" for x in data])
            else:
                ndim = str(data.ndim)
            raise ValueError(
                "Data should have same number of dimensions "
                f"like length of axes_order (axis :{len(axes_order)}, ndim: {ndim})"
            )

    def _fit_mask(self, mask, data, axes_order):
        mask_array = self._prepare_mask(mask, data, axes_order)
        if mask_array is not None:
            mask_array = self.fit_mask_to_image(mask_array)
        return mask_array

    @classmethod
    def _prepare_mask(cls, mask, data, axes_order) -> np.ndarray | None:
        """Reshape ``mask`` to the input data shape (without C) and reorder it to ``array_axis_order``."""
        if mask is None:
            return None
        if isinstance(data, list):
            # per-channel arrays already lack the channel axis
            data_shape = list(data[0].shape)
        else:
            data_shape = list(data.shape)
            with suppress(ValueError):
                data_shape.pop(axes_order.index("C"))
        mask = cls._fit_array_to_image(data_shape, mask)
        return cls.reorder_axes(mask, axes_order.replace("C", ""))

    @classmethod
    def _split_data_on_channels(cls, data: np.ndarray | list[np.ndarray], axes_order: str) -> list[np.ndarray]:
        """Split ``data`` into per-channel arrays ordered as ``array_axis_order``."""
        if isinstance(data, list) and not axes_order.startswith("C"):  # pragma: no cover
            raise ValueError("When passing data as list of numpy arrays then Channel must be first axis.")
        if "C" not in axes_order:
            if not isinstance(data, np.ndarray):  # pragma: no cover
                raise TypeError("If `axes_order` does not contain `C` then data must be numpy array.")
            return [cls.reorder_axes(data, axes_order)]
        if axes_order.startswith("C"):
            if isinstance(data, list):
                # unify dtypes of channels passed as separate arrays
                dtype = np.result_type(*data)
                return [cls.reorder_axes(x, axes_order[1:]).astype(dtype) for x in data]
            return [cls.reorder_axes(x, axes_order[1:]) for x in data]
        if not isinstance(data, np.ndarray):
            raise TypeError("If `data` is list of arrays then `axes_order` must start with `C`")  # pragma: no cover
        pos: list[slice | int] = [slice(None) for _ in range(data.ndim)]
        c_pos = axes_order.index("C")
        res = []
        for i in range(data.shape[c_pos]):
            pos[c_pos] = i
            res.append(cls.reorder_axes(data[tuple(pos)], axes_order.replace("C", "")))
        return res

    @staticmethod
    def _merge_channel_names(base_channel_names: list[str], new_channel_names: list[str]) -> list[str]:
        """Append ``new_channel_names`` to a copy of ``base_channel_names``, renaming collisions."""
        base_channel_names = base_channel_names[:]
        reg = re.compile(r"channel \d+")
        for name in new_channel_names:
            match = reg.match(name)
            new_name = name
            base_name = name
            if match and base_name in base_channel_names:
                new_name = f"channel {len(base_channel_names) + 1}"
            i = 1
            while new_name in base_channel_names:
                new_name = f"{base_name} ({i})"
                i += 1
                if i > 10000:  # pragma: no cover
                    raise ValueError("fail when try to fix channel name")
            base_channel_names.append(new_name)
        return base_channel_names

    @property
    def channel_info(self) -> list[ChannelInfoFull]:
        """Shallow copies of the per-channel metadata."""
        return [copy(x) for x in self._channel_info]

    @property
    def ranges(self) -> list[tuple[float, float]]:
        return [x.contrast_limits for x in self._channel_info]

    @property
    def default_coloring(self) -> list[str | np.ndarray]:
        return [x.color_map for x in self._channel_info]

    def merge(self, image: Image, axis: str) -> Image:
        """
        Produce new image merging image data along given axis.
        All metadata are obtained from self (for ``axis == "C"`` the channel
        metadata of ``image`` is appended).

        :param Image image: Image to be merged
        :param str axis: axis letter along which data is joined
        :return: New image produced from merge
        :rtype: Image
        """
        if axis == "C":
            data = self._image_data_normalize(
                self._channel_arrays + [self.reorder_axes(x, image.array_axis_order) for x in image._channel_arrays]
            )
            channel_names = self._merge_channel_names(self.channel_names, image.channel_names)
            color_map = self.default_coloring + image.default_coloring
            ranges = self.ranges + image.ranges
        else:
            index = self.array_axis_order.index(axis)
            data = self._image_data_normalize(
                [
                    # join this image's channel with the matching channel of ``image``
                    np.concatenate((x, self.reorder_axes(y, image.array_axis_order)), axis=index)
                    for x, y in zip(self._channel_arrays, image._channel_arrays)
                ]
            )
            channel_names = self.channel_names
            color_map = self.default_coloring
            ranges = self.ranges
        return self.substitute(data=data, ranges=ranges, channel_names=channel_names, default_coloring=color_map)

    @property
    def channel_names(self) -> list[str]:
        return [x.name for x in self._channel_info]

    @property
    def channel_pos(self) -> int:  # pragma: no cover
        """Channel axis. Need to have 'C' in :py:attr:`axis_order`"""
        warnings.warn(
            "channel_pos is deprecated and code its using may not work properly", category=FutureWarning, stacklevel=2
        )
        return self.axis_order.index("C")

    @property
    def x_pos(self):
        return self.array_axis_order.index("X")

    @property
    def y_pos(self):
        return self.array_axis_order.index("Y")

    @property
    def time_pos(self):
        """Time axis. Need to have 'T' in :py:attr:`axis_order`"""
        return self.array_axis_order.index("T")

    @property
    def stack_pos(self) -> int:
        """Stack axis. Need to have 'Z' in :py:attr:`axis_order`"""
        return self.array_axis_order.index("Z")

    @property
    def dtype(self) -> np.dtype:
        """dtype of image array"""
        return self._channel_arrays[0].dtype

    @staticmethod
    def _reorder_axes(array: np.ndarray, input_axes: str, return_axes) -> np.ndarray:
        """Move axes of ``array`` from ``input_axes`` order to ``return_axes`` order, adding singleton axes."""
        if array.ndim != len(input_axes):
            raise ValueError(f"array.ndim ({array.ndim}) need to be equal to length of axes ('{input_axes}')")
        if input_axes == return_axes:
            return array
        mapping_dict = {v: i for i, v in enumerate(return_axes)}
        if array.ndim < len(return_axes):
            array = array.reshape(array.shape + (1,) * (len(return_axes) - array.ndim))
        new_positions = [mapping_dict[x] for x in input_axes if x in mapping_dict]
        axes_to_map = [i for i, x in enumerate(input_axes) if x in mapping_dict]
        return np.moveaxis(array, axes_to_map, new_positions)

    @classmethod
    def reorder_axes(cls, array: np.ndarray, axes: str) -> np.ndarray:
        """
        reorder axes to internal storage format

        :param np.ndarray array: array to have changed order of axes
        :param str axes: axes order
        :return: array with correct order of axes
        """
        return cls._reorder_axes(array, axes, cls.array_axis_order)

    def get_dimension_number(self) -> int:
        """return number of nontrivial dimensions"""
        return np.squeeze(self._channel_arrays[0]).ndim

    def get_dimension_letters(self) -> str:
        """
        :return: letters which indicates non trivial dimensions
        """
        return "".join(key for val, key in zip(self._channel_arrays[0].shape, self.array_axis_order) if val > 1)

    def substitute(
        self,
        data=None,
        image_spacing=None,
        file_path=None,
        mask=_DEF,
        default_coloring=None,
        ranges=None,
        channel_names=None,
    ) -> Image:
        """
        Create copy of image with substitution of not None elements.

        ``mask`` uses a sentinel so that ``mask=None`` explicitly removes the mask.
        Shift, name and metadata are carried over from this image.
        """
        data = self._channel_arrays if data is None else data
        image_spacing = self._image_spacing if image_spacing is None else image_spacing
        file_path = self.file_path if file_path is None else file_path
        mask = self._mask_array if mask is _DEF else mask
        default_coloring = self.default_coloring if default_coloring is None else default_coloring
        ranges = self.ranges if ranges is None else ranges
        channel_names = self.channel_names if channel_names is None else channel_names

        channel_info = [
            ChannelInfo(name=name, color_map=color, contrast_limits=contrast_limits)
            for name, color, contrast_limits in zip_longest(channel_names, default_coloring, ranges)
        ]

        return self.__class__(
            data=data,
            spacing=image_spacing,
            file_path=file_path,
            mask=mask,
            axes_order=self.axis_order,
            channel_info=channel_info,
            shift=self._shift,
            name=self.name,
            metadata_dict=self.metadata,
        )

    def set_mask(self, mask: np.ndarray | None, axes: str | None = None):
        """
        Set mask for image, check if it has proper shape.

        :param mask: mask in same shape like image. May not contains 1 dim axes.
        :param axes: order of axes in mask, use if different than :py:attr:`array_axis_order`
        :raise ValueError: on wrong shape
        """
        if mask is None:
            self._mask_array = None
        elif axes is not None:
            self._mask_array = self.fit_mask_to_image(self.reorder_axes(mask, axes))
        else:
            self._mask_array = self.fit_mask_to_image(mask)

    def get_data(self) -> np.ndarray:
        """Return image data as a single array in :py:attr:`axis_order`."""
        if "C" in self.axis_order:
            return np.stack(self._channel_arrays, axis=self.axis_order.index("C"))
        return self._channel_arrays[0]

    @property
    def mask(self) -> np.ndarray | None:
        return self._mask_array[:] if self._mask_array is not None else None

    @staticmethod
    def _fit_array_to_image(base_shape, array: np.ndarray) -> np.ndarray:
        """
        change shape of array with inserting single dimensional entries

        :raises ValueError: if ``array`` cannot be reshaped to ``base_shape`` this way
        """
        shape = list(array.shape)
        for i, el in enumerate(base_shape):
            if i >= len(shape):
                # trailing axes missing from ``array`` are allowed only if they are singletons
                if el != 1:
                    raise ValueError(f"Wrong array shape {shape} for {base_shape}")
                shape.append(1)
            elif el == 1 and el != shape[i]:
                shape.insert(i, 1)
            elif el != shape[i]:
                raise ValueError(f"Wrong array shape {shape} for {base_shape}")
        if len(shape) != len(base_shape):
            raise ValueError(f"Wrong array shape {shape} for {base_shape}")
        return np.reshape(array, shape)

    def fit_array_to_image(self, array: np.ndarray) -> np.ndarray:
        """
        Change shape of array with inserting single dimensional entries

        :param np.ndarray array: array to be fitted
        :return: reshaped array with added missing 1 in shape
        :raises ValueError: if cannot fit array
        """
        return self._fit_array_to_image(self._channel_arrays[0].shape, array)

    # noinspection DuplicatedCode
    def fit_mask_to_image(self, array: np.ndarray) -> np.ndarray:
        """
        call :py:meth:`fit_array_to_image` and then relabel and change type to minimal
        which fit all information
        """
        array = self.fit_array_to_image(array)
        if np.max(array) == 1:
            return array.astype(np.uint8)
        unique = np.unique(array)
        if unique.size == 1:
            if unique[0] != 0:
                return np.ones(array.shape, dtype=np.uint8)
            return array.astype(np.uint8)
        max_val = np.max(unique)
        return reduce_array(array, unique, max_val)

    def get_image_for_save(self) -> np.ndarray:
        """
        :return: numpy array in imagej tiff order axes
        """
        if "C" in self.axis_order:
            return self._reorder_axes(
                np.stack(self._channel_arrays, axis=self.axis_order.index("C")), self.axis_order, "TZCYX"
            )
        return self._reorder_axes(self._channel_arrays[0], self.axis_order, "TZCYX")

    def get_mask_for_save(self) -> np.ndarray | None:
        """
        :return: if image has mask then return mask with axes in proper order
        """
        if self._mask_array is None:
            return None
        return self._reorder_axes(self._mask_array, self.array_axis_order, "TZCYX")

    @property
    def has_mask(self) -> bool:
        """check if image is masked"""
        return self._mask_array is not None

    @property
    def is_time(self) -> bool:
        """check if image contains time data"""
        return self.times > 1

    @property
    def is_stack(self) -> bool:
        """check if image contain 3d data"""
        return self.layers > 1

    @property
    def channels(self) -> int:
        """number of image channels"""
        return len(self._channel_arrays)

    @property
    def layers(self) -> int:
        """z-dim of image"""
        return self._channel_arrays[0].shape[self.stack_pos]

    @property
    def times(self) -> int:
        """number of time frames"""
        return self._channel_arrays[0].shape[self.time_pos]

    @property
    def plane_shape(self) -> tuple[int, int]:
        """y,x size of image"""
        return self._channel_arrays[0].shape[self.y_pos], self._channel_arrays[0].shape[self.x_pos]

    @property
    def shape(self):
        """Shape of a single channel array, in :py:attr:`array_axis_order`."""
        return self._channel_arrays[0].shape

    def swap_time_and_stack(self):
        """
        Swap time and stack axes.
        For example my be used to convert time image in 3d image.
        The mask, if present, is swapped the same way.
        """
        image_array_list = [np.swapaxes(x, self.time_pos, self.stack_pos) for x in self._channel_arrays]
        # mask is stored in the same axis order as channel arrays, so it must be swapped too
        mask = None if self._mask_array is None else np.swapaxes(self._mask_array, self.time_pos, self.stack_pos)
        return self.substitute(data=self._image_data_normalize(image_array_list), mask=mask)

    @classmethod
    def get_axis_positions(cls) -> dict[str, int]:
        """
        :return: dict with mapping axis to its position
        :rtype: dict
        """
        return {letter: i for i, letter in enumerate(cls.axis_order)}

    @classmethod
    def get_array_axis_positions(cls) -> dict[str, int]:
        """
        :return: dict with mapping axis to its position for array fitted to image
        :rtype: dict
        """
        return {letter: i for i, letter in enumerate(cls.array_axis_order)}

    def get_data_by_axis(self, **kwargs) -> np.ndarray:
        """
        Get part of data extracted by sub axis. Axis is selected by single letter from :py:attr:`axis_order`

        :param kwargs: mapping from axis letter (case-insensitive) to index or slice;
            the channel may also be given by name or :class:`Channel`
        :return: selected data; axes selected with an integer are dropped
        :rtype: np.ndarray
        """
        slices: list[int | slice] = [slice(None) for _ in range(len(self.array_axis_order))]
        axis_pos = self.get_array_axis_positions()
        if "c" in kwargs:
            kwargs["C"] = kwargs.pop("c")
        channel = kwargs.pop("C", slice(None) if "C" in self.axis_order else 0)
        if isinstance(channel, Channel):
            channel = channel.value
        if isinstance(channel, str):
            channel = self.channel_names.index(channel)
        axis_order = self.axis_order
        for name, value in kwargs.items():
            if name.upper() in axis_pos:
                slices[axis_pos[name.upper()]] = value
                # numpy integers select a single position too, so the axis disappears
                if isinstance(value, (int, np.integer)):
                    axis_order = axis_order.replace(name.upper(), "")

        slices_t = tuple(slices)
        if isinstance(channel, (int, np.integer)):
            return self._channel_arrays[channel][slices_t]
        return np.stack([x[slices_t] for x in self._channel_arrays[channel]], axis=axis_order.index("C"))

    def clip_array(self, array: np.ndarray, **kwargs: int | slice) -> np.ndarray:
        """
        Clip array by axis. Axis is selected by single letter from :py:attr:`axis_order`

        :param array: array to clip
        :param kwargs: mapping from axis to position or slice on this axis
        :return: clipped array
        """
        array = self.fit_array_to_image(array)
        slices: list[int | slice] = [slice(None) for _ in range(len(self.array_axis_order))]
        axis_pos = self.get_array_axis_positions()
        for name in kwargs:
            if (n := name.upper()) in axis_pos:
                slices[axis_pos[n]] = kwargs[name]
        return array[tuple(slices)]

    def get_channel(self, num: int | str | Channel) -> np.ndarray:
        """
        Alias for :py:meth:`get_data_by_axis` with argument ``c=num``

        :param int | str | Channel num: channel num or name to be extracted
        :return: given channel array
        :rtype: numpy.ndarray
        """
        return self.get_data_by_axis(c=num)

    def has_channel(self, num: int | str | Channel) -> bool:
        """Check whether a channel with the given index, name or :class:`Channel` exists."""
        if isinstance(num, Channel):
            num = num.value

        if isinstance(num, str):
            return num in self.channel_names
        return 0 <= num < self.channels

    def get_layer(self, time: int, stack: int) -> np.ndarray:
        """
        return single layer contains data for all channel

        :param time: time coordinate. For images with not time use 0.
        :param stack: "z coordinate. For time data use 0.
        :return:
        """
        warnings.warn(
            "Image.get_layer is deprecated. Use get_data_by_axis instead", category=DeprecationWarning, stacklevel=2
        )
        return self.get_data_by_axis(T=time, Z=stack)

    @property
    def is_2d(self) -> bool:
        """
        Check if image z and time dimension are equal to 1.
        Equivalent to:
        `image.layers == 1 and image.times == 1`
        """
        return self.layers == 1 and self.times == 1

    @property
    def spacing(self) -> Spacing:
        """image spacing"""
        return tuple(self._image_spacing[1:]) if self.is_2d else self._image_spacing

    def normalized_scaling(self, factor=DEFAULT_SCALE_FACTOR) -> Spacing:
        """Spacing multiplied by ``factor``, left-padded with 1 to four values."""
        if self.is_2d:
            return (1, 1, *tuple(np.multiply(self.spacing, factor)))
        return (1, *tuple(np.multiply(self.spacing, factor)))

    @property
    def shift(self):
        return self._shift[1:] if self.is_2d else self._shift

    @property
    def voxel_size(self) -> Spacing:
        """alias for spacing"""
        return self.spacing

    def set_spacing(self, value: Spacing):
        """set image spacing (silently ignored when any value is 0)"""
        if 0 in value:
            return
        if self.is_2d and len(value) + 1 == len(self._image_spacing):
            value = (1.0, *tuple(value))
        if len(value) != len(self._image_spacing):  # pragma: no cover
            raise ValueError("Correction of spacing fail.")
        self._image_spacing = tuple(value)

    @staticmethod
    def _frame_array(array: np.ndarray | None, index_to_add: list[int], frame=FRAME_THICKNESS):
        """Return a new array with a zero margin of width ``frame`` on both sides of the given axes."""
        if array is None:  # pragma: no cover
            return array
        result_shape = list(array.shape)
        image_pos = [slice(None) for _ in range(array.ndim)]

        for index in index_to_add:
            result_shape[index] += frame * 2
            image_pos[index] = slice(frame, result_shape[index] - frame)

        data = np.zeros(shape=result_shape, dtype=array.dtype)
        data[tuple(image_pos)] = array
        return data

    @staticmethod
    def calc_index_to_frame(array_axis: str, important_axis: str) -> list[int]:
        """
        calculate in which axis frame should be added

        :param str array_axis: list of image axis
        :param str important_axis: list of framed axis
        :return: list of indices to add frame.
        """
        return [array_axis.index(letter) for letter in important_axis]

    def _frame_cut_area(self, cut_area: typing.Iterable[slice], frame: int):
        """Widen slices on the spatial axes by ``frame`` (start clamped at 0)."""
        cut_area = list(cut_area)
        important_axis = "XY" if self.is_2d else "XYZ"
        for ind in self.calc_index_to_frame(self.array_axis_order, important_axis):
            sl = cut_area[ind]
            cut_area[ind] = slice(
                max(sl.start - frame, 0) if sl.start is not None else None,
                sl.stop + frame if sl.stop is not None else None,
                sl.step,
            )
        return cut_area

    def _cut_image_slices(
        self, cut_area: typing.Iterable[slice], frame: int
    ) -> tuple[list[np.ndarray], np.ndarray | None]:
        new_mask = None
        cut_area = self._frame_cut_area(cut_area, frame)
        new_image = [x[tuple(cut_area)] for x in self._channel_arrays]
        if self._mask_array is not None:
            new_mask = self._mask_array[tuple(cut_area)]
        return new_image, new_mask

    def _roi_to_slices(self, roi: np.ndarray) -> list[slice]:
        """Bounding-box slices of the non-zero region of ``roi``."""
        cut_area = self.fit_array_to_image(roi)
        points = np.nonzero(cut_area)
        lower_bound = np.min(points, axis=1)
        upper_bound = np.max(points, axis=1)
        return [slice(x, y + 1) for x, y in zip(lower_bound, upper_bound)]

    def _cut_with_roi(self, cut_area: np.ndarray, replace_mask: bool, frame: int):
        """Cut the ROI bounding box, zero everything outside the ROI and add a frame."""
        new_mask = None
        cut_area = self.fit_array_to_image(cut_area)
        new_cut = tuple(self._roi_to_slices(cut_area))
        catted_cut_area = cut_area[new_cut]
        new_image = [np.copy(x[new_cut]) for x in self._channel_arrays]
        for el in new_image:
            el[catted_cut_area == 0] = 0
        if replace_mask:
            new_mask = catted_cut_area
        elif self._mask_array is not None:
            # basic slicing returns a view; copy so zeroing does not modify this image's mask
            new_mask = np.copy(self._mask_array[new_cut])
            new_mask[catted_cut_area == 0] = 0
        important_axis = "XY" if self.is_2d else "XYZ"
        index_to_frame = self.calc_index_to_frame(self.array_axis_order, important_axis)
        new_image = [self._frame_array(x, index_to_frame, frame) for x in new_image]
        new_mask = self._frame_array(new_mask, index_to_frame, frame)
        return new_image, new_mask

    def cut_image(
        self,
        cut_area: np.ndarray | typing.Iterable[slice],
        replace_mask=False,
        frame: int = FRAME_THICKNESS,
        zero_out_cut_area: bool = True,
    ) -> Image:
        """
        Create new image base on mask or list of slices

        :param bool replace_mask: if cut area is represented by mask array,
            then in result image the mask is set base on cut_area if cur_area is np.ndarray
        :param typing.Union[np.ndarray, typing.Iterable[slice]] cut_area: area to cut. Defined with slices or mask
        :param int frame: additional frame around cut_area
        :param bool zero_out_cut_area: if ``cut_area`` is an array, zero data outside it
        :return: Image
        """
        if isinstance(cut_area, np.ndarray):
            if zero_out_cut_area:
                new_image, new_mask = self._cut_with_roi(cut_area, replace_mask, frame)
            else:
                # fit first so the slices computed from the ROI can also index the ROI itself
                roi = self.fit_array_to_image(cut_area)
                new_cut = self._roi_to_slices(roi)
                new_image, new_mask = self._cut_image_slices(new_cut, frame)
                if replace_mask:
                    new_mask = roi[tuple(self._frame_cut_area(new_cut, frame))]
        else:
            new_image, new_mask = self._cut_image_slices(cut_area, frame)

        return self.__class__(
            data=self._image_data_normalize(new_image),
            spacing=self._image_spacing,
            file_path=None,
            mask=new_mask,
            channel_info=self._channel_info,
            axes_order=self.axis_order,
        )

    def get_imagej_colors(self):
        """Per-channel lookup tables (one row per colour component, 256 entries each)."""
        res = []
        for color in self.default_coloring:
            if isinstance(color, str):
                color_array = _hex_to_rgb(color) if color.startswith("#") else _name_to_rgb(color)
                res.append(np.array([np.linspace(0, x, num=256) for x in color_array]).astype(np.uint8))
            elif color.ndim == 1:
                res.append(np.array([np.linspace(0, x, num=256) for x in color]).astype(np.uint8))
            elif color.shape[1] != 256:
                # resample every component row of the colormap to 256 entries
                source_points = np.arange(color.shape[1])
                target_points = np.linspace(0, color.shape[1] - 1, num=256)
                res.append(np.array([np.interp(target_points, source_points, x) for x in color]).astype(color.dtype))
            else:
                res.append(color)
        return res

    def get_colors(self) -> list[str | list[int]]:
        """Colours as strings, or as the last column (2D) / whole vector (1D) of array colormaps."""
        res: list[str | list[int]] = []
        for color in self.default_coloring:
            if isinstance(color, str):
                res.append(color)
            elif color.ndim == 2:
                res.append(list(color[:, -1]))
            else:
                res.append(list(color))
        return res

    def get_um_spacing(self) -> Spacing:
        """image spacing in micrometers"""
        return tuple(float(x * 10**6) for x in self.spacing)

    def get_um_shift(self) -> Spacing:
        """image shift in micrometers"""
        return tuple(float(x * 10**6) for x in self.shift)

    def get_ranges(self) -> list[tuple[float, float]]:
        """image brightness ranges for each channel"""
        return self.ranges[:]

    def __str__(self):
        return (
            f"{self.__class__} Shape {self._channel_arrays[0].shape}, dtype: {self._channel_arrays[0].dtype}, "
            f"labels: {self.channel_names}, coloring: {self.get_colors()} mask: {self.has_mask}"
        )

    def __repr__(self):
        mask_info = f"mask=True, mask_dtype={self._mask_array.dtype}" if self.mask is not None else "mask=False"
        return (
            f"Image(shape={self._channel_arrays[0].shape} dtype={self._channel_arrays[0].dtype}, spacing={self.spacing}"
            f", labels={self.channel_names}, channels={self.channels}, axes={self.axis_order!r}, {mask_info})"
        )

    @classmethod
    def _image_data_normalize(cls, data: _IMAGE_DATA) -> _IMAGE_DATA:
        """Convert a list of channel arrays into the form accepted by ``__init__`` for this class."""
        if isinstance(data, np.ndarray):
            return data
        if cls.axis_order.startswith("C"):
            shape = data[0].shape
            if any(x.shape != shape for x in data):
                raise ValueError(f"Shape of arrays are different {[x.shape for x in data]}")
            return data
        if "C" not in cls.axis_order:
            return data[0]
        return np.stack(data, axis=cls.axis_order.index("C"))
def _hex_to_rgb(hex_code: str) -> tuple[int, int, int]: """ Convert a hex color code to an RGB tuple. :param str hex_code: The hex color code, either short form (#RGB) or long form (#RRGGBB) :return: A tuple containing the RGB values (R, G, B) """ hex_code = hex_code.lstrip("#") if len(hex_code) in {3, 4}: hex_code = "".join([c * 2 for c in hex_code]) elif len(hex_code) not in {6, 8}: raise ValueError(f"Invalid hex code format: {hex_code}") return int(hex_code[:2], 16), int(hex_code[2:4], 16), int(hex_code[4:6], 16) def _name_to_rgb(name: str) -> tuple[int, int, int]: """ Convert a color name to an RGB tuple. :param str name: The color name :return: A tuple containing the RGB values (R, G, B) """ name = name.lower() if name not in _NAMED_COLORS: raise ValueError(f"Unknown color name: {name}") return _hex_to_rgb(_NAMED_COLORS[name]) try: from vispy.color import get_color_dict except ImportError: # pragma: no cover _NAMED_COLORS = { "red": "#FF0000", "green": "#008000", "blue": "#0000FF", "yellow": "#FFFF00", "cyan": "#00FFFF", "magenta": "#FF00FF", "white": "#FFFFFF", "black": "#000000", "orange": "#FFA500", } else: _NAMED_COLORS = get_color_dict()