Skip to content
Snippets Groups Projects
Commit a73f5bb6 authored by Lucia D. Hradecka's avatar Lucia D. Hradecka
Browse files

refactor functional.py: remove repeated logic

parent ac5ceb38
No related branches found
No related tags found
1 merge request!12Make version 1.3.2 default
......@@ -269,25 +269,6 @@ def rot90_keypoints(keypoints, factor, axes, img_shape):
return keypoints
def pad(img, pad_width, border_mode, cval, mask=True):
if not mask:
pad_width = [(0, 0)] + pad_width
if len(img.shape) > len(pad_width):
pad_width = pad_width + [(0, 0)]
assert len(img.shape) == len(pad_width)
if border_mode == "constant":
return np.pad(img, pad_width, border_mode, constant_values=cval)
if border_mode == "linear_ramp":
return np.pad(img, pad_width, border_mode, end_values=cval)
result = np.pad(img, pad_width, border_mode)
return result
def pad_keypoints(keypoints, pad_size):
a, b, c, d, e, f = pad_size
......@@ -307,9 +288,11 @@ def pad_pixels(img, input_pad_width: TypeSextetInt, border_mode, cval, mask=Fals
pad_width = [(0, 0)] + pad_width
# zeroes for temporal dimension
if len(img.shape) == 5:
if len(img.shape) == 5: # if len(img.shape) > len(pad_width):
pad_width = pad_width + [(0, 0)]
assert len(img.shape) == len(pad_width)
if border_mode == "constant":
return np.pad(img, pad_width, border_mode, constant_values=cval)
if border_mode == "linear_ramp":
......@@ -323,22 +306,23 @@ def get_spatial_shape(array: np.array, mask: bool) -> TypeSpatialShape:
# Used in crop()
def get_pad_dims(spatial_shape: TypeSpatialShape, crop_shape: TypeSpatialShape):
pad_dims = []
def get_pad_dims(spatial_shape: TypeSpatialShape, crop_shape: TypeSpatialShape) -> TypeSextetInt:
pad_dims = [0] * 6
for i in range(3):
i_dim, c_dim = spatial_shape[i], crop_shape[i]
current_pad_dims = (0, 0)
if i_dim < c_dim:
pad_size = c_dim - i_dim
if pad_size % 2 != 0:
pad_dims.append((int(pad_size // 2 + 1), int(pad_size // 2)))
current_pad_dims = (int(pad_size // 2 + 1), int(pad_size // 2))
else:
pad_dims.append((int(pad_size // 2), int(pad_size // 2)))
else:
pad_dims.append((0, 0))
return pad_dims
current_pad_dims = (int(pad_size // 2), int(pad_size // 2))
pad_dims[i * 2:(i + 1) * 2] = current_pad_dims
return tuple(pad_dims)
# Too similar to the random_crop. Could be made into one function
def crop(input_array: np.array,
crop_shape: TypeSpatialShape,
crop_position: TypeSpatialShape,
......@@ -352,7 +336,7 @@ def crop(input_array: np.array,
UserWarning)
# pad
input_array = pad(input_array, pad_dims, border_mode, cval, mask)
input_array = pad_pixels(input_array, pad_dims, border_mode, cval, mask)
# test
input_spatial_shape = get_spatial_shape(input_array, mask)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment