Skip to content
Snippets Groups Projects
Commit 23e25314 authored by Lucia D. Hradecka's avatar Lucia D. Hradecka
Browse files

refactor utils.py: group similar function, remove repeated logic

parent 2ed6bf73
No related branches found
No related tags found
1 merge request!12Make version 1.3.2 default
# ============================================================================================= #
# Author: Filip Lux, Lucia Hradecká #
# Copyright: Filip Lux lux.filip@gmail.com #
# Lucia Hradecká lucia.d.hradecka@gmail.com
# #
# MIT License. #
# #
......@@ -23,7 +24,7 @@
# SOFTWARE. #
# ============================================================================================= #
from typing import Sequence, Union
from typing import Sequence, Union, Optional
from ..biovol_typing import TypeSextetFloat, TypeTripletFloat, TypePairFloat, \
TypeSpatioTemporalCoordinate, TypeSpatialCoordinate, TypePairInt, TypeSextetInt
import numpy as np
......@@ -36,106 +37,156 @@ DEBUG = False
def parse_limits(input_limit: Union[float, TypePairFloat, TypeTripletFloat, TypeSextetFloat],
scale: bool = False) -> TypeSextetFloat:
"""
Parse the limits of affine transformation: rotation, scaling, or translation.
Args:
input_limit: transformation limits (type None, float, tuple of 2 floats, tuple of 3 floats, tuple of 6 floats)
scale: a flag (True if computing limits for scaling, False otherwise)
Returns: a tuple of 6 floats representing the limits for all 3 spatial dimensions.
input_limit = None --> return (0., 0., 0., 0., 0., 0.)
input_limit = x : float --> return (1/x, x, 1/x, x, 1/x, x) if scale, else (-x, +x, -x, +x, -x, +x)
input_limit = (a, b) : TypePairFloat --> return (a, b, a, b, a, b)
input_limit = (a, b, c) : TypeTripletFloat --> return (1/a, a, 1/b, b, 1/c, c) if scale, else (-a, +a, -b, +b, -c, +c)
input_limit = ((a, b), (c, d), (e, f)) : TypeTripletFloat --> return (a, b, c, d, e, f)
input_limit = (a, b, c, d, e, f) : TypeSextetFloat --> return (a, b, c, d, e, f)
"""
# input_limit = x : float --> return (1/x, x, 1/x, x, 1/x, x) if scale, else (-x, +x, -x, +x, -x, +x)
if (type(input_limit) is float) or (type(input_limit) is int):
limit_range = parse_helper_affine_limits_1d(input_limit, scale=scale)
return limit_range * 3
# input_limit = None
# returns (0., 0., 0., 0., 0., 0.)
if input_limit is None:
return (0., ) * 6
# input_limit = x : float
# returns (-x, +x, -x, +x, -x, +x)
# if scale, returns (1/x, x, 1/x, x, 1/x, x)
elif (type(input_limit) is float) or (type(input_limit) is int):
range = sorted([input_limit, 1 / input_limit]) if scale else [-input_limit, input_limit]
return tuple((range[0], range[1]) * 3)
# input_limit = (a, b) : TypePairFloat
# returns (a, b, a, b, a, b)
elif len(input_limit) == 2:
a, b = input_limit
return a, b, a, b, a, b
# input_limit = (a, b, c) : TypeTripletFloat
# returns (-a, +a, -b, +b, -c, +c)
elif len(input_limit) == 3:
# input_limit : TypeTripletFloat
# if input_limit = ((a, b), (c, d), (e, f)) --> return (a, b, c, d, e, f)
# elif input_limit = (a, b, c) --> return (-a, +a, -b, +b, -c, +c)
# if scale, return (1/a, a, 1/b, b, 1/c, c)
if len(input_limit) == 3:
res = []
for item in input_limit:
# input_limit = ((a, b), (c, d), (e, f))
# return (a, b, c, d, e, f)
if isinstance(item, Iterable):
for val in item:
res.append(float(val))
# input_limit = (a, b, c)
# return (-a, +a, -b, +b, -c, +c)
# if scale, returns (1/a, a, 1/b, b, 1/c, c)
else:
range = sorted([item, 1 / item]) if scale else [-item, item]
res.append(float(range[0]))
res.append(float(range[1]))
limit_range = parse_helper_affine_limits_1d(item, scale=scale)
res.append(limit_range[0])
res.append(limit_range[1])
return tuple(res)
return parse_helper_sextet_common_cases(input_limit, return_float=True)
# input_limit = (a, b, c, d, e, f)
# returns (a, b, c, d, e, f)
elif len(input_limit) == 6:
return input_limit
def parse_helper_affine_limits_1d(input_limit: float, scale: bool) -> tuple:
"""
Create a 2-tuple of transformation limits for a single spatial axis.
Returns: (1/x, x) if scale=True, (-x, +x) otherwise
"""
return tuple(sorted([input_limit, 1 / input_limit])) if scale else (-input_limit, input_limit)
def parse_pads(pad_size: Union[int, TypePairInt, TypeSextetInt]) -> TypeSextetInt:
"""
Parse the padding argument.
# pad_size = None
# returns (0, 0, 0, 0, 0, 0)
if pad_size is None:
return 0, 0, 0, 0, 0, 0
Args:
pad_size: padding size (type None, int, tuple of 2 ints, tuple of 6 ints)
Returns: a tuple of 6 ints representing padding for all 3 spatial dimensions.
pad_size = None --> return (0, 0, 0, 0, 0, 0)
pad_size = x : int --> return (x, x, x, x, x, x)
input_limit = (a, b) : TypePairInt --> return (a, b, a, b, a, b)
input_limit = (a, b, c, d, e, f) --> return (a, b, c, d, e, f)
"""
# pad_size = x : int
# returns (x, x, x, x, x, x)
elif type(pad_size) is int:
if type(pad_size) is int:
return tuple((pad_size,) * 6)
# input_limit = (a, b) : TypePairInt
# returns (a, b, a, b, a, b)
elif len(pad_size) == 2:
a, b = pad_size
return a, b, a, b, a, b
return parse_helper_sextet_common_cases(pad_size, return_float=False)
def parse_helper_sextet_common_cases(arg: Optional[tuple], return_float=False):
"""
A helper function for argument parsing functions.
Takes care of the common cases when type(arg) is None, 2-tuple, or 6-tuple.
"""
if arg is None:
elem = 0. if return_float else 0
return (elem,) * 6
# input_limit = (a, b, c, d, e, f)
# returns (a, b, c, d, e, f)
elif len(pad_size) == 6:
return pad_size
elif len(arg) == 2:
return arg * 3
elif len(arg) == 6:
return arg
def parse_coefs(coefs: Union[float, tuple],
identity_element: float = 1,
d4: bool = False) -> tuple:
# input_limit = None
# return (ie, ie, ie)
def parse_coefs(coefs: Union[float, tuple], identity_element: float = 1, dim4: bool = False) -> tuple:
"""
Parse the coefficients of affine transformation: rotation, scaling, or translation.
Args:
coefs: transformation coefficients
identity_element: identity element (e.g. 1 for scaling, 0 for translation)
dim4: a flag (True if time-lapse data, False otherwise)
Returns: a tuple of 3 floats representing the transformation parameters for all 3 spatial dimensions.
"""
# input_limit = None --> return (ie, ie, ie)
if coefs is None:
return tuple((identity_element, ) * 3)
return (identity_element,) * 3
# return (a, a, a)
elif isinstance(coefs, (int, float)):
return coefs, coefs, coefs
# return (a, b, c)
elif len(coefs) == 3:
return coefs
# return (a, b, c, d) for time-lapse (4D) data
elif d4 and len(coefs) == 4:
return (coefs,) * 3
# return (a, b, c) for 3D data or (a, b, c, d) for time-lapse (4D) data
elif (len(coefs) == 3) or (dim4 and len(coefs) == 4):
return coefs
def get_image_center(shape: Union[TypeSpatioTemporalCoordinate, TypeSpatialCoordinate],
spacing: TypeTripletFloat = (1., 1., 1.),
lps: bool = False) -> TypeTripletFloat:
def get_first_img_keyword(targets: dict = None):
"""
Get the first 'image'-type keyword from the targets dictionary.
"""
center = (np.array(shape)[:3] - 1) / 2
if (targets is not None) and isinstance(targets, dict):
return targets.get('img_keywords')[0]
return 'image' # <-- best effort, if we don't have concrete naming in the `targets` dict
if lps:
center = ras_to_lps(center)
return center * np.array(spacing)
def get_spatio_temporal_domain_limit(sample: dict, targets: dict = None) -> TypeSpatioTemporalCoordinate:
"""
Returns a vector of spatio-temporal coordinates of length 4.
The vector limits the domain of the image.
Args:
sample: dictionary with data
targets: dictionary with targets
"""
shape = list(sample[get_first_img_keyword(targets)].shape)
if len(shape) == 3:
# 3D image without channels and the time axis
limit = shape + [1]
elif len(shape) == 4:
# 3D image with channels, without the time axis
limit = shape[1:] + [1]
elif len(shape) == 5:
# 3D image with channels and the time axis
limit = shape[1:5]
assert len(limit) == 4
return tuple(limit)
def to_spatio_temporal(shape: tuple) -> TypeSpatioTemporalCoordinate:
"""
Return spatio-temporal shape given the input shape.
"""
shape = list(shape)
if len(shape) == 3:
......@@ -145,38 +196,31 @@ def to_spatio_temporal(shape: tuple) -> TypeSpatioTemporalCoordinate:
return tuple(shape)
def to_tuple(param, low=None, bias=None):
def to_tuple(param: Union[int, float, Iterable]):
"""Convert input argument to min-max tuple
Args:
param (scalar, tuple or list of 2+ elements): Input value.
If value is scalar, return value would be (offset - value, offset + value).
If value is tuple, return value would be value + offset (broadcasted).
low: Second element of tuple can be passed as optional argument
bias: An offset factor added to each element
param (scalar or Iterable): Input value.
If scalar, the return value is (-value, +value). Otherwise, convert the Iterable to tuple.
"""
if low is not None and bias is not None:
raise ValueError("Arguments low and bias are mutually exclusive")
if param is None:
return param
if isinstance(param, (int, float)):
if low is None:
param = -param, +param
else:
param = (low, param) if low < param else (param, low)
elif isinstance(param, Sequence):
param = tuple(param)
else:
raise ValueError("Argument param must be either scalar (int, float) or tuple")
if bias is not None:
return tuple(bias + x for x in param)
return -param, +param
return tuple(param)
def get_image_center(shape: Union[TypeSpatioTemporalCoordinate, TypeSpatialCoordinate],
spacing: TypeTripletFloat = (1., 1., 1.), lps: bool = False) -> TypeTripletFloat:
center = (np.array(shape)[:3] - 1) / 2.0
if lps:
center = ras_to_lps(center)
return center * np.array(spacing)
# Simple ITK uses LPS coordinates format
def ras_to_lps(triplet: Sequence[float]):
return np.array((-1, -1, 1), dtype=float) * np.asarray(triplet)
......@@ -221,9 +265,21 @@ def sitk_to_np(sitk_img: sitk.Image,
return img
def validate_bbox(new_bbox: tuple,
old_bbox: tuple,
ratio: float = 0.5) -> bool:
def is_included(shape: Union[TypeSpatialCoordinate, TypeSpatioTemporalCoordinate], coo):
coo_arr = np.array(coo) + 0.5
shape_arr = np.array(shape[:3])
assert len(shape_arr) == len(coo_arr), f'shape: {shape_arr} coo: {coo_arr}'
res = all(coo_arr >= 0) and (coo_arr < shape_arr).all()
if DEBUG:
print('IS INCLUDED', shape, coo, res)
return res
def validate_bbox(new_bbox: tuple, old_bbox: tuple, ratio: float = 0.5) -> bool:
assert len(new_bbox) == len(old_bbox)
......@@ -247,49 +303,5 @@ def get_bbox_size(bbox: tuple) -> float:
return volume
def get_first_img_keyword(targets: dict = None):
if (targets is not None) and isinstance(targets, dict):
return targets.get('img_keywords')[0]
return 'image' # <-- best effort, if we don't have concrete naming in the `targets` dict
def get_spatio_temporal_domain_limit(sample: dict, targets: dict = None) -> TypeSpatioTemporalCoordinate:
"""
Returns vector of spatio-temporal coordinates of length 4.
The vector limits a domain of the image.
Args:
sample: dictionary with data
targets: dictionary with targets
Returns:
"""
shape = list(sample[get_first_img_keyword(targets)].shape)
if len(shape) == 3:
limit = shape + [1]
elif len(shape) == 4:
limit = shape[1:] + [1]
elif len(shape) == 5:
limit = shape[1:5]
assert len(limit) == 4
return tuple(limit)
def is_included(shape: Union[TypeSpatialCoordinate, TypeSpatioTemporalCoordinate], coo):
coo_arr = np.array(coo) + 0.5
shape_arr = np.array(shape[:3])
assert len(shape_arr) == len(coo_arr), f'shape: {shape_arr} coo: {coo_arr}'
res = all(coo_arr >= 0) and (coo_arr < shape_arr).all()
if DEBUG:
print('IS INCLUDED', shape, coo, res)
return res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment