Source code for PartSegCore.segmentation.noise_filtering

import typing
import warnings
from abc import ABC
from enum import Enum

import numpy as np
from local_migrator import register_class, rename_key, update_argument
from pydantic import Field

from PartSegCore.algorithm_describe_base import AlgorithmDescribeBase, AlgorithmSelection
from PartSegCore.image_operations import bilateral, gaussian, median
from PartSegCore.segmentation.algorithm_base import calculate_operation_radius as _calculate_operation_radius
from PartSegCore.utils import BaseModel


@register_class(old_paths=["PartSeg.utils.segmentation.noise_filtering.GaussType"])
class DimensionType(Enum):
    Layer = 1
    Stack = 2

    def __str__(self):
        return self.name.replace("_", " ")


[docs] class NoiseFilteringBase(AlgorithmDescribeBase, ABC): """Base class for noise filtering operations"""
[docs] @classmethod def noise_filter(cls, channel: np.ndarray, spacing: typing.Iterable[float], arguments: dict) -> np.ndarray: """ This function need be overloaded in implementation :param channel: single channel ad 2d or 3d array :param spacing: image spacing :param arguments: additional arguments defined by :py:meth:`get_fields` :return: channel array with removed noise """ raise NotImplementedError
class NoneNoiseFiltering(NoiseFilteringBase): __argument_class__ = BaseModel @classmethod def get_name(cls): return "None" @classmethod def noise_filter(cls, channel: np.ndarray, spacing: typing.Iterable[float], arguments: dict): return channel @register_class(version="0.0.1", migrations=[("0.0.1", rename_key("gauss_type", "dimension_type", optional=True))]) class GaussNoiseFilteringParams(BaseModel): dimension_type: DimensionType = Field(DimensionType.Layer, title="Gauss type") radius: float = Field(1.0, title="Gauss radius", ge=0, le=100) class GaussNoiseFiltering(NoiseFilteringBase): __argument_class__ = GaussNoiseFilteringParams @classmethod def get_name(cls): return "Gauss" @classmethod @update_argument("arguments") def noise_filter(cls, channel: np.ndarray, spacing: typing.Iterable[float], arguments: GaussNoiseFilteringParams): gauss_radius = calculate_operation_radius(arguments.radius, spacing, arguments.dimension_type) layer = arguments.dimension_type == DimensionType.Layer return gaussian(channel, gauss_radius, layer=layer) class BilateralNoiseFilteringParams(BaseModel): dimension_type: DimensionType = Field(DimensionType.Layer, title="Bilateral type") radius: float = Field(1.0, title="Bilateral radius", ge=0, le=100) class BilateralNoiseFiltering(NoiseFilteringBase): __argument_class__ = BilateralNoiseFilteringParams @classmethod def get_name(cls): return "Bilateral" @classmethod @update_argument("arguments") def noise_filter( cls, channel: np.ndarray, spacing: typing.Iterable[float], arguments: BilateralNoiseFilteringParams ): gauss_radius = calculate_operation_radius(arguments.radius, spacing, arguments.dimension_type) layer = arguments.dimension_type == DimensionType.Layer return bilateral(channel, max(gauss_radius), layer=layer) def calculate_operation_radius(radius, spacing, gauss_type): res = _calculate_operation_radius(radius, spacing, gauss_type) return [radius for _ in spacing] if res == radius else res class MedianNoiseFilteringParams(BaseModel): dimension_type: DimensionType = Field(DimensionType.Layer, title="Median type") radius: int = Field(1, title="Median radius", ge=0, le=100) class MedianNoiseFiltering(NoiseFilteringBase): __argument_class__ = MedianNoiseFilteringParams @classmethod def get_name(cls): return "Median" @classmethod @update_argument("arguments") def noise_filter(cls, channel: np.ndarray, spacing: typing.Iterable[float], arguments: MedianNoiseFilteringParams): gauss_radius = calculate_operation_radius(arguments.radius, spacing, arguments.dimension_type) layer = arguments.dimension_type == DimensionType.Layer gauss_radius = [int(x) for x in gauss_radius] return median(channel, gauss_radius, layer=layer) class NoiseFilterSelection(AlgorithmSelection, class_methods=["noise_filter"], suggested_base_class=NoiseFilteringBase): pass NoiseFilterSelection.register(NoneNoiseFiltering) NoiseFilterSelection.register(GaussNoiseFiltering) NoiseFilterSelection.register(MedianNoiseFiltering) NoiseFilterSelection.register(BilateralNoiseFiltering) def __getattr__(name): # pragma: no cover if name == "noise_filtering_dict": warnings.warn( "noise_filtering_dict is deprecated. Please use NoiseFilterSelection instead", category=FutureWarning, stacklevel=2, ) return NoiseFilterSelection.__register__ raise AttributeError(f"module {__name__} has no attribute {name}")