Commit 5199e2bc authored by Filip Lux

3D watershed

parent aaf050b9
Loading
Loading
Loading
Loading
+19 −4
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ from .config import Config
from .preprocessing import get_pixel_weight_function, preprocess_gt, smooth_components
from datetime import datetime
import random
import tifffile as tiff


from .eidg import random_transform
@@ -341,7 +342,7 @@ class Dataset:
        self.input_depth = config.INPUT_DEPTH

        #  get flists
        self.img_sequences = self._get_img_sequences()
        self.img_sequences = self._get_source_seq()
        self.flist_img = None
        self.flist_gt = None
        self.flist_mask = None
@@ -461,7 +462,7 @@ class Dataset:
        assert len(img_flist) > 0, f'no samples were found\nIMG path: {img_path}'
        return img_flist, gt_flist, mask_flist

    def _get_img_sequences(self):
    def _get_source_seq(self):
        if self.mode == 1:
            # list all the possible sequence directories
            # exclude validation_seq
@@ -477,8 +478,7 @@ class Dataset:
            return [self.seq]

    def update_flists(self):
        self.flist_img, self.flist_gt, self.flist_mask = \
            self._get_flists()
        self.flist_img, self.flist_gt, self.flist_mask = self._get_flists()

        if self.mode == 1:
            if self.validation_sequence is None:
@@ -542,3 +542,18 @@ class Dataset:
        img_path = self.flist_img[0][0]
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        return img.shape

    def get_z_depth(self):
        """Infer the z-depth (number of slices) of the dataset's images.

        Inspects one file from the first available image sequence: a 2D
        image yields a depth of 1, otherwise the size of the leading axis
        of the stack is returned. Implicitly returns None when there are
        no sequences at all.
        """
        for sequence in self.img_sequences:
            seq_dir = os.path.join(self.config.DATA_PATH, self.name, sequence)

            assert os.path.isdir(seq_dir), seq_dir
            file_names = os.listdir(seq_dir)
            assert len(file_names) > 0

            # A single sample file is enough to determine the stack depth.
            stack = tiff.imread(os.path.join(seq_dir, file_names[0]))
            return 1 if len(stack.shape) == 2 else stack.shape[0]
+87 −25
Original line number Diff line number Diff line
@@ -16,7 +16,9 @@ from .models import UNetModel
from .postprocessing import \
    postprocess_markers as pm,\
    postprocess_markers_09,\
    postprocess_foreground
    postprocess_foreground, \
    postprocess_markers_3D, \
    postprocess_foreground_3D
from .utils import \
    get_formatted_shape,\
    overlay_labels,\
@@ -139,12 +141,12 @@ class DeepWater:
            if self.tracking:
                print('Creates tracking...\n')
                create_tracking(self.out_path, self.out_path, self.track_threshold)
            if self.display:
            if self.display and ('3D' not in self.name):
                """ create new path for vizualization """
                self._store_visualisations()

            if '3D' in self.name:
                self._merge_results()
            # if '3D' in self.name:
            #     self._merge_results()

            gt_path = os.path.join(self.config.DATA_PATH, self.name, f'{self.seq}_GT/SEG')
            if not os.path.isdir(gt_path):
@@ -202,22 +204,45 @@ class DeepWater:

        batch_ids = self._get_batch_ids(n_batches)

        z_depth = self.dataset.get_z_depth()
        assert z_depth > 0

        m, n = self.dataset.get_original_size()
        
        marker_prediction = np.zeros((0, m, n))
        foreground_prediction = np.zeros((0, m, n))

        for batch_index in tqdm(range(n_batches)):
            marker_prediction = marker_model.predict_dataset(self.dataset,
                                                             batch_index=batch_index)[..., -1]
            foreground_prediction = foreground_model.predict_dataset(self.dataset,
                                                                     batch_index=batch_index)[..., -1]
            n_samples = len(foreground_prediction)

            ids = batch_ids[batch_index]

            for i in range(n_samples):
                marker_image = marker_prediction[i, ...]
                foreground_image = foreground_prediction[i, ...]
            # get new batch
            new_marker_pred = marker_model.predict_dataset(self.dataset, batch_index=batch_index)[..., -1]
            new_for_pred = foreground_model.predict_dataset(self.dataset, batch_index=batch_index)[..., -1]

            assert new_marker_pred.shape[1:] == (m, n)
            assert new_for_pred.shape[1:] == (m, n)

            # append to existing
            marker_prediction = np.concatenate([marker_prediction, new_marker_pred], axis=0)
            foreground_prediction = np.concatenate([foreground_prediction, new_for_pred], axis=0)

            tot_samples = marker_prediction.shape[0]
            n_imgs = tot_samples // z_depth

            for i in range(n_imgs):
                marker_image, marker_prediction = np.split(marker_prediction, [z_depth])
                foreground_image, foreground_prediction = np.split(foreground_prediction, [z_depth])

                assert foreground_image.shape == (z_depth, m, n), foreground_image.shape
                assert marker_image.shape == (z_depth, m, n), marker_image.shape

                marker_image = (marker_image * 255).astype(np.uint8)
                foreground_image = (foreground_image * 255 * 255).astype(np.uint16)

                if z_depth == 1:
                    marker_image = np.squeeze(marker_image)
                    foreground_image = np.squeeze(foreground_image)
                    _, marker_function = postprocess_markers(marker_image,
                                                             threshold=self.config.THR_MARKERS,
                                                             c=self.config.MARKER_DIAMETER,
@@ -226,6 +251,7 @@ class DeepWater:

                    foreground = postprocess_foreground(foreground_image,
                                                        threshold=self.config.THR_FOREGROUND)
                    # impose markers to foreground
                    foreground = np.maximum(foreground, (marker_function > 0) * 255)
                    segmentation_function = self._get_segmentation_function(foreground_image, foreground)

@@ -233,6 +259,21 @@ class DeepWater:
                    labels = remove_edge_cells(labels, self.border)

                    self._store_results(ids[i], labels, marker_image, foreground_image, marker_function, foreground)
                else:
                    marker_function = postprocess_markers_3D(marker_image,
                                                             threshold=self.config.THR_MARKERS,
                                                             c=self.config.MARKER_DIAMETER)
                    foreground = postprocess_foreground_3D(foreground_image,
                                                           threshold=self.config.THR_FOREGROUND)

                    labels = watershed(-foreground_image, marker_function, mask=foreground)
                    # labels = remove_edge_cells(labels, self.border)

                    self._store_results_3D(ids[i], labels, marker_image, foreground_image, marker_function, foreground)





    def _get_batch_ids(self, n_batches):
        img_flist, _, _ = self._get_flists(self.dataset)
@@ -305,6 +346,27 @@ class DeepWater:
            cv2.imwrite(f'{self.viz_path}/c{index}.tif',
                        (foreground_image / 255).astype(np.uint8))

    def _store_results_3D(self, index, labels, marker_image, foreground_image, marker_function, foreground):
        """Write the 3D label mask to disk, plus visualisation stacks when enabled.

        Only the part of *index* before the first underscore names the outputs.
        """
        index = index.split('_')[0]

        # Final segmentation result as a 16-bit label image.
        mask_path = os.path.join(self.out_path, f'mask{index}.tif')
        tifffile.imwrite(mask_path,
                         labels.astype(np.uint16),
                         photometric='minisblack')

        if not self.display:
            return

        # Watershed inputs side by side: binarised markers next to the foreground.
        side_by_side = np.concatenate(((marker_function>0)*255, foreground), axis=2)
        tifffile.imwrite(f'{self.viz_path}/ws_functions{index}.tif',
                         side_by_side.astype(np.uint8),
                         photometric='minisblack')

        # Raw network predictions, rescaled to 8-bit where needed.
        tifffile.imwrite(f'{self.viz_path}/m{index}.tif',
                         marker_image,
                         photometric='minisblack')
        tifffile.imwrite(f'{self.viz_path}/c{index}.tif',
                         (foreground_image / 255).astype(np.uint8),
                         photometric='minisblack')


    def _get_segmentation_function(self, foreground_image, foreground):
        if self.version == 1.0:
            # imposing markers into segmenation function
@@ -434,7 +496,7 @@ class DeepWater:

            seq = find_sequences(self.config.IMG_PATH)
            assert len(seq) > 0, f'there are no image sequences in {self.config.IMG_PATH}'
            print(seq)
            # print(seq)

            for s in seq:
                if self.input_suffix:
+1 −2
Original line number Diff line number Diff line
@@ -227,16 +227,15 @@ class UNetModel():
                       workers=5)

    def predict_dataset(self, dataset: Dataset, batch_index=None):
        # with a final crop to original image size
        generator = dataset.get_img_generator()
        m, n = dataset.get_original_size()

        if batch_index is None:

            prediction = self.model.predict_generator(generator,
                                                      verbose=True,
                                                      use_multiprocessing=True,
                                                      workers=10)

        else:
            x = generator[batch_index]
            prediction = self.model.predict(x)
+9 −7
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ import cv2
import numpy as np
from scipy.ndimage.morphology import binary_fill_holes
from skimage.segmentation import watershed
from skimage.morphology import ball, binary_opening
from skimage.morphology import disk, binary_opening
from skimage.measure import label


@@ -75,15 +75,16 @@ def postprocess_markers_3D(img,
    threshold == tm
    """

    kernel = ball(c)
    markers = binary_opening(img, kernel)
    glob_f = (markers > threshold).astype(np.uint8)
    kernel = np.expand_dims(disk(c), axis=0)
    # print('opening', kernel.shape, img.shape)
    markers = binary_opening((img > threshold), kernel).astype(np.uint8)

    # label connected components
    labels = label(glob_f)
    labels = label(markers)
    # print('Marker threshold:', threshold, c, 'sum:', np.sum(labels>0))

    # print(threshold, c, circular, h)
    return np.unique(labels), labels
    return labels


# postprocess markers
@@ -131,7 +132,8 @@ def postprocess_foreground(b, threshold=230):
# Postprocess the foreground prediction of the 3D pipeline.
def postprocess_foreground_3D(b, threshold=230):
    """Binarise the foreground prediction *b* into a 0/255 mask.

    NOTE(review): the first assignment below is dead code — it is
    immediately overwritten by the second, which rescales *b* (stored as
    uint16, 0..65025) back into the 0..255 range before thresholding.
    Confirm the first line can be removed.
    """
    # thresholding
    bt = (b > int(threshold)) * 255
    bt = (b//255 > int(threshold)) * 255
    # print('Foreground threshold:', threshold, 'sum:', np.sum(bt>0))

    return bt

+29 −13
Original line number Diff line number Diff line
@@ -19,7 +19,6 @@ from tqdm import tqdm
from time import sleep



class Normalizer:
    def __init__(self, normalization: str, uneven_illumination: bool = False):
        self.normalfce = get_normal_fce(normalization)
@@ -176,19 +175,22 @@ def create_tracking(path, output_path, threshold=0.15):
    names = [name for name in names if '.tif' in name and 'mask' in name]
    names.sort()

    img = cv2.imread(os.path.join(path, names[0]), cv2.IMREAD_ANYDEPTH)
    mi, ni = img.shape
    img = tiff.imread(os.path.join(path, names[0]))
    shape = img.shape

    print('Relabelling the segmentation masks.')
    records = {}

    old = np.zeros((mi, ni))
    old = np.zeros(img.shape)
    index = 1
    n_images = len(names)

    for i, name in enumerate(tqdm(names)):
        result = np.zeros((mi, ni), np.uint16)
        result = np.zeros(img.shape, np.uint16)

        img_path = os.path.join(path, name)
        img = tiff.imread(img_path)

        img = cv2.imread(os.path.join(path, name), cv2.IMREAD_ANYDEPTH)
        assert img.shape == shape, f'{img_path} -> {img.shape}'

        labels = np.unique(img)[1:]

@@ -197,6 +199,8 @@ def create_tracking(path, output_path, threshold=0.15):
        for label in labels:
            mask = (img == label) * 1

            assert mask.shape == shape

            mask_size = np.sum(mask)
            overlap = mask * old
            candidates = np.unique(overlap)[1:]
@@ -232,6 +236,8 @@ def create_tracking(path, output_path, threshold=0.15):
                        m_mask = (result == max_candidate) * 1
                        result = result - m_mask * max_candidate + m_mask * index

                        assert result.shape == shape

                        records[index] = [i, i, max_candidate.astype(np.uint16)]
                        index += 1

@@ -243,11 +249,13 @@ def create_tracking(path, output_path, threshold=0.15):
                # update of used parent cells
                parent_cells.append(max_candidate)
        # store result
        cv2.imwrite(os.path.join(output_path, name), result.astype(np.uint16))
        # print('result shape', result.shape)
        tiff.imwrite(os.path.join(output_path, name), result.astype(np.uint16))
        old = result

    # store tracking
    print('Generating the tracking file.')
    print('tracking results:', records)
    with open(os.path.join(output_path, 'res_track.txt'), "w") as file:
        for key in records.keys():
            file.write('{} {} {} {}\n'.format(key, records[key][0], records[key][1], records[key][2]))
@@ -262,17 +270,23 @@ def remove_edge_cells(label_img, border=20):

def get_edge_indexes(label_img, border=20):
    mask = np.ones(label_img.shape)
    if len(mask.shape) == 2:
        mi, ni = mask.shape
        mask[border:mi - border, border:ni - border] = 0
    elif len(mask.shape) == 3:
        mi, ni, di = mask.shape
        mask[border:mi - border, border:ni - border, border:di - border] = 0
    else:
        assert False
    border_cells = mask * label_img
    indexes = (np.unique(border_cells))
    indexes = np.delete(np.unique(border_cells), [0])

    result = []

    # get only cells with center inside the mask
    for index in indexes:
        cell_size = sum(sum(label_img == index))
        gap_size = sum(sum(border_cells == index))
        cell_size = np.sum(label_img == index)
        gap_size = np.sum(border_cells == index)
        if cell_size * 0.5 < gap_size:
            result.append(index)

@@ -537,3 +551,5 @@ def unfold_3D(path, markers=False, filter_empty=False):
            if not filter_empty or np.sum(img) > 0:
                cv2.imwrite(file_path, img)

# Ad-hoc entry point: relabels segmentation masks for one local dataset.
# NOTE(review): the input/output path is machine-specific (a developer's
# home directory) — replace with a CLI argument before wider use.
if __name__ == '__main__':
    create_tracking('/home/xlux/PROJECTS/deepwater/tmp/test_A549', '/home/xlux/PROJECTS/deepwater/tmp/test_A549')