import typing
from typing import Callable, Tuple, Optional
import cv2
from gbvision.constants.types import Frame, Rect, TrackerType
# Split the OpenCV version string into its components (kept as strings; callers convert with int()).
# maxsplit=2 guarantees exactly three fields whenever the string has at least two dots, so the unpacking
# cannot fail with "too many values" if the version carries extra dot-separated components; any such
# extra text stays inside _subminor_ver.
_major_ver, _minor_ver, _subminor_ver = cv2.__version__.split('.', 2)
class _EmptyTracker:
    """
    A placeholder tracker that performs no actual tracking

    It remembers the rectangle handed to init and reports that same rectangle, unchanged, on every update
    """

    def __init__(self) -> None:
        # nothing to report until init is called
        self.__rect = None

    def init(self, frame: Frame, rect: Rect) -> bool:
        """
        Store the rectangle to report; the frame is ignored

        :param frame: The frame (unused)
        :param rect: The rectangle that update will return
        :return: Always True
        """
        self.__rect = rect
        return True

    def update(self, frame: Frame) -> Tuple[bool, Rect]:
        """
        Report the stored rectangle; the frame is ignored

        :param frame: The frame (unused)
        :return: A pair of True and the rectangle last given to init (None if init was never called)
        """
        stored_rect = self.__rect
        return True, stored_rect
def _get_tracker_create(tracker_name: str) -> Optional[Callable[[], TrackerType]]:
    """
    Find the factory function that creates a tracker of the given algorithm

    :param tracker_name: The algorithm name as it appears in the OpenCV factory name
        (for example 'MedianFlow' for TrackerMedianFlow_create), or 'EMPTY' for the placeholder tracker
    :return: A zero-argument callable that creates the tracker, or None if neither cv2 nor cv2.legacy
        provides a matching Tracker<name>_create attribute
    """
    if tracker_name == 'EMPTY':
        return _EmptyTracker

    # NOTE(review): presumably targets old OpenCV releases that expose the generic Tracker_create
    # factory - confirm the intended version range. The lookup is deferred into the lambda, so a
    # missing factory only surfaces when the tracker is actually constructed.
    if int(_major_ver) < 4 and int(_minor_ver) < 3:
        return lambda: cv2.cv2.Tracker_create(tracker_name.upper())

    attr_full_name = f'Tracker{tracker_name}_create'

    # getattr with a default never raises for a missing name. The previous __dict__[...] lookups
    # raised KeyError (which was not caught) when cv2.legacy existed but lacked the factory.
    factory = getattr(cv2, attr_full_name, None)
    if factory is None:
        legacy_module = getattr(cv2, 'legacy', None)
        factory = getattr(legacy_module, attr_full_name, None)
    return factory
class Tracker:
    """
    A tracker object that tracks a rectangle in a video using an opencv tracking algorithm

    :param tracker_type: Tracker algorithm taken from this list: BOOSTING, MIL, KCF, TLD, MEDIANFLOW,
        GOTURN, MOSSE, CSRT, EMPTY. (Default is EMPTY)
    """
    TRACKER_TYPE_BOOSTING = 'BOOSTING'
    TRACKER_TYPE_MIL = 'MIL'
    TRACKER_TYPE_KCF = 'KCF'
    TRACKER_TYPE_TLD = 'TLD'
    TRACKER_TYPE_MEDIANFLOW = 'MEDIANFLOW'
    TRACKER_TYPE_GOTURN = 'GOTURN'
    TRACKER_TYPE_MOSSE = 'MOSSE'
    TRACKER_TYPE_CSRT = 'CSRT'
    TRACKER_TYPE_EMPTY = 'EMPTY'

    # Maps each public tracker type to its factory, resolved once when the class body is evaluated.
    # The second element of every pair is the spelling used inside the OpenCV factory name.
    # A value of None means no factory was found for that algorithm.
    _TRACKER_ALGORITHMS = {
        tracker_name: _get_tracker_create(tracker_type)
        for tracker_name, tracker_type in
        [(TRACKER_TYPE_BOOSTING, 'Boosting'),
         (TRACKER_TYPE_MIL, 'MIL'),
         (TRACKER_TYPE_KCF, 'KCF'),
         (TRACKER_TYPE_TLD, 'TLD'),
         (TRACKER_TYPE_MEDIANFLOW, 'MedianFlow'),
         (TRACKER_TYPE_GOTURN, 'GOTURN'),
         (TRACKER_TYPE_MOSSE, 'MOSSE'),
         (TRACKER_TYPE_CSRT, 'CSRT'),
         (TRACKER_TYPE_EMPTY, 'EMPTY')]
    }

    def __init__(self, tracker_type: str = TRACKER_TYPE_EMPTY):
        """
        Create the underlying tracker for the requested algorithm

        :param tracker_type: Name of the tracking algorithm, case insensitive
        :raises AssertionError: If the name is unknown, or no factory was found for it
        """
        tracker_type = tracker_type.upper()
        assert tracker_type in self._TRACKER_ALGORITHMS, f'Unknown tracker type: {tracker_type}'
        assert self._TRACKER_ALGORITHMS[tracker_type] is not None,\
            f'Your version of OpenCV has no support for tracker type {tracker_type}'
        self.tracker = self._TRACKER_ALGORITHMS[tracker_type]()
        self.tracker_type = tracker_type

    def init(self, frame: Frame, rect: Rect) -> bool:
        """
        Initialize the tracker

        Every component of the rectangle is clamped to be at least 0 and converted to int before it is
        handed to the underlying tracker

        :param frame: The frame
        :param rect: Given rectangle
        :return: True if initialization went successfully, False otherwise
        """
        clamped_rect = tuple(int(max(x, 0)) for x in rect)
        return self.tracker.init(frame, typing.cast(Rect, clamped_rect))

    def update(self, frame: Frame) -> Rect:
        """
        Get the rect location in new frame

        Only the second element of the underlying tracker's result is returned; the first element is
        discarded

        :param frame: The frame
        :return: The location of the rect in new frame
        """
        return self.tracker.update(frame)[1]