# ============================================================================================= # # Author: Pavel Iakubovskii, ZFTurbo, ashawkey, Dominik Müller, # # Samuel Šuľan, Lucia Hradecká, Filip Lux # # Copyright: albumentations: : https://github.com/albumentations-team # # Pavel Iakubovskii : https://github.com/qubvel # # ZFTurbo : https://github.com/ZFTurbo # # ashawkey : https://github.com/ashawkey # # Dominik Müller : https://github.com/muellerdo # # Lucia Hradecká : lucia.d.hradecka@gmail.com # # Filip Lux : lux.filip@gmail.com # # Samuel Šuľan # # # # Volumentations History: # # - Original: https://github.com/albumentations-team/albumentations # # - 3D Conversion: https://github.com/ashawkey/volumentations # # - Continued Development: https://github.com/ZFTurbo/volumentations # # - Enhancements: https://github.com/qubvel/volumentations # # - Further Enhancements: https://github.com/muellerdo/volumentations # # - Biomedical Enhancements: https://gitlab.fi.muni.cz/cbia/bio-volumentations # # # # MIT License. # # # # Permission is hereby granted, free of charge, to any person obtaining a copy # # of this software and associated documentation files (the "Software"), to deal # # in the Software without restriction, including without limitation the rights # # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # # copies of the Software, and to permit persons to whom the Software is # # furnished to do so, subject to the following conditions: # # # # The above copyright notice and this permission notice shall be included in all # # copies or substantial portions of the Software. # # # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # # SOFTWARE. # # ============================================================================================= # import numpy as np from functools import wraps import skimage.transform as skt from skimage.exposure import equalize_hist from scipy.ndimage import zoom, gaussian_filter from warnings import warn from typing import Union from ..biovol_typing import TypeTripletFloat, TypeSpatioTemporalCoordinate, TypeSextetInt, TypeSpatialShape from .spatial_functional import get_affine_transform, apply_sitk_transform from .utils import is_included MAX_VALUES_BY_DTYPE = { np.dtype("uint8"): 255, np.dtype("uint16"): 65535, np.dtype("uint32"): 4294967295, np.dtype("float32"): 1.0, } # SITK interpolations SITK_interpolation = { 0: 'sitkNearestNeighbor', 1: 'sitkLinear', 2: 'sitkBSpline', 3: 'sitkGaussian' } """ vol: [C, D, H, W (, T)] you should give (D, H, W) form shape. skimage interpolation notations: order = 0: Nearest-Neighbor order = 1: Bi-Linear (default) order = 2: Bi-Quadratic order = 3: Bi-Cubic order = 4: Bi-Quartic order = 5: Bi-Quintic Interpolation behaves strangely when input of type int. ** Be sure to change volume and mask data type to float !!! ** (already done by Float() in compose) But for parameters use primarily ints. """ def preserve_shape(func): """ Preserve shape of the image """ @wraps(func) def wrapped_function(img, *args, **kwargs): shape = img.shape result = func(img, *args, **kwargs) result = result.reshape(shape) return result return wrapped_function def get_center_crop_coords(img_shape, crop_shape): froms = (img_shape - crop_shape) // 2 tos = froms + crop_shape return froms, tos # Too similar to the random_crop. Could be made into one function def crop(input_array: np.array, crop_shape: TypeSpatialShape, crop_position: TypeSpatialShape, pad_dims, border_mode, cval, mask): input_spatial_shape = get_spatial_shape(input_array, mask) if np.any(input_spatial_shape < crop_shape): warn(f'F.crop(): Input size {input_spatial_shape} smaller than crop size {crop_shape}, pad by {border_mode}.', UserWarning) # pad input_array = pad(input_array, pad_dims, border_mode, cval, mask) # test input_spatial_shape = get_spatial_shape(input_array, mask) assert np.all(input_spatial_shape >= crop_shape) x1, y1, z1 = crop_position x2, y2, z2 = np.array(crop_position) + np.array(crop_shape) if mask: result = input_array[x1:x2, y1:y2, z1:z2] assert np.all(result.shape[:3] == crop_shape), f'{result.shape} {crop_shape} {mask} {crop_position}' else: result = input_array[:, x1:x2, y1:y2, z1:z2] assert np.all(result.shape[1:4] == crop_shape) return result def crop_keypoints(keypoints, crop_shape: TypeSpatialShape, crop_position: TypeSpatialShape, pad_dims, keep_all: bool): (px, _), (py, _), (pz, _) = pad_dims pad = np.array((px, py, pz)) res = [] for keypoint in keypoints: k = keypoint[:3] - crop_position + pad if keep_all or (np.all(k >= 0) and np.all((k + .5) < crop_shape)): res.append(k) return res def get_spatial_shape(array: np.array, mask: bool) -> TypeSpatialShape: return np.array(array.shape)[:3] if mask else np.array(array.shape)[1:4] def get_pad_dims(spatial_shape: TypeSpatialShape, crop_shape: TypeSpatialShape): pad_dims = [] for i in range(3): i_dim, c_dim = spatial_shape[i], crop_shape[i] if i_dim < c_dim: pad_size = c_dim - i_dim if pad_size % 2 != 0: pad_dims.append((int(pad_size // 2 + 1), int(pad_size // 2))) else: pad_dims.append((int(pad_size // 2), int(pad_size // 2))) else: pad_dims.append((0, 0)) return pad_dims def pad(img, pad_width, border_mode, cval, mask=True): if not mask: pad_width = [(0, 0)] + pad_width if len(img.shape) > len(pad_width): pad_width = pad_width + [(0, 0)] assert len(img.shape) == len(pad_width) if border_mode == "constant": return np.pad(img, pad_width, border_mode, constant_values=cval) if border_mode == "linear_ramp": return np.pad(img, pad_width, border_mode, end_values=cval) result = np.pad(img, pad_width, border_mode) return result def pad_keypoints(keypoints, pad_size): a, b, c, d, e, f = pad_size res = [] for coo in keypoints: padding = np.array((a, c, e)) if len(coo) == 3 else np.array((a, c, e, 0)) res.append(coo + padding) return res def flip_keypoints(keypoints, axes, img_shape): # all values in axes are in [1, 2, 3] assert np.all(np.array([ax in [1, 2, 3] for ax in axes])), f'{axes} does not contain values from [1, 2, 3]' mult, add = np.ones(3, int), np.zeros(3, int) for ax in axes: mult[ax-1] = -1 add[ax-1] = img_shape[ax-1] - 1 res = [] for k in keypoints: flipped = list(np.array(k[:3]) * mult + add) if len(k) == 4: flipped.append(k[-1]) res.append(tuple(flipped)) return res def rot90_keypoints(keypoints, factor, axes, img_shape): if factor == 1: keypoints = flip_keypoints(keypoints, [axes[1]], img_shape) keypoints = transpose_keypoints(keypoints, axes[0], axes[1]) elif factor == 2: keypoints = flip_keypoints(keypoints, axes, img_shape) elif factor == 3: keypoints = transpose_keypoints(keypoints, axes[0], axes[1]) keypoints = flip_keypoints(keypoints, [axes[1]], img_shape) return keypoints def transpose_keypoints(keypoints, ax1, ax2): # all values in axes are in [1, 2, 3] assert (ax1 in [1, 2, 3]) and (ax2 in [1, 2, 3]), f'[{ax1} {ax2}] does not contain values from [1, 2, 3]' res = [] for k in keypoints: k = list(k) k[ax1-1], k[ax2-1] = k[ax2-1], k[ax1-1] res.append(tuple(k)) return res def pad_pixels(img, input_pad_width: TypeSextetInt, border_mode, cval, mask=False): a, b, c, d, e, f = input_pad_width pad_width = [(a, b), (c, d), (e, f)] # zeroes for channel dimension if not mask: pad_width = [(0, 0)] + pad_width # zeroes for temporal dimension if len(img.shape) == 5: pad_width = pad_width + [(0, 0)] if border_mode == "constant": return np.pad(img, pad_width, border_mode, constant_values=cval) if border_mode == "linear_ramp": return np.pad(img, pad_width, border_mode, end_values=cval) return np.pad(img, pad_width, border_mode) def normalize_mean_std(img, mean, denominator): if len(mean.shape) == 0: mean = mean[..., None] if len(denominator.shape) == 0: denominator = denominator[..., None] new_axis = [i + 1 for i in range(len(img.shape) - 1)] img -= np.expand_dims(mean, axis=new_axis) img *= np.expand_dims(denominator, axis=new_axis) return img # formula taken from # https://stats.stackexchange.com/questions/46429/transform-data-to-desired-mean-and-standard-deviation def normalize_channel(img, mean, std): return (img - img.mean()) * (std / img.std()) + mean def value_to_list(value, length): if isinstance(value, (float, int)): return [value for _ in range(length)] else: return value def correct_length_list(list_to_check, length, value_to_fill=1, list_name="###Default###"): if len(list_to_check) < length: warn(f"{list_name} have elements {len(list_to_check)}, should be {length} appending {value_to_fill} " + "till length matches", UserWarning) for i in range(length - len(list_to_check)): list_to_check = list_to_check + [value_to_fill] if len(list_to_check) > length: warn(f"{list_name} have elements {len(list_to_check)}, should be {length} removing elements from behind " + " till length matches", UserWarning) list_to_check = [list_to_check[i] for i in range(length)] return list_to_check def normalize(img, input_mean, input_std): mean = value_to_list(input_mean, img.shape[0]) std = value_to_list(input_std, img.shape[0]) mean = correct_length_list(mean, img.shape[0], value_to_fill=0, list_name="mean") std = correct_length_list(std, img.shape[0], value_to_fill=1, list_name="std") for i in range(img.shape[0]): img[i] = normalize_channel(img[i], mean[i], std[i]) return img def gaussian_noise(img, mean, sigma): img = img.astype("float32") noise = np.random.normal(mean, sigma, img.shape).astype(np.float32) return img + noise def poisson_noise(img, peak): img = img.astype("float32") return img + np.random.poisson(img).astype(np.float32) # TODO parameter # Anti-aliasing - gaussian filter to smooth. using automatically when downsampling, except when integer # and interpolation is 0. (so mask) # float mask - how, for now no gaussian filter. def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, mask=False, anti_aliasing_downsample=True): # TODO: random fix, check if it is correct new_shape = list(input_new_shape)[:-1] # Zero or negative check for dimension in new_shape: if dimension <= 0: warn(f"Resize(): shape: {new_shape} contains zero or negative number, continuing without Resize.", UserWarning) return img # shape check if mask: # too many or few dimensions of new_shape if len(new_shape) < len(img.shape) - 1 or len(new_shape) > len(img.shape): warn(f"Resize(): wrong parameter shape: {new_shape}," + f"expecting something with dimensions of {img.shape } or {img.shape[0:-1] }, " + "continuing without resizing ", UserWarning) return img # Adding time dimension elif len(new_shape) == len(img.shape) - 1: new_shape = np.append(new_shape, img.shape[-1]) else: if len(new_shape) < len(img.shape[1:]) - 1 or len(new_shape) > len(img.shape[1:]): warn(f"Resize(): wrong dimensions of shape: {new_shape}," + f"expecting something with dimensions of {img.shape[1:] } or {img.shape[1:-1] }, continuing " + "without resizing ", UserWarning) return img # adding time dimension elif len(new_shape) == len(img.shape[1:]) - 1: new_shape = np.append(new_shape, img.shape[-1]) anti_aliasing = False if mask: new_img = skt.resize( img, new_shape, order=interpolation, mode=border_mode, cval=cval, clip=True, anti_aliasing=anti_aliasing ) return new_img if anti_aliasing_downsample and np.any(np.array(img.shape[1:]) < np.array(new_shape)): anti_aliasing = True data = [] for i in range(img.shape[0]): subimg = img[i].copy() d0 = skt.resize( subimg, new_shape, order=interpolation, mode=border_mode, cval=cval, clip=True, anti_aliasing=anti_aliasing ) data.append(d0.copy()) new_img = np.stack(data, axis=0) return new_img def resize_keypoints(keypoints, domain_limit: TypeSpatioTemporalCoordinate, new_shape: TypeSpatioTemporalCoordinate): assert len(domain_limit) == len(new_shape) == 4 # for each dim compute ratio ratio = np.array(new_shape[:3]) / np.array(domain_limit[:3]) # it supposes that length of keypoint is 3 return [keypoint * ratio for keypoint in keypoints] # TODO compare with skt.rescale, new version got channel_axis def scale(img, input_scale_factor, interpolation=0, border_mode='reflect', cval=0, mask=True): scale_factor = input_scale_factor # check for zero or negative numbers if isinstance(scale_factor, (int, float)): if scale_factor <= 0: warn(f"RandomScale()/Scale(): scale_factor: {len(scale_factor)} is zero or negative number" + f" continuing without scaling ", UserWarning) return img else: for dimension in scale_factor: if dimension <= 0: warn(f"RandomScale()/Scale(): scale_factor: {len(scale_factor)} contains zero or negative number " + "continuing without scaling ", UserWarning) return img img_shape = img.shape if scale_factor is None: return img if isinstance(scale_factor, (list, tuple)): scale_factor = np.array(scale_factor) if not mask: img_shape = img_shape[1:] # TODO, maybe user wants to add shape for only spatial dimensions if len(img_shape) != len(scale_factor) and len(img_shape) - 1 != len(scale_factor): warn(f"RandomScale()/Scale(): Wrong dimension of scaling factor list: {len(scale_factor)}," + f"expecting {len(img_shape)} or {len(img_shape[:-1]) }, continuing without scaling ", UserWarning) return img elif len(img_shape) - 1 == len(scale_factor): scale_factor = np.append(scale_factor, 1) else: scale_factor = [scale_factor for _ in range(len(img_shape) - 1)] if mask: scale_factor.append(scale_factor[0]) # Not scaling time dimensions if len(scale_factor) == 4: scale_factor[-1] = 1 if mask: return zoom(img, scale_factor, order=interpolation, mode=border_mode, cval=cval) data = [] for i in range(img.shape[0]): subimg = img[i].copy() d0 = zoom(subimg, scale_factor, order=interpolation, mode=border_mode, cval=cval) data.append(d0.copy()) new_img = np.stack(data, axis=0) return new_img ''' #TODO maybe add parameter for order of rotations #LIMIT dimensions def affine_transform(img, input_x_angle, input_y_angle, input_z_angle, translantion, interpolation = 1, border_mode = 'constant', value = 0, input_scaling_coef = None, scale_back = True, mask = False ): if mask: img = img[np.newaxis, :] x_angle, y_angle, z_angle = [np.pi * i / 180 for i in [input_x_angle, input_y_angle, input_z_angle]] if not(input_scaling_coef is None): scaling_coef = np.array(input_scaling_coef) #no scaling on the channels if the scaling_coef is in wrong format if(len(scaling_coef) != 3): warn(f"Rotate transform: Wrong dimension of scaling coeficient list: {len(scaling_coef)}, expecting {3}, continuing without scaling ", UserWarning) inverse_affine_matrix = np.linalg.inv(rotation_matrix_calculation(len(img.shape),x_angle,y_angle,z_angle )) else: scaling_coef = np.insert(scaling_coef, 0, 1 ) if len(scaling_coef) < len(img.shape): scaling_coef = np.append(scaling_coef, 1 ) inverse_scaling_matrix = np.diag([ 1/i for i in scaling_coef]) inverse_rotation_matrix = np.linalg.inv(rotation_matrix_calculation(len(img.shape),x_angle,y_angle,z_angle )) inverse_affine_matrix = inverse_scaling_matrix @ inverse_rotation_matrix if scale_back: inverse_scale_back_matrix = np.diag([ i for i in scaling_coef]) inverse_affine_matrix = inverse_affine_matrix @ inverse_scale_back_matrix else: inverse_affine_matrix = np.linalg.inv(rotation_matrix_calculation(len(img.shape),x_angle,y_angle,z_angle )) c_in=0.5*np.array(img.shape) offset=c_in-inverse_affine_matrix.dot(c_in) if not(translantion is None): if len(translantion) > len(img.shape) - 1: warn(f"Rotate transform(): translation list has wrong length {len(translantion)}, expected {len(img.shape) - 1}", UserWarning) else: for i in range(len(translantion)): offset[i + 1] -= translantion[i] img = sci.affine_transform(img, inverse_affine_matrix, offset, order=interpolation, mode=border_mode, cval= value) if mask: img = img[0] return img ''' def affine(img: np.array, degrees: TypeTripletFloat = (0, 0, 0), scales: TypeTripletFloat = (1, 1, 1), translation: TypeTripletFloat = (0, 0, 0), interpolation: str = 'linear', border_mode: str = 'constant', value: float = 0, spacing: TypeTripletFloat = (1, 1, 1)): """ img (np.array) : format (channel, ax1, ax2, ax3, [time]) """ shape = img.shape[1:] transform = get_affine_transform(shape, scales=scales, degrees=degrees, translation=translation, spacing=spacing) return apply_sitk_transform(img, sitk_transform=transform, interpolation=interpolation, default_value=value, spacing=spacing) def affine_keypoints(keypoints: list, domain_limit: TypeSpatioTemporalCoordinate, degrees: TypeTripletFloat = (0, 0, 0), scales: TypeTripletFloat = (1, 1, 1), translation: TypeTripletFloat = (0, 0, 0), border_mode: str = 'constant', keep_all: bool = False, spacing: TypeTripletFloat = (1, 1, 1)): """ Args: keypoints: list of input keypoints domain_limit: limit of the domain, there keyp-points can appear, it is used to define center of transforms and to filter out output key-point from the outside of the domain degrees: scales: translation: border_mode: not used keep_all: True to keep also key_point frou poutside the domain spacing: relative voxel size Returns: list of transformed key-points """ transform = get_affine_transform(domain_limit, scales=scales, degrees=degrees, translation=translation, spacing=spacing) transform = transform.GetInverse() res = [] for point in keypoints: transformed_point = transform.TransformPoint(point) if keep_all or is_included(domain_limit, transformed_point): res.append(transformed_point) return res # TO REMOVE def rotation_matrix_calculation(dim, x_angle, y_angle, z_angle): rot_matrix = np.identity(dim).astype(np.float32) rot_matrix = rot_matrix @ rot_x(x_angle, dim) rot_matrix = rot_matrix @ rot_y(y_angle, dim) rot_matrix = rot_matrix @ rot_z(z_angle, dim) return rot_matrix def rot_x(angle, dim): if dim == 4: rotation_x = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, np.cos(angle), -np.sin(angle)], [0, 0, np.sin(angle), np.cos(angle)]]) if dim == 5: rotation_x = np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, np.cos(angle), -np.sin(angle), 0], [0, 0, np.sin(angle), np.cos(angle), 0], [0, 0, 0, 0, 1]]) return rotation_x def rot_y(angle, dim): if dim == 4: rotation_y = np.array([[1, 0, 0, 0], [0, np.cos(angle), 0, np.sin(angle)], [0, 0, 1, 0], [0, -np.sin(angle), 0, np.cos(angle)]]) if dim == 5: rotation_y = np.array([[1, 0, 0, 0, 0], [0, np.cos(angle), 0, np.sin(angle), 0], [0, 0, 1, 0, 0], [0, -np.sin(angle), 0, np.cos(angle), 0], [0, 0, 0, 0, 1]]) return rotation_y def rot_z(angle, dim): if dim == 4: rotation_z = np.array([[1, 0, 0, 0], [0, np.cos(angle), -np.sin(angle), 0], [0, np.sin(angle), np.cos(angle), 0], [0, 0, 0, 1]]) if dim == 5: rotation_z = np.array([[1, 0, 0, 0, 0], [0, np.cos(angle), -np.sin(angle), 0, 0], [0, np.sin(angle), np.cos(angle), 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]]) return rotation_z # TODO clipped tag may be important for types other that float32, but tags are from fork and not tested # @clipped def brightness_contrast_adjust(img, alpha=1, beta=0): if alpha != 1: img *= alpha if beta != 0: img += beta return img def histogram_equalization(img, bins): for i in range(img.shape[0]): img[i] = equalize_hist(img[i], bins) return img def gaussian_blur(img, input_sigma, border_mode, cval): sigma = input_sigma if isinstance(sigma, list): if img.shape[0] != len(sigma): warn(f'GaussianBlur(): wrong list size {len(sigma)}, expecting same as number of dimensions {img.shape[0]}. Ignoring', UserWarning) return img return gaussian_blur_stack(img, sigma, border_mode, cval) if isinstance(sigma, (int, float)): sigma = np.repeat(sigma, len(img.shape)) sigma[0] = 0 # Checking for time dimension if len(img.shape) > 4: sigma[-1] = 0 else: # TODO what to expect in the input. if len(sigma) == len(img.shape) - 2: sigma = np.append(sigma, 0) if len(sigma) == len(img.shape) - 1: sigma = np.insert(sigma, 0, 0) # TODO better warning if len(sigma) != len(img.shape): warn(f'GaussianBlur(): wrong sigma tuple, ignoring', UserWarning) return img return gaussian_filter(img, sigma=sigma, mode=border_mode, cval=cval) def gaussian_blur_stack(img, input_sigma, border_mode, cval): sigma = list(np.asarray(input_sigma).copy()) # simple sigma check for channel in sigma: if not isinstance(channel, (float, int, tuple)): warn(f'GaussianBlur(): wrong sigma format, Inside list can be only tuple,float or int. Ignoring', UserWarning) return img # TODO try different techniques for better optimalization. for i in range(len(sigma)): if isinstance(sigma[i], (float, int)): sigma[i] = np.repeat(sigma[i], len(img.shape) - 1) if len(sigma[i]) >= 4: sigma[i][-1] = 0 else: if len(sigma[i]) == len(img.shape) - 2: sigma[i] = np.append(sigma[i], 0) img[i] = gaussian_filter(img[i], sigma=sigma[i], mode=border_mode, cval=cval) return img def gamma_transform(img, gamma): if np.all(img < 0) or np.all(img > 1) : warn(f"Gamma transform: image is not in range [0,1]. continuing without transform", UserWarning) return img else: return np.power(img, gamma)