Skip to content
Snippets Groups Projects
transforms.py 78.5 KiB
Newer Older
Lucia Hradecká's avatar
Lucia Hradecká committed
            ignore_index (float | None, optional): If a float, then transformation of `mask` is done with 
                ``border_mode = 'constant'`` and ``mval = ignore_index``. 
                
                If ``None``, this argument is ignored.

                Defaults to ``None``.
            always_apply (bool, optional): Always apply this transformation in composition. 
            
                Defaults to ``False``.
            p (float, optional): Chance of applying this transformation in composition. 
            
                Defaults to ``0.5``.

        Targets:
            image, mask, float mask, key points, bounding boxes
Lucia Hradecká's avatar
Lucia Hradecká committed
    """
    def __init__(self, angles: TypeTripletFloat = (0, 0, 0),
                 translation: TypeTripletFloat = (0, 0, 0),
                 scale: TypeTripletFloat = (1, 1, 1),
                 spacing: TypeTripletFloat = (1, 1, 1),
                 change_to_isotropic: bool = False,
                 interpolation: str = 'linear',
                 border_mode: str = 'constant', ival: float = 0, mval: float = 0,
                 ignore_index: Union[float, None] = None, always_apply: bool = False, p: float = 0.5):
        """Parse and store the affine parameters together with the fill settings for each target type.

        The geometric coefficients are normalized by ``parse_coefs`` (identity element ``0`` for ``angles``
        and ``translation``, ``1`` for ``scale`` and ``spacing``), and ``interpolation`` is converted by
        ``parse_itk_interpolation``. If ``ignore_index`` is given, masks are filled with it in ``'constant'``
        mode; this overrides ``border_mode`` and ``mval`` for masks only.
        """
        super().__init__(always_apply, p)
        self.angles: TypeTripletFloat = parse_coefs(angles, identity_element=0)
        self.translation: TypeTripletFloat = parse_coefs(translation, identity_element=0)
        self.scale: TypeTripletFloat = parse_coefs(scale, identity_element=1)
        self.spacing: TypeTripletFloat = parse_coefs(spacing, identity_element=1)
        self.interpolation: str = parse_itk_interpolation(interpolation)

        # Fill settings forwarded to F.affine:
        #   `border_mode` / `ival` for images (see `apply`),
        #   `mask_mode` / `mval` for masks (see `apply_to_mask` / `apply_to_float_mask`).
        # NOTE(review): these used to be commented "not used", but they ARE passed to F.affine.
        # Confirm whether F.affine actually honors `border_mode`.
        self.border_mode: str = border_mode
        self.mask_mode: str = border_mode
        self.ival: float = ival
        self.mval: float = mval

        # NOTE(review): `keep_scale` is not read by any method visible in this chunk; confirm its consumer.
        self.keep_scale: bool = not change_to_isotropic

        # `ignore_index` replaces the mask fill with a constant ignore label.
        if ignore_index is not None:
            self.mask_mode = "constant"
            self.mval = ignore_index

    def apply(self, img, **params):
        """Warp `img` with the stored affine parameters, using the image interpolation and image fill settings."""
        affine_kwargs = dict(
            scales=self.scale,
            degrees=self.angles,
            translation=self.translation,
            interpolation=self.interpolation,
            border_mode=self.border_mode,
            value=self.ival,
            spacing=self.spacing,
        )
        return F.affine(img, **affine_kwargs)

    def apply_to_mask(self, mask, **params):
        """Warp an integer `mask` with the stored affine parameters.

        Nearest-neighbour interpolation is always used, so no new label values are created by blending.
        The mask fill settings (``mask_mode`` / ``mval``) are used instead of the image ones.
        A leading axis is added before calling ``F.affine`` and dropped from its result.
        NOTE(review): presumably because F.affine expects a channel axis; confirm.
        """
        interpolation = parse_itk_interpolation('nearest')   # refers to 'sitkNearestNeighbor'
        return F.affine(np.expand_dims(mask, 0),
                        scales=self.scale,
                        degrees=self.angles,
                        translation=self.translation,
                        interpolation=interpolation,
                        border_mode=self.mask_mode,
                        value=self.mval,
                        spacing=self.spacing)[0]

    def apply_to_float_mask(self, mask, **params):
        """Warp a floating-point `mask` with the stored affine parameters.

        Unlike ``apply_to_mask``, the configured (image) interpolation is used, since float masks may be blended.
        The mask fill settings (``mask_mode`` / ``mval``) apply. A leading axis is added before calling
        ``F.affine`` and dropped from its result.
        """
        return F.affine(np.expand_dims(mask, 0),
                        scales=self.scale,
                        degrees=self.angles,
                        translation=self.translation,
                        interpolation=self.interpolation,
                        border_mode=self.mask_mode,
                        value=self.mval,
                        spacing=self.spacing)[0]

    def apply_to_keypoints(self, keypoints, **params):
        """Transform key points with the stored affine parameters.

        Needs ``params['domain_limit']``, which is produced by ``get_params``.
        """
        return F.affine_keypoints(keypoints,
                                  scales=self.scale,
                                  degrees=self.angles,
                                  translation=self.translation,
                                  spacing=self.spacing,
                                  domain_limit=params['domain_limit'])

    def get_params(self, targets, **data):
        """Compute the per-call parameters: the spatio-temporal domain limit of the input data.

        The ``"domain_limit"`` entry is read by ``apply_to_keypoints``.
        """
        # set parameters of the transform
        domain_limit = get_spatio_temporal_domain_limit(data, targets)

        return {
            "domain_limit": domain_limit
        }


# IMAGE ONLY TRANSFORMS
# TODO potential upgrade : different sigmas for different channels
class GaussianNoise(ImageOnlyTransform):
    """Adds Gaussian noise to the image. The noise is drawn from normal distribution with given parameters.

        Args:
            var_limit (tuple, optional): Variance of normal distribution is randomly chosen from this interval.

                Defaults to ``(0.001, 0.1)``.
            mean (float, optional): Mean of normal distribution.

                Defaults to ``0``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``False``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``0.5``.

        Targets:
            image
    """

    def __init__(self, var_limit: TypePairFloat = (0.001, 0.1), mean: float = 0,
                 always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)
        self.var_limit = var_limit
        self.mean = mean

    def apply(self, img, **params):
        return F.gaussian_noise(img, sigma=params['sigma'], mean=self.mean)

    def get_params(self, targets, **data):
        # The variance is drawn uniformly from `var_limit`; F.gaussian_noise receives its square root as `sigma`.
        var = uniform(self.var_limit[0], self.var_limit[1])
        sigma = var ** 0.5
        return {"sigma": sigma}

    def __repr__(self):
        return f'GaussianNoise({self.var_limit}, {self.mean}, {self.always_apply}, {self.p})'


class PoissonNoise(ImageOnlyTransform):
    """Adds Poisson noise to the image.

        Args:
            peak_limit (tuple, optional): Range to sample the expected intensity of Poisson noise.

                Defaults to ``(0.1, 0.5)``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``False``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``0.5``.

        Targets:
            image
    """

    def __init__(self, peak_limit: TypePairFloat = (0.1, 0.5),
                 always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)
        self.peak_limit = peak_limit

    def apply(self, img, **params):
        return F.poisson_noise(img, peak=params['peak'])

    def get_params(self, targets, **data):
        # The noise peak is drawn uniformly from `peak_limit`; `apply` reads it as params['peak'].
        peak = uniform(self.peak_limit[0], self.peak_limit[1])
        return {"peak": peak}

    def __repr__(self):
        return f'PoissonNoise({self.peak_limit}, {self.always_apply}, {self.p})'


# TODO create checks (mean, std, got good shape, and etc.), what if given list but only one channel, and reverse.
class NormalizeMeanStd(ImageOnlyTransform):
    """Normalize image values to have mean 0 and standard deviation 1, given channel-wise means and standard deviations.

        For a single-channel image, the normalization is applied by the formula: :math:`img = (img - mean) / std`.
        If the image contains more channels, then the formula is used for each channel separately.

        It is recommended to input dataset-wide means and standard deviations.

        Args:
            mean (float | List[float]): Channel-wise image mean.

                Must be either of: ``M`` (for single-channel images),
                ``(M_1, M_2, ..., M_C)`` (for multi-channel images).
            std (float | List[float]): Channel-wise image standard deviation.

                Must be either of: ``S`` (for single-channel images),
                ``(S_1, S_2, ..., S_C)`` (for multi-channel images).

                Must have the same shape as ``mean`` and must not contain zeros.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``True``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``1``.

        Raises:
            ValueError: If ``mean`` and ``std`` differ in shape, or if ``std`` contains a zero.

        Targets:
            image
    """
    def __init__(self, mean: Union[tuple, float], std: Union[tuple, float],
                 always_apply: bool = True, p: float = 1.0):
        super().__init__(always_apply, p)
        self.mean: np.ndarray = np.array(mean, dtype=np.float32)
        self.std: np.ndarray = np.array(std, dtype=np.float32)
        # Explicit check rather than `assert`, which is stripped when Python runs with -O.
        if self.mean.shape != self.std.shape:
            raise ValueError(f'mean and std must have the same shape, got {self.mean.shape} and {self.std.shape}')
        # A zero std would make 1 / std infinite and corrupt the normalized channel.
        if np.any(self.std == 0):
            raise ValueError('std must not contain zeros')
        # Reciprocal of std, handed to F.normalize_mean_std by `apply`.
        self.denominator = np.reciprocal(self.std, dtype=np.float32)

    def apply(self, image, **params):
        return F.normalize_mean_std(image, self.mean, self.denominator)

    def __repr__(self):
        return f'NormalizeMeanStd({self.mean}, {self.std}, ' \
               f'{self.always_apply}, {self.p})'


class GaussianBlur(ImageOnlyTransform):
    """Performs Gaussian blurring of the image. In case of a multi-channel image, individual channels are blurred separately.

        Internally, the ``scipy.ndimage.gaussian_filter`` function is used. The ``border_mode`` and ``cval``
        arguments are forwarded to it. More details at:
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.gaussian_filter.html.

        Args:
            sigma (float, Tuple(float), List[Tuple(float) | float] , optional): Gaussian sigma.

                Must be either of: ``S``, ``(S_Z, S_Y, S_X)``, ``(S_Z, S_Y, S_X, S_T)``, ``[S_1, S_2, ..., S_C]``,
                ``[(S_Z1, S_Y1, S_X1), (S_Z2, S_Y2, S_X2), ..., (S_ZC, S_YC, S_XC)]``, or
                ``[(S_Z1, S_Y1, S_X1, S_T1), (S_Z2, S_Y2, S_X2, S_T2), ..., (S_ZC, S_YC, S_XC, S_TC)]``.

                If a float, the spatial dimensions are blurred with the same strength (equivalent to ``(S, S, S)``).

                If a tuple, the sigmas for spatial dimensions and possibly the time dimension must be specified.

                If a list, sigmas for each channel must be specified either as a single number or as a tuple.

                Defaults to ``0.8``.
            border_mode (str, optional): Values outside image domain are filled according to this mode.

                Defaults to ``'reflect'``.
            cval (float, optional): Value to fill past edges of image. Only applied when ``border_mode = 'constant'``.

                Defaults to ``0``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``False``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``0.5``.

        Targets:
            image
    """
    def __init__(self, sigma: Union[float, tuple, List[Union[tuple, float]]] = 0.8,
                 border_mode: str = "reflect", cval: float = 0,
                 always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)
        self.sigma = sigma
        self.border_mode = border_mode
        self.cval = cval

    def apply(self, img, **params):
        return F.gaussian_blur(img, self.sigma, self.border_mode, self.cval)


class RandomGaussianBlur(ImageOnlyTransform):
    """Performs Gaussian blur on the image with a random strength blurring.
        In case of a multi-channel image, individual channels are blurred separately.

        Behaves similarly to GaussianBlur. The Gaussian sigma is randomly drawn from
        the interval [min_sigma, s] for the respective s from ``max_sigma`` for each channel and dimension.

        Internally, the ``scipy.ndimage.gaussian_filter`` function is used. The ``border_mode`` and ``cval``
        arguments are forwarded to it. More details at:
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.gaussian_filter.html.

        Args:
            max_sigma (float, Tuple(float), List[Tuple(float) | float] , optional): Maximum Gaussian sigma.

                Must be either of: ``S``, ``(S_Z, S_Y, S_X)``, ``(S_Z, S_Y, S_X, S_T)``, ``[S_1, S_2, ..., S_C]``,
                ``[(S_Z1, S_Y1, S_X1), (S_Z2, S_Y2, S_X2), ..., (S_ZC, S_YC, S_XC)]``, or
                ``[(S_Z1, S_Y1, S_X1, S_T1), (S_Z2, S_Y2, S_X2, S_T2), ..., (S_ZC, S_YC, S_XC, S_TC)]``.

                If a float, the spatial dimensions are blurred equivalently (equivalent to ``(S, S, S)``).

                If a tuple, the sigmas for spatial dimensions and possibly the time dimension must be specified.

                If a list, sigmas for each channel must be specified either as a single number or as a tuple.

                Defaults to ``0.8``.
            min_sigma (float, optional): Minimum Gaussian sigma for all channels and dimensions.

                Defaults to ``0``.
            border_mode (str, optional): Values outside image domain are filled according to this mode.

                Defaults to ``'reflect'``.
            cval (float, optional): Value to fill past edges of image. Only applied when ``border_mode = 'constant'``.

                Defaults to ``0``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``False``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``0.5``.

        Targets:
            image
    """
    def __init__(self, max_sigma: Union[float, tuple, List[Union[float, tuple]]] = 0.8,
                 min_sigma: float = 0, border_mode: str = "reflect", cval: float = 0,
                 always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)
        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        self.border_mode = border_mode
        self.cval = cval

    def apply(self, img, **params):
        return F.gaussian_blur(img, params["sigma"], self.border_mode, self.cval)

    def get_params(self, targets, **data):
        # Draw sigma(s) from [min_sigma, s] for every s in `max_sigma`, mirroring its structure:
        # a scalar, a per-dimension tuple, or a per-channel list of scalars / tuples.
        def draw(upper):
            return random.uniform(self.min_sigma, upper)

        if isinstance(self.max_sigma, (float, int)):
            sigma = draw(self.max_sigma)
        elif isinstance(self.max_sigma, tuple):
            sigma = tuple(draw(s) for s in self.max_sigma)
        else:
            sigma = []
            for channel in self.max_sigma:
                if isinstance(channel, (float, int)):
                    sigma.append(draw(channel))
                else:
                    sigma.append(tuple(draw(s) for s in channel))
        return {"sigma": sigma}


class RandomGamma(ImageOnlyTransform):
    """Performs the gamma transformation with a randomly chosen gamma. If image values (in any channel) are outside
        the [0,1] interval, this transformation is not performed.

        Args:
            gamma_limit (Tuple(float), optional): Interval from which gamma is selected.

                Defaults to ``(0.8, 1.2)``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``False``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``0.5``.

        Targets:
            image
    """
    def __init__(self, gamma_limit: TypePairFloat = (0.8, 1.2),
                 always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)
        self.gamma_limit = gamma_limit

    def apply(self, img, gamma=1, **params):
        return F.gamma_transform(img, gamma=gamma)

    def get_params(self, targets, **data):
        # Presumably the framework passes these params to `apply` as keyword arguments (hence its `gamma=` parameter).
        return {"gamma": random.uniform(self.gamma_limit[0], self.gamma_limit[1])}

    def __repr__(self):
        return f'RandomGamma({self.gamma_limit}, {self.always_apply}, {self.p})'


class RandomBrightnessContrast(ImageOnlyTransform):
    """Randomly change brightness and contrast of the input image.

        Unlike ``RandomBrightnessContrast`` from `Albumentations`, this transform is using the
        formula :math:`f(a) = (c+1) * a + b`, where :math:`c` is contrast and :math:`b` is brightness.

        Args:
            brightness_limit ((float, float) | float, optional): Interval from which the change in brightness is
                randomly drawn. If the change in brightness is 0, the brightness will not change.

                Must be either of: ``B``, ``(B1, B2)``.

                If a float, the interval will be ``(-B, B)``.

                Defaults to ``0.2``.
            contrast_limit ((float, float) | float, optional): Interval from which the change in contrast is
                randomly drawn. If the change in contrast is 0, the contrast will not change.

                Must be either of: ``C``, ``(C1, C2)``.

                If a float, the interval will be ``(-C, C)``.

                Defaults to ``0.2``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``False``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``0.5``.

        Targets:
            image
    """
    def __init__(self, brightness_limit: Union[float, TypePairFloat] = 0.2,
                 contrast_limit: Union[float, TypePairFloat] = 0.2,
                 always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)
        self.brightness_limit = to_tuple(brightness_limit)
        self.contrast_limit = to_tuple(contrast_limit)

    def apply(self, img, **params):
        return F.brightness_contrast_adjust(img, params['alpha'], params['beta'])

    def get_params(self, targets, **data):
        # alpha = c + 1 (multiplicative factor) and beta = b (additive offset) in f(a) = (c + 1) * a + b.
        return {
            "alpha": 1.0 + random.uniform(self.contrast_limit[0], self.contrast_limit[1]),
            "beta": 0.0 + random.uniform(self.brightness_limit[0], self.brightness_limit[1]),
        }

    def __repr__(self):
        return f'RandomBrightnessContrast({self.brightness_limit}, {self.contrast_limit}, ' \
               f'{self.always_apply}, {self.p})'


class HistogramEqualization(ImageOnlyTransform):
    """Equalizes the image histogram, treating every channel independently of the others.

        **Warning! Images are normalized over both spatial and temporal domains together. The output is in the range [0, 1].**

        Args:
            bins (int, optional): How many histogram bins are used for the equalization.

                Defaults to ``256``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``False``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``1``.

        Targets:
            image
    """
    def __init__(self, bins: int = 256, always_apply: bool = False, p: float = 1):
        super().__init__(always_apply, p)
        self.bins = bins

    def apply(self, img, **params):
        n_bins = self.bins
        equalized = F.histogram_equalization(img, n_bins)
        return equalized


class Pad(DualTransform):
    """Pads the input.

        Internally, the ``numpy.pad`` function is used. The ``border_mode``, ``ival`` and ``mval``
        arguments are forwarded to it. More details at:
        https://numpy.org/doc/stable/reference/generated/numpy.pad.html.

        Args:
            pad_size (int | Tuple[int]): Number of pixels padded to the edges of each axis.

                Must be either of: ``P``, ``(P1, P2)``, or ``(P_Z1, P_Z2, P_Y1, P_Y2, P_X1, P_X2)``.

                If an integer, it is equivalent to ``(P, P, P, P, P, P)``.

                If a tuple of two numbers, it is equivalent to ``(P1, P2, P1, P2, P1, P2)``.

                Otherwise, it must specify padding for all spatial dimensions.

                The unspecified dimensions (C and T) are not affected.
            border_mode (str, optional): Values outside image domain are filled according to this mode.

                Defaults to ``'constant'``.
            ival (float | Sequence, optional): Values of `image` voxels outside of the `image` domain.
                Only applied when ``border_mode = 'constant'`` or ``border_mode = 'linear_ramp'``.

                Defaults to ``0``.
            mval (float | Sequence, optional): Values of `mask` voxels outside of the `mask` domain.
                Only applied when ``border_mode = 'constant'`` or ``border_mode = 'linear_ramp'``.

                Defaults to ``0``.
            ignore_index (float | None, optional): If a float, then transformation of `mask` is done with
                ``border_mode = 'constant'`` and ``mval = ignore_index``.

                If ``None``, this argument is ignored.

                Defaults to ``None``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``True``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``1``.

        Targets:
            image, mask, float mask, key points, bounding boxes
    """
    def __init__(self, pad_size: Union[int, TypePairInt, TypeSextetInt],
                 border_mode: str = 'constant', ival: Union[float, Sequence] = 0, mval: Union[float, Sequence] = 0,
                 ignore_index: Union[float, None] = None, always_apply: bool = True, p: float = 1):
        super().__init__(always_apply, p)
        self.pad_size: TypeSextetInt = parse_pads(pad_size)
        self.border_mode = border_mode
        self.mask_mode = border_mode
        self.ival = ival
        self.mval = mval

        # `ignore_index` replaces the mask fill with a constant ignore label.
        if ignore_index is not None:
            self.mask_mode = "constant"
            self.mval = ignore_index

    def apply(self, img, **params):
        return F.pad_pixels(img, self.pad_size, self.border_mode, self.ival)

    def apply_to_mask(self, mask, **params):
        # NOTE(review): the trailing positional `True` presumably tells F.pad_pixels the input is a mask; confirm.
        return F.pad_pixels(mask, self.pad_size, self.mask_mode, self.mval, True)

    def apply_to_keypoints(self, keypoints, **params):
        return F.pad_keypoints(keypoints, self.pad_size)

    def __repr__(self):
        return f'Pad({self.pad_size}, {self.border_mode}, {self.ival}, {self.mval}, {self.always_apply}, ' \
               f'{self.p})'


class Normalize(ImageOnlyTransform):
    """Change image mean and standard deviation to the given values (channel-wise).

        Args:
            mean (float | List[float], optional): The desired channel-wise means.

                Must be either of: ``M`` (for single-channel images),
                ``[M_1, M_2, ..., M_C]`` (for multi-channel images).

                Defaults to ``0``.
            std (float | List[float], optional): The desired channel-wise standard deviations.

                Must be either of: ``S`` (for single-channel images),
                ``[S_1, S_2, ..., S_C]`` (for multi-channel images).

                Defaults to ``1``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``True``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``1``.

        Targets:
            image
    """
    def __init__(self, mean: Union[float, List[float]] = 0, std: Union[float, List[float]] = 1,
                 always_apply: bool = True, p: float = 1.0):
        super().__init__(always_apply, p)
        self.mean = mean
        self.std = std

    def apply(self, img, **params):
        return F.normalize(img, self.mean, self.std)

    def __repr__(self):
        return f'Normalize({self.mean}, {self.std}, {self.always_apply}, {self.p})'


class Rescale(DualTransform):
    """Rescales the input and changes its shape accordingly.

        Internally, the ``skimage.transform.resize`` function is used.
        The ``interpolation``, ``border_mode``, ``ival``, ``mval``,
        and ``anti_aliasing_downsample`` arguments are forwarded to it. More details at:
        https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize.

        Args:
            scales (float|List[float], optional): Value by which the input should be scaled.

                Must be either of: ``S``, ``[S_Z, S_Y, S_X]``.

                If a float, then all spatial dimensions are scaled by it (equivalent to ``[S, S, S]``).

                The unspecified dimensions (C and T) are not affected.

                Defaults to ``1``.
            interpolation (int, optional): Order of spline interpolation.

                Defaults to ``1``.
            border_mode (str, optional): Values outside image domain are filled according to this mode.

                Defaults to ``'reflect'``.
            ival (float, optional): Value of `image` voxels outside of the `image` domain. Only applied when ``border_mode = 'constant'``.

                Defaults to ``0``.
            mval (float, optional): Value of `mask` and `float_mask` voxels outside of the domain. Only applied when ``border_mode = 'constant'``.

                Defaults to ``0``.
            anti_aliasing_downsample (bool, optional): Controls if the Gaussian filter should be applied before
                downsampling. Recommended.

                Defaults to ``True``.
            ignore_index (float | None, optional): If a float, then transformation of `mask` is done with
                ``border_mode = 'constant'`` and ``mval = ignore_index``.

                If ``None``, this argument is ignored.

                Defaults to ``None``.
            always_apply (bool, optional): Always apply this transformation in composition.

                Defaults to ``True``.
            p (float, optional): Chance of applying this transformation in composition.

                Defaults to ``1``.

        Targets:
            image, mask, float mask, key points, bounding boxes
    """

    def __init__(self, scales=1, interpolation: int = 1, border_mode: str = 'reflect', ival: float = 0,
                 mval: float = 0, anti_aliasing_downsample: bool = True, ignore_index=None,
                 always_apply: bool = True, p: float = 1, **kwargs):
        # NOTE: extra keyword arguments (`**kwargs`) are accepted and ignored.
        super().__init__(always_apply, p)
        self.scale = parse_coefs(scales, identity_element=1.)
        self.interpolation = interpolation
        self.border_mode = border_mode
        self.mask_mode = border_mode
        self.ival = ival
        self.mval = mval
        self.anti_aliasing_downsample = anti_aliasing_downsample
        # `ignore_index` replaces the mask fill with a constant ignore label.
        if ignore_index is not None:
            self.mask_mode = "constant"
            self.mval = ignore_index

    def apply(self, img, **params):
        return F.resize(img, input_new_shape=params['new_shape'], interpolation=self.interpolation, cval=self.ival,
                        border_mode=self.border_mode, anti_aliasing_downsample=self.anti_aliasing_downsample)

    def apply_to_mask(self, mask, **params):
        # Integer masks: order-0 (nearest) interpolation and no anti-aliasing, so no new label values appear.
        return F.resize(mask, input_new_shape=params['new_shape'], interpolation=0, cval=self.mval,
                        border_mode=self.mask_mode, anti_aliasing_downsample=False, mask=True)

    def apply_to_float_mask(self, mask, **params):
        return F.resize(mask, input_new_shape=params['new_shape'], interpolation=self.interpolation, cval=self.mval,
                        border_mode=self.mask_mode, anti_aliasing_downsample=False, mask=True)

    def apply_to_keypoints(self, keypoints, **params):
        return F.resize_keypoints(keypoints,
                                  domain_limit=params['domain_limit'],
                                  new_shape=params['new_shape'])

    # NOTE: disabled draft of bounding-box support, kept as an unexecuted string literal.
    # It is incomplete as written (e.g. `res` is never initialized).
    """
    def apply_to_bboxes(self, bboxes, **params):
        for bbox in bboxes:
            new_bbox = F.resize_keypoints(bbox,
                                          input_new_shape=params['new_shape'],
                                          original_shape=params['original_shape'],
                                          keep_all=True)

            if validate_bbox(bbox, new_bbox, min_overlay_ratio):
                res.append(new_bbox)

        return res
    """

    def get_params(self, targets, **data):
        """Compute the per-call parameters: the input domain limit and the target shape after rescaling."""
        domain_limit: TypeSpatioTemporalCoordinate = get_spatio_temporal_domain_limit(data, targets)

        # compute shape of the resized image: the first three domain-limit entries
        # (presumably the spatial Z, Y, X extents) multiplied by the per-axis scale
        # TODO +(0,) because of the F.resize error/hotfix
        new_shape = tuple(np.asarray(domain_limit[:3]) * np.asarray(self.scale)) + (0,)

        return {
            "domain_limit": domain_limit,
            "new_shape": new_shape,
        }

    def __repr__(self):
        return f'Rescale({self.scale}, {self.interpolation}, {self.border_mode}, {self.ival}, {self.mval}, ' \
               f'{self.anti_aliasing_downsample}, {self.always_apply}, {self.p})'


class RemoveBackgroundGaussian(ImageOnlyTransform):
    """
    Removes background by subtracting a blurred image from the original image.

    The background image is created using Gaussian blurring. In case of a multi-channel image, individual channels
    are blurred separately.

    Internally, the ``scipy.ndimage.gaussian_filter`` function is used. The ``border_mode`` and ``cval``
    arguments are forwarded to it. More details at:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.gaussian_filter.html.

    Args:
        sigma (float, Tuple[float], List[Tuple[float] | float], optional): Gaussian sigma.

            Must be one of: ``S``, ``(S_Z, S_Y, S_X)``, ``(S_Z, S_Y, S_X, S_T)``, ``[S_1, S_2, ..., S_C]``,
            ``[(S_Z1, S_Y1, S_X1), (S_Z2, S_Y2, S_X2), ..., (S_ZC, S_YC, S_XC)]``, or
            ``[(S_Z1, S_Y1, S_X1, S_T1), (S_Z2, S_Y2, S_X2, S_T2), ..., (S_ZC, S_YC, S_XC, S_TC)]``.

            If a float, the spatial dimensions are blurred with the same strength (equivalent to ``(S, S, S)``).

            If a tuple, the sigmas for spatial dimensions and possibly the time dimension must be specified.

            If a list, sigmas for each channel must be specified either as a single number or as a tuple.

            Defaults to ``10``.
        mode (str, optional): How to compute the background and remove it. Possible values:
            ``'default'`` (subtract blurred image from the input image),
            ``'bright_objects'`` (subtract the point-wise minimum of (blurred image, input image) from the input image),
            ``'dark_objects'`` (subtract the input image from the point-wise maximum of (blurred image, input image)).

            Defaults to ``'default'``.
        border_mode (str, optional): Values outside image domain are filled according to this mode.

            Defaults to ``'reflect'``.
        cval (float, optional): Value to fill past edges of image. Only applied when ``border_mode = 'constant'``.

            Defaults to ``0``.
        always_apply (bool, optional): Always apply this transformation in composition.

            Defaults to ``True``.
        p (float, optional): Chance of applying this transformation in composition.

            Defaults to ``1.0``.

    Raises:
        ValueError: If ``mode`` is not one of ``'default'``, ``'bright_objects'``, ``'dark_objects'``.

    Targets:
        image
    """

    # Accepted values of ``mode``; anything else used to fall back silently to 'default'.
    _MODES = ('default', 'bright_objects', 'dark_objects')

    def __init__(self, sigma: Union[float, tuple, List[Union[tuple, float]]] = 10, mode: str = 'default',
                 border_mode: str = "reflect", cval: float = 0,
                 always_apply: bool = True, p: float = 1.0):

        # Reject typos such as 'bright_object' instead of silently running the default mode.
        if mode not in self._MODES:
            raise ValueError(f"Unsupported mode {mode!r}; expected one of {self._MODES}.")

        super().__init__(always_apply, p)
        self.sigma = sigma
        self.mode = mode
        self.border_mode = border_mode
        self.cval = cval

    def apply(self, img, **params):
        background = F.gaussian_blur(img, self.sigma, self.border_mode, self.cval)

        if self.mode == 'bright_objects':
            # Clamp the background from above by the image, so the result is >= 0.
            return img - np.minimum(background, img)

        if self.mode == 'dark_objects':
            # Inverted difference, clamped at zero: dark structures become positive.
            return np.maximum(background, img) - img

        # 'default': plain subtraction. The result can be negative where the image is darker
        # than its blurred version.
        # NOTE(review): for unsigned integer inputs this may wrap around, depending on the
        # dtype F.gaussian_blur returns — confirm.
        return img - background

    def __repr__(self):
        return f'RemoveBackgroundGaussian({self.sigma}, {self.mode}, {self.border_mode} , {self.cval}, ' \
               f'{self.always_apply}, {self.p})'


Lucia Hradecká's avatar
Lucia Hradecká committed
class Contiguous(DualTransform):
    """Transform the image data to a contiguous array.

    Images and masks are passed through ``np.ascontiguousarray`` (C order).
    No dedicated float-mask handler is defined here. Float masks presumably go
    through the ``DualTransform`` default handler — confirm in the base class.

    Args:
        always_apply (bool, optional): Always apply this transformation in composition.

            Defaults to ``True``.
        p (float, optional): Chance of applying this transformation in composition.

            Defaults to ``1``.

    Targets:
        image, mask, float mask
    """
    def __init__(self, always_apply: bool = True, p: float = 1.0):
        super().__init__(always_apply, p)

    def apply(self, image, **params):
        """Return ``image`` as a C-contiguous array."""
        contiguous = np.ascontiguousarray(image)
        return contiguous

    def apply_to_mask(self, mask, **params):
        """Return ``mask`` as a C-contiguous array."""
        contiguous = np.ascontiguousarray(mask)
        return contiguous

    def __repr__(self):
        return 'Contiguous({}, {})'.format(self.always_apply, self.p)


class StandardizeDatatype(DualTransform):
    """Change image and float mask datatype to ``np.float32`` without changing intensities.
    Change mask datatype to ``np.int32``.

    Each handler uses ``ndarray.astype``, which returns a new array.

    Note that ``float32`` represents integers exactly only up to ``2**24`` in magnitude,
    and casting a mask to ``int32`` truncates any fractional values.

    Args:
        always_apply (bool, optional): Always apply this transformation in composition.

            Defaults to ``True``.
        p (float, optional): Chance of applying this transformation in composition.

            Defaults to ``1``.

    Targets:
        image, mask, float mask
    """
    def __init__(self, always_apply: bool = True, p: float = 1.0):
        super().__init__(always_apply, p)

    def apply(self, image, **params):
        """Return ``image`` cast to ``np.float32``."""
        return image.astype(np.float32)

    def apply_to_mask(self, mask, **params):
        """Return the integer label ``mask`` cast to ``np.int32``."""
        return mask.astype(np.int32)

    def apply_to_float_mask(self, mask, **params):
        """Return the float ``mask`` cast to ``np.float32``."""
        return mask.astype(np.float32)

    def __repr__(self):
        # Previously reported 'Float(...)', which does not match this class's name.
        return f'StandardizeDatatype({self.always_apply}, {self.p})'