Source code for gbvision.utils.continuity.continues_shape_wrapper

from typing import List, Dict, Callable, Union, Optional

from gbvision.constants.types import Frame, Shape
from gbvision.utils.continuity import ContinuesCircle
from gbvision.utils.continuity.continues_rect import ContinuesRect
from gbvision.utils.continuity import ContinuesRotatedRect
from gbvision.utils.continuity import ContinuesShape
from gbvision.utils.tracker import Tracker


class ContinuesShapeWrapper:
    """
    An object that tracks several shapes in a frame using continuity

    :param shapes: A list of shapes to track using continuity (must be of the same shape)
    :param frame: The frame from which the shapes were found
    :param finding_pipeline: A function that finds the shapes in a given frame and returns them
        (order irrelevant); its return value is copied and never modified by this class
    :param shape_type: The type of the shape, can be either 'CIRCLE', 'RECT', or 'ROTATED_RECT',
        default is 'RECT', can also be a class that inherits from ContinuesShape
    :param tracker_type: The type of the trackers to use, default is 'EMPTY'
    :param shape_lifespan: The maximum amount of frames for a shape to not be found until it is
        considered lost
    :param track_new: Indicates whether to track new shapes that were un-tracked so far or ignore
        them, default is False (ignore)
    :param args: Additional arguments for continues shape constructor
    :param kwargs: Additional keyword arguments for continues shape constructor
    """

    SHAPE_TYPE_CIRCLE = 'CIRCLE'
    SHAPE_TYPE_RECT = 'RECT'
    SHAPE_TYPE_ROTATED_RECT = 'ROTATED_RECT'

    # Maps the string shape-type names to the continues-shape classes they construct.
    _CONTINUES_SHAPE_TYPES = {
        SHAPE_TYPE_CIRCLE: ContinuesCircle,
        SHAPE_TYPE_RECT: ContinuesRect,
        SHAPE_TYPE_ROTATED_RECT: ContinuesRotatedRect
    }

    def __init__(self, shapes: List[Shape], frame: Frame,
                 finding_pipeline: Callable[[Frame], List[Shape]],
                 shape_type: Union[str, type] = SHAPE_TYPE_RECT,
                 tracker_type: str = Tracker.TRACKER_TYPE_EMPTY,
                 shape_lifespan: Optional[int] = None,
                 track_new: bool = False, *args, **kwargs):
        # A known name is resolved to its class; anything else is stored as given and is
        # later called as the continues-shape constructor.
        if shape_type in self._CONTINUES_SHAPE_TYPES:
            self.shape_type = self._CONTINUES_SHAPE_TYPES[shape_type]
        else:
            self.shape_type = shape_type
        self.tracker_type = tracker_type
        self.shape_lifespan = shape_lifespan
        self.finding_pipeline = finding_pipeline
        # Tracked shapes, keyed by their unique id.
        self.shapes: Dict[int, ContinuesShape] = {}
        self.track_new = track_new
        self.__args = args
        self.__kwargs = kwargs
        for i, shape in enumerate(shapes):
            self.shapes[i] = self.__create_continues_shape(shape, frame)
        # Next unused id; only ever incremented, so ids of removed shapes are not reused.
        self.__idx = len(shapes)

    def __create_continues_shape(self, shape, frame) -> ContinuesShape:
        """
        Builds a continues shape of the configured type, with a fresh tracker of the
        configured tracker type and the extra constructor arguments given to this wrapper

        :param shape: The shape to wrap
        :param frame: The frame in which the shape was found
        :return: The newly created continues shape
        """
        return self.shape_type(shape, frame, Tracker(self.tracker_type),
                               *self.__args, **self.__kwargs)

    def find_shapes(self, frame: Frame) -> Dict[int, Shape]:
        """
        Finds all shapes in the frame, then performs a continues shape operation on them and
        returns the result as a dict where the keys are unique ids and the values are the shapes
        if a shape was lost it removes it from the tracked shapes list
        if a new shape was found and the track_new field is set to True it adds it to the
        tracked shapes list

        :param frame: The frame to search in
        :return: A dict mapping from unique ids to shapes, based on continuity
        """
        # Work on a private copy: matched detections are removed below, and the pipeline's
        # own return value must not be mutated (it may also be a tuple or another
        # non-list iterable, which does not support item deletion).
        candidates = list(self.finding_pipeline(frame))
        result = {}
        lost_ids = []
        for shape_id, cont_shape in self.shapes.items():
            if cont_shape.is_lost(self.shape_lifespan):
                # Removal is deferred: the dict cannot change size while being iterated.
                lost_ids.append(shape_id)
                continue
            for j, candidate in enumerate(candidates):
                # The first detection accepted by update() is consumed, so it cannot be
                # claimed by another tracked shape or registered as a new one.
                if cont_shape.update(candidate, frame):
                    del candidates[j]
                    break
            else:
                # No detection was accepted for this shape in this frame.
                cont_shape.update_forced(frame)
            result[shape_id] = cont_shape.get()

        for shape_id in lost_ids:
            del self.shapes[shape_id]

        if self.track_new:
            # Whatever is left was not claimed by any tracked shape.
            for candidate in candidates:
                new_id = self.__idx
                self.shapes[new_id] = self.__create_continues_shape(candidate, frame)
                result[new_id] = self.shapes[new_id].get()
                self.__idx += 1
        return result

    def get_shapes(self) -> Dict[int, Shape]:
        """
        Returns the current location of the shapes based on continuity
        NOTE! this will be applied to the last frame given to the find_shapes method, only use
        this method if you need to get the shapes twice in an iteration

        :return: A dict mapping from unique ids to shapes
        """
        return {shape_id: cont_shape.get() for shape_id, cont_shape in self.shapes.items()}

    def get_shapes_as_list(self) -> List[Shape]:
        """
        Gets all the shapes as a list instead of a dictionary

        :return: A list of all the tracked shapes (sorted by unique id's)
        """
        return list(self.get_shapes().values())