# ============================================================================================= #
#  Author:       Pavel Iakubovskii, ZFTurbo, ashawkey, Dominik Müller,                          #
#                Samuel Šuľan, Lucia Hradecká, Filip Lux                                        #
#  Copyright:    albumentations     : https://github.com/albumentations-team                    #
#                Pavel Iakubovskii  : https://github.com/qubvel                                 #
#                ZFTurbo            : https://github.com/ZFTurbo                                #
#                ashawkey           : https://github.com/ashawkey                               #
#                Dominik Müller     : https://github.com/muellerdo                              #
#                Lucia Hradecká     : lucia.d.hradecka@gmail.com                                #
#                Filip Lux          : lux.filip@gmail.com                                       #
#                Samuel Šuľan                                                                   #
#                                                                                               #
#  Volumentations History:                                                                      #
#       - Original:                 https://github.com/albumentations-team/albumentations       #
#       - 3D Conversion:            https://github.com/ashawkey/volumentations                  #
#       - Continued Development:    https://github.com/ZFTurbo/volumentations                   #
#       - Enhancements:             https://github.com/qubvel/volumentations                    #
#       - Further Enhancements:     https://github.com/muellerdo/volumentations                 #
#       - Biomedical Enhancements:  https://gitlab.fi.muni.cz/cbia/bio-volumentations           #
#                                                                                               #
#  MIT License.                                                                                 #
#                                                                                               #
#  Permission is hereby granted, free of charge, to any person obtaining a copy                 #
#  of this software and associated documentation files (the "Software"), to deal                #
#  in the Software without restriction, including without limitation the rights                 #
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell                    #
#  copies of the Software, and to permit persons to whom the Software is                        #
#  furnished to do so, subject to the following conditions:                                     #
#                                                                                               #
#  The above copyright notice and this permission notice shall be included in all               #
#  copies or substantial portions of the Software.                                              #
#                                                                                               #
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR                   #
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,                     #
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE                  #
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER                       #
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,                #
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE                #
#  SOFTWARE.                                                                                    #
# ============================================================================================= #

import numpy as np
import skimage.transform as skt
from skimage.exposure import equalize_hist
from scipy.ndimage import gaussian_filter
from warnings import warn

from ..biovol_typing import TypeTripletFloat, TypeSpatioTemporalCoordinate, TypeSextetInt, TypeSpatialShape
from .sitk_utils import get_affine_transform, apply_sitk_transform
from .utils import is_included


"""
vol: [C, D, H, W (, T)]

Shapes should be given in the (D, H, W) form.

skimage interpolation notations:

order = 0: Nearest-Neighbor
order = 1: Bi-Linear (default)
order = 2: Bi-Quadratic
order = 3: Bi-Cubic
order = 4: Bi-Quartic
order = 5: Bi-Quintic

Interpolation behaves unpredictably when the input is of an integer type.
** Be sure to change the volume and mask data type to float !!! **  (already done by Float() in compose - TODO: not done for integer masks)

Parameters, however, should primarily be ints.
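
Illustrative example (a sketch only; assumes the functions defined below in this module):

    img = np.zeros((1, 64, 64, 64), dtype=np.float32)     # a single-channel volume, (C, D, H, W)
    blurred = gaussian_blur(img, 1.5, border_mode='reflect', cval=0)
    # blurred.shape == (1, 64, 64, 64); the blur is applied to the spatial axes only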
"""


# TODO parameter
# Anti-aliasing - a Gaussian filter used to smooth the image. It is applied automatically when downsampling,
# except when the input is of an integer type and the interpolation order is 0 (i.e. a mask).
# Float masks - unclear how to handle them; for now, no Gaussian filter is applied.
def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, mask=False,
           anti_aliasing_downsample=True):
    # TODO: ad-hoc fix (drops the last, temporal element of the requested shape), check if it is correct
    new_shape = list(input_new_shape)[:-1]

    # Zero or negative check
    for dimension in new_shape:
        if dimension <= 0:
            warn(f"Resize(): shape: {new_shape} contains zero or negative number, continuing without Resize.",
                 UserWarning)
            return img

    # shape check
    if mask:
        # too many or few dimensions of new_shape
        if len(new_shape) < len(img.shape) - 1 or len(new_shape) > len(img.shape):
            warn(f"Resize(): wrong parameter shape:  {new_shape}," +
                 f"expecting something with dimensions of {img.shape} or {img.shape[0:-1]}, " +
Filip Lux's avatar
Filip Lux committed
                 "continuing without resizing ", UserWarning)
            return img
        # Adding time dimension
        elif len(new_shape) == len(img.shape) - 1:
            new_shape = np.append(new_shape, img.shape[-1])
    else:
        if len(new_shape) < len(img.shape[1:]) - 1 or len(new_shape) > len(img.shape[1:]):
            warn(f"Resize(): wrong dimensions of shape:  {new_shape}," +
                 f"expecting something with dimensions of {img.shape[1:]} or {img.shape[1:-1]}, continuing " +
Filip Lux's avatar
Filip Lux committed
                 "without resizing ", UserWarning)
            return img
        # adding time dimension
        elif len(new_shape) == len(img.shape[1:]) - 1:
            new_shape = np.append(new_shape, img.shape[-1])

    anti_aliasing = False
    if mask:
        new_img = skt.resize(
            img,
            new_shape,
            order=interpolation,
            mode=border_mode,
            cval=cval,
            clip=True,
            anti_aliasing=anti_aliasing
        )
        return new_img
    # apply anti-aliasing only when actually downsampling (the requested shape is smaller than the input)
    if anti_aliasing_downsample and np.any(np.array(new_shape) < np.array(img.shape[1:])):
        anti_aliasing = True
    data = []
    for i in range(img.shape[0]):
        subimg = img[i].copy()
        d0 = skt.resize(
            subimg,
            new_shape,
            order=interpolation,
            mode=border_mode,
            cval=cval,
            clip=True,
            anti_aliasing=anti_aliasing
        )
        data.append(d0.copy())
    new_img = np.stack(data, axis=0)
    return new_img
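# Illustrative example (a sketch, not a guaranteed API contract): resize() interprets the last element of
# the requested shape as the temporal extent and drops it, so a purely spatial volume is resized like this:
#   vol = np.random.rand(1, 64, 64, 64).astype(np.float32)
#   resize(vol, (32, 48, 56, 1), interpolation=1).shape  # -> (1, 32, 48, 56)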


def resize_keypoints(keypoints,
                     domain_limit: TypeSpatioTemporalCoordinate,
                     new_shape: TypeSpatioTemporalCoordinate):
    assert len(domain_limit) == len(new_shape) == 4

    # for each dim compute ratio
    ratio = np.array(new_shape[:3]) / np.array(domain_limit[:3])

    # assumes that each key-point has exactly 3 coordinates
    return [keypoint * ratio for keypoint in keypoints]
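# Illustrative example (a sketch): shrinking the domain from (64, 64, 64, 1) to (32, 32, 32, 1)
# halves every spatial coordinate:
#   resize_keypoints([(10., 20., 30.)], (64, 64, 64, 1), (32, 32, 32, 1))  # -> [array([ 5., 10., 15.])]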


def affine(img: np.array,
           degrees: TypeTripletFloat = (0, 0, 0),
           scales: TypeTripletFloat = (1, 1, 1),
           translation: TypeTripletFloat = (0, 0, 0),
           interpolation: str = 'linear',
           border_mode: str = 'constant',
           value: float = 0,
           spacing: TypeTripletFloat = (1, 1, 1)):
    """
    img (np.array) : format (channel, ax1, ax2, ax3, [time])
    """
    shape = img.shape[1:]
    transform = get_affine_transform(shape,
                                     scales=scales,
                                     degrees=degrees,
                                     translation=translation,
                                     spacing=spacing)

    return apply_sitk_transform(img,
                                sitk_transform=transform,
                                interpolation=interpolation,
                                default_value=value,
                                spacing=spacing)
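# Illustrative call (a sketch; get_affine_transform and apply_sitk_transform come from .sitk_utils, and the
# parameter names suggest rotation angles in degrees, per-axis scale factors, and a translation vector):
#   vol = np.random.rand(1, 32, 32, 32).astype(np.float32)
#   rotated = affine(vol, degrees=(0, 0, 45), scales=(1, 1, 1), translation=(0, 0, 0),
#                    interpolation='linear', value=0, spacing=(1, 1, 1))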


def affine_keypoints(keypoints: list,
                     domain_limit: TypeSpatioTemporalCoordinate,
                     degrees: TypeTripletFloat = (0, 0, 0),
                     scales: TypeTripletFloat = (1, 1, 1),
                     translation: TypeTripletFloat = (0, 0, 0),
                     border_mode: str = 'constant',
                     keep_all: bool = False,
                     spacing: TypeTripletFloat = (1, 1, 1)):
    """

    Args:
        keypoints: list of input keypoints
        domain_limit: limit of the domain where key-points can appear; it is used to define the center
                of the transform and to filter out key-points that fall outside the domain
        degrees: rotation angles per spatial axis
        scales: scale factors per spatial axis
        translation: translation per spatial axis
        border_mode: not used
        keep_all: if True, also keep key-points that fall outside the domain
        spacing: relative voxel size

    Returns: list of transformed key-points

    """
    transform = get_affine_transform(domain_limit,
                                     scales=scales,
                                     degrees=degrees,
                                     translation=translation,
                                     spacing=spacing)

    transform = transform.GetInverse()

    res = []
    for point in keypoints:
        transformed_point = transform.TransformPoint(point)
        if keep_all or is_included(domain_limit, transformed_point):
            res.append(transformed_point)
    return res


# Used in rot90_keypoints
def flip_keypoints(keypoints, axes, img_shape):

    # all values in axes are in [1, 2, 3]
    assert np.all(np.array([ax in [1, 2, 3] for ax in axes])), f'{axes} contains values outside [1, 2, 3]'

    mult, add = np.ones(3, int), np.zeros(3, int)
    for ax in axes:
        mult[ax-1] = -1
        add[ax-1] = img_shape[ax-1] - 1

    res = []
    for k in keypoints:
        flipped = list(np.array(k[:3]) * mult + add)
        if len(k) == 4:
            flipped.append(k[-1])
        res.append(tuple(flipped))
    return res
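# Illustrative example (a sketch, with img_shape taken as the spatial (D, H, W) shape, as indexed above):
# flipping axis 1 of a (4, 4, 4) domain maps coordinate 0 to 3:
#   flip_keypoints([(0, 1, 2)], axes=[1], img_shape=(4, 4, 4))  # -> [(3, 1, 2)]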


# Used in rot90_keypoints
def transpose_keypoints(keypoints, ax1, ax2):

    # all values in axes are in [1, 2, 3]
    assert (ax1 in [1, 2, 3]) and (ax2 in [1, 2, 3]), f'[{ax1}, {ax2}] contains values outside [1, 2, 3]'

    res = []
    for k in keypoints:
        k = list(k)
        k[ax1-1], k[ax2-1] = k[ax2-1], k[ax1-1]
        res.append(tuple(k))
    return res


def rot90_keypoints(keypoints, factor, axes, img_shape):

    if factor == 1:
        keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)
        keypoints = transpose_keypoints(keypoints, axes[0], axes[1])

    elif factor == 2:
        keypoints = flip_keypoints(keypoints, axes, img_shape)
    elif factor == 3:

        keypoints = transpose_keypoints(keypoints, axes[0], axes[1])
        keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)

    return keypoints
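# Illustrative trace (a sketch, with img_shape taken as the spatial (D, H, W) shape): a 90-degree rotation
# (factor=1) in the plane of axes (1, 2) flips axis 2 and then swaps axes 1 and 2:
#   rot90_keypoints([(1, 0, 0)], factor=1, axes=(1, 2), img_shape=(4, 5, 6))
#   # (1, 0, 0) -> flip axis 2 -> (1, 4, 0) -> transpose axes 1, 2 -> (4, 1, 0)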


def pad_keypoints(keypoints, pad_size):
    a, b, c, d, e, f = pad_size

    res = []
    for coo in keypoints:
        padding = np.array((a, c, e)) if len(coo) == 3 else np.array((a, c, e, 0))
        res.append(coo + padding)
    return res


def pad_pixels(img, input_pad_width: TypeSextetInt, border_mode, cval, mask=False):
    a, b, c, d, e, f = input_pad_width
    pad_width = [(a, b), (c, d), (e, f)]

    # no padding for the channel dimension
    if not mask:
        pad_width = [(0, 0)] + pad_width

    # no padding for the temporal dimension
    if len(img.shape) == 5:  # if len(img.shape) > len(pad_width):
        pad_width = pad_width + [(0, 0)]

    assert len(img.shape) == len(pad_width)

    if border_mode == "constant":
        return np.pad(img, pad_width, border_mode, constant_values=cval)
    if border_mode == "linear_ramp":
        return np.pad(img, pad_width, border_mode, end_values=cval)
    return np.pad(img, pad_width, border_mode)
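# Illustrative example (a sketch): the sextet gives (before, after) padding per spatial axis; the channel
# (and temporal) axes are never padded:
#   vol = np.zeros((1, 10, 10, 10), dtype=np.float32)
#   pad_pixels(vol, (2, 3, 0, 0, 1, 1), border_mode='constant', cval=0).shape  # -> (1, 15, 10, 12)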


# Used in crop()
def get_spatial_shape(array: np.array, mask: bool) -> TypeSpatialShape:
    return np.array(array.shape)[:3] if mask else np.array(array.shape)[1:4]


# Used in crop()
def get_pad_dims(spatial_shape: TypeSpatialShape, crop_shape: TypeSpatialShape) -> TypeSextetInt:
    pad_dims = [0] * 6
    for i in range(3):
        i_dim, c_dim = spatial_shape[i], crop_shape[i]
        current_pad_dims = (0, 0)
        if i_dim < c_dim:
            pad_size = c_dim - i_dim
            if pad_size % 2 != 0:
                current_pad_dims = (int(pad_size // 2 + 1), int(pad_size // 2))
            else:
                current_pad_dims = (int(pad_size // 2), int(pad_size // 2))

        pad_dims[i * 2:(i + 1) * 2] = current_pad_dims

    return tuple(pad_dims)
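# Illustrative check (a sketch): padding a (10, 10, 10) volume to fit a (16, 17, 10) crop needs 6 voxels
# along the first axis (split 3 + 3), 7 along the second (split 4 + 3), and none along the third:
#   get_pad_dims((10, 10, 10), (16, 17, 10))  # -> (3, 3, 4, 3, 0, 0)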


def crop(input_array: np.array,
         crop_shape: TypeSpatialShape,
         crop_position: TypeSpatialShape,
         pad_dims,
         border_mode, cval, mask):

    input_spatial_shape = get_spatial_shape(input_array, mask)

    if np.any(input_spatial_shape < crop_shape):
        warn(f'F.crop(): input size {input_spatial_shape} is smaller than crop size {crop_shape}, padding with "{border_mode}".',
             UserWarning)

        # pad
        input_array = pad_pixels(input_array, pad_dims, border_mode, cval, mask)

        # test
        input_spatial_shape = get_spatial_shape(input_array, mask)
        assert np.all(input_spatial_shape >= crop_shape)

    x1, y1, z1 = crop_position
    x2, y2, z2 = np.array(crop_position) + np.array(crop_shape)

    if mask:
        result = input_array[x1:x2, y1:y2, z1:z2]
        assert np.all(result.shape[:3] == crop_shape), f'{result.shape} {crop_shape} {mask} {crop_position}'
    else:
        result = input_array[:, x1:x2, y1:y2, z1:z2]
        assert np.all(result.shape[1:4] == crop_shape)

    return result
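# Illustrative example (a sketch): cropping a (C, D, H, W) volume when no padding is needed:
#   vol = np.random.rand(1, 64, 64, 64).astype(np.float32)
#   patch = crop(vol, crop_shape=(16, 16, 16), crop_position=(8, 8, 8),
#                pad_dims=(0, 0, 0, 0, 0, 0), border_mode='constant', cval=0, mask=False)
#   # patch.shape == (1, 16, 16, 16)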


def crop_keypoints(keypoints,
                   crop_shape: TypeSpatialShape,
                   crop_position: TypeSpatialShape,
                   pad_dims,
                   keep_all: bool):

    (px, _), (py, _), (pz, _) = pad_dims
    pad = np.array((px, py, pz))

    res = []
    for keypoint in keypoints:
        k = keypoint[:3] - crop_position + pad
        if keep_all or (np.all(k >= 0) and np.all((k + .5) < crop_shape)):
            res.append(k)

    return res


def gaussian_blur(img, input_sigma, border_mode, cval):
    sigma = input_sigma
    if isinstance(sigma, list):
        if img.shape[0] != len(sigma):
            warn(f'GaussianBlur(): wrong list size {len(sigma)}, expecting one value per channel '
                 f'({img.shape[0]}). Ignoring.',
                 UserWarning)
            return img
        return gaussian_blur_stack(img, sigma, border_mode, cval)

    if isinstance(sigma, (int, float)):
        sigma = np.repeat(sigma, len(img.shape))
        sigma[0] = 0
        # Checking for time dimension
        if len(img.shape) > 4:
            sigma[-1] = 0
    else:
        # TODO: clarify which sigma formats are expected here
        if len(sigma) == len(img.shape) - 2:
            sigma = np.append(sigma, 0)
        if len(sigma) == len(img.shape) - 1:
            sigma = np.insert(sigma, 0, 0)
    if len(sigma) != len(img.shape):
        warn(f'GaussianBlur(): sigma {sigma} does not match the image dimensionality {img.shape}, ignoring.',
             UserWarning)
        return img
    return gaussian_filter(img, sigma=sigma, mode=border_mode, cval=cval)

def gaussian_blur_stack(img, input_sigma, border_mode, cval):
    sigma = list(np.asarray(input_sigma).copy())
    # simple sigma check
    for channel in sigma:
        if not isinstance(channel, (float, int, tuple)):
            warn('GaussianBlur(): wrong sigma format, the list may only contain tuples, floats, or ints. Ignoring.',
                 UserWarning)
            return img
    # TODO: try different techniques for better optimization
    for i in range(len(sigma)):
        if isinstance(sigma[i], (float, int)):
            sigma[i] = np.repeat(sigma[i], len(img.shape) - 1)
            if len(sigma[i]) >= 4:
                sigma[i][-1] = 0
        else:
            if len(sigma[i]) == len(img.shape) - 2:
                sigma[i] = np.append(sigma[i], 0)
        img[i] = gaussian_filter(img[i], sigma=sigma[i], mode=border_mode, cval=cval)
    return img


# TODO clipped tag may be important for types other than float32, but tags are from fork and not tested
# @clipped
def brightness_contrast_adjust(img, alpha=1, beta=0):
    if alpha != 1:
        img *= alpha
    if beta != 0:
        img += beta
    return img
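# In effect, this is the linear mapping out = alpha * img + beta, applied in place on float arrays.
# Illustrative example (a sketch):
#   brightness_contrast_adjust(np.full((1, 2, 2, 2), 1.0, dtype=np.float32), alpha=2.0, beta=0.5)
#   # -> every voxel equals 2.5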


def gamma_transform(img, gamma):
    if np.any(img < 0) or np.any(img > 1):
        warn("Gamma transform: image is not in the [0, 1] range, continuing without transform.", UserWarning)
        return img
    else:
        return np.power(img, gamma)


def histogram_equalization(img, bins):
    for i in range(img.shape[0]):
        img[i] = equalize_hist(img[i], bins)
    return img


def gaussian_noise(img, mean, sigma):
    img = img.astype("float32")
    noise = np.random.normal(mean, sigma, img.shape).astype(np.float32)
    return img + noise


def poisson_noise(img, peak):
    img = img.astype("float32")
    # assumption: `peak` scales the Poisson rate, so a larger peak produces relatively weaker noise
    return img + np.random.poisson(img * peak).astype(np.float32) / peak


def value_to_list(value, length):
    if isinstance(value, (float, int)):
        return [value for _ in range(length)]
    else:
        return value


def correct_length_list(list_to_check, length, value_to_fill=1, list_name="###Default###"):
    if len(list_to_check) < length:
        warn(f"{list_name} have elements {len(list_to_check)}, should be {length} appending {value_to_fill} " +
             "till length matches", UserWarning)
        for i in range(length - len(list_to_check)):
            list_to_check = list_to_check + [value_to_fill]
    if len(list_to_check) > length:
        warn(f"{list_name} have elements {len(list_to_check)}, should be {length} removing elements from behind " +
             " till length matches", UserWarning)
        list_to_check = [list_to_check[i] for i in range(length)]
    return list_to_check


# formula taken from
# https://stats.stackexchange.com/questions/46429/transform-data-to-desired-mean-and-standard-deviation
def normalize_channel(img, mean, std):
    return (img - img.mean()) * (std / img.std()) + mean
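# Illustrative example (a sketch): a channel with mean 10 and (population) std ~1.63 is mapped to mean 0, std 1:
#   normalize_channel(np.array([8., 10., 12.]), mean=0, std=1)  # -> approx. [-1.22, 0., 1.22]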


def normalize(img, input_mean, input_std):
    mean = value_to_list(input_mean, img.shape[0])
    std = value_to_list(input_std, img.shape[0])

    mean = correct_length_list(mean, img.shape[0], value_to_fill=0, list_name="mean")
    std = correct_length_list(std, img.shape[0], value_to_fill=1, list_name="std")

    for i in range(img.shape[0]):
        img[i] = normalize_channel(img[i], mean[i], std[i])
    return img
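# Illustrative per-channel usage (a sketch; note that the input volume is modified in place):
#   vol = np.random.rand(2, 8, 8, 8).astype(np.float32)
#   vol = normalize(vol, input_mean=[0.0, 0.5], input_std=[1.0, 2.0])
#   # channel 0 now has mean ~0 and std ~1, channel 1 mean ~0.5 and std ~2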


def normalize_mean_std(img, mean, denominator):
    if len(mean.shape) == 0:
        mean = mean[..., None]
    if len(denominator.shape) == 0:
        denominator = denominator[..., None]
    new_axis = [i + 1 for i in range(len(img.shape) - 1)]
    img -= np.expand_dims(mean, axis=new_axis)
    img *= np.expand_dims(denominator, axis=new_axis)
    return img