Skip to content
Snippets Groups Projects
Commit 835481c0 authored by libo's avatar libo
Browse files

add vot iter

parent b899d80c
No related branches found
No related tags found
No related merge requests found
% error('Tracker not configured! Please edit the tracker_test.m file.'); % Remove this line after proper configuration
% The human readable label for the tracker, used to identify the tracker in reports
% If not set, it will be set to the same value as the identifier.
% It does not have to be unique, but it is best that it is.
tracker_label = ['SiamRPNpp'];
% For Python implementations we have created a handy function that generates the appropritate
% command that will run the python executable and execute the given script that includes your
% tracker implementation.
%
% Please customize the line below by substituting the first argument with the name of the
% script of your tracker (not the .py file but just the name of the script) and also provide the
% path (or multiple paths) where the tracker sources % are found as the elements of the cell
% array (second argument).
setenv('MKL_NUM_THREADS','1');
pysot_root = 'path/to/pysot';
track_build_path = 'path/to/track/build';
tracker_command = generate_python_command('vot_iter.vot_iter', {pysot_root; [track_build_path '/python/lib']})
tracker_interpreter = 'python';
tracker_linkpath = {track_build_path};
% tracker_linkpath = {}; % A cell array of custom library directories used by the tracker executable (optional)
"""
\file vot.py
@brief Python utility functions for VOT integration
@author Luka Cehovin, Alessio Dore
@date 2016
"""
import sys
import copy
import collections
try:
import trax
except ImportError:
raise Exception('TraX support not found. Please add trax module to Python path.')
Rectangle = collections.namedtuple('Rectangle', ['x', 'y', 'width', 'height'])
Point = collections.namedtuple('Point', ['x', 'y'])
Polygon = collections.namedtuple('Polygon', ['points'])
class VOT(object):
""" Base class for Python VOT integration """
def __init__(self, region_format, channels=None):
""" Constructor
Args:
region_format: Region format options
"""
assert(region_format in [trax.Region.RECTANGLE, trax.Region.POLYGON])
if channels is None:
channels = ['color']
elif channels == 'rgbd':
channels = ['color', 'depth']
elif channels == 'rgbt':
channels = ['color', 'ir']
elif channels == 'ir':
channels = ['ir']
else:
raise Exception('Illegal configuration {}.'.format(channels))
self._trax = trax.Server([region_format], [trax.Image.PATH], channels)
request = self._trax.wait()
assert(request.type == 'initialize')
if isinstance(request.region, trax.Polygon):
self._region = Polygon([Point(x[0], x[1]) for x in request.region])
else:
self._region = Rectangle(*request.region.bounds())
self._image = [x.path() for k, x in request.image.items()]
if len(self._image) == 1:
self._image = self._image[0]
self._trax.status(request.region)
def region(self):
"""
Send configuration message to the client and receive the initialization
region and the path of the first image
Returns:
initialization region
"""
return self._region
def report(self, region, confidence = None):
"""
Report the tracking results to the client
Arguments:
region: region for the frame
"""
assert(isinstance(region, Rectangle) or isinstance(region, Polygon))
if isinstance(region, Polygon):
tregion = trax.Polygon.create([(x.x, x.y) for x in region.points])
else:
tregion = trax.Rectangle.create(region.x, region.y, region.width, region.height)
properties = {}
if not confidence is None:
properties['confidence'] = confidence
self._trax.status(tregion, properties)
def frame(self):
"""
Get a frame (image path) from client
Returns:
absolute path of the image
"""
if hasattr(self, "_image"):
image = self._image
del self._image
return image
request = self._trax.wait()
if request.type == 'frame':
image = [x.path() for k, x in request.image.items()]
if len(image) == 1:
return image[0]
return image
else:
return None
def quit(self):
if hasattr(self, '_trax'):
self._trax.quit()
def __del__(self):
self.quit()
import sys
import cv2
import torch
import numpy as np
import os
from os.path import join
from pysot.core.config import cfg
from pysot.models.model_builder import ModelBuilder
from pysot.tracker.tracker_builder import build_tracker
from pysot.utils.bbox import get_axis_aligned_bbox
from pysot.utils.model_load import load_pretrain
from toolkit.datasets import DatasetFactory
from toolkit.utils.region import vot_overlap, vot_float2str
from . import vot
from .vot import Rectangle, Polygon, Point
# modify root
cfg_root = "path/to/expr"
model_file = join(cfg_root, 'model.pth')
cfg_file = join(cfg_root, 'config.yaml')
def warmup(model):
for i in range(10):
model.template(torch.FloatTensor(1,3,127,127).cuda())
def setup_tracker():
cfg.merge_from_file(cfg_file)
model = ModelBuilder()
model = load_pretrain(model, model_file).cuda().eval()
tracker = build_tracker(model)
warmup(model)
return tracker
tracker = setup_tracker()
handle = vot.VOT("polygon")
region = handle.region()
try:
region = np.array([region[0][0][0], region[0][0][1], region[0][1][0], region[0][1][1],
region[0][2][0], region[0][2][1], region[0][3][0], region[0][3][1]])
except:
region = np.array(region)
cx, cy, w, h = get_axis_aligned_bbox(region)
image_file = handle.frame()
if not image_file:
sys.exit(0)
im = cv2.imread(image_file) # HxWxC
# init
target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
gt_bbox_ = [cx-(w-1)/2, cy-(h-1)/2, w, h]
tracker.init(im, gt_bbox_)
while True:
img_file = handle.frame()
if not img_file:
break
im = cv2.imread(img_file)
outputs = tracker.track(im)
pred_bbox = outputs['bbox']
result = Rectangle(*pred_bbox)
score = outputs['best_score']
if cfg.MASK.MASK:
pred_bbox = outputs['polygon']
result = Polygon(Point(x[0], x[1]) for x in pred_bbox)
handle.report(result, score)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment