Newer
Older
# ============================================================================================= #
# Lucia Hradecká lucia.d.hradecka@gmail.com
# #
# MIT License. #
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy #
# of this software and associated documentation files (the "Software"), to deal #
# in the Software without restriction, including without limitation the rights #
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell #
# copies of the Software, and to permit persons to whom the Software is #
# furnished to do so, subject to the following conditions: #
# #
# The above copyright notice and this permission notice shall be included in all #
# copies or substantial portions of the Software. #
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR #
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, #
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, #
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE #
# SOFTWARE. #
# ============================================================================================= #
from typing import Sequence, Union, Optional
from ..biovol_typing import TypeSextetFloat, TypeTripletFloat, TypePairFloat, \
TypeSpatioTemporalCoordinate, TypeSpatialCoordinate, TypePairInt, TypeSextetInt
import numpy as np
import SimpleITK as sitk
from collections.abc import Iterable
def parse_limits(input_limit: Union[float, TypePairFloat, TypeTripletFloat, TypeSextetFloat],
                 scale: bool = False) -> TypeSextetFloat:
    """
    Parse the limits of affine transformation: rotation, scaling, or translation.

    Args:
        input_limit: transformation limits (type None, float, tuple of 2 floats, tuple of 3 floats,
            tuple of 3 pairs of floats, tuple of 6 floats)
        scale: a flag (True if computing limits for scaling, False otherwise)

    Returns: a tuple of 6 floats representing the limits for all 3 spatial dimensions.
        input_limit = None --> return (0., 0., 0., 0., 0., 0.)
        input_limit = x : float --> return (1/x, x, 1/x, x, 1/x, x) if scale, else (-x, +x, -x, +x, -x, +x)
        input_limit = (a, b) : TypePairFloat --> return (a, b, a, b, a, b)
        input_limit = (a, b, c) : TypeTripletFloat --> return (1/a, a, 1/b, b, 1/c, c) if scale,
            else (-a, +a, -b, +b, -c, +c)
        input_limit = ((a, b), (c, d), (e, f)) : triplet of pairs --> return (a, b, c, d, e, f)
        input_limit = (a, b, c, d, e, f) : TypeSextetFloat --> return (a, b, c, d, e, f)
    """
    # input_limit = x : scalar --> the same 1D range is repeated for all three axes.
    # isinstance (rather than an exact type check) also accepts float/int subclasses.
    if isinstance(input_limit, (int, float)):
        return parse_helper_affine_limits_1d(input_limit, scale=scale) * 3

    # input_limit : triplet (one entry per spatial axis).
    # Each entry is either an explicit range (iterable) or a scalar that is expanded to a 1D range.
    # None is excluded here so that it reaches the common-case helper below instead of failing in len().
    if input_limit is not None and len(input_limit) == 3:
        res = []
        for item in input_limit:
            if isinstance(item, Iterable):
                res.extend(float(val) for val in item)
            else:
                res.extend(parse_helper_affine_limits_1d(item, scale=scale))
        return tuple(res)

    # Remaining cases: None, 2-tuple, 6-tuple.
    return parse_helper_sextet_common_cases(input_limit, return_float=True)
def parse_helper_affine_limits_1d(input_limit: float, scale: bool) -> tuple:
    """
    Build the pair of transformation limits for one spatial axis.

    Args:
        input_limit: the limit value x for the axis
        scale: True when the limits describe scaling, False otherwise

    Returns: x and 1/x in ascending order if scale is truthy, (-x, +x) otherwise
    """
    if not scale:
        return -input_limit, input_limit

    # For scaling, the range spans the factor and its reciprocal, smaller value first.
    lower, upper = sorted((input_limit, 1 / input_limit))
    return lower, upper
def parse_pads(pad_size: Union[int, TypePairInt, TypeSextetInt]) -> TypeSextetInt:
    """
    Parse the padding argument.

    Args:
        pad_size: padding size (type None, int, tuple of 2 ints, tuple of 6 ints)

    Returns: a tuple of 6 ints representing padding for all 3 spatial dimensions.
        pad_size = None --> return (0, 0, 0, 0, 0, 0)
        pad_size = x : int --> return (x, x, x, x, x, x)
        pad_size = (a, b) : TypePairInt --> return (a, b, a, b, a, b)
        pad_size = (a, b, c, d, e, f) : TypeSextetInt --> return (a, b, c, d, e, f)
    """
    # A single int is used for both sides of all three axes.
    # It must be handled here: the common-case helper calls len() on its argument.
    if isinstance(pad_size, int):
        return (pad_size,) * 6

    # Remaining cases: None, 2-tuple, 6-tuple.
    return parse_helper_sextet_common_cases(pad_size, return_float=False)
def parse_helper_sextet_common_cases(arg: Optional[tuple], return_float=False):
    """
    A helper function for argument parsing functions.

    Takes care of the common cases when arg is None, a 2-element sequence, or a 6-element sequence.

    Args:
        arg: None, or a sequence of 2 or 6 numbers
        return_float: if True, the returned values are floats (zeros are 0. and given values are
            converted with float()); if False, given values are returned unchanged

    Returns: a tuple of 6 values.
        arg = None --> six zeros
        arg = (a, b) --> (a, b, a, b, a, b)
        arg = (a, b, c, d, e, f) --> (a, b, c, d, e, f)

    Raises:
        ValueError: if arg is a sequence whose length is neither 2 nor 6.
    """
    if arg is None:
        elem = 0. if return_float else 0
        return (elem,) * 6

    # Always hand back a tuple (a list input would otherwise produce a list),
    # and honour return_float for user-supplied values as well as for the zero default.
    values = tuple(float(val) for val in arg) if return_float else tuple(arg)

    if len(values) == 2:
        return values * 3
    if len(values) == 6:
        return values

    # Fail loudly instead of silently returning None for an unsupported length.
    raise ValueError(f'Expected None or a sequence of length 2 or 6, got length {len(values)}: {arg}')
def parse_coefs(coefs: Union[float, tuple], identity_element: float = 1, dim4: bool = False) -> tuple:
    """
    Parse the coefficients of affine transformation: rotation, scaling, or translation.

    Args:
        coefs: transformation coefficients (None, a scalar, or a sequence of 3 values;
            a sequence of 4 values is also accepted when dim4 is True)
        identity_element: identity element (e.g. 1 for scaling, 0 for translation)
        dim4: a flag (True if time-lapse data, False otherwise)

    Returns: a tuple representing the transformation parameters.
        coefs = None --> return (ie, ie, ie) where ie is the identity element
        coefs = x : scalar --> return (x, x, x)
        coefs = (a, b, c) --> return (a, b, c)
        coefs = (a, b, c, d) and dim4 --> return (a, b, c, d)

    Raises:
        ValueError: if coefs is a sequence of an unsupported length.
    """
    # coefs = None --> return (ie, ie, ie)
    if coefs is None:
        return (identity_element,) * 3

    # coefs = x : scalar --> the same coefficient for all 3 spatial dimensions
    if isinstance(coefs, (int, float)):
        return (coefs,) * 3

    # return (a, b, c) for 3D data or (a, b, c, d) for time-lapse (4D) data
    if (len(coefs) == 3) or (dim4 and len(coefs) == 4):
        return tuple(coefs)

    # Anything else cannot be interpreted; do not fall through and return None implicitly.
    raise ValueError(f'Cannot parse transformation coefficients {coefs} (dim4={dim4}).')
def get_first_img_keyword(targets: Optional[dict] = None):
    """
    Get the first 'image'-type keyword from the targets dictionary.

    Args:
        targets: dictionary with targets; its 'img_keywords' entry is read if present

    Returns: the first element of targets['img_keywords'], or 'image' if targets is not a dict
        or the entry is missing or empty.
    """
    if isinstance(targets, dict):
        keywords = targets.get('img_keywords')
        # A missing or empty entry must not crash the lookup; use the fallback below instead.
        if keywords is not None and len(keywords) > 0:
            return keywords[0]
    return 'image'  # <-- best effort, if we don't have concrete naming in the `targets` dict
def get_spatio_temporal_domain_limit(sample: dict, targets: dict = None) -> TypeSpatioTemporalCoordinate:
    """
    Returns a vector of spatio-temporal coordinates of length 4.
    The vector limits the domain of the image.

    Args:
        sample: dictionary with data; the entry under the first image keyword must expose `.shape`
        targets: dictionary with targets (used to look up the first image keyword)

    Returns: a tuple of 4 values derived from the image shape.

    Raises:
        ValueError: if the image shape has fewer than 3 or more than 5 dimensions.
    """
    shape = list(sample[get_first_img_keyword(targets)].shape)
    ndim = len(shape)

    if ndim == 3:
        # 3D image without channels and the time axis
        limit = shape + [1]
    elif ndim == 4:
        # 3D image with channels, without the time axis
        limit = shape[1:] + [1]
    elif ndim == 5:
        # 3D image with channels and the time axis
        limit = shape[1:5]
    else:
        # Previously this case left `limit` unbound and failed with an unrelated error.
        raise ValueError(f'Unsupported image shape {tuple(shape)}: expected 3, 4, or 5 dimensions.')

    return tuple(limit)
def to_spatio_temporal(shape: tuple) -> TypeSpatioTemporalCoordinate:
    """
    Convert a shape to its 4-element spatio-temporal form.

    A shape with 3 elements gets a trailing 0 appended. The result must have exactly 4 elements,
    otherwise the assertion fails.
    """
    dims = tuple(shape)
    if len(dims) == 3:
        dims += (0,)
    assert len(dims) == 4
    return dims
def to_tuple(param: Union[int, float, Iterable]):
    """
    Convert the input value to a tuple.

    Args:
        param (scalar or Iterable): Input value.
            If scalar, the return value is (-value, +value). Otherwise, convert the Iterable to tuple.

    Returns: None if param is None, (-param, +param) for a scalar, tuple(param) otherwise.
    """
    if param is None:
        return param
    if isinstance(param, (int, float)):
        return -param, +param
    # Any other input is expected to be an Iterable and is converted as documented.
    return tuple(param)
def get_image_center(shape: Union[TypeSpatioTemporalCoordinate, TypeSpatialCoordinate],
                     spacing: TypeTripletFloat = (1., 1., 1.), lps: bool = False) -> TypeTripletFloat:
    """
    Compute the center of an image from its shape.

    Args:
        shape: image shape; only the first 3 elements are used
        spacing: per-axis factors the center is multiplied by
        lps: if True, the center is passed through `ras_to_lps` before applying the spacing

    Returns: a numpy array with (size - 1) / 2 per axis, multiplied element-wise by `spacing`.
    """
    spatial_size = np.array(shape)[:3]
    midpoint = (spatial_size - 1) / 2.0
    if not lps:
        return midpoint * np.array(spacing)
    return ras_to_lps(midpoint) * np.array(spacing)
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# Simple ITK uses LPS coordinates format
def ras_to_lps(triplet: Sequence[float]):
    """
    Negate the first two components of a triplet and keep the third one unchanged.

    Returns: a numpy array equal to (-1, -1, 1) multiplied element-wise by `triplet`.
    """
    axis_signs = np.array((-1, -1, 1), dtype=float)
    coords = np.asarray(triplet)
    return axis_signs * coords
def np_to_sitk(img: np.array) -> sitk.Image:
    """
    Convert a 5-dimensional numpy array to a SimpleITK image.

    Args:
        img: array in format (c, s1, s2, s3, t)

    Returns: the result of `sitk.GetImageFromArray` applied to the array rearranged to
        (s3, s2, s1, c * t).
    """
    # image in format (c, s1, s2, s3, [t])
    assert len(img.shape) == 5
    n_channels, size_1, size_2, size_3, n_frames = img.shape

    # (c, s1, s2, s3, t) -> (s1, s2, s3, c, t) -> (s1, s2, s3, c * t)
    merged = np.moveaxis(img, 0, 3).reshape((size_1, size_2, size_3, n_channels * n_frames))

    # TODO: rather swap axis of parameters than data
    # The first and the third spatial axes are exchanged before handing the data over.
    return sitk.GetImageFromArray(np.swapaxes(merged, 0, 2))
def sitk_to_np(sitk_img: sitk.Image,
               channels,
               frames=1) -> np.array:
    """
    Convert a SimpleITK image to a 5-dimensional numpy array.

    Args:
        sitk_img: the SimpleITK image to convert
        channels: number of channels stored in the last axis of the image array
        frames: number of frames stored in the last axis of the image array

    Returns: an array of shape (c, w, h, d, f).
    """
    # shape (d, h, w, c*f)
    img = sitk.GetArrayFromImage(sitk_img)

    # A 3-dimensional array carries no trailing axis; add one of size 1.
    if len(img.shape) == 3:
        img = np.expand_dims(img, 3)

    assert channels * frames == img.shape[-1], (f'Number of channels ({channels}) and frames ({frames}) '
                                                f'does not correspond to the sitk vector size {img.shape[-1]}')

    # split channels and frames
    d, h, w = img.shape[:3]
    img = img.reshape((d, h, w, channels, frames))

    # (d, h, w, c, f) -> (w, h, d, c, f) -> (c, w, h, d, f)
    img = np.swapaxes(img, 0, 2)
    img = np.moveaxis(img, 3, 0)

    # shape (c, w, h, d, f)
    return img
def is_included(shape: Union[TypeSpatialCoordinate, TypeSpatioTemporalCoordinate], coo):
    """
    Check whether a coordinate lies inside the domain given by the first 3 elements of `shape`.

    The coordinate is shifted by 0.5 and must then be non-negative and strictly smaller than
    the shape along every axis.

    Args:
        shape: domain shape; only the first 3 elements are used
        coo: coordinate with the same number of elements as the used part of the shape

    Returns: a truth value telling whether the coordinate is inside the domain.
    """
    shifted = np.array(coo) + 0.5
    bounds = np.array(shape[:3])
    assert len(bounds) == len(shifted), f'shape: {bounds} coo: {shifted}'

    # The upper bound is evaluated only when the lower bound holds.
    if not all(shifted >= 0):
        inside = False
    else:
        inside = (shifted < bounds).all()

    if DEBUG:
        print('IS INCLUDED', shape, coo, inside)
    return inside
def validate_bbox(new_bbox: tuple, old_bbox: tuple, ratio: float = 0.5) -> bool:
    """
    Compare the sizes of two bounding boxes.

    Args:
        new_bbox: the new bounding box
        old_bbox: the old bounding box (must have the same number of elements as new_bbox)
        ratio: threshold for the size quotient

    Returns: True if size(old_bbox) / size(new_bbox) is at least `ratio`, False otherwise.
    """
    assert len(new_bbox) == len(old_bbox)
    # NOTE(review): the quotient is old / new; confirm this is the intended direction.
    return get_bbox_size(old_bbox) / get_bbox_size(new_bbox) >= ratio
def get_bbox_size(bbox: tuple) -> float:
    """
    Compute the size of a bounding box as the product of its extents.

    Args:
        bbox: flat sequence of (min, max) pairs, i.e. (min_1, max_1, min_2, max_2, ...)

    Returns: the product of (max - min) over all pairs.
    """
    assert len(bbox) % 2 == 0
    bounds = np.array(bbox).reshape(-1, 2)

    size = 1.
    for lower, upper in bounds:
        assert upper >= lower, f'The definition of bbox is invalid {bbox}.'
        size = size * (upper - lower)
    return size