Commit f62db7d7 authored by Lucia D. Hradecka's avatar Lucia D. Hradecka
Browse files

[fix] RandomRotate90 (finished)

1. fixed for all target types
2. users can input multiple factors
3. improved code readability
parent febbc9d1
Loading
Loading
Loading
Loading
+10 −5
Original line number Diff line number Diff line
@@ -265,21 +265,26 @@ def transpose_keypoints(keypoints, ax1, ax2):


def rot90_keypoints(keypoints, factor, axes, img_shape):
    """
    Rotates keypoints by 0, 90, 180, and 270 degrees.
    Returns rotated keypoints and the shape of the rotated image domain.
    """
    img_shape = np.array(img_shape)

    if factor == 1:
        keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)
        keypoints = transpose_keypoints(keypoints, axes[0], axes[1])
        img_shape[axes[0] - 1], img_shape[axes[1] - 1] = img_shape[axes[1] - 1], img_shape[axes[0] - 1]

    elif factor == 2:
        keypoints = flip_keypoints(keypoints, axes, img_shape)

    elif factor == 3:
        keypoints = transpose_keypoints(keypoints, axes[0], axes[1])
        img_shape_transposed = np.array(img_shape)
        img_shape_transposed[axes[0] - 1], img_shape_transposed[axes[1] - 1] = \
            img_shape_transposed[axes[1] - 1], img_shape_transposed[axes[0] - 1]
        keypoints = flip_keypoints(keypoints, [axes[1]], img_shape_transposed)
        img_shape[axes[0] - 1], img_shape[axes[1] - 1] = img_shape[axes[1] - 1], img_shape[axes[0] - 1]
        keypoints = flip_keypoints(keypoints, [axes[1]], img_shape)

    return keypoints
    return keypoints, img_shape


def pad_keypoints(keypoints, pad_size):
+15 −13
Original line number Diff line number Diff line
@@ -910,14 +910,17 @@ class RandomRotate90(DualTransform):
        Targets:
            image, mask, float mask, key points, bounding boxes
    """
    def __init__(self, axes: List[int] = None, shuffle_axis: bool = False, factor: Optional[int] = None,
    def __init__(self, axes: List[int] = None, shuffle_axis: bool = False, factor: Optional[int | List[int]] = None,
                 always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)

        self.axes = axes
        # Rotate around all spatial axes if not specified by the user:
        if self.axes is None:
            self.axes = [1, 2, 3]
            self.axes = [1, 2, 3]  # rotate around all spatial axes if not specified by the user

        if (factor is not None) and (not isinstance(factor, int)) and (len(self.axes) != len(factor)):
            raise ValueError(f'Lengths of the "axis" and "factor" arguments are not equal: '
                             f'{len(self.axes)} vs {len(factor)}')

        # Create all combinations for rotating
        axes_to_rotate = {1: (3, 2), 2: (1, 3), 3: (2, 1)}
@@ -928,11 +931,10 @@ class RandomRotate90(DualTransform):

        self.shuffle_axis = shuffle_axis
        self.factor = factor
        self.last_factor = None

    def apply(self, img, **params):
        for factor, axes in zip(params['factor'], params['rotation_around']):
            img = np.rot90(img, factor, axes=axes)
        for rot, factor in zip(params['rotation_around'], params['factor']):
            img = np.rot90(img, factor, axes=rot)
        return img

    def apply_to_mask(self, mask, **params):
@@ -941,11 +943,10 @@ class RandomRotate90(DualTransform):
        return mask

    def apply_to_keypoints(self, keypoints, **params):
        img_shape = params['img_shape']
        for rot, factor in zip(params['rotation_around'], params['factor']):
            keypoints = F.rot90_keypoints(keypoints,
                                          factor=factor,
                                          axes=(rot[0], rot[1]),
                                          img_shape=params['img_shape'])
            keypoints, img_shape = F.rot90_keypoints(keypoints, factor=factor, axes=(rot[0], rot[1]),
                                                     img_shape=img_shape)
        return keypoints

    def get_params(self, targets, **data):
@@ -957,9 +958,10 @@ class RandomRotate90(DualTransform):
        # If not specified, choose the angle to rotate
        if self.factor is None:
            factor = list(randint(0, 3, size=len(self.axes)))
        else:
        elif isinstance(self.factor, int) or isinstance(self.factor, float):
            factor = [self.factor] * len(self.axes)
        self.last_factor = factor
        else:
            factor = self.factor

        img_shape = get_spatial_shape_from_image(data, targets)

@@ -968,7 +970,7 @@ class RandomRotate90(DualTransform):
                'img_shape': img_shape}

    def __repr__(self):
        return f'RandomRotate90(axes={self.axes}, shuffle_axis={self.shuffle_axis}, factor={self.last_factor}, ' \
        return f'RandomRotate90(axes={self.axes}, shuffle_axis={self.shuffle_axis}, factor={self.factor}, ' \
               f'always_apply={self.always_apply}, p={self.p})'


+2 −2
Original line number Diff line number Diff line
@@ -240,9 +240,9 @@ class TestRandomRotate90(unittest.TestCase):
        axes_list = [[1], [2], [3],
                     [1, 2],
                     [1, 2, 3], None,
                     [1, 2, 3, 2, 3, 1, 3]]  # TODO sometimes fails for [1, 2, 3, 2, 3, 1, 3]
                     [1, 2, 3, 2, 3, 1, 3]]

        for _ in range(200):
        for _ in range(32):
            for axes in axes_list:
                tests = get_keypoints_tests(RandomRotate90, params={'axes': axes})
                for value, expected_value, msg in tests: