# ============================================================================================= #
#  Author:       Pavel Iakubovskii, ZFTurbo, ashawkey, Dominik Müller,                          #
#                Samuel Šuľan, Lucia Hradecká, Filip Lux                                        #
#  Copyright:    albumentations     : https://github.com/albumentations-team                    #
#                Pavel Iakubovskii  : https://github.com/qubvel                                 #
#                ZFTurbo            : https://github.com/ZFTurbo                                #
#                ashawkey           : https://github.com/ashawkey                               #
#                Dominik Müller     : https://github.com/muellerdo                              #
#                Lucia Hradecká     : lucia.d.hradecka@gmail.com                                #
#                Filip Lux          : lux.filip@gmail.com                                       #
#                Samuel Šuľan                                                                   #
#                                                                                               #
#  Volumentations History:                                                                      #
#       - Original:                 https://github.com/albumentations-team/albumentations       #
#       - 3D Conversion:            https://github.com/ashawkey/volumentations                  #
#       - Continued Development:    https://github.com/ZFTurbo/volumentations                   #
#       - Enhancements:             https://github.com/qubvel/volumentations                    #
#       - Further Enhancements:     https://github.com/muellerdo/volumentations                 #
#       - Biomedical Enhancements:  https://gitlab.fi.muni.cz/cbia/bio-volumentations           #
#                                                                                               #
#  MIT License.                                                                                 #
#                                                                                               #
#  Permission is hereby granted, free of charge, to any person obtaining a copy                 #
#  of this software and associated documentation files (the "Software"), to deal                #
#  in the Software without restriction, including without limitation the rights                 #
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell                    #
#  copies of the Software, and to permit persons to whom the Software is                        #
#  furnished to do so, subject to the following conditions:                                     #
#                                                                                               #
#  The above copyright notice and this permission notice shall be included in all               #
#  copies or substantial portions of the Software.                                              #
#                                                                                               #
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR                   #
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,                     #
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE                  #
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER                       #
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,                #
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE                #
#  SOFTWARE.                                                                                    #
# ============================================================================================= #

import numpy as np
import skimage.transform as skt
from skimage.exposure import equalize_hist
from scipy.ndimage import gaussian_filter
from warnings import warn

from .sitk_utils import get_affine_transform, apply_sitk_transform
from .utils import is_included, get_nonchannel_axes, atleast_kd
from src.biovol_typing import TypeTripletFloat, TypeSpatioTemporalCoordinate, TypeSextetInt, TypeSpatialShape
from src.random_utils import normal, poisson

"""
vol: [C, D, H, W (, T)]

you should give (D, H, W) form shape.

skimage interpolation notations:

order = 0: Nearest-Neighbor
order = 1: Bi-Linear (default)
order = 2: Bi-Quadratic
order = 3: Bi-Cubic
order = 4: Bi-Quartic
order = 5: Bi-Quintic

Interpolation behaves strangely when input of type int.
** Be sure to change volume and mask data type to float !!! **  (already done by Float() in compose - TODO not for int-mask)
Filip Lux's avatar
Filip Lux committed

But for parameters use primarily ints.
"""


# TODO parameter
# Anti-aliasing: a Gaussian filter used for smoothing. It is applied automatically when downsampling,
# except when the input is of integer type and the interpolation order is 0 (i.e. for masks).
# Float masks: unresolved; for now, no Gaussian filter is applied to them.
def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, mask=False,
           anti_aliasing_downsample=True):
    # drop the last element of the requested shape (assumed to be the temporal dimension);
    # it is re-appended from the image below if needed (TODO: check that this is correct)
    new_shape = list(input_new_shape)[:-1]

    # Zero or negative check
    if np.any(np.asarray(new_shape) <= 0):
        warn(f'Resize(): shape {new_shape} contains a zero or negative value, continuing without Resize.',
             UserWarning)
        return img

    # shape check
    if mask:
        # too many or few dimensions of new_shape
        if len(new_shape) < len(img.shape) - 1 or len(new_shape) > len(img.shape):
            warn(f'Resize(): wrong parameter shape {new_shape}: '
                 f'expected a shape compatible with {img.shape} or {img.shape[0:-1]}, '
                 'continuing without resizing.', UserWarning)
            return img
        # Adding time dimension
        elif len(new_shape) == len(img.shape) - 1:
            new_shape = np.append(new_shape, img.shape[-1])
    else:
        if len(new_shape) < len(img.shape[1:]) - 1 or len(new_shape) > len(img.shape[1:]):
            warn(f'Resize(): wrong dimensions of shape {new_shape}: '
                 f'expected a shape compatible with {img.shape[1:]} or {img.shape[1:-1]}, '
                 'continuing without resizing.', UserWarning)
            return img
        # adding time dimension
        elif len(new_shape) == len(img.shape[1:]) - 1:
            new_shape = np.append(new_shape, img.shape[-1])

    anti_aliasing = False
    if mask:
        new_img = skt.resize(
            img,
            new_shape,
            order=interpolation,
            mode=border_mode,
            cval=cval,
            clip=True,
            anti_aliasing=anti_aliasing
        )
        return new_img

    # apply anti-aliasing only when downsampling, i.e. when the target shape is smaller than the input
    if anti_aliasing_downsample and np.any(np.array(new_shape) < np.array(img.shape[1:])):
        anti_aliasing = True
    data = []
    for i in range(img.shape[0]):  # for each channel
        subimg = img[i].copy()
        d0 = skt.resize(
            subimg,
            new_shape,
            order=interpolation,
            mode=border_mode,
            cval=cval,
            clip=True,
            anti_aliasing=anti_aliasing
        )
        data.append(d0.copy())
    new_img = np.stack(data, axis=0)
    return new_img


def resize_keypoints(keypoints,
                     domain_limit: TypeSpatioTemporalCoordinate,
                     new_shape: TypeSpatioTemporalCoordinate):
    assert len(domain_limit) == len(new_shape) == 4

    # for each dim compute ratio
    ratio = np.array(new_shape[:3]) / np.array(domain_limit[:3])

    # (we suppose here that length of keypoint is 3)
    return list(map(tuple, np.asarray(keypoints) * ratio))
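
# For example, a keypoint (10, 20, 30) in a (100, 100, 100, 1) domain resized to (50, 50, 50, 1)
# is scaled by the per-axis ratio (0.5, 0.5, 0.5) and becomes (5.0, 10.0, 15.0).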


def affine(img: np.array,
           degrees: TypeTripletFloat = (0, 0, 0),
           scales: TypeTripletFloat = (1, 1, 1),
           translation: TypeTripletFloat = (0, 0, 0),
           interpolation: str = 'linear',
           border_mode: str = 'constant',
           value: float = 0,
           spacing: TypeTripletFloat = (1, 1, 1)):
    """
    img (np.array) : format (channel, ax1, ax2, ax3, [time])
    """
    shape = img.shape[1:]  # ignore the channel dimension
    transform = get_affine_transform(shape,
                                     scales=scales,
                                     degrees=degrees,
                                     translation=translation,
                                     spacing=spacing)

    return apply_sitk_transform(img,
                                sitk_transform=transform,
                                interpolation=interpolation,
                                default_value=value,
                                spacing=spacing)


def affine_keypoints(keypoints: list,
                     domain_limit: TypeSpatioTemporalCoordinate,
                     degrees: TypeTripletFloat = (0, 0, 0),
                     scales: TypeTripletFloat = (1, 1, 1),
                     translation: TypeTripletFloat = (0, 0, 0),
                     border_mode: str = 'constant',
                     keep_all: bool = False,
                     spacing: TypeTripletFloat = (1, 1, 1)):
    """

    Args:
        keypoints: list of input keypoints
        domain_limit: limit of the domain, there keyp-points can appear, it is used to define center of transforms
                and to filter out output key-point from the outside of the domain
        degrees:
        scales:
        translation:
        border_mode: not used
        keep_all: True to keep also key_point frou poutside the domain
        spacing: relative voxel size

    Returns: list of transformed key-points

    """
    transform = get_affine_transform(domain_limit,  # domain_limit is image shape without the channel axis
                                     scales=scales,
                                     degrees=degrees,
                                     translation=translation,
                                     spacing=spacing)

    transform = transform.GetInverse()

    res = []
    for point in keypoints:
        transformed_point = transform.TransformPoint(point)
        if keep_all or is_included(domain_limit, transformed_point):
            res.append(transformed_point)
    return res


# Used in rot90_keypoints
def flip_keypoints(keypoints, axes, img_shape):
    # all values in axes are in [1, 2, 3]
    assert np.all(np.array([ax in [1, 2, 3] for ax in axes])), f'{axes} does not contain values from [1, 2, 3]'

    keys = np.asarray(keypoints)

    ndim = keys.shape[1]
    mult = np.ones(ndim, int)
    add = np.zeros(ndim, int)
    for ax in axes:
        mult[ax - 1] = -1
        add[ax - 1] = img_shape[ax - 1] - 1
    keys = keys * mult + add

    return list(map(tuple, keys))
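
# For example, flipping axis 1 of a volume with spatial shape (4, 5, 6) maps the keypoint
# (0, 1, 2) to (3, 1, 2): the first spatial coordinate becomes (4 - 1) - 0 = 3.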


# Used in rot90_keypoints
def transpose_keypoints(keypoints, ax1, ax2):
    # all values in axes are in [1, 2, 3]
    assert (ax1 in [1, 2, 3]) and (ax2 in [1, 2, 3]), f'[{ax1} {ax2}] does not contain values from [1, 2, 3]'

    axis1 = ax1 - 1
    axis2 = ax2 - 1
    keys = np.asarray(keypoints)
    # swap the two coordinate columns; the .copy() keeps the original axis1 values so the second
    # assignment does not read data already overwritten by the first
    keys[:, axis1], keys[:, axis2] = keys[:, axis2], keys[:, axis1].copy()

    # Return a list of tuples
    return list(map(tuple, keys))


def rot90_keypoints(keypoints, factor, axes, img_shape):
    if factor == 1:
        keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)
        keypoints = transpose_keypoints(keypoints, axes[0], axes[1])

    elif factor == 2:
        keypoints = flip_keypoints(keypoints, axes, img_shape)

    elif factor == 3:
        keypoints = transpose_keypoints(keypoints, axes[0], axes[1])
        keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)

    return keypoints
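
# For example, with axes=(1, 2) and spatial shape (4, 5, 6), a rotation by 90 degrees (factor=1)
# first flips axis 2 and then transposes axes 1 and 2, so the keypoint (0, 1, 2) goes
# to (0, 3, 2) and then to (3, 0, 2).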


def pad_keypoints(keypoints, pad_size):
    a, b, c, d, e, f = pad_size

    keys = np.asarray(keypoints)
    padding = np.asarray((a, c, e) if keys.shape[1] == 3 else (a, c, e, 0))  # we only need the 'before' pad size

    # Return a list of tuples
    return list(map(tuple, keys + padding))
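
# For example, with pad_size=(2, 3, 4, 5, 6, 7) the 'before' padding is (2, 4, 6),
# so the keypoint (1, 1, 1) is shifted to (3, 5, 7).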


def pad_pixels(img, input_pad_width: TypeSextetInt, border_mode, cval, mask=False):
    # convert the padding argument to appropriate format
    a, b, c, d, e, f = input_pad_width
    pad_width = [(a, b), (c, d), (e, f)]

    # zeroes for channel dimension
    if not mask:
        pad_width = [(0, 0)] + pad_width

    # zeroes for temporal dimension
    if len(img.shape) > len(pad_width):
        pad_width = pad_width + [(0, 0)]

    assert len(img.shape) == len(pad_width)

    if border_mode == 'constant':
        return np.pad(img, pad_width, border_mode, constant_values=cval)
    if border_mode == 'linear_ramp':
        return np.pad(img, pad_width, border_mode, end_values=cval)
    return np.pad(img, pad_width, border_mode)
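
# For example, padding an image of shape (1, 4, 4, 4) with input_pad_width=(1, 1, 2, 2, 0, 0)
# and border_mode='constant' yields an array of shape (1, 6, 8, 4); the channel axis is never padded.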


# Used in crop()
def get_spatial_shape(array: np.array, mask: bool) -> TypeSpatialShape:
    return np.array(array.shape)[:3] if mask else np.array(array.shape)[1:4]  # mask has no channel dim


def get_pad_dims(spatial_shape: TypeSpatialShape, crop_shape: TypeSpatialShape) -> TypeSextetInt:
    pad_dims = [0] * 6
    for i in range(3):  # for each spatial axis
        i_dim, c_dim = spatial_shape[i], crop_shape[i]
        current_pad_dims = (0, 0)
        if i_dim < c_dim:  # if we want a larger crop than the size of the image (in the given axis), we must pad:
            pad_size = c_dim - i_dim
            if pad_size % 2 != 0:
                current_pad_dims = (int(pad_size // 2 + 1), int(pad_size // 2))
            else:
                current_pad_dims = (int(pad_size // 2), int(pad_size // 2))

        pad_dims[i * 2:(i + 1) * 2] = current_pad_dims  # store the axis padding tuple (before, after) to pad_dims

    return pad_dims

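# For example, get_pad_dims(spatial_shape=(10, 10, 10), crop_shape=(16, 10, 10)) returns
# [3, 3, 0, 0, 0, 0]; an odd difference such as crop size 15 vs. image size 10 is split as (3, 2).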

def crop(input_array: np.array,
         crop_shape: TypeSpatialShape,
         crop_position: TypeSpatialShape,
         pad_dims,
         border_mode, cval, mask):
    input_spatial_shape = get_spatial_shape(input_array, mask)  # get shape for the spatial dims only
    # if we want larger crop than is the size of the image (in any axis), we must pad the axis
    if np.any(input_spatial_shape < crop_shape):
        warn(f'F.crop(): Input size {input_spatial_shape} is smaller than crop size {crop_shape}; '
             f'padding with {border_mode}.', UserWarning)

        # pad
        input_array = pad_pixels(input_array, pad_dims, border_mode, cval, mask)

        # test
        input_spatial_shape = get_spatial_shape(input_array, mask)
        assert np.all(input_spatial_shape >= crop_shape)

    x1, y1, z1 = crop_position
    x2, y2, z2 = np.array(crop_position) + np.array(crop_shape)

    if mask:
        result = input_array[x1:x2, y1:y2, z1:z2]
        assert np.all(result.shape[:3] == crop_shape), f'{result.shape} {crop_shape} {mask} {crop_position}'
    else:
        result = input_array[:, x1:x2, y1:y2, z1:z2]
        assert np.all(result.shape[1:4] == crop_shape)

    return result
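
# For example, cropping a (1, 4, 4, 4) image with crop_shape=(2, 2, 2) at crop_position=(1, 1, 1)
# needs no padding and returns a view of shape (1, 2, 2, 2).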


def crop_keypoints(keypoints,
                   crop_shape: TypeSpatialShape,
                   crop_position: TypeSpatialShape,
                   pad_dims,
                   keep_all: bool):
    px, _, py, _, pz, _ = pad_dims  # we only need the 'before' padding size
    pad = np.asarray((px, py, pz))
    # Compute new keypoint positions
    keys = np.asarray(keypoints)[:, :3] - np.asarray(crop_position) + pad  # ignore the time dimension of keypoints
    # Filter the keypoints
    if not keep_all:
        mask = (keys >= 0) & (keys + .5 < np.asarray(crop_shape))
        keys = keys[np.sum(mask, axis=1) == 3, :]

    # Return a list of tuples
    return list(map(tuple, keys))
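
# For example, with crop_position=(2, 2, 2), crop_shape=(4, 4, 4) and no padding, the keypoint
# (3, 3, 3) becomes (1, 1, 1) and is kept, while (7, 3, 3) maps outside the crop and is
# dropped unless keep_all=True.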


def gaussian_blur(img, input_sigma, border_mode, cval):
    sigma = input_sigma

    # if sigma is of type list, we have different sigma for each channel --> delegate to function gaussian_blur_stack()
    if isinstance(sigma, list):
        if img.shape[0] != len(sigma):
            warn(f'GaussianBlur(): wrong list size ({len(sigma)}), it should equal the number of channels '
                 f'({img.shape[0]}). Skipping the transformation.', UserWarning)
            return img
        return gaussian_blur_stack(img, sigma, border_mode, cval)

    # replicate sigma for each dimension if necessary
    if isinstance(sigma, (int, float)):
        sigma = np.repeat(sigma, len(img.shape))
        sigma[0] = 0
        # Checking for time dimension
        if len(img.shape) > 4:
            sigma[-1] = 0
    else:
        # TODO what to expect in the input.
        if len(sigma) == len(img.shape) - 2:
            sigma = np.append(sigma, 0)
        if len(sigma) == len(img.shape) - 1:
            sigma = np.insert(sigma, 0, 0)

    # check if we have correct format of sigma
    # TODO better warning
    if len(sigma) != len(img.shape):
        warn(f'GaussianBlur(): wrong sigma tuple (length does not equal the number of affected dimensions). '
             f'Skipping the transformation.', UserWarning)
        return img
    return gaussian_filter(img, sigma=sigma, mode=border_mode, cval=cval)
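
# An illustrative call of gaussian_blur() above with a per-channel sigma list
# (vol as in the sketch at the top of this file, i.e. a 2-channel volume);
# a list sigma is dispatched to gaussian_blur_stack() below:
#
#   blurred = gaussian_blur(vol, [1.0, 2.5], border_mode='reflect', cval=0)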

def gaussian_blur_stack(img, input_sigma, border_mode, cval):
    sigma = list(np.asarray(input_sigma).copy())
    # simple sigma check
    for channel in sigma:
        if not isinstance(channel, (float, int, tuple)):
            warn(f'GaussianBlur(): wrong sigma format: the list can only contain tuple, float or int. '
                 f'Skipping the transformation.', UserWarning)
            return img
    # TODO try different techniques for better optimization
    for i in range(len(sigma)):  # for each channel
        if isinstance(sigma[i], (float, int)):  # replicate sigma for each dimension if necessary
            sigma[i] = np.repeat(sigma[i], len(img.shape) - 1)
            if len(sigma[i]) >= 4:
                sigma[i][-1] = 0
        else:
            if len(sigma[i]) == len(img.shape) - 2:
                sigma[i] = np.append(sigma[i], 0)
        img[i] = gaussian_filter(img[i], sigma=sigma[i], mode=border_mode, cval=cval)  # compute
    return img


def brightness_contrast_adjust(img, alpha=1, beta=0):
    if alpha != 1:
        img *= alpha
    if beta != 0:
        img += beta
    return img
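
# For example, with alpha=2.0 and beta=0.1, a pixel value of 0.3 is mapped to 0.3 * 2.0 + 0.1 = 0.7.
# The adjustment is done in place, so a float image is assumed.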


def gamma_transform(img, gamma):
    if np.any(img < 0) or np.any(img > 1):
        warn('Gamma transform: image is not in range [0, 1]. Skipping the transformation.', UserWarning)
        return img
    else:
        return np.power(img, gamma)


def histogram_equalization(img, bins):
    for i in range(img.shape[0]):  # for each channel
        img[i] = equalize_hist(img[i], bins)
    return img


def gaussian_noise(img, mean, sigma):
    img = img.astype('float32')
    noise = normal(mean, sigma, img.shape).astype(np.float32)
    return img + noise


def poisson_noise(img, peak):
    # NOTE: the 'peak' parameter is currently unused
    img = img.astype('float32')
    return img + poisson(img).astype(np.float32)


def value_to_list(value, length):
    # replicate a scalar to a per-channel list; sequences are simply converted to a list
    if isinstance(value, (float, int)):
        return [value] * length
    return list(value)


def correct_length_list(list_to_check, length, value_to_fill=1, list_name='###Default###'):
    if len(list_to_check) < length:
        warn(f'{list_name} has {len(list_to_check)} elements, should have {length}; appending {value_to_fill} '
             'until the lengths match.', UserWarning)
        for i in range(length - len(list_to_check)):
            list_to_check = list_to_check + [value_to_fill]
    if len(list_to_check) > length:
        warn(f'{list_name} has {len(list_to_check)} elements, should have {length}; removing elements from the end '
             'until the lengths match.', UserWarning)
        list_to_check = [list_to_check[i] for i in range(length)]
    return list_to_check
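
# For example, correct_length_list([1, 2], 3, value_to_fill=0, list_name='mean') warns and returns
# [1, 2, 0], while correct_length_list([1, 2, 3, 4], 3) warns and returns [1, 2, 3].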


def normalize(img, input_mean, input_std):
    """
    Normalize a multi-channel image to have the desired mean and standard deviation values.

    Formula from: https://stats.stackexchange.com/questions/46429/transform-data-to-desired-mean-and-standard-deviation
    """

    mean = value_to_list(input_mean, img.shape[0])
    std = value_to_list(input_std, img.shape[0])

    mean = correct_length_list(mean, img.shape[0], value_to_fill=0, list_name='mean')
    std = correct_length_list(std, img.shape[0], value_to_fill=1, list_name='std')
    mean = atleast_kd(mean, img.ndim)
    std = atleast_kd(std, img.ndim)
    img_mean = atleast_kd(img.mean(axis=get_nonchannel_axes(img)), img.ndim)
    img_std = atleast_kd(img.std(axis=get_nonchannel_axes(img)), img.ndim)

    if np.any(np.isclose(img_std, 0)):
        warn(f'Normalize(): standard deviation of at least one input channel is 0. Skipping this transformation.',
             UserWarning)
        return img

    img = (img - img_mean) * (std / img_std) + mean

    return img.astype(img.dtype)
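
# For example, a channel with values [0.0, 2.0] (mean 1, std 1) normalized with target mean 0 and std 1
# becomes [-1.0, 1.0]: each value x is mapped to (x - img_mean) * (std / img_std) + mean.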


def normalize_mean_std(img, mean, denominator):
    img -= atleast_kd(mean, k=img.ndim)
    img *= atleast_kd(denominator, k=img.ndim)
    return img