Skip to content
Snippets Groups Projects
test_transforms.py 21.6 KiB
Newer Older
Filip Lux's avatar
Filip Lux committed
# ============================================================================================= #
#  Author:       Filip Lux                                                                      #
#  Copyright:    Filip Lux          : lux.filip@gmail.com                                       #
#                                                                                               #
#  MIT License.                                                                                 #
#                                                                                               #
#  Permission is hereby granted, free of charge, to any person obtaining a copy                 #
#  of this software and associated documentation files (the "Software"), to deal                #
#  in the Software without restriction, including without limitation the rights                 #
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell                    #
#  copies of the Software, and to permit persons to whom the Software is                        #
#  furnished to do so, subject to the following conditions:                                     #
#                                                                                               #
#  The above copyright notice and this permission notice shall be included in all               #
#  copies or substantial portions of the Software.                                              #
#                                                                                               #
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR                   #
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,                     #
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE                  #
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER                       #
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,                #
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE                #
#  SOFTWARE.                                                                                    #
# ============================================================================================= #


Lucia Hradecká's avatar
Lucia Hradecká committed
import unittest
Filip Lux's avatar
Filip Lux committed
from bio_volumentations.augmentations.transforms import (
    GaussianNoise, PoissonNoise, Resize, Pad, Scale, Flip, CenterCrop, AffineTransform,
    RandomScale, RandomRotate90, RandomFlip, RandomCrop, RandomAffineTransform, RandomGamma,
    NormalizeMeanStd, GaussianBlur, Normalize, HistogramEqualization, RandomBrightnessContrast)
from bio_volumentations.core.composition import Compose
Lucia Hradecká's avatar
Lucia Hradecká committed
import numpy as np

# When True, get_keypoints_tests prints the keypoints before and after the transform.
DEBUG = False
Filip Lux's avatar
Filip Lux committed
class TestScale(unittest.TestCase):
    """Shape, dtype and keypoint checks of ``Scale`` for one factor above and one below 1."""

    SCALES = (1.5, 0.8)

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        for scales in self.SCALES:
            # subTest reports each factor separately instead of stopping at the first failure.
            with self.subTest(scales=scales):
                tests = get_shape_tests(Scale, (31, 32, 33), params={'scales': scales})
                for tr_img, expected_shape, data_type in tests:
                    self.assertTupleEqual(tr_img.shape, expected_shape)
                    self.assertEqual(tr_img.dtype, data_type)

    def test_keypoints(self):
        """Values sampled at the transformed keypoints must exceed 10 % of the expected value."""
        for scales in self.SCALES:
            with self.subTest(scales=scales):
                tests = get_keypoints_tests(Scale, params={'scales': scales})
                for value, expected_value, msg in tests:
                    self.assertGreater(value, expected_value * 0.1, msg)


def get_keypoints_tests(transform,
                        in_shape: tuple = (32, 33, 34),
                        params: "dict | None" = None):
    """
    Builds keypoint-consistency checks for a transform.

    A 4-channel image and a mask are filled with zeros, 15 random keypoints are drawn and a cube
    of value 10 is painted around each of them in both arrays. The sample is passed through
    the transform and the image (channel 0) and the mask are read at every transformed keypoint.

    Args:
        transform: transform class; instantiated as ``transform(**params, p=1)``
        in_shape: spatial dimension (W, H, D) of the generated image
        params: optional, keyword arguments of the transform

    Returns:
        list of tuples (sampled value, 10., message); two entries per transformed keypoint,
        the first one for the image and the second one for the mask

    """
    # None instead of a shared mutable ``{}`` default.
    params = {} if params is None else params

    w, h, d = in_shape

    img = np.zeros((4, w, h, d), np.float32)
    mask = np.zeros((w, h, d), np.int32)
    keypoints = []

    # Half-size of the painted cube; also keeps the keypoints away from the border.
    lbd = 3

    # Uses the global, unseeded NumPy RNG, so the keypoints differ between runs.
    for _ in range(15):
        w1 = np.random.randint(lbd, w - lbd)
        h1 = np.random.randint(lbd, h - lbd)
        d1 = np.random.randint(lbd, d - lbd)
        img[:, w1 - lbd:w1 + lbd, h1 - lbd:h1 + lbd, d1 - lbd:d1 + lbd] = 10.
        mask[w1 - lbd:w1 + lbd, h1 - lbd:h1 + lbd, d1 - lbd:d1 + lbd] = 10
        keypoints.append((float(w1), float(h1), float(d1)))

    sample = {'image': img,
              'mask': mask,
              'keypoints': keypoints}

    tr = Compose([transform(**params, p=1)])
    sample_transformed = tr(**sample)

    keypoints_transformed = sample_transformed['keypoints']
    if DEBUG:
        print('KEYPOINTS', transform, keypoints)
        print('KEYPOINTS TRANSFORMED', transform, keypoints_transformed)

    # NOTE(review): assumes the pipeline returns only keypoints that still lie inside the
    # transformed volume; otherwise the indexing below raises IndexError or wraps around for
    # negative coordinates - confirm against the library.
    tests = []
    for k in keypoints_transformed:
        # Round to the nearest voxel (adding .5 and truncating; valid for non-negative values).
        coos = (np.array(k) + .5).astype(int)
        # The labels now match the array that is actually sampled (they used to be swapped).
        tests.append((sample_transformed['image'][0, coos[0], coos[1], coos[2]], 10.,
                      f'img, {k} {coos} {transform} {params}'))
        tests.append((sample_transformed['mask'][coos[0], coos[1], coos[2]], 10.,
                      f'mask, {k} {coos} {transform} {params}'))

    return tests


def get_shape_tests(transform,
                    in_shape: tuple,
                    params=None):
    """
    Passes arrays of three different layouts through the transform and pairs every output
    with the shape and dtype it is expected to have.

    Args:
        transform: biovol transform class; instantiated as ``transform(**params, p=1)``
        in_shape: spatial dimension (W, H, D) of the input image
        params: optional, params of the biovol transform; if it contains the key 'shape',
            that value is used as the expected spatial shape instead of ``in_shape``

    Returns:
        list of tuples (output array, expected shape, expected dtype); three entries
        (image, mask, float_mask) for each of the three input layouts

    """
    # None instead of a shared mutable ``{}`` default.
    params = {} if params is None else params

    w, h, d = in_shape
    w_, h_, d_ = params.get('shape', (w, h, d))

    tr = Compose([transform(**params, p=1)])

    # (image shape, mask shape, mask dtype, expected image shape, expected mask shape)
    cases = [
        # img (W, H, D), mask (W, H, D); the image is expected back with a leading axis of size 1
        ((w, h, d), (w, h, d), np.int32, (1, w_, h_, d_), (w_, h_, d_)),
        # img (C, W, H, D), mask (W, H, D)
        ((4, w, h, d), (w, h, d), int, (4, w_, h_, d_), (w_, h_, d_)),
        # img (C, W, H, D, T), mask (W, H, D, T)
        ((4, w, h, d, 5), (w, h, d, 5), int, (4, w_, h_, d_, 5), (w_, h_, d_, 5)),
    ]

    res = []
    for img_shape, mask_shape, mask_dtype, exp_img_shape, exp_mask_shape in cases:
        img = np.ones(img_shape, dtype=np.float32)
        mask = np.ones(mask_shape, dtype=mask_dtype)
        fmask = np.ones(mask_shape, dtype=np.float32)
        tr_img = tr(image=img, mask=mask, float_mask=fmask)
        res.append((tr_img['image'], exp_img_shape, np.float32))
        # The mask is expected back as int32 even when it was created with dtype=int.
        res.append((tr_img['mask'], exp_mask_shape, np.int32))
        res.append((tr_img['float_mask'], exp_mask_shape, np.float32))

    return res

Lucia Hradecká's avatar
Lucia Hradecká committed

Filip Lux's avatar
Filip Lux committed
class TestRandomScale(unittest.TestCase):
    """Shape, dtype and keypoint checks of ``RandomScale`` for several ``scaling_limit`` forms."""

    # A scalar, a pair, a triple and a six-tuple; shared by both tests.
    SCALING_LIMITS = (0.2,
                      (0.8, 1.2),
                      (0.2, 0.3, 0.1),
                      (0.8, 1.2, 0.9, 1.1, 0.7, 1.))

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        for scaling_limit in self.SCALING_LIMITS:
            # subTest reports each limit separately instead of stopping at the first failure.
            with self.subTest(scaling_limit=scaling_limit):
                tests = get_shape_tests(RandomScale,
                                        in_shape=(31, 32, 33),
                                        params={'scaling_limit': scaling_limit})
                for tr_img, expected_shape, data_type in tests:
                    self.assertTupleEqual(tr_img.shape, expected_shape)
                    self.assertEqual(tr_img.dtype, data_type)

    def test_keypoints(self):
        """Values sampled at the transformed keypoints must exceed 50 % of the expected value."""
        for scaling_limit in self.SCALING_LIMITS:
            with self.subTest(scaling_limit=scaling_limit):
                tests = get_keypoints_tests(RandomScale,
                                            in_shape=(61, 62, 63),
                                            params={'scaling_limit': scaling_limit})
                for value, expected_value, msg in tests:
                    self.assertGreater(value, expected_value * 0.5, msg)

Filip Lux's avatar
Filip Lux committed

class TestRandomRotate90(unittest.TestCase):
    """Shape and dtype checks of ``RandomRotate90`` for several ``axes`` settings."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        axes_list = [None,
                     [1],
                     [1, 2],
                     [1, 2, 3]]

        for axes in axes_list:
            # subTest reports each setting separately instead of stopping at the first failure.
            with self.subTest(axes=axes):
                tests = get_shape_tests(RandomRotate90, (30, 30, 30),
                                        params={'axes': axes})
                for tr_img, expected_shape, data_type in tests:
                    self.assertTupleEqual(tr_img.shape, expected_shape)
                    self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


class TestFlip(unittest.TestCase):
    """Shape and dtype checks of ``Flip`` with its default parameters."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        tests = get_shape_tests(Flip, (31, 32, 33))
        for tr_img, expected_shape, data_type in tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


class TestRandomFlip(unittest.TestCase):
    """Shape and dtype checks of ``RandomFlip`` for several ``axes_to_choose`` settings."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        axes_list = [None,
                     [],
                     [1],
                     [1, 2],
                     [1, 2, 3]]

        for axes in axes_list:
            # subTest reports each setting separately instead of stopping at the first failure.
            with self.subTest(axes_to_choose=axes):
                tests = get_shape_tests(RandomFlip, (30, 30, 30),
                                        params={'axes_to_choose': axes})
                for tr_img, expected_shape, data_type in tests:
                    self.assertTupleEqual(tr_img.shape, expected_shape)
                    self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


class TestCenterCrop(unittest.TestCase):
    """Shape and dtype checks of ``CenterCrop`` for targets larger and smaller than the input."""

    IN_SHAPE = (32, 31, 30)

    def _assert_output_shape(self, target_shape):
        """Run ``CenterCrop`` with ``target_shape`` and check the shape and dtype of each output."""
        shape_tests = get_shape_tests(CenterCrop, self.IN_SHAPE, {'shape': target_shape})
        for tr_img, expected_shape, data_type in shape_tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)

    def test_inflate(self):
        """Target shape larger than the input along every axis."""
        self._assert_output_shape((40, 41, 42))

    def test_deflate(self):
        """Target shape smaller than the input along every axis."""
        self._assert_output_shape((20, 21, 22))
Filip Lux's avatar
Filip Lux committed


Filip Lux's avatar
Filip Lux committed
class TestRandomCrop(unittest.TestCase):
    """Shape and dtype checks of ``RandomCrop`` for targets larger and smaller than the input."""

    IN_SHAPE = (32, 31, 30)

    def _assert_output_shape(self, target_shape):
        """Run ``RandomCrop`` with ``target_shape`` and check the shape and dtype of each output."""
        shape_tests = get_shape_tests(RandomCrop, self.IN_SHAPE, {'shape': target_shape})
        for tr_img, expected_shape, data_type in shape_tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)

    def test_inflate(self):
        """Target shape larger than the input along every axis."""
        self._assert_output_shape((40, 41, 42))

    def test_deflate(self):
        """Target shape smaller than the input along every axis."""
        self._assert_output_shape((20, 21, 22))
Filip Lux's avatar
Filip Lux committed


class TestResize(unittest.TestCase):
    """Shape, dtype and keypoint checks of ``Resize`` for a larger and a smaller target shape."""

    IN_SHAPE = (32, 31, 30)
    LARGER_SHAPE = (40, 41, 42)
    SMALLER_SHAPE = (20, 21, 22)

    def _assert_output_shape(self, target_shape):
        """Run ``Resize`` with ``target_shape`` and check the shape and dtype of each output."""
        shape_tests = get_shape_tests(Resize, self.IN_SHAPE, {'shape': target_shape})
        for tr_img, expected_shape, data_type in shape_tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)

    def test_inflate(self):
        """Target shape larger than the input along every axis."""
        self._assert_output_shape(self.LARGER_SHAPE)

    def test_deflate(self):
        """Target shape smaller than the input along every axis."""
        self._assert_output_shape(self.SMALLER_SHAPE)

    def test_keypoints(self):
        """Values sampled at the transformed keypoints must exceed 50 % of the expected value."""
        for target_shape in (self.LARGER_SHAPE, self.SMALLER_SHAPE):
            # subTest reports each target separately instead of stopping at the first failure.
            with self.subTest(shape=target_shape):
                tests = get_keypoints_tests(Resize, self.IN_SHAPE,
                                            params={'shape': target_shape})
                for value, expected_value, msg in tests:
                    self.assertGreater(value, expected_value * 0.5, msg)

Lucia Hradecká's avatar
Lucia Hradecká committed

class TestPad(unittest.TestCase):
    """Shape and keypoint checks of the ``Pad`` transform."""

    def test_1(self):
        """With ``Pad(2)`` every spatial axis of length 30 is expected to become 34."""
        pipeline = Compose([Pad(2)])

        # (input shape, expected output shape); channel and time axes must keep their size,
        # and an image without a channel axis is expected back with a leading axis of size 1.
        cases = (
            ((30, 30, 30), (1, 34, 34, 34)),
            ((1, 30, 30, 30), (1, 34, 34, 34)),
            ((4, 30, 30, 30), (4, 34, 34, 34)),
            ((4, 30, 30, 30, 5), (4, 34, 34, 34, 5)),
        )

        for input_shape, padded_shape in cases:
            padded = pipeline(image=np.empty(input_shape))['image']
            self.assertTupleEqual(padded.shape, padded_shape)

    def test_keypoints(self):
        """Values sampled at the transformed keypoints must exceed 50 % of the expected value."""
        volume_shape = (32, 31, 30)

        # A pair, a scalar and a six-tuple, in this order.
        for pad_size in ((5, 8), 4, (3, 4, 5, 6, 7, 8)):
            checks = get_keypoints_tests(Pad, volume_shape, params={'pad_size': pad_size})
            for value, expected_value, msg in checks:
                self.assertGreater(value, expected_value * 0.5, msg)

Lucia Hradecká's avatar
Lucia Hradecká committed

Filip Lux's avatar
Filip Lux committed
class TestRandomAffineTransform(unittest.TestCase):
    """Shape, dtype and keypoint checks of ``RandomAffineTransform``, one parameter at a time."""

    # Each parameter is tried as a scalar, a pair, a triple and a six-tuple.
    ANGLE_LIMITS = (10,
                    (-20, 20),
                    (12, 30, 0),
                    (-20, 20, -180, 180, 0, 0))

    SCALING_LIMITS = (0.2,
                      (0.8, 1.2),
                      (0.2, 0.3, 0.1),
                      (0.8, 1.2, 0.9, 1.1, 0.7, 1.))

    TRANSLATION_LIMITS = (10,
                          (0, 12),
                          (3, 5, 10),
                          (-3, 3, -5, 5, 0, 0))

    def _check_shapes(self, param_name, limits):
        """Check the shape and dtype of each output for every value of one parameter."""
        for limit in limits:
            # subTest reports each value separately instead of stopping at the first failure.
            with self.subTest(param=param_name, limit=limit):
                tests = get_shape_tests(RandomAffineTransform, (31, 32, 33),
                                        params={param_name: limit})
                for tr_img, expected_shape, data_type in tests:
                    self.assertTupleEqual(tr_img.shape, expected_shape)
                    self.assertEqual(tr_img.dtype, data_type)

    def _check_keypoints(self, param_name, limits, min_fraction):
        """Check that values sampled at the keypoints exceed ``min_fraction`` of the expected."""
        for limit in limits:
            with self.subTest(param=param_name, limit=limit):
                tests = get_keypoints_tests(RandomAffineTransform,
                                            in_shape=(61, 62, 63),
                                            params={param_name: limit})
                for value, expected_value, msg in tests:
                    self.assertGreater(value, expected_value * min_fraction, msg)

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        self._check_shapes('angle_limit', self.ANGLE_LIMITS)
        self._check_shapes('scaling_limit', self.SCALING_LIMITS)
        self._check_shapes('translation_limit', self.TRANSLATION_LIMITS)

    def test_keypoints(self):
        """Keypoints must still hit the painted cubes; the threshold differs per parameter."""
        self._check_keypoints('angle_limit', self.ANGLE_LIMITS, 0.1)
        self._check_keypoints('scaling_limit', self.SCALING_LIMITS, 0.5)
        self._check_keypoints('translation_limit', self.TRANSLATION_LIMITS, 0.2)


Filip Lux's avatar
Filip Lux committed

class TestAffineTransform(unittest.TestCase):
    """Shape, dtype and keypoint checks of ``AffineTransform``, one parameter at a time."""

    SCALE = (1.2, 0.8, 1)
    TRANSLATION = (0, 1, -40)
    ANGLES = (-20, 0, -0.5)

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        cases = (('translation', self.TRANSLATION),
                 ('scale', self.SCALE),
                 ('angles', self.ANGLES))

        for param_name, value in cases:
            # subTest reports each parameter separately instead of stopping at the first failure.
            with self.subTest(param=param_name):
                tests = get_shape_tests(AffineTransform, (31, 32, 33),
                                        params={param_name: value})
                for tr_img, expected_shape, data_type in tests:
                    self.assertTupleEqual(tr_img.shape, expected_shape)
                    self.assertEqual(tr_img.dtype, data_type)

    def test_keypoints(self):
        """Values sampled at the transformed keypoints must exceed 50 % of the expected value."""
        cases = (('scale', self.SCALE),
                 ('translation', self.TRANSLATION),
                 ('angles', self.ANGLES))

        for param_name, param_value in cases:
            with self.subTest(param=param_name):
                tests = get_keypoints_tests(AffineTransform,
                                            in_shape=(61, 62, 63),
                                            params={param_name: param_value})
                for value, expected_value, msg in tests:
                    self.assertGreater(value, expected_value * 0.5, msg)

# ImageTransforms
Filip Lux's avatar
Filip Lux committed
class TestNormalizeMeanStd(unittest.TestCase):
    """Shape and dtype checks of ``NormalizeMeanStd`` with a scalar mean and std."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        mean = 1.2
        std = 2
        tests = get_shape_tests(NormalizeMeanStd, (31, 32, 33),
                                params={'mean': mean,
                                        'std': std})
        for tr_img, expected_shape, data_type in tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)


class TestGaussianNoise(unittest.TestCase):
    """Shape and dtype checks of ``GaussianNoise`` with its default parameters."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        for output, shape, dtype in get_shape_tests(GaussianNoise, (31, 32, 33)):
            self.assertTupleEqual(output.shape, shape)
            self.assertEqual(output.dtype, dtype)


class TestPoissonNoise(unittest.TestCase):
    """Shape and dtype checks of ``PoissonNoise`` with its default parameters."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        tests = get_shape_tests(PoissonNoise, (31, 32, 33))
        for tr_img, expected_shape, data_type in tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


class TestGaussianBlur(unittest.TestCase):
    """Shape and dtype checks of ``GaussianBlur`` with its default parameters."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        tests = get_shape_tests(GaussianBlur, (31, 32, 33))
        for tr_img, expected_shape, data_type in tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


class TestRandomGamma(unittest.TestCase):
    """Shape and dtype checks of ``RandomGamma`` with its default parameters."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        tests = get_shape_tests(RandomGamma, (31, 32, 33))
        for tr_img, expected_shape, data_type in tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


class TestRandomBrightnessContrast(unittest.TestCase):
    """Shape and dtype checks of ``RandomBrightnessContrast`` for scalar and pair limits."""

    def test_shape(self):
        """Every brightness/contrast combination must yield the expected shape and dtype."""
        brightness_list = [3, (2, 5)]
        contrast_list = [0.5, (.7, 1.1)]

        for brightness in brightness_list:
            for contrast in contrast_list:
                # subTest reports each combination separately instead of stopping early.
                with self.subTest(brightness_limit=brightness, contrast_limit=contrast):
                    tests = get_shape_tests(RandomBrightnessContrast, (30, 31, 32),
                                            params={'brightness_limit': brightness,
                                                    'contrast_limit': contrast})
                    for tr_img, expected_shape, data_type in tests:
                        self.assertTupleEqual(tr_img.shape, expected_shape)
                        self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


class TestHistogramEqualization(unittest.TestCase):
    """Shape and dtype checks of ``HistogramEqualization`` with its default parameters."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        tests = get_shape_tests(HistogramEqualization, (31, 32, 33))
        for tr_img, expected_shape, data_type in tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


class TestNormalize(unittest.TestCase):
    """Shape and dtype checks of ``Normalize`` with its default parameters."""

    def test_shape(self):
        """Every output must have the shape and dtype listed by ``get_shape_tests``."""
        tests = get_shape_tests(Normalize, (31, 32, 33))
        for tr_img, expected_shape, data_type in tests:
            self.assertTupleEqual(tr_img.shape, expected_shape)
            self.assertEqual(tr_img.dtype, data_type)
Filip Lux's avatar
Filip Lux committed


Lucia Hradecká's avatar
Lucia Hradecká committed
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()