Skip to content
Snippets Groups Projects
gen_json.py 2.76 KiB
Newer Older
# Author: Kristyna Janku, Snippets creation and structure inspired by gen_json.py for VID dataset

import os
import glob
import json
import cv2

# Sequence ids to process; each one names a sub-folder of the data directory.
video_list = ['01', '02']

# Results per sequence: {video: {object_id: {frame: [x1, y1, x2, y2]}}}, filled by the loop below.
snippets = {}
# Count of sequences processed so far (used only in progress messages).
num_videos = 0

# Root directory of the dataset: raw frames live in <data_path>/<video>/, gold-truth
# annotations in <data_path>/<video>_GT/ and silver-truth ones in <data_path>/<video>_ST/.
data_path = 'data'

for video in video_list:
    num_videos += 1
    # Object id -> range of frame numbers the object is tracked in (last frame inclusive).
    # A dict gives O(1) id lookup and a range gives O(1) frame membership tests.
    id_frames = {}
    snippets[video] = dict()

    # Load numbers of all frames from filenames (only the digits of each .tif name are kept)
    frames = sorted(''.join(filter(str.isdigit, frame))
                    for frame in os.listdir(os.path.join(data_path, video))
                    if frame.endswith('.tif'))

    # Load tracking file and get information about objects and frames they appear in.
    # Only the first three whitespace-separated fields are used: object id, first frame, last frame.
    with open(os.path.join(data_path, video + '_GT', 'TRA', 'man_track.txt'), 'r') as track_file:
        for line in track_file:
            parts = line.split()
            if not parts:
                # Tolerate blank lines, e.g. an empty line at the end of the file
                continue
            obj_id = int(parts[0])
            id_frames[obj_id] = range(int(parts[1]), int(parts[2]) + 1)
            # Key by the normalised integer text so the str(obj_id) lookups below always match
            snippets[video][str(obj_id)] = dict()

    for frame in frames:
        # Load segmented image, gold truth if available, silver truth otherwise.
        # Sorting makes the choice of file deterministic if several names match the pattern.
        # NOTE(review): '*' + frame + '*' assumes frame numbers are uniquely embedded in the
        # SEG filenames (e.g. same zero-padding, no per-slice files) - confirm for each dataset.
        seg = sorted(glob.glob(os.path.join(data_path, video + '_GT', 'SEG', '*' + frame + '*')))
        if not seg:
            seg = sorted(glob.glob(os.path.join(data_path, video + '_ST', 'SEG', '*' + frame + '*')))
        if not seg:
            print('Video: ' + str(num_videos) + ' Frame: ' + frame + ' No segmentation found, skipping frame')
            continue
        seg_image = cv2.imread(seg[0], cv2.IMREAD_ANYDEPTH)
        if seg_image is None:
            print('Video: ' + str(num_videos) + ' Frame: ' + frame + ' Cannot read ' + seg[0] + ', skipping frame')
            continue

        # int() avoids fixed-width overflow of the image's scalar max value when adding 1
        for obj_id in range(1, int(seg_image.max()) + 1):
            if obj_id in id_frames and int(frame) in id_frames[obj_id]:
                # 8-bit 0/1 mask of the pixels belonging to the current object
                mask = (seg_image == obj_id).astype('uint8')
                # Some ST segmentation data was inconsistent with GT tracking data (e.g. dividing cells),
                # so checking if object really is in segmented image
                if mask.max() == 0:
                    print('Video: ' + str(num_videos) + ' Frame: ' + frame + ' Missing object id in seg: ' + str(obj_id))
                else:
                    # [-2] is the contour list under both OpenCV 3 (image, contours, hierarchy)
                    # and OpenCV 4 (contours, hierarchy) return conventions
                    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
                    # An object may consist of several disconnected pieces; use the union of
                    # all pieces' bounding boxes instead of only the first contour's box
                    boxes = [cv2.boundingRect(c) for c in contours]
                    x1 = min(x for x, _, _, _ in boxes)
                    y1 = min(y for _, y, _, _ in boxes)
                    x2 = max(x + w for x, _, w, _ in boxes)
                    y2 = max(y + h for _, y, _, h in boxes)
                    snippets[video][str(obj_id)][frame] = [x1, y1, x2, y2]

    print('Video number: {:d} Snippets count: {:d}'.format(num_videos, len(snippets[video])))

# Divide into training and validation data: sequence '01' is used for training and
# sequence '02' for validation (exact key match, not a substring test)
train = {k: v for (k, v) in snippets.items() if k == '01'}
val = {k: v for (k, v) in snippets.items() if k == '02'}

# Save to json files; the with-blocks guarantee each file is flushed and closed
with open('train.json', 'w') as train_file:
    json.dump(train, train_file, indent=4, sort_keys=True)
with open('val.json', 'w') as val_file:
    json.dump(val, val_file, indent=4, sort_keys=True)
print('All videos done!')