# ============================================================================================= #
#  Author:       Filip Lux, Lucia Hradecká                                                      #
#  Copyright:    Filip Lux          lux.filip@gmail.com                                         #
#                Lucia Hradecká     lucia.d.hradecka@gmail.com                                  #
#                                                                                               #
#  MIT License.                                                                                 #
#                                                                                               #
#  Permission is hereby granted, free of charge, to any person obtaining a copy                 #
#  of this software and associated documentation files (the "Software"), to deal                #
#  in the Software without restriction, including without limitation the rights                 #
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell                    #
#  copies of the Software, and to permit persons to whom the Software is                        #
#  furnished to do so, subject to the following conditions:                                     #
#                                                                                               #
#  The above copyright notice and this permission notice shall be included in all               #
#  copies or substantial portions of the Software.                                              #
#                                                                                               #
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR                   #
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,                     #
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE                  #
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER                       #
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,                #
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE                #
#  SOFTWARE.                                                                                    #
# ============================================================================================= #

from typing import Union, Optional
from collections.abc import Iterable
import numpy as np

from src.biovol_typing import TypeSextetFloat, TypeTripletFloat, TypePairFloat, \
    TypeSpatioTemporalCoordinate, TypeSpatialCoordinate, TypePairInt, TypeSextetInt
from src.random_utils import uniform


DEBUG = False


def get_nonchannel_axes(array):
    """
    Return the non-channel axis indices for a given image.

    Axis 0 is treated as the channel axis, so the result is (1, 2, ..., array.ndim - 1).
    """
    return tuple(range(1, array.ndim))


def atleast_kd(array, k):
    """
    Add singleton dimensions to the input array s.t. the new shape is at least k-dimensional.

    The singleton dimensions are appended at the end of the shape. If the array already has
    k or more dimensions, its shape is left unchanged.
    """
    array = np.asarray(array)
    # A non-positive multiplier yields an empty tuple, so no axes are added in that case.
    new_shape = array.shape + (1,) * (k - array.ndim)
    return array.reshape(new_shape)


def get_sigma_axiswise(min_sigma, max_sigma):
    """
    Randomly choose a single sigma for all axes and channels (if max_sigma is int or float)
    or a sigma for each axis (except the channel axis).

    The sampling itself is delegated to `uniform(min_sigma, max_sigma)`.
    """
    sigma = uniform(min_sigma, max_sigma)
    if isinstance(max_sigma, tuple):
        # If tuple on input, we must return a tuple (not np.ndarray)
        sigma = tuple(sigma)
    return sigma


def get_spatial_shape_from_image(data, targets):
    """
    Return the spatial shape (axes 1 to 3) of the first image in `data` as a np.ndarray.

    Args:
        data: dictionary with data, indexed by the keyword from `get_first_img_keyword(targets)`
        targets: dictionary with targets
    """
    # Image is always [C, D, H, W] or [C, D, H, W, T]
    return np.array(data[get_first_img_keyword(targets)].shape[1:4])


def parse_limits(input_limit: Union[float, TypePairFloat, TypeTripletFloat, TypeSextetFloat],
                 scale: bool = False) -> TypeSextetFloat:
    """
    Parse the limits of affine transformation: rotation, scaling, or translation.

    Args:
        input_limit: transformation limits (type None, float, tuple of 2 floats, tuple of 3 floats,
            tuple of 6 floats)
        scale: a flag (True if computing limits for scaling, False otherwise)

    Returns:
        a tuple of 6 floats representing the limits for all 3 spatial dimensions.

        input_limit = None                                      --> return (0., 0., 0., 0., 0., 0.)
        input_limit = x : float                                 --> return (1/x, x, 1/x, x, 1/x, x) if scale,
                                                                    else (-x, +x, -x, +x, -x, +x)
        input_limit = (a, b) : TypePairFloat                    --> return (a, b, a, b, a, b)
        input_limit = (a, b, c) : TypeTripletFloat              --> return (1/a, a, 1/b, b, 1/c, c) if scale,
                                                                    else (-a, +a, -b, +b, -c, +c)
        input_limit = ((a, b), (c, d), (e, f)) : TypeTripletFloat --> return (a, b, c, d, e, f)
        input_limit = (a, b, c, d, e, f) : TypeSextetFloat      --> return (a, b, c, d, e, f)
    """
    # input_limit = None --> return (0., 0., 0., 0., 0., 0.)
    # This must be handled before the len() call below, because len(None) raises TypeError.
    if input_limit is None:
        return parse_helper_sextet_common_cases(None, return_float=True)

    # input_limit = x : float --> return (1/x, x, 1/x, x, 1/x, x) if scale, else (-x, +x, -x, +x, -x, +x)
    if isinstance(input_limit, (int, float)):
        limit_range = parse_helper_affine_limits_1d(input_limit, scale=scale)  # get (1/x, x) or (-x, +x)
        return limit_range * 3  # copy the tuple for each spatial axis

    # input_limit : TypeTripletFloat
    #   if input_limit = ((a, b), (c, d), (e, f)) --> return (a, b, c, d, e, f)
    #   elif input_limit = (a, b, c) --> return (-a, +a, -b, +b, -c, +c)
    #       if scale, return (1/a, a, 1/b, b, 1/c, c)
    if len(input_limit) == 3:
        res = []
        for item in input_limit:  # for each spatial axis
            if isinstance(item, Iterable):
                # we already have a tuple -> add it to the result
                res.extend(item)
            else:
                # we need to create a tuple: (1/x, x) or (-x, +x)
                res.extend(parse_helper_affine_limits_1d(item, scale=scale))
        return tuple(res)

    return parse_helper_sextet_common_cases(input_limit, return_float=True)


def parse_helper_affine_limits_1d(input_limit: float, scale: bool) -> tuple:
    """
    Create a 2-tuple of transformation limits for a single spatial axis.

    Returns:
        the pair (x, 1/x) sorted in ascending order if scale=True, (-x, +x) otherwise

    Raises:
        ZeroDivisionError: if scale=True and input_limit is 0.
    """
    return tuple(sorted([input_limit, 1 / input_limit])) if scale else (-input_limit, input_limit)


def parse_pads(pad_size: Union[int, TypePairInt, TypeSextetInt]) -> TypeSextetInt:
    """
    Parse the padding argument.

    Args:
        pad_size: padding size (type None, int, tuple of 2 ints, tuple of 6 ints)

    Returns:
        a tuple of 6 ints representing padding for all 3 spatial dimensions.

        pad_size = None                 --> return (0, 0, 0, 0, 0, 0)
        pad_size = x : int              --> return (x, x, x, x, x, x)
        pad_size = (a, b) : TypePairInt --> return (a, b, a, b, a, b)
        pad_size = (a, b, c, d, e, f)   --> return (a, b, c, d, e, f)
    """
    if isinstance(pad_size, int):
        return (pad_size,) * 6

    return parse_helper_sextet_common_cases(pad_size, return_float=False)


def parse_helper_sextet_common_cases(arg: Optional[tuple], return_float=False):
    """
    A helper function for argument parsing functions.

    Takes care of the common cases when type(arg) is None, 2-tuple, or 6-tuple.

    Args:
        arg: None, a 2-tuple, or a 6-tuple
        return_float: if True, the zeros returned for arg=None are floats, otherwise ints

    Returns:
        arg = None               --> six zeros
        arg = (a, b)             --> (a, b, a, b, a, b)
        arg = (a, b, c, d, e, f) --> arg, unchanged
        any other length         --> None
    """
    if arg is None:
        elem = 0. if return_float else 0
        return (elem,) * 6
    if len(arg) == 2:
        return arg * 3
    if len(arg) == 6:
        return arg
    # Unsupported length: kept as a None result so that existing callers are not affected.
    return None


def parse_coefs(coefs: Union[float, tuple],
                identity_element: float = 1,
                dim4: bool = False) -> tuple:
    """
    Parse the coefficients of affine transformation: rotation, scaling, or translation.

    Args:
        coefs: transformation coefficients
        identity_element: identity element (e.g. 1 for scaling, 0 for translation)
        dim4: a flag (True if time-lapse data, False otherwise)

    Returns:
        a tuple of 3 floats representing the transformation parameters for all 3 spatial dimensions
        (a sequence of 4 values is passed through unchanged if dim4 is True).
        None is returned if `coefs` is a sequence of any other length.
    """
    # coefs = None --> return (ie, ie, ie)
    if coefs is None:
        return (identity_element,) * 3

    # coefs = a : scalar --> return (a, a, a)
    if isinstance(coefs, (int, float)):
        return (coefs,) * 3

    # return (a, b, c) for 3D data or (a, b, c, d) for time-lapse (4D) data
    if (len(coefs) == 3) or (dim4 and len(coefs) == 4):
        return coefs

    # Unsupported length: kept as a None result so that existing callers are not affected.
    return None


def get_first_img_keyword(targets: dict = None):
    """
    Get the first 'image'-type keyword from the targets dictionary.

    Args:
        targets: dictionary with targets; its 'img_keywords' entry is read if present

    Returns:
        the first element of targets['img_keywords'], or 'image' if `targets` is not a dictionary
        or the 'img_keywords' entry is missing or empty.
    """
    if isinstance(targets, dict):
        keywords = targets.get('img_keywords')
        # Guard against a missing or empty entry instead of failing on `None[0]`.
        if keywords is not None and len(keywords) > 0:
            return keywords[0]
    return 'image'  # <-- best effort, if we don't have concrete naming in the `targets` dict


def get_spatio_temporal_domain_limit(sample: dict, targets: dict = None) -> TypeSpatioTemporalCoordinate:
    """
    Returns a vector of spatio-temporal coordinates of length 4. The vector limits the domain of the image.

    Args:
        sample: dictionary with data
        targets: dictionary with targets

    Raises:
        ValueError: if the image is not 3-, 4-, or 5-dimensional.
    """
    shape = list(sample[get_first_img_keyword(targets)].shape)

    if len(shape) == 3:
        # 3D image without channels and the time axis
        limit = shape + [1]
    elif len(shape) == 4:
        # 3D image with channels, without the time axis
        limit = shape[1:] + [1]
    elif len(shape) == 5:
        # 3D image with channels and the time axis
        limit = shape[1:5]
    else:
        # Without this branch, `limit` would be unbound and the return would fail obscurely.
        raise ValueError(f'Unsupported image shape {tuple(shape)}: expected 3, 4, or 5 dimensions.')

    return tuple(limit)


def to_spatio_temporal(shape: tuple) -> TypeSpatioTemporalCoordinate:
    """
    Return spatio-temporal shape given the input shape (without the channel dimension).

    A shape of length 3 is extended with a trailing 0; the result must have length 4.
    """
    shape = list(shape)
    if len(shape) == 3:
        shape.append(0)
    assert len(shape) == 4
    return tuple(shape)


def to_tuple(param: Union[int, float, Iterable]):
    """Convert input argument to min-max tuple

    Args:
        param (scalar or Iterable): Input value.
            If None, None is returned.
            If scalar, the return value is (-value, +value).
            Otherwise, convert the Iterable to tuple.
    """
    if param is None:
        return param

    if isinstance(param, (int, float)):
        return -param, +param

    return tuple(param)


def is_included(shape: Union[TypeSpatialCoordinate, TypeSpatioTemporalCoordinate], coo):
    """
    Check whether the coordinate `coo`, shifted by +0.5 along each axis, lies inside [0, shape).

    Args:
        shape: spatial or spatio-temporal shape; only the first 3 elements are used
        coo: coordinate with the same number of elements as the spatial part of `shape`

    Returns:
        True if every shifted coordinate is non-negative and strictly smaller than
        the corresponding shape element, False otherwise.
    """
    coo_arr = np.array(coo) + 0.5
    shape_arr = np.array(shape[:3])  # ignore the time dimension
    assert len(shape_arr) == len(coo_arr), f'shape: {shape_arr} coo: {coo_arr}'
    res = all(coo_arr >= 0) and (coo_arr < shape_arr).all()
    if DEBUG:
        print('IS INCLUDED', shape, coo, res)
    return res


def validate_bbox(new_bbox: tuple, old_bbox: tuple, ratio: float = 0.5) -> bool:
    """
    Compare the sizes of two bounding boxes.

    Args:
        new_bbox: new bounding box, in the format accepted by `get_bbox_size`
        old_bbox: old bounding box, with the same number of elements as `new_bbox`
        ratio: threshold for the size ratio

    Returns:
        True if size(old_bbox) / size(new_bbox) >= ratio, False otherwise.
    """
    assert len(new_bbox) == len(old_bbox)
    old_size = get_bbox_size(old_bbox)
    new_size = get_bbox_size(new_bbox)
    return old_size / new_size >= ratio


def get_bbox_size(bbox: tuple) -> float:
    """
    Compute the volume of a bounding box.

    Args:
        bbox: flat sequence of (min, max) pairs, one pair per axis,
            e.g. (min_0, max_0, min_1, max_1, ...)

    Returns:
        the product of (max - min) over all axes.
    """
    assert len(bbox) % 2 == 0
    dims = np.reshape(np.array(bbox), (-1, 2))  # one (min, max) row per axis
    volume = 1.
    for v_min, v_max in dims:
        assert v_max >= v_min, f'The definition of bbox is invalid {bbox}.'
        volume *= v_max - v_min
    return volume