Skip to content
Snippets Groups Projects
Commit d19cdb5a authored by Lucia D. Hradecka's avatar Lucia D. Hradecka
Browse files

add tests of invalid input samples

parent 41a23f93
No related branches found
No related tags found
1 merge request!12Make version 1.3.2 default
......@@ -783,5 +783,91 @@ class TestInputArgs(unittest.TestCase):
# tr = Compose([Normalize(2, [3, 4])])
class TestInvalidInput(unittest.TestCase):
def invalid_range_check(self, transform, sample=None, **params):
if sample is None:
img_shape = (4, 120, 120, 120)
img = np.ones(img_shape, dtype=np.float64)
mask = np.ones(img_shape[1:], dtype=np.int64)
fmask = np.ones(img_shape[1:], dtype=np.float64)
else:
img, mask, fmask = sample
tr = Compose([transform(p=1, **params)])
tr_img = tr(image=img, mask=mask, float_mask=fmask)
# some checks - we just need to make sure that the computation did not fail
self.assertTrue(np.issubdtype(tr_img['image'].dtype, np.floating))
self.assertTrue(np.issubdtype(tr_img['mask'].dtype, np.integer))
self.assertTrue(np.issubdtype(tr_img['float_mask'].dtype, np.floating))
def test_invalid_range_crop(self):
self.invalid_range_check(RandomCrop, shape=(10, 10, 10))
self.invalid_range_check(CenterCrop, shape=(10, 10, 10))
def test_invalid_range_scale(self):
self.invalid_range_check(Scale, scales=0.5)
def test_invalid_range_gamma(self):
img_shape = (4, 120, 120, 120)
img = np.ones(img_shape, dtype=np.float64) * 2
mask = np.ones(img_shape[1:], dtype=np.int64)
fmask = np.ones(img_shape[1:], dtype=np.float64)
self.invalid_range_check(RandomGamma, sample=(img, mask, fmask))
def test_invalid_range_gaussian_blur(self):
self.invalid_range_check(GaussianBlur)
def test_invalid_range_normalize(self):
img_shape = (4, 120, 120, 120)
img = np.ones(img_shape, dtype=np.float64)
mask = np.ones(img_shape[1:], dtype=np.int64)
fmask = np.ones(img_shape[1:], dtype=np.float64)
self.invalid_range_check(RandomGamma, sample=(img, mask, fmask))
def invalid_dtype_check(self, transform, **params):
img_shape = (4, 120, 120, 120)
img = np.ones(img_shape, dtype=int)
mask = np.ones(img_shape[1:], dtype=float)
fmask = np.ones(img_shape[1:], dtype=int)
tr = Compose([transform(p=1, **params)])
tr_img = tr(image=img, mask=mask, float_mask=fmask)
self.assertTrue(np.issubdtype(tr_img['image'].dtype, np.floating))
self.assertTrue(np.issubdtype(tr_img['mask'].dtype, np.integer))
self.assertTrue(np.issubdtype(tr_img['float_mask'].dtype, np.floating))
def test_invalid_dtype_crop(self):
self.invalid_dtype_check(RandomCrop, shape=(10, 10, 10))
self.invalid_dtype_check(CenterCrop, shape=(10, 10, 10))
def test_invalid_dtype_scale(self):
self.invalid_dtype_check(Scale, scales=0.5)
def test_invalid_dtype_gamma(self):
self.invalid_dtype_check(RandomGamma)
def test_invalid_dtype_gaussian_blur(self):
self.invalid_dtype_check(RandomGamma)
def test_invalid_size_crop(self):
img_shape = (4, 120, 120, 120)
img = np.ones(img_shape, dtype=np.float64)
mask = np.ones(img_shape[1:], dtype=np.int64)
fmask = np.ones(img_shape[1:], dtype=np.float64)
tr = Compose([CenterCrop(shape=(140, 120, 100), p=1)])
tr_img = tr(image=img, mask=mask, float_mask=fmask)
# some checks - we just need to make sure that the computation did not fail
self.assertTrue(np.issubdtype(tr_img['image'].dtype, np.floating))
self.assertTrue(np.issubdtype(tr_img['mask'].dtype, np.integer))
self.assertTrue(np.issubdtype(tr_img['float_mask'].dtype, np.floating))
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment