Skip to content
Snippets Groups Projects
Commit 41a23f93 authored by Lucia D. Hradecka's avatar Lucia D. Hradecka
Browse files

replace explicit loops with numpy vectorised operations

parent d5ca1ee0
No related branches found
No related tags found
1 merge request!12Make version 1.3.2 default
......@@ -46,7 +46,7 @@ from scipy.ndimage import gaussian_filter
from warnings import warn
from .sitk_utils import get_affine_transform, apply_sitk_transform
from .utils import is_included
from .utils import is_included, get_nonchannel_axes, atleast_kd
from src.biovol_typing import TypeTripletFloat, TypeSpatioTemporalCoordinate, TypeSextetInt, TypeSpatialShape
from src.random_utils import normal, poisson
......@@ -81,11 +81,10 @@ def resize(img, input_new_shape, interpolation=1, border_mode='reflect', cval=0,
new_shape = list(input_new_shape)[:-1]
# Zero or negative check
for dimension in new_shape:
if dimension <= 0:
warn(f'Resize(): shape: {new_shape} contains zero or negative number, continuing without Resize.',
UserWarning)
return img
if np.any(np.asarray(new_shape) <= 0):
warn(f'Resize(): shape: {new_shape} contains zero or negative number, continuing without Resize.',
UserWarning)
return img
# shape check
if mask:
......@@ -151,7 +150,7 @@ def resize_keypoints(keypoints,
ratio = np.array(new_shape[:3]) / np.array(domain_limit[:3])
# (we suppose here that length of keypoint is 3)
return [keypoint * ratio for keypoint in keypoints]
return list(map(tuple, np.asarray(keypoints) * ratio))
def affine(img: np.array,
......@@ -224,18 +223,18 @@ def flip_keypoints(keypoints, axes, img_shape):
# all values in axes are in [1, 2, 3]
assert np.all(np.array([ax in [1, 2, 3] for ax in axes])), f'{axes} does not contain values from [1, 2, 3]'
mult, add = np.ones(3, int), np.zeros(3, int)
keys = np.asarray(keypoints)
ndim = keys.shape[1]
mult = np.ones(ndim, int)
add = np.zeros(ndim, int)
for ax in axes:
mult[ax - 1] = -1
add[ax - 1] = img_shape[ax - 1] - 1
res = []
for k in keypoints:
flipped = list(np.array(k[:3]) * mult + add)
if len(k) == 4:
flipped.append(k[-1])
res.append(tuple(flipped))
return res
keys = keys * mult + add
return list(map(tuple, keys))
# Used in rot90_keypoints
......@@ -243,12 +242,13 @@ def transpose_keypoints(keypoints, ax1, ax2):
# all values in axes are in [1, 2, 3]
assert (ax1 in [1, 2, 3]) and (ax2 in [1, 2, 3]), f'[{ax1} {ax2}] does not contain values from [1, 2, 3]'
res = []
for k in keypoints:
k = list(k)
k[ax1 - 1], k[ax2 - 1] = k[ax2 - 1], k[ax1 - 1]
res.append(tuple(k))
return res
axis1 = ax1 - 1
axis2 = ax2 - 1
keys = np.asarray(keypoints)
keys[:, axis1], keys[:, axis2] = keys[:, axis2], keys[:, axis1].copy()
# Return a list of tuples
return list(map(tuple, keys))
def rot90_keypoints(keypoints, factor, axes, img_shape):
......@@ -269,11 +269,11 @@ def rot90_keypoints(keypoints, factor, axes, img_shape):
def pad_keypoints(keypoints, pad_size):
a, b, c, d, e, f = pad_size
res = []
for coo in keypoints:
padding = np.array((a, c, e)) if len(coo) == 3 else np.array((a, c, e, 0)) # we only need the 'before' pad size
res.append(coo + padding)
return res
keys = np.asarray(keypoints)
padding = np.asarray((a, c, e) if keys.shape[1] == 3 else (a, c, e, 0)) # we only need the 'before' pad size
# Return a list of tuples
return list(map(tuple, keys + padding))
def pad_pixels(img, input_pad_width: TypeSextetInt, border_mode, cval, mask=False):
......@@ -359,16 +359,20 @@ def crop_keypoints(keypoints,
crop_position: TypeSpatialShape,
pad_dims,
keep_all: bool):
# Get padding information
px, _, py, _, pz, _ = pad_dims # we only need the 'before' padding size
pad = np.array((px, py, pz))
pad = np.asarray((px, py, pz))
res = []
for keypoint in keypoints:
k = keypoint[:3] - crop_position + pad
if keep_all or (np.all(k >= 0) and np.all((k + .5) < crop_shape)):
res.append(k)
# Compute new keypoint positions
keys = np.asarray(keypoints)[:, :3] - np.asarray(crop_position) + pad # ignore the time dimension of keypoints
return res
# Filter the keypoints
if not keep_all:
mask = (keys >= 0) & (keys + .5 < np.asarray(crop_shape))
keys = keys[np.sum(mask, axis=1) == 3, :]
# Return a list of tuples
return list(map(tuple, keys))
def gaussian_blur(img, input_sigma, border_mode, cval):
......@@ -447,7 +451,7 @@ def gamma_transform(img, gamma):
def histogram_equalization(img, bins):
for i in range(img.shape[0]):
for i in range(img.shape[0]): # for each channel
img[i] = equalize_hist(img[i], bins)
return img
......@@ -465,9 +469,9 @@ def poisson_noise(img, peak):
def value_to_list(value, length):
if isinstance(value, (float, int)):
return [value for _ in range(length)]
return [value] * length
else:
return value
return value # TODO: maybe return list(value)?
def correct_length_list(list_to_check, length, value_to_fill=1, list_name='###Default###'):
......@@ -486,32 +490,34 @@ def correct_length_list(list_to_check, length, value_to_fill=1, list_name='###De
return list_to_check
def normalize_channel(img, mean, std):
def normalize(img, input_mean, input_std):
"""
Normalize a single-channel image to have the desired mean and variance values.
Used formula from: https://stats.stackexchange.com/questions/46429/transform-data-to-desired-mean-and-standard-deviation
Normalize a multi-channel image to have the desired mean and standard deviation values.
Formula from: https://stats.stackexchange.com/questions/46429/transform-data-to-desired-mean-and-standard-deviation
"""
return (img - img.mean()) * (std / img.std()) + mean
def normalize(img, input_mean, input_std):
mean = value_to_list(input_mean, img.shape[0])
std = value_to_list(input_std, img.shape[0])
mean = correct_length_list(mean, img.shape[0], value_to_fill=0, list_name='mean')
std = correct_length_list(std, img.shape[0], value_to_fill=1, list_name='std')
for i in range(img.shape[0]): # for each channel
img[i] = normalize_channel(img[i], mean[i], std[i])
return img
mean = atleast_kd(mean, img.ndim)
std = atleast_kd(std, img.ndim)
img_mean = atleast_kd(img.mean(axis=get_nonchannel_axes(img)), img.ndim)
img_std = atleast_kd(img.std(axis=get_nonchannel_axes(img)), img.ndim)
if np.any(np.isclose(img_std, 0)):
warn(f'Normalize(): standard deviation of at least one input channel is 0. Skipping this transformation.',
UserWarning)
return img
img = (img - img_mean) * (std / img_std) + mean
return img.astype(img.dtype)
def normalize_mean_std(img, mean, denominator):
if len(mean.shape) == 0:
mean = mean[..., None]
if len(denominator.shape) == 0:
denominator = denominator[..., None]
new_axis = [i + 1 for i in range(len(img.shape) - 1)]
img -= np.expand_dims(mean, axis=new_axis)
img *= np.expand_dims(denominator, axis=new_axis)
img -= atleast_kd(mean, k=img.ndim)
img *= atleast_kd(denominator, k=img.ndim)
return img
......@@ -1802,6 +1802,10 @@ class NormalizeMeanStd(ImageOnlyTransform):
# Compute the formula denominator once as it is computationally expensive:
self.denominator = np.reciprocal(self.std, dtype=np.float32)
if len(self.mean.shape) == 0: # shapes of self.mean and self.denominator are the same
self.mean = self.mean[..., None]
self.denominator = self.denominator[..., None]
def apply(self, image, **params):
return F.normalize_mean_std(image, self.mean, self.denominator)
......
......@@ -36,6 +36,22 @@ from src.random_utils import uniform
DEBUG = False
def get_nonchannel_axes(array):
"""
Return the non-channel axis indices for a given image.
"""
return tuple(range(1, array.ndim))
def atleast_kd(array, k):
"""
Add singleton dimensions to the input array s.t. the new shape is at least k-dimensional.
"""
array = np.asarray(array)
new_shape = array.shape + (1,) * (k - array.ndim)
return array.reshape(new_shape)
def get_sigma_axiswise(min_sigma, max_sigma):
"""
Randomly choose a single sigma for all axes and channels (if max_sigma is int or float)
......@@ -88,8 +104,7 @@ def parse_limits(input_limit: Union[float, TypePairFloat, TypeTripletFloat, Type
for item in input_limit: # for each spatial axis
if isinstance(item, Iterable):
# we already have a tuple -> add it to the result
for val in item:
res.append(float(val))
res.extend(item)
else:
# we need to create a tuple
limit_range = parse_helper_affine_limits_1d(item, scale=scale) # get (1/x, x) or (-x, +x)
......
import unittest
class MyTestTemplate(unittest.TestCase):
def test_something(self):
self.assertEqual(True, True) # add assertion here
import time
import numpy as np
from src.augmentations import functional, atleast_kd, get_nonchannel_axes
class TestNormalization(unittest.TestCase):
def normalize_cycle(self, img, mean, std):
# def normalize_channel(img, mean, std):
# return (img - img.mean()) * (std / img.std()) + mean
# for-cycle implementation
for i in range(img.shape[0]):
img[i] = (img[i] - img[i].mean()) * (std[i] / img[i].std()) + mean[i]
# img[i] = normalize_channel(img[i], mean[i], std[i])
return img
def normalize_vect(self, img, mean, std):
mean = atleast_kd(mean, img.ndim)
std = atleast_kd(std, img.ndim)
img_mean = atleast_kd(img.mean(axis=get_nonchannel_axes(img)), img.ndim)
img_std = atleast_kd(img.std(axis=get_nonchannel_axes(img)), img.ndim)
return (img - img_mean) * (std / img_std) + mean
def test_normalize_cycle_imlp(self):
img = np.random.random(size=(3, 5, 6, 7))
mean = [0, 0, 1]
std = [1, 2, 1]
img = self.normalize_cycle(img, mean, std)
# verify it
for i in range(img.shape[0]):
self.assertAlmostEqual(img[i].mean(), mean[i])
self.assertAlmostEqual(img[i].std(), std[i])
def test_normalize_vect_imlp(self):
img = np.random.random(size=(3, 5, 6, 7))
mean = [0, 0, 1]
std = [1, 2, 1]
# numpy-vectorised implementation
img = self.normalize_vect(img, mean, std)
# verify it
for i in range(img.shape[0]):
self.assertAlmostEqual(img[i].mean(), mean[i])
self.assertAlmostEqual(img[i].std(), std[i])
def test_runtime(self):
n = 30
img_size = (3, 256, 256, 256)
img_size = (3, 130, 130, 100, 4)
# img_size = (8, 200, 200, 200)
time_cycle = self.measure_exec_time(self.normalize_cycle, n=n, img_size=img_size)
print(f'Runtime (for cycle implementation): {time_cycle}')
time_vect = self.measure_exec_time(self.normalize_vect, n=n, img_size=img_size)
print(f'Runtime (vectorised implementation): {time_vect}')
def measure_exec_time(self, fn, n=100, img_size=(3, 256, 256, 256)):
total_time = 0
for i in range(n):
img = np.random.random(size=img_size)
mean = [0, 0, 1] + [-1] * (img_size[0] - 3)
std = [1, 2, 1] + [3] * (img_size[0] - 3)
start = time.time()
res = fn(img, mean, std)
res2 = res.shape
end = time.time()
total_time += (end - start)
return total_time / n
def test_normalize_fn_1(self):
img = np.random.random(size=(3, 5, 6, 7))
mean = [0, 0, 1]
std = [1, 2, 1]
res = functional.normalize(img, mean, std)
for i in range(res.shape[0]):
self.assertAlmostEqual(res[i].mean(), mean[i])
self.assertAlmostEqual(res[i].std(), std[i])
def test_normalize_fn_2(self):
img = np.random.random(size=(3, 5, 6, 7, 4))
mean = [0, 0, 1]
std = [1, 2, 2]
res = functional.normalize(img, mean, std)
for i in range(res.shape[0]):
self.assertAlmostEqual(res[i].mean(), mean[i])
self.assertAlmostEqual(res[i].std(), std[i])
def test_normalize_fn_3(self):
img = np.random.random(size=(1, 5, 6, 7))
mean = [1]
std = [2]
res = functional.normalize(img, mean, std)
for i in range(res.shape[0]):
self.assertAlmostEqual(res[i].mean(), mean[i])
self.assertAlmostEqual(res[i].std(), std[i])
class TestGaussianBlur(unittest.TestCase):
def gaussian_blur_vect(self, img, sigma, border_mode, cval):
from skimage.filters import gaussian
# If None, input is filtered along all axes. Otherwise, input is filtered along the specified axes.
# When axes is specified, any tuples used for sigma, order, mode and/or radius must match the length of axes.
# The ith entry in any of these tuples corresponds to the ith entry in axes.
# return functional.gaussian_filter(img, sigma=sigma, mode=border_mode, cval=cval, axes=[0]) # compute
return gaussian(img, sigma=sigma, channel_axis=0, preserve_range=True) # compute
def test_gaussian_blur_fn_1(self):
img = np.random.random(size=(3, 100, 101, 120))
sigma = [2, 1, 1.5]
res = self.gaussian_blur_vect(img, sigma, 'reflect', 0)
res1 = functional.gaussian_blur_stack(img, sigma, 'reflect', 0)
self.assertTrue(np.allclose(res, res1))
class TestCropPadEtc(unittest.TestCase):
def test_keypoints_crop_1(self):
keypoints = [(1, 2, 3), (2, 2, 2), (10, 2, 1), (20, 23, 20), (4, 5, 0)]
out_shape = np.asarray((5, 5, 5))
corner_position = np.asarray((1, 0, 1))
pad = np.asarray((0, 0, 0))
# res = self.crop_keypoints_cycle(keypoints, corner_position, pad, False, out_shape)
res = self.crop_keypoints_vect(keypoints, corner_position, pad, False, out_shape)
self.assertEqual(len(res), 2)
self.assertTupleEqual(tuple(res[0]), (0, 2, 2))
self.assertTupleEqual(tuple(res[1]), (1, 2, 1))
# res = self.crop_keypoints_cycle(keypoints, corner_position, pad, True, out_shape)
res = self.crop_keypoints_vect(keypoints, corner_position, pad, True, out_shape)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (0, 2, 2))
self.assertTupleEqual(tuple(res[1]), (1, 2, 1))
self.assertTupleEqual(tuple(res[2]), (9, 2, 0))
self.assertTupleEqual(tuple(res[3]), (19, 23, 19))
self.assertTupleEqual(tuple(res[4]), (3, 5, -1))
def test_keypoints_crop_2(self):
keypoints = [(1, 2, 3, 0), (2, 2, 2, 1), (10, 2, 1, 3), (20, 23, 20, 1), (4, 5, 0, 0)]
out_shape = np.asarray((5, 5, 5))
corner_position = np.asarray((1, 0, 1))
pad = np.asarray((0, 0, 0))
# res = self.crop_keypoints_cycle(keypoints, corner_position, pad, False, out_shape)
res = self.crop_keypoints_vect(keypoints, corner_position, pad, False, out_shape)
self.assertEqual(len(res), 2)
self.assertTupleEqual(tuple(res[0]), (0, 2, 2))
self.assertTupleEqual(tuple(res[1]), (1, 2, 1))
# res = self.crop_keypoints_cycle(keypoints, corner_position, pad, True, out_shape)
res = self.crop_keypoints_vect(keypoints, corner_position, pad, True, out_shape)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (0, 2, 2))
self.assertTupleEqual(tuple(res[1]), (1, 2, 1))
self.assertTupleEqual(tuple(res[2]), (9, 2, 0))
self.assertTupleEqual(tuple(res[3]), (19, 23, 19))
self.assertTupleEqual(tuple(res[4]), (3, 5, -1))
def crop_keypoints_cycle(self, keypoints, crop_position, pad, keep_all, crop_shape):
res = []
for keypoint in keypoints: # keypoints = list of tuples of floats
k = keypoint[:3] - crop_position + pad # ignore time-dim keypoint position
if keep_all or (np.all(k >= 0) and np.all((k + .5) < crop_shape)):
res.append(k)
return res
def crop_keypoints_vect(self, keypoints, crop_position, pad, keep_all, crop_shape):
keys = np.asarray(keypoints)[:, :3] - crop_position + pad
if keep_all:
return keys
mask = (keys >= 0) & (keys+.5 < crop_shape)
res = keys[np.sum(mask, axis=1) == 3, :]
return res
def pad_keypoints_cycle(self, keypoints, pad):
res = []
for coo in keypoints:
padding = np.array(pad) if len(coo) == 3 else np.array(pad + (0,))
res.append(coo + padding)
return res
def pad_keypoints_vect(self, keypoints, pad):
keys = np.asarray(keypoints)
padding = np.asarray(pad if keys.shape[1] == 3 else pad+(0,)) # we only need the 'before' pad size
return keys + padding
def test_keypoints_pad_1(self):
keypoints = [(1, 2, 3), (2, 2, 2), (10, 2, 1), (20, 23, 20), (4, 5, 0)]
pad = (0, 1, 3)
# res = self.pad_keypoints_cycle(keypoints, pad)
res = self.pad_keypoints_vect(keypoints, pad)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (1, 3, 6))
self.assertTupleEqual(tuple(res[1]), (2, 3, 5))
self.assertTupleEqual(tuple(res[2]), (10, 3, 4))
self.assertTupleEqual(tuple(res[3]), (20, 24, 23))
self.assertTupleEqual(tuple(res[4]), (4, 6, 3))
def test_keypoints_pad_2(self):
keypoints = [(1, 2, 3, 0), (2, 2, 2, 1), (10, 2, 1, 4), (20, 23, 20, 2), (4, 5, 0, 1)]
pad = (0, 1, 3)
# res = self.pad_keypoints_cycle(keypoints, pad)
res = self.pad_keypoints_vect(keypoints, pad)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (1, 3, 6, 0))
self.assertTupleEqual(tuple(res[1]), (2, 3, 5, 1))
self.assertTupleEqual(tuple(res[2]), (10, 3, 4, 4))
self.assertTupleEqual(tuple(res[3]), (20, 24, 23, 2))
self.assertTupleEqual(tuple(res[4]), (4, 6, 3, 1))
def transpose_keypoints_cycle(self, keypoints, ax1, ax2):
res = []
for k in keypoints:
k = list(k)
k[ax1 - 1], k[ax2 - 1] = k[ax2 - 1], k[ax1 - 1]
res.append(tuple(k))
return res
def transpose_keypoints_vect(self, keypoints, ax1, ax2):
keys = np.asarray(keypoints)
ax1, ax2 = ax1-1, ax2-1
keys[:, ax1], keys[:, ax2] = keys[:, ax2], keys[:, ax1].copy()
return keys
def test_keypoints_transpose_1(self):
keypoints = [(1, 2, 3), (2, 2, 2), (10, 2, 1), (20, 23, 20), (4, 5, 0)]
# res = self.transpose_keypoints_cycle(keypoints, 1, 2)
res = self.transpose_keypoints_vect(keypoints, 1, 2)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (2, 1, 3))
self.assertTupleEqual(tuple(res[4]), (5, 4, 0))
# res = self.transpose_keypoints_cycle(keypoints, 1, 3)
res = self.transpose_keypoints_vect(keypoints, 1, 3)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (3, 2, 1))
self.assertTupleEqual(tuple(res[4]), (0, 5, 4))
# res = self.transpose_keypoints_cycle(keypoints, 3, 2)
res = self.transpose_keypoints_vect(keypoints, 3, 2)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (1, 3, 2))
self.assertTupleEqual(tuple(res[4]), (4, 0, 5))
def test_keypoints_transpose_2(self):
keypoints = [(1, 2, 3, 1), (2, 2, 2, 1), (10, 2, 1, 1), (20, 23, 20, 1), (4, 5, 0, 1)]
# res = self.transpose_keypoints_cycle(keypoints, 1, 2)
res = self.transpose_keypoints_vect(keypoints, 1, 2)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (2, 1, 3, 1))
self.assertTupleEqual(tuple(res[4]), (5, 4, 0, 1))
# res = self.transpose_keypoints_cycle(keypoints, 1, 3)
res = self.transpose_keypoints_vect(keypoints, 1, 3)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (3, 2, 1, 1))
self.assertTupleEqual(tuple(res[4]), (0, 5, 4, 1))
# res = self.transpose_keypoints_cycle(keypoints, 3, 2)
res = self.transpose_keypoints_vect(keypoints, 3, 2)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (1, 3, 2, 1))
self.assertTupleEqual(tuple(res[4]), (4, 0, 5, 1))
def flip_keypoints_cycle(self, keypoints, axes, img_shape):
mult, add = np.ones(3, int), np.zeros(3, int)
for ax in axes:
mult[ax - 1] = -1
add[ax - 1] = img_shape[ax - 1] - 1
res = []
for k in keypoints:
flipped = list(np.array(k[:3]) * mult + add)
if len(k) == 4:
flipped.append(k[-1])
res.append(tuple(flipped))
return res
def flip_keypoints_vect(self, keypoints, axes, img_shape):
keys = np.asarray(keypoints)
ndim = keys.shape[1]
mult, add = np.ones(ndim, int), np.zeros(ndim, int)
for ax in axes:
mult[ax - 1] = -1
add[ax - 1] = img_shape[ax - 1] - 1
flipped = keys * mult + add
return flipped
def test_keypoints_flip_1(self):
keypoints = [(1, 2, 3), (2, 2, 2), (10, 2, 1), (20, 23, 20), (4, 5, 0)]
img_shape = (25, 25, 25)
axes = []
# res = self.flip_keypoints_cycle(keypoints, axes, img_shape)
res = self.flip_keypoints_vect(keypoints, axes, img_shape)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (1, 2, 3))
self.assertTupleEqual(tuple(res[4]), (4, 5, 0))
axes = [1]
# res = self.flip_keypoints_cycle(keypoints, axes, img_shape)
res = self.flip_keypoints_vect(keypoints, axes, img_shape)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (23, 2, 3))
self.assertTupleEqual(tuple(res[4]), (20, 5, 0))
axes = [1, 3]
# res = self.flip_keypoints_cycle(keypoints, axes, img_shape)
res = self.flip_keypoints_vect(keypoints, axes, img_shape)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (23, 2, 21))
self.assertTupleEqual(tuple(res[4]), (20, 5, 24))
def test_keypoints_flip_2(self):
keypoints = [(1, 2, 3, 0), (2, 2, 2, 1), (10, 2, 1, 3), (20, 23, 20, 1), (4, 5, 0, 1)]
img_shape = (25, 25, 25)
axes = []
# res = self.flip_keypoints_cycle(keypoints, axes, img_shape)
res = self.flip_keypoints_vect(keypoints, axes, img_shape)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (1, 2, 3, 0))
self.assertTupleEqual(tuple(res[4]), (4, 5, 0, 1))
axes = [1]
# res = self.flip_keypoints_cycle(keypoints, axes, img_shape)
res = self.flip_keypoints_vect(keypoints, axes, img_shape)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (23, 2, 3, 0))
self.assertTupleEqual(tuple(res[4]), (20, 5, 0, 1))
axes = [1, 3]
# res = self.flip_keypoints_cycle(keypoints, axes, img_shape)
res = self.flip_keypoints_vect(keypoints, axes, img_shape)
self.assertEqual(len(res), 5)
self.assertTupleEqual(tuple(res[0]), (23, 2, 21, 0))
self.assertTupleEqual(tuple(res[4]), (20, 5, 24, 1))
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment