Skip to content
Snippets Groups Projects
Commit 2ed6bf73 authored by Lucia D. Hradecka's avatar Lucia D. Hradecka
Browse files

group similar transformations

parent c43087fc
No related branches found
No related tags found
1 merge request!12Make version 1.3.2 default
...@@ -40,12 +40,10 @@ ...@@ -40,12 +40,10 @@
# ============================================================================================= # # ============================================================================================= #
import numpy as np import numpy as np
from functools import wraps
import skimage.transform as skt import skimage.transform as skt
from skimage.exposure import equalize_hist from skimage.exposure import equalize_hist
from scipy.ndimage import zoom, gaussian_filter from scipy.ndimage import gaussian_filter
from warnings import warn from warnings import warn
from typing import Union
from ..biovol_typing import TypeTripletFloat, TypeSpatioTemporalCoordinate, TypeSextetInt, TypeSpatialShape from ..biovol_typing import TypeTripletFloat, TypeSpatioTemporalCoordinate, TypeSextetInt, TypeSpatialShape
from .spatial_functional import get_affine_transform, apply_sitk_transform from .spatial_functional import get_affine_transform, apply_sitk_transform
...@@ -88,262 +86,12 @@ But for parameters use primarily ints. ...@@ -88,262 +86,12 @@ But for parameters use primarily ints.
""" """
def preserve_shape(func):
"""
Preserve shape of the image
"""
@wraps(func)
def wrapped_function(img, *args, **kwargs):
shape = img.shape
result = func(img, *args, **kwargs)
result = result.reshape(shape)
return result
return wrapped_function
def get_center_crop_coords(img_shape, crop_shape):
froms = (img_shape - crop_shape) // 2
tos = froms + crop_shape
return froms, tos
# Too similar to the random_crop. Could be made into one function
def crop(input_array: np.array,
crop_shape: TypeSpatialShape,
crop_position: TypeSpatialShape,
pad_dims,
border_mode, cval, mask):
input_spatial_shape = get_spatial_shape(input_array, mask)
if np.any(input_spatial_shape < crop_shape):
warn(f'F.crop(): Input size {input_spatial_shape} smaller than crop size {crop_shape}, pad by {border_mode}.',
UserWarning)
# pad
input_array = pad(input_array, pad_dims, border_mode, cval, mask)
# test
input_spatial_shape = get_spatial_shape(input_array, mask)
assert np.all(input_spatial_shape >= crop_shape)
x1, y1, z1 = crop_position
x2, y2, z2 = np.array(crop_position) + np.array(crop_shape)
if mask:
result = input_array[x1:x2, y1:y2, z1:z2]
assert np.all(result.shape[:3] == crop_shape), f'{result.shape} {crop_shape} {mask} {crop_position}'
else:
result = input_array[:, x1:x2, y1:y2, z1:z2]
assert np.all(result.shape[1:4] == crop_shape)
return result
def crop_keypoints(keypoints,
crop_shape: TypeSpatialShape,
crop_position: TypeSpatialShape,
pad_dims,
keep_all: bool):
(px, _), (py, _), (pz, _) = pad_dims
pad = np.array((px, py, pz))
res = []
for keypoint in keypoints:
k = keypoint[:3] - crop_position + pad
if keep_all or (np.all(k >= 0) and np.all((k + .5) < crop_shape)):
res.append(k)
return res
def get_spatial_shape(array: np.array, mask: bool) -> TypeSpatialShape:
return np.array(array.shape)[:3] if mask else np.array(array.shape)[1:4]
def get_pad_dims(spatial_shape: TypeSpatialShape, crop_shape: TypeSpatialShape):
pad_dims = []
for i in range(3):
i_dim, c_dim = spatial_shape[i], crop_shape[i]
if i_dim < c_dim:
pad_size = c_dim - i_dim
if pad_size % 2 != 0:
pad_dims.append((int(pad_size // 2 + 1), int(pad_size // 2)))
else:
pad_dims.append((int(pad_size // 2), int(pad_size // 2)))
else:
pad_dims.append((0, 0))
return pad_dims
def pad(img, pad_width, border_mode, cval, mask=True):
if not mask:
pad_width = [(0, 0)] + pad_width
if len(img.shape) > len(pad_width):
pad_width = pad_width + [(0, 0)]
assert len(img.shape) == len(pad_width)
if border_mode == "constant":
return np.pad(img, pad_width, border_mode, constant_values=cval)
if border_mode == "linear_ramp":
return np.pad(img, pad_width, border_mode, end_values=cval)
result = np.pad(img, pad_width, border_mode)
return result
def pad_keypoints(keypoints, pad_size):
a, b, c, d, e, f = pad_size
res = []
for coo in keypoints:
padding = np.array((a, c, e)) if len(coo) == 3 else np.array((a, c, e, 0))
res.append(coo + padding)
return res
def flip_keypoints(keypoints, axes, img_shape):
# all values in axes are in [1, 2, 3]
assert np.all(np.array([ax in [1, 2, 3] for ax in axes])), f'{axes} does not contain values from [1, 2, 3]'
mult, add = np.ones(3, int), np.zeros(3, int)
for ax in axes:
mult[ax-1] = -1
add[ax-1] = img_shape[ax-1] - 1
res = []
for k in keypoints:
flipped = list(np.array(k[:3]) * mult + add)
if len(k) == 4:
flipped.append(k[-1])
res.append(tuple(flipped))
return res
def rot90_keypoints(keypoints, factor, axes, img_shape):
if factor == 1:
keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)
keypoints = transpose_keypoints(keypoints, axes[0], axes[1])
elif factor == 2:
keypoints = flip_keypoints(keypoints, axes, img_shape)
elif factor == 3:
keypoints = transpose_keypoints(keypoints, axes[0], axes[1])
keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)
return keypoints
def transpose_keypoints(keypoints, ax1, ax2):
# all values in axes are in [1, 2, 3]
assert (ax1 in [1, 2, 3]) and (ax2 in [1, 2, 3]), f'[{ax1} {ax2}] does not contain values from [1, 2, 3]'
res = []
for k in keypoints:
k = list(k)
k[ax1-1], k[ax2-1] = k[ax2-1], k[ax1-1]
res.append(tuple(k))
return res
def pad_pixels(img, input_pad_width: TypeSextetInt, border_mode, cval, mask=False):
a, b, c, d, e, f = input_pad_width
pad_width = [(a, b), (c, d), (e, f)]
# zeroes for channel dimension
if not mask:
pad_width = [(0, 0)] + pad_width
# zeroes for temporal dimension
if len(img.shape) == 5:
pad_width = pad_width + [(0, 0)]
if border_mode == "constant":
return np.pad(img, pad_width, border_mode, constant_values=cval)
if border_mode == "linear_ramp":
return np.pad(img, pad_width, border_mode, end_values=cval)
return np.pad(img, pad_width, border_mode)
def normalize_mean_std(img, mean, denominator):
if len(mean.shape) == 0:
mean = mean[..., None]
if len(denominator.shape) == 0:
denominator = denominator[..., None]
new_axis = [i + 1 for i in range(len(img.shape) - 1)]
img -= np.expand_dims(mean, axis=new_axis)
img *= np.expand_dims(denominator, axis=new_axis)
return img
# formula taken from
# https://stats.stackexchange.com/questions/46429/transform-data-to-desired-mean-and-standard-deviation
def normalize_channel(img, mean, std):
return (img - img.mean()) * (std / img.std()) + mean
def value_to_list(value, length):
if isinstance(value, (float, int)):
return [value for _ in range(length)]
else:
return value
def correct_length_list(list_to_check, length, value_to_fill=1, list_name="###Default###"):
if len(list_to_check) < length:
warn(f"{list_name} have elements {len(list_to_check)}, should be {length} appending {value_to_fill} " +
"till length matches", UserWarning)
for i in range(length - len(list_to_check)):
list_to_check = list_to_check + [value_to_fill]
if len(list_to_check) > length:
warn(f"{list_name} have elements {len(list_to_check)}, should be {length} removing elements from behind " +
" till length matches", UserWarning)
list_to_check = [list_to_check[i] for i in range(length)]
return list_to_check
def normalize(img, input_mean, input_std):
mean = value_to_list(input_mean, img.shape[0])
std = value_to_list(input_std, img.shape[0])
mean = correct_length_list(mean, img.shape[0], value_to_fill=0, list_name="mean")
std = correct_length_list(std, img.shape[0], value_to_fill=1, list_name="std")
for i in range(img.shape[0]):
img[i] = normalize_channel(img[i], mean[i], std[i])
return img
def gaussian_noise(img, mean, sigma):
img = img.astype("float32")
noise = np.random.normal(mean, sigma, img.shape).astype(np.float32)
return img + noise
def poisson_noise(img, peak):
img = img.astype("float32")
return img + np.random.poisson(img).astype(np.float32)
# TODO parameter # TODO parameter
# Anti-aliasing - gaussian filter to smooth. using automatically when downsampling, except when integer # Anti-aliasing - gaussian filter to smooth. using automatically when downsampling, except when integer
# and interpolation is 0. (so mask) # and interpolation is 0. (so mask)
# float mask - how, for now no gaussian filter. # float mask - how, for now no gaussian filter.
def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, mask=False, def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, mask=False,
anti_aliasing_downsample=True): anti_aliasing_downsample=True):
# TODO: random fix, check if it is correct # TODO: random fix, check if it is correct
new_shape = list(input_new_shape)[:-1] new_shape = list(input_new_shape)[:-1]
...@@ -359,7 +107,7 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, ...@@ -359,7 +107,7 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0,
# too many or few dimensions of new_shape # too many or few dimensions of new_shape
if len(new_shape) < len(img.shape) - 1 or len(new_shape) > len(img.shape): if len(new_shape) < len(img.shape) - 1 or len(new_shape) > len(img.shape):
warn(f"Resize(): wrong parameter shape: {new_shape}," + warn(f"Resize(): wrong parameter shape: {new_shape}," +
f"expecting something with dimensions of {img.shape } or {img.shape[0:-1] }, " + f"expecting something with dimensions of {img.shape} or {img.shape[0:-1]}, " +
"continuing without resizing ", UserWarning) "continuing without resizing ", UserWarning)
return img return img
# Adding time dimension # Adding time dimension
...@@ -368,7 +116,7 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, ...@@ -368,7 +116,7 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0,
else: else:
if len(new_shape) < len(img.shape[1:]) - 1 or len(new_shape) > len(img.shape[1:]): if len(new_shape) < len(img.shape[1:]) - 1 or len(new_shape) > len(img.shape[1:]):
warn(f"Resize(): wrong dimensions of shape: {new_shape}," + warn(f"Resize(): wrong dimensions of shape: {new_shape}," +
f"expecting something with dimensions of {img.shape[1:] } or {img.shape[1:-1] }, continuing " + f"expecting something with dimensions of {img.shape[1:]} or {img.shape[1:-1]}, continuing " +
"without resizing ", UserWarning) "without resizing ", UserWarning)
return img return img
# adding time dimension # adding time dimension
...@@ -387,10 +135,10 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, ...@@ -387,10 +135,10 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0,
anti_aliasing=anti_aliasing anti_aliasing=anti_aliasing
) )
return new_img return new_img
if anti_aliasing_downsample and np.any(np.array(img.shape[1:]) < np.array(new_shape)): if anti_aliasing_downsample and np.any(np.array(img.shape[1:]) < np.array(new_shape)):
anti_aliasing = True anti_aliasing = True
data = [] data = []
for i in range(img.shape[0]): for i in range(img.shape[0]):
subimg = img[i].copy() subimg = img[i].copy()
...@@ -405,14 +153,13 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0, ...@@ -405,14 +153,13 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0,
) )
data.append(d0.copy()) data.append(d0.copy())
new_img = np.stack(data, axis=0) new_img = np.stack(data, axis=0)
return new_img return new_img
def resize_keypoints(keypoints, def resize_keypoints(keypoints,
domain_limit: TypeSpatioTemporalCoordinate, domain_limit: TypeSpatioTemporalCoordinate,
new_shape: TypeSpatioTemporalCoordinate): new_shape: TypeSpatioTemporalCoordinate):
assert len(domain_limit) == len(new_shape) == 4 assert len(domain_limit) == len(new_shape) == 4
# for each dim compute ratio # for each dim compute ratio
...@@ -422,62 +169,12 @@ def resize_keypoints(keypoints, ...@@ -422,62 +169,12 @@ def resize_keypoints(keypoints,
return [keypoint * ratio for keypoint in keypoints] return [keypoint * ratio for keypoint in keypoints]
# TODO compare with skt.rescale, new version got channel_axis
def scale(img, input_scale_factor, interpolation=0, border_mode='reflect', cval=0, mask=True):
scale_factor = input_scale_factor
# check for zero or negative numbers
if isinstance(scale_factor, (int, float)):
if scale_factor <= 0:
warn(f"RandomScale()/Scale(): scale_factor: {len(scale_factor)} is zero or negative number" +
f" continuing without scaling ", UserWarning)
return img
else:
for dimension in scale_factor:
if dimension <= 0:
warn(f"RandomScale()/Scale(): scale_factor: {len(scale_factor)} contains zero or negative number " +
"continuing without scaling ", UserWarning)
return img
img_shape = img.shape
if scale_factor is None:
return img
if isinstance(scale_factor, (list, tuple)):
scale_factor = np.array(scale_factor)
if not mask:
img_shape = img_shape[1:]
# TODO, maybe user wants to add shape for only spatial dimensions
if len(img_shape) != len(scale_factor) and len(img_shape) - 1 != len(scale_factor):
warn(f"RandomScale()/Scale(): Wrong dimension of scaling factor list: {len(scale_factor)}," +
f"expecting {len(img_shape)} or {len(img_shape[:-1]) }, continuing without scaling ", UserWarning)
return img
elif len(img_shape) - 1 == len(scale_factor):
scale_factor = np.append(scale_factor, 1)
else:
scale_factor = [scale_factor for _ in range(len(img_shape) - 1)]
if mask:
scale_factor.append(scale_factor[0])
# Not scaling time dimensions
if len(scale_factor) == 4:
scale_factor[-1] = 1
if mask:
return zoom(img, scale_factor, order=interpolation, mode=border_mode, cval=cval)
data = []
for i in range(img.shape[0]):
subimg = img[i].copy()
d0 = zoom(subimg, scale_factor, order=interpolation, mode=border_mode, cval=cval)
data.append(d0.copy())
new_img = np.stack(data, axis=0)
return new_img
''' '''
#TODO maybe add parameter for order of rotations #TODO maybe add parameter for order of rotations
#LIMIT dimensions #LIMIT dimensions
def affine_transform(img, input_x_angle, input_y_angle, input_z_angle, translantion, interpolation = 1, border_mode = 'constant', def affine_transform(img, input_x_angle, input_y_angle, input_z_angle, translantion, interpolation = 1, border_mode = 'constant',
value = 0, input_scaling_coef = None, scale_back = True, mask = False ): value = 0, input_scaling_coef = None, scale_back = True, mask = False ):
if mask: if mask:
img = img[np.newaxis, :] img = img[np.newaxis, :]
x_angle, y_angle, z_angle = [np.pi * i / 180 for i in [input_x_angle, input_y_angle, input_z_angle]] x_angle, y_angle, z_angle = [np.pi * i / 180 for i in [input_x_angle, input_y_angle, input_z_angle]]
...@@ -509,7 +206,7 @@ def affine_transform(img, input_x_angle, input_y_angle, input_z_angle, translant ...@@ -509,7 +206,7 @@ def affine_transform(img, input_x_angle, input_y_angle, input_z_angle, translant
for i in range(len(translantion)): for i in range(len(translantion)):
offset[i + 1] -= translantion[i] offset[i + 1] -= translantion[i]
img = sci.affine_transform(img, inverse_affine_matrix, offset, order=interpolation, mode=border_mode, cval= value) img = sci.affine_transform(img, inverse_affine_matrix, offset, order=interpolation, mode=border_mode, cval= value)
if mask: if mask:
img = img[0] img = img[0]
return img return img
...@@ -581,84 +278,183 @@ def affine_keypoints(keypoints: list, ...@@ -581,84 +278,183 @@ def affine_keypoints(keypoints: list,
return res return res
# TO REMOVE # Used in rot90_keypoints
def rotation_matrix_calculation(dim, x_angle, y_angle, z_angle): def flip_keypoints(keypoints, axes, img_shape):
rot_matrix = np.identity(dim).astype(np.float32)
rot_matrix = rot_matrix @ rot_x(x_angle, dim)
rot_matrix = rot_matrix @ rot_y(y_angle, dim)
rot_matrix = rot_matrix @ rot_z(z_angle, dim)
return rot_matrix
def rot_x(angle, dim):
if dim == 4:
rotation_x = np.array([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, np.cos(angle), -np.sin(angle)],
[0, 0, np.sin(angle), np.cos(angle)]])
if dim == 5:
rotation_x = np.array([[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, np.cos(angle), -np.sin(angle), 0],
[0, 0, np.sin(angle), np.cos(angle), 0],
[0, 0, 0, 0, 1]])
return rotation_x
def rot_y(angle, dim):
if dim == 4:
rotation_y = np.array([[1, 0, 0, 0],
[0, np.cos(angle), 0, np.sin(angle)],
[0, 0, 1, 0],
[0, -np.sin(angle), 0, np.cos(angle)]])
if dim == 5:
rotation_y = np.array([[1, 0, 0, 0, 0],
[0, np.cos(angle), 0, np.sin(angle), 0],
[0, 0, 1, 0, 0],
[0, -np.sin(angle), 0, np.cos(angle), 0],
[0, 0, 0, 0, 1]])
return rotation_y
def rot_z(angle, dim):
if dim == 4:
rotation_z = np.array([[1, 0, 0, 0],
[0, np.cos(angle), -np.sin(angle), 0],
[0, np.sin(angle), np.cos(angle), 0],
[0, 0, 0, 1]])
if dim == 5:
rotation_z = np.array([[1, 0, 0, 0, 0],
[0, np.cos(angle), -np.sin(angle), 0, 0],
[0, np.sin(angle), np.cos(angle), 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]])
return rotation_z
# all values in axes are in [1, 2, 3]
assert np.all(np.array([ax in [1, 2, 3] for ax in axes])), f'{axes} does not contain values from [1, 2, 3]'
# TODO clipped tag may be important for types other that float32, but tags are from fork and not tested mult, add = np.ones(3, int), np.zeros(3, int)
# @clipped for ax in axes:
def brightness_contrast_adjust(img, alpha=1, beta=0): mult[ax-1] = -1
if alpha != 1: add[ax-1] = img_shape[ax-1] - 1
img *= alpha
if beta != 0: res = []
img += beta for k in keypoints:
return img flipped = list(np.array(k[:3]) * mult + add)
if len(k) == 4:
flipped.append(k[-1])
res.append(tuple(flipped))
return res
def histogram_equalization(img, bins): # Used in rot90_keypoints
for i in range(img.shape[0]): def transpose_keypoints(keypoints, ax1, ax2):
img[i] = equalize_hist(img[i], bins)
return img # all values in axes are in [1, 2, 3]
assert (ax1 in [1, 2, 3]) and (ax2 in [1, 2, 3]), f'[{ax1} {ax2}] does not contain values from [1, 2, 3]'
res = []
for k in keypoints:
k = list(k)
k[ax1-1], k[ax2-1] = k[ax2-1], k[ax1-1]
res.append(tuple(k))
return res
def rot90_keypoints(keypoints, factor, axes, img_shape):
if factor == 1:
keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)
keypoints = transpose_keypoints(keypoints, axes[0], axes[1])
elif factor == 2:
keypoints = flip_keypoints(keypoints, axes, img_shape)
elif factor == 3:
keypoints = transpose_keypoints(keypoints, axes[0], axes[1])
keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)
return keypoints
def pad(img, pad_width, border_mode, cval, mask=True):
if not mask:
pad_width = [(0, 0)] + pad_width
if len(img.shape) > len(pad_width):
pad_width = pad_width + [(0, 0)]
assert len(img.shape) == len(pad_width)
if border_mode == "constant":
return np.pad(img, pad_width, border_mode, constant_values=cval)
if border_mode == "linear_ramp":
return np.pad(img, pad_width, border_mode, end_values=cval)
result = np.pad(img, pad_width, border_mode)
return result
def pad_keypoints(keypoints, pad_size):
a, b, c, d, e, f = pad_size
res = []
for coo in keypoints:
padding = np.array((a, c, e)) if len(coo) == 3 else np.array((a, c, e, 0))
res.append(coo + padding)
return res
def pad_pixels(img, input_pad_width: TypeSextetInt, border_mode, cval, mask=False):
a, b, c, d, e, f = input_pad_width
pad_width = [(a, b), (c, d), (e, f)]
# zeroes for channel dimension
if not mask:
pad_width = [(0, 0)] + pad_width
# zeroes for temporal dimension
if len(img.shape) == 5:
pad_width = pad_width + [(0, 0)]
if border_mode == "constant":
return np.pad(img, pad_width, border_mode, constant_values=cval)
if border_mode == "linear_ramp":
return np.pad(img, pad_width, border_mode, end_values=cval)
return np.pad(img, pad_width, border_mode)
# Used in crop()
def get_spatial_shape(array: np.array, mask: bool) -> TypeSpatialShape:
return np.array(array.shape)[:3] if mask else np.array(array.shape)[1:4]
# Used in crop()
def get_pad_dims(spatial_shape: TypeSpatialShape, crop_shape: TypeSpatialShape):
pad_dims = []
for i in range(3):
i_dim, c_dim = spatial_shape[i], crop_shape[i]
if i_dim < c_dim:
pad_size = c_dim - i_dim
if pad_size % 2 != 0:
pad_dims.append((int(pad_size // 2 + 1), int(pad_size // 2)))
else:
pad_dims.append((int(pad_size // 2), int(pad_size // 2)))
else:
pad_dims.append((0, 0))
return pad_dims
# Too similar to the random_crop. Could be made into one function
def crop(input_array: np.array,
crop_shape: TypeSpatialShape,
crop_position: TypeSpatialShape,
pad_dims,
border_mode, cval, mask):
input_spatial_shape = get_spatial_shape(input_array, mask)
if np.any(input_spatial_shape < crop_shape):
warn(f'F.crop(): Input size {input_spatial_shape} smaller than crop size {crop_shape}, pad by {border_mode}.',
UserWarning)
# pad
input_array = pad(input_array, pad_dims, border_mode, cval, mask)
# test
input_spatial_shape = get_spatial_shape(input_array, mask)
assert np.all(input_spatial_shape >= crop_shape)
x1, y1, z1 = crop_position
x2, y2, z2 = np.array(crop_position) + np.array(crop_shape)
if mask:
result = input_array[x1:x2, y1:y2, z1:z2]
assert np.all(result.shape[:3] == crop_shape), f'{result.shape} {crop_shape} {mask} {crop_position}'
else:
result = input_array[:, x1:x2, y1:y2, z1:z2]
assert np.all(result.shape[1:4] == crop_shape)
return result
def crop_keypoints(keypoints,
crop_shape: TypeSpatialShape,
crop_position: TypeSpatialShape,
pad_dims,
keep_all: bool):
(px, _), (py, _), (pz, _) = pad_dims
pad = np.array((px, py, pz))
res = []
for keypoint in keypoints:
k = keypoint[:3] - crop_position + pad
if keep_all or (np.all(k >= 0) and np.all((k + .5) < crop_shape)):
res.append(k)
return res
def gaussian_blur(img, input_sigma, border_mode, cval): def gaussian_blur(img, input_sigma, border_mode, cval):
sigma = input_sigma sigma = input_sigma
if isinstance(sigma, list): if isinstance(sigma, list):
if img.shape[0] != len(sigma): if img.shape[0] != len(sigma):
warn(f'GaussianBlur(): wrong list size {len(sigma)}, expecting same as number of dimensions {img.shape[0]}. Ignoring', UserWarning) warn(
f'GaussianBlur(): wrong list size {len(sigma)}, expecting same as number of dimensions {img.shape[0]}. Ignoring',
UserWarning)
return img return img
return gaussian_blur_stack(img, sigma, border_mode, cval) return gaussian_blur_stack(img, sigma, border_mode, cval)
...@@ -671,15 +467,15 @@ def gaussian_blur(img, input_sigma, border_mode, cval): ...@@ -671,15 +467,15 @@ def gaussian_blur(img, input_sigma, border_mode, cval):
else: else:
# TODO what to expect in the input. # TODO what to expect in the input.
if len(sigma) == len(img.shape) - 2: if len(sigma) == len(img.shape) - 2:
sigma = np.append(sigma, 0) sigma = np.append(sigma, 0)
if len(sigma) == len(img.shape) - 1: if len(sigma) == len(img.shape) - 1:
sigma = np.insert(sigma, 0, 0) sigma = np.insert(sigma, 0, 0)
# TODO better warning # TODO better warning
if len(sigma) != len(img.shape): if len(sigma) != len(img.shape):
warn(f'GaussianBlur(): wrong sigma tuple, ignoring', UserWarning) warn(f'GaussianBlur(): wrong sigma tuple, ignoring', UserWarning)
return img return img
return gaussian_filter(img, sigma=sigma, mode=border_mode, cval=cval) return gaussian_filter(img, sigma=sigma, mode=border_mode, cval=cval)
def gaussian_blur_stack(img, input_sigma, border_mode, cval): def gaussian_blur_stack(img, input_sigma, border_mode, cval):
sigma = list(np.asarray(input_sigma).copy()) sigma = list(np.asarray(input_sigma).copy())
...@@ -689,7 +485,7 @@ def gaussian_blur_stack(img, input_sigma, border_mode, cval): ...@@ -689,7 +485,7 @@ def gaussian_blur_stack(img, input_sigma, border_mode, cval):
warn(f'GaussianBlur(): wrong sigma format, Inside list can be only tuple,float or int. Ignoring', warn(f'GaussianBlur(): wrong sigma format, Inside list can be only tuple,float or int. Ignoring',
UserWarning) UserWarning)
return img return img
# TODO try different techniques for better optimalization. # TODO try different techniques for better optimalization.
for i in range(len(sigma)): for i in range(len(sigma)):
if isinstance(sigma[i], (float, int)): if isinstance(sigma[i], (float, int)):
...@@ -703,6 +499,16 @@ def gaussian_blur_stack(img, input_sigma, border_mode, cval): ...@@ -703,6 +499,16 @@ def gaussian_blur_stack(img, input_sigma, border_mode, cval):
return img return img
# TODO clipped tag may be important for types other that float32, but tags are from fork and not tested
# @clipped
def brightness_contrast_adjust(img, alpha=1, beta=0):
if alpha != 1:
img *= alpha
if beta != 0:
img += beta
return img
def gamma_transform(img, gamma): def gamma_transform(img, gamma):
if np.all(img < 0) or np.all(img > 1) : if np.all(img < 0) or np.all(img > 1) :
warn(f"Gamma transform: image is not in range [0,1]. continuing without transform", UserWarning) warn(f"Gamma transform: image is not in range [0,1]. continuing without transform", UserWarning)
...@@ -710,3 +516,68 @@ def gamma_transform(img, gamma): ...@@ -710,3 +516,68 @@ def gamma_transform(img, gamma):
else: else:
return np.power(img, gamma) return np.power(img, gamma)
def histogram_equalization(img, bins):
for i in range(img.shape[0]):
img[i] = equalize_hist(img[i], bins)
return img
def gaussian_noise(img, mean, sigma):
img = img.astype("float32")
noise = np.random.normal(mean, sigma, img.shape).astype(np.float32)
return img + noise
def poisson_noise(img, peak):
img = img.astype("float32")
return img + np.random.poisson(img).astype(np.float32)
def value_to_list(value, length):
if isinstance(value, (float, int)):
return [value for _ in range(length)]
else:
return value
def correct_length_list(list_to_check, length, value_to_fill=1, list_name="###Default###"):
if len(list_to_check) < length:
warn(f"{list_name} have elements {len(list_to_check)}, should be {length} appending {value_to_fill} " +
"till length matches", UserWarning)
for i in range(length - len(list_to_check)):
list_to_check = list_to_check + [value_to_fill]
if len(list_to_check) > length:
warn(f"{list_name} have elements {len(list_to_check)}, should be {length} removing elements from behind " +
" till length matches", UserWarning)
list_to_check = [list_to_check[i] for i in range(length)]
return list_to_check
# formula taken from
# https://stats.stackexchange.com/questions/46429/transform-data-to-desired-mean-and-standard-deviation
def normalize_channel(img, mean, std):
return (img - img.mean()) * (std / img.std()) + mean
def normalize(img, input_mean, input_std):
mean = value_to_list(input_mean, img.shape[0])
std = value_to_list(input_std, img.shape[0])
mean = correct_length_list(mean, img.shape[0], value_to_fill=0, list_name="mean")
std = correct_length_list(std, img.shape[0], value_to_fill=1, list_name="std")
for i in range(img.shape[0]):
img[i] = normalize_channel(img[i], mean[i], std[i])
return img
def normalize_mean_std(img, mean, denominator):
if len(mean.shape) == 0:
mean = mean[..., None]
if len(denominator.shape) == 0:
denominator = denominator[..., None]
new_axis = [i + 1 for i in range(len(img.shape) - 1)]
img -= np.expand_dims(mean, axis=new_axis)
img *= np.expand_dims(denominator, axis=new_axis)
return img
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment