Source code for PartSegCore.segmentation.threshold

import typing
import warnings
from abc import ABC

import numpy as np
import SimpleITK as sitk
from local_migrator import register_class, rename_key, update_argument
from pydantic import Field

from PartSegCore.algorithm_describe_base import AlgorithmDescribeBase, AlgorithmSelection
from PartSegCore.segmentation.algorithm_base import SegmentationLimitException
from PartSegCore.segmentation.utils import close_small_holes
from PartSegCore.utils import BaseModel


[docs] class SingleThresholdParams(BaseModel): threshold: float = Field(8000.0, ge=-100000, le=100000, title="Threshold", description="Threshold values")
[docs] @register_class(version="0.0.1", migrations=[("0.0.1", rename_key("masked", "apply_mask"))]) class SimpleITKThresholdParams128(BaseModel): apply_mask: bool = Field(True, description="If apply mask before calculate threshold") bins: int = Field(128, title="Histogram bins", ge=8, le=2**16)
[docs] @register_class(version="0.0.1", migrations=[("0.0.1", rename_key("masked", "apply_mask"))]) class SimpleITKThresholdParams256(BaseModel): apply_mask: bool = Field(True, description="If apply mask before calculate threshold") bins: int = Field(128, title="Histogram bins", ge=8, le=2**16)
[docs] class MultipleOtsuThresholdParams(BaseModel): components: int = Field(2, title="Number of Components", ge=2, lt=100) border_component: int = Field(1, title="Border Component", ge=1, lt=100) valley: bool = Field(True, title="Valley emphasis") bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16)
[docs] class BaseThreshold(AlgorithmDescribeBase, ABC): @classmethod def calculate_mask( cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: BaseModel, operator: typing.Callable[[object, object], bool], ): raise NotImplementedError
[docs] class ManualThreshold(BaseThreshold): __argument_class__ = SingleThresholdParams
[docs] @classmethod def get_name(cls): return "Manual"
@classmethod @update_argument("arguments") def calculate_mask( cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: SingleThresholdParams, operator ): result = np.array(operator(data, arguments.threshold)).astype(np.uint8) if mask is not None: result[mask == 0] = 0 return result, arguments.threshold
[docs] class SitkThreshold(BaseThreshold, ABC): __argument_class__ = SimpleITKThresholdParams128 @classmethod @update_argument("arguments") def calculate_mask( cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: SimpleITKThresholdParams128, operator ): if mask is not None and mask.dtype != np.uint8 and arguments.apply_mask: mask = (mask > 0).astype(np.uint8) ob, bg, th_op = (0, 1, np.min) if operator(1, 0) else (1, 0, np.max) image_sitk = sitk.GetImageFromArray(data) if arguments.apply_mask and mask is not None: mask_sitk = sitk.GetImageFromArray(mask) calculated = cls.calculate_threshold(image_sitk, mask_sitk, ob, bg, arguments.bins, True, 1) else: calculated = cls.calculate_threshold(image_sitk, ob, bg, arguments.bins) result = sitk.GetArrayFromImage(calculated) if mask is not None: result[mask == 0] = 0 threshold = th_op(data[result > 0]) if np.any(result) else th_op(-data) return result, threshold @staticmethod def calculate_threshold(*args, **kwargs): raise NotImplementedError
[docs] class OtsuThreshold(SitkThreshold):
[docs] @classmethod def get_name(cls): return "Otsu"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.OtsuThreshold(*args)
[docs] class LiThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Li"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.LiThreshold(*args)
[docs] class MaximumEntropyThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Maximum Entropy"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.MaximumEntropyThreshold(*args)
[docs] class RenyiEntropyThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Renyi Entropy"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.RenyiEntropyThreshold(*args)
[docs] class ShanbhagThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Shanbhag"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.ShanbhagThreshold(*args)
[docs] class TriangleThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Triangle"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.TriangleThreshold(*args)
[docs] class YenThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Yen"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.YenThreshold(*args)
[docs] class HuangThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Huang"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.HuangThreshold(*args)
[docs] class IntermodesThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Intermodes"
@staticmethod def calculate_threshold(*args, **kwargs): try: return sitk.IntermodesThreshold(*args) except RuntimeError as e: if "Exceeded maximum iterations for histogram smoothing" in e.args[0]: raise SegmentationLimitException(*e.args) from e raise
[docs] class IsoDataThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Iso Data"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.IsoDataThreshold(*args)
[docs] class KittlerIllingworthThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Kittler Illingworth"
@staticmethod def calculate_threshold(*args, **kwargs): try: return sitk.KittlerIllingworthThreshold(*args) except RuntimeError as e: if "sigma2 <= 0" in e.args[0]: raise SegmentationLimitException(*e.args) from e raise
[docs] class MomentsThreshold(SitkThreshold): __argument_class__ = SimpleITKThresholdParams256
[docs] @classmethod def get_name(cls): return "Moments"
@staticmethod def calculate_threshold(*args, **kwargs): return sitk.MomentsThreshold(*args)
[docs] class MultipleOtsuThreshold(BaseThreshold): __argument_class__ = MultipleOtsuThresholdParams @classmethod def calculate_mask( cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: MultipleOtsuThresholdParams, operator: typing.Callable[[object, object], bool], ): cleaned_image_sitk = sitk.GetImageFromArray(data) res = sitk.OtsuMultipleThresholds(cleaned_image_sitk, arguments.components, 0, arguments.bins, arguments.valley) res = sitk.GetArrayFromImage(res) if operator(1, 0): res = (res >= arguments.border_component).astype(np.uint8) threshold = np.min(data[res > 0]) if np.any(res) else np.max(data) else: res = (res < arguments.border_component).astype(np.uint8) threshold = np.max(data[res > 0]) if np.any(res) else np.min(data) if mask is not None: res[mask == 0] = 0 return res, threshold
[docs] @classmethod def get_name(cls) -> str: return "Multiple Otsu"
[docs] class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold): pass
ThresholdSelection.register(ManualThreshold) ThresholdSelection.register(OtsuThreshold) ThresholdSelection.register(LiThreshold) ThresholdSelection.register(RenyiEntropyThreshold) ThresholdSelection.register(ShanbhagThreshold) ThresholdSelection.register(TriangleThreshold) ThresholdSelection.register(YenThreshold) ThresholdSelection.register(HuangThreshold) ThresholdSelection.register(IntermodesThreshold) ThresholdSelection.register(IsoDataThreshold) ThresholdSelection.register(KittlerIllingworthThreshold) ThresholdSelection.register(MomentsThreshold) ThresholdSelection.register(MaximumEntropyThreshold) ThresholdSelection.register(MultipleOtsuThreshold)
[docs] class DoubleThresholdParams(BaseModel): core_threshold: ThresholdSelection = ThresholdSelection.get_default() base_threshold: ThresholdSelection = ThresholdSelection.get_default()
[docs] class DoubleThreshold(BaseThreshold): __argument_class__ = DoubleThresholdParams
[docs] @classmethod def get_name(cls): return "Base/Core"
@classmethod @update_argument("arguments") def calculate_mask( cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: DoubleThresholdParams, operator ): thr: BaseThreshold = ThresholdSelection[arguments.core_threshold.name] mask1, thr_val1 = thr.calculate_mask(data, mask, arguments.core_threshold.values, operator) thr: BaseThreshold = ThresholdSelection[arguments.base_threshold.name] mask2, thr_val2 = thr.calculate_mask(data, mask, arguments.base_threshold.values, operator) mask2[mask2 > 0] = 1 mask2[mask1 > 0] = 2 return mask2, (thr_val1, thr_val2)
[docs] class RangeThresholdParams(DoubleThresholdParams): core_threshold: ThresholdSelection = Field(default_factory=ThresholdSelection.get_default, title="Upper threshold") base_threshold: ThresholdSelection = Field(default_factory=ThresholdSelection.get_default, title="Lower threshold")
[docs] class RangeThreshold(DoubleThreshold): __argument_class__ = RangeThresholdParams
[docs] @classmethod def get_name(cls): return "Range"
[docs] @register_class(version="0.0.1", migrations=[("0.0.1", rename_key("hist_num", "bins"))]) class DoubleOtsuParams(BaseModel): valley: bool = Field(True, title="Valley emphasis") bins: int = Field(128, title="Histogram bins", ge=8, le=2**16)
[docs] class DoubleOtsu(BaseThreshold): __argument_class__ = DoubleOtsuParams
[docs] @classmethod def get_name(cls): return "Double Otsu"
@classmethod @update_argument("arguments") def calculate_mask( cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: DoubleOtsuParams, operator: typing.Callable[[object, object], bool], ): cleaned_image_sitk = sitk.GetImageFromArray(data) res = sitk.OtsuMultipleThresholds(cleaned_image_sitk, 2, 0, arguments.bins, arguments.valley) res = sitk.GetArrayFromImage(res) if not np.any(res.flat): return res, (0, 0) thr1 = data[res == 2].min() thr2 = data[res == 1].min() return res, (thr1, thr2)
[docs] class MultipleOtsuDoubleThresholdParams(BaseModel): components: int = Field(2, title="Number of Components", ge=2, lt=100) lower_component: int = Field(1, title="Lower Component", ge=1, lt=100) upper_component: int = Field(1, title="Upper Component", ge=1, lt=100) valley: bool = Field(True, title="Valley emphasis") bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16)
[docs] class MultipleOtsu(BaseThreshold): __argument_class__ = MultipleOtsuDoubleThresholdParams
[docs] @classmethod def get_name(cls): return "Multiple Otsu"
@classmethod def calculate_mask( cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: MultipleOtsuDoubleThresholdParams, operator: typing.Callable[[object, object], bool], ): cleaned_image_sitk = sitk.GetImageFromArray(data) res = sitk.OtsuMultipleThresholds(cleaned_image_sitk, arguments.components, 0, arguments.bins, arguments.valley) res = sitk.GetArrayFromImage(res) if not np.any(res.flat): return res, (0, 0) map_component = np.zeros(arguments.components + 1, dtype=np.uint8) map_component[: arguments.lower_component] = 0 map_component[arguments.lower_component : arguments.upper_component] = 1 map_component[arguments.upper_component :] = 2 res2 = map_component[res] thr1 = data[res2 == 2].min() if np.any(res2 == 2) else data[res2 == 1].max() thr2 = data[res2 == 1].min() if np.any(res2 == 1) else data.max() return res2, (thr1, thr2)
[docs] @register_class(version="0.0.1", migrations=[("0.0.1", rename_key("minimum_radius", "minimum_border_distance"))]) class MaximumDistanceWatershedParams(BaseModel): threshold: ThresholdSelection = ThresholdSelection.get_default() dilate_radius: int = Field(5, title="Dilate Radius", ge=1, le=100, description="To merge small objects") minimum_size: int = Field(100, title="Minimum Size", ge=1, le=1000000, description="To remove small objects") minimum_border_distance: int = Field( 10, title="Border Radius", ge=0.0, le=100.0, description="Minimum distance of local maxima from the border. To avoid artifacts", )
[docs] class MaximumDistanceCore(BaseThreshold): """ This Is algorithm intended to bue used in "* threshold with watershed" algorithms. It generates array with three values: * 0 - background, * 1 - area to watershed, * 2 - core objects to start watershed from This algorithm is developed to make possible split of almost convex objects that are touching each other. Core objects are identified as local maxima of distance from the border. To perform this task the following steps are performed: 1. Thresholding - to detect whole area of objects. This is controlled by ``threshold`` parameter. 2. Remove small objects - to remove small objects. This is controlled by ``minimum_size` parameter. 3. Small objects close - to merge small objects. As distance transform is used, it is required small holes. This steep closes holes smaller tan 10px. 4. Distance transform - to find distance from the border 5. Identify local maxima - to find core objects 6. Remove local maxima that are too close to the border - to avoid artifacts. This distance is controlled by ``minimum_border_distance`` parameter. 7. Dilate core objects - to make them bigger. For elongated objects it is possible to have multiple local maxima along longest axis of object. This step is to merge them. This distance is controlled by ``dilate_radius`` parameter. This is algorithm that detect core objects """ __argument_class__ = MaximumDistanceWatershedParams
[docs] @classmethod def get_name(cls): return "Maximum Distance Core"
@classmethod @update_argument("arguments") def calculate_mask( cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: MaximumDistanceWatershedParams, operator: typing.Callable[[object, object], bool], ): thr: BaseThreshold = ThresholdSelection[arguments.threshold.name] mask1, thr_val = thr.calculate_mask(data, mask, arguments.threshold.values, operator) mask1 = sitk.GetArrayFromImage( sitk.RelabelComponent(sitk.ConnectedComponent(sitk.GetImageFromArray(mask1)), arguments.minimum_size) ) if not np.any(mask1): return mask1, (thr_val, thr_val) mask2 = close_small_holes((mask1 > 0), 10) data = sitk.GetArrayFromImage(sitk.DanielssonDistanceMap(sitk.GetImageFromArray((mask2 == 0).astype(np.uint8)))) maxima = sitk.GetArrayFromImage(sitk.RegionalMaxima(sitk.GetImageFromArray(data))) maxima[data < arguments.minimum_border_distance] = 0 dilated_maxima = sitk.GetArrayFromImage( sitk.BinaryDilate(sitk.GetImageFromArray(maxima), [arguments.dilate_radius] * 3) ) mask1[mask1 > 0] = 1 mask1[dilated_maxima > 0] = 2 if operator(0, 1): meth = np.amax else: meth = np.amin return mask1, (thr_val, float(meth(data[mask1 == 2]) if np.any(mask1 == 2) else meth(data[mask1 == 1])))
[docs] class DoubleThresholdSelection( AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold ): pass
DoubleThresholdSelection.register(DoubleThreshold) DoubleThresholdSelection.register(DoubleOtsu) DoubleThresholdSelection.register(MultipleOtsu) DoubleThresholdSelection.register(MaximumDistanceCore, old_names=["Maximum Distance Watershed"])
[docs] class RangeThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold): pass
RangeThresholdSelection.register(RangeThreshold) RangeThresholdSelection.register(DoubleOtsu) def __getattr__(name): # pragma: no cover if name == "threshold_dict": warnings.warn( "threshold_dict is deprecated. Please use ThresholdSelection instead", category=FutureWarning, stacklevel=2 ) return ThresholdSelection.__register__ if name == "double_threshold_dict": warnings.warn( "double_threshold_dict is deprecated. Please use DoubleThresholdSelection instead", category=FutureWarning, stacklevel=2, ) return DoubleThresholdSelection.__register__ raise AttributeError(f"module {__name__} has no attribute {name}")