Compare commits

...

13 Commits · 17 files changed (number of changed lines in parentheses; BIN = binary)

  1. Oxygen-Sys-Warning.wav (BIN)
  2. app.py (161)
  3. config.yaml (2)
  4. core.py (132)
  5. detector/demo.py (2)
  6. gpuMonitor.py (47)
  7. tracker/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/functional.py (1)
  8. tracker/ltr/models/backbone/resnet.py (2)
  9. tracker/ltr/models/bbreg/atom_iou_net.py (1)
  10. tracker/ltr/models/layers/distance.py (2)
  11. tracker/ltr/models/target_classifier/features.py (2)
  12. tracker/pytracking/features/augmentation.py (6)
  13. tracker/pytracking/features/preprocessing.py (7)
  14. tracker/pytracking/libs/dcf.py (1)
  15. tracker/pytracking/tracker/dimp/dimp.py (3)
  16. tracker/pytracking/utils/params.py (2)
  17. video_streamer/vision_service.cpython-37m-x86_64-linux-gnu.so (BIN)

BIN
Oxygen-Sys-Warning.wav

161
app.py

@@ -4,9 +4,11 @@ import sys
 from datetime import datetime, timedelta
 from time import sleep
 from typing import List
-import sys
 import os
+import time
+import threading  # Import threading module
+from detector import Detector

 # Add the proto directory to the Python path
 proto_dir = os.path.join(os.path.dirname(__file__), 'message_queue', 'proto')
@@ -24,6 +26,7 @@ from icecream import ic
 from configs import ConfigManager
 from core import Core
+from tracker import Tracker
 from message_queue.Bridge import Bridge
 from message_queue.Manager import Manager
 from message_queue.proto.ImageMessage_pb2 import ImageMessage, TrackMode
@@ -37,8 +40,6 @@ from message_queue.proto.ConnectStatus_pb2 import ConnectStatus
 from video_streamer.gst_video_streamer import GstVideoStreamer
 import cv2
-import os
-import time  # Ensure time module is imported

 config_manager = ConfigManager('config.yaml')
 rtsp_links = config_manager.configs['rtsp_links'].get()
@@ -53,53 +54,112 @@ def handle_camera_status(status: int):

 # Helper class to track client connection status
 class ConnectionTracker:
     def __init__(self):
-        self.last_message_time = time.time()  # Use time.time() to get the current time
-        self.timeout = 15  # Timeout in seconds (adjust as needed)
+        self.last_message_time = time.time()
+        self.timeout = 15  # 15-second timeout
+        self.last_client_ip = None  # Track last connected client IP

-    def update_last_message_time(self):
-        self.last_message_time = time.time()  # Update with the current time
+    def update_last_message_time(self, client_ip=None):
+        self.last_message_time = time.time()
+        if client_ip:
+            self.last_client_ip = client_ip  # Update last known client

-    def is_client_connected(self):
-        return (time.time() - self.last_message_time) < self.timeout  # Use time.time()
+    def is_client_active(self):
+        """Check if the last client is still active within the last 30 seconds."""
+        return (time.time() - self.last_message_time) < self.timeout

 # Create an instance of ConnectionTracker
 connection_tracker = ConnectionTracker()

-def check_client_connection():
-    if not connection_tracker.is_client_connected():
-        print("Client disconnected. Searching for another client...")
-        cl_ip, _ = start_discovery_service(12345)
-        print(cl_ip)
-        # Reinitialize the Manager with the new client address
-        global manager
-        manager = Manager(f"tcp://{cl_ip}:5558", f"tcp://{cl_ip}:5557")
-        manager.start(manager_callback)
-        connection_tracker.update_last_message_time()  # Reset the timer after reconnecting
+class ConnectionThread(threading.Thread):
+    def __init__(self):
+        super().__init__()
+        self.running = True
+
+    def run(self):
+        while self.running:
+            sleep(0.5)
+            if connection_tracker.last_client_ip:
+                if debug:
+                    print(f"Checking if last client {connection_tracker.last_client_ip} is still available...")
+                if self.query_last_client(connection_tracker.last_client_ip):
+                    connection_tracker.update_last_message_time()
+                    if debug:
+                        print(f"Last client {connection_tracker.last_client_ip} responded. Continuing...")
+                    continue  # Skip discovering a new client
+            if not connection_tracker.is_client_active():
+                print("Client inactive for 30 seconds. Searching for another client...")
+                cl_ip, _ = self.start_discovery_service(12345)
+                print(f"New client found: {cl_ip}")
+                # Reinitialize the Manager with the new client address
+                global manager
+                manager = Manager(f"tcp://{cl_ip}:5558", f"tcp://{cl_ip}:5557")
+                manager.start(manager_callback)  # Restart the Manager and re-register the callback
+                connection_tracker.update_last_message_time(cl_ip)  # Update with new client
+
+    def query_last_client(self, client_ip):
+        """Send a heartbeat packet to the last known client IP on port 29170 and wait for a response."""
+        if debug:
+            print(f"Sending heartbeat to {client_ip}:29170")  # Debugging
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.settimeout(5)  # 5-second timeout
+        heartbeat_port = 29170  # New port for heartbeat
+        try:
+            sock.sendto(b"HEARTBEAT", (client_ip, heartbeat_port))  # Send heartbeat message
+            data, addr = sock.recvfrom(1024)  # Wait for response
+            if debug:
+                print(f"Received response from {addr}: {data.decode()}")  # Debugging
+            if data.decode() == "HEARTBEAT_ACK":
+                if debug:
+                    print(f"Client {client_ip} is still responding.")
+                return True
+        except socket.timeout:
+            print(f"Client {client_ip} did not respond. Marking as inactive.")
+        finally:
+            sock.close()
+        return False  # Client did not respond
+
+    def start_discovery_service(self, port):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.bind(('0.0.0.0', port))
+        print(f"Discovery service listening on port {port}...")
+        while True:
+            data, addr = sock.recvfrom(1024)
+            if data.decode() == "DISCOVER_SERVER":
+                print(f"Received discovery request from {addr}")
+                sock.sendto(b"SERVER_RESPONSE", addr)
+                break
+        sock.close()
+        return addr
+
+    def stop(self):
+        self.running = False

 if __name__ == '__main__':
     QCoreApplication.setAttribute(Qt.AA_EnableHighDpiScaling)
@@ -113,7 +173,9 @@ if __name__ == '__main__':
     print(f'{videoStreamer.id} connected')
     videoStreamers.append(videoStreamer)

-    core = Core(videoStreamers)
+    tracker = Tracker()
+    detector = Detector(classes=[0, 2, 5, 7])
+    core = Core(videoStreamers, tracker, detector)

     def manager_callback(msg_str):
         msg = Message()
@@ -156,24 +218,32 @@ if __name__ == '__main__':
             print(f"Received discovery request from {addr}")
             sock.sendto(b"SERVER_RESPONSE", addr)
             break
+        sock.close()
         return addr

     cl_ip, _ = start_discovery_service(12345)
     print(cl_ip)
+    global manager
     manager = Manager(f"tcp://{cl_ip}:5558", f"tcp://{cl_ip}:5557")
     manager.start(manager_callback)

     def gotNewFrame(bboxes, id_, isDetection, ctime):
-        #print(f"Got new frame, bboxes : {bboxes} Id: {id} Is detection {isDetection}")
         m = Message()
         m.msgType = MessageType.MESSAGE_TYPE_IMAGE
         m.image.timestamp = int(ctime.value)
         for bbox in bboxes:
+            # Skip if bbox is None, doesn't have exactly 4 elements, or contains None values
+            if bbox is None or len(bbox) != 4 or not all(element is not None for element in bbox):
+                continue
+            # Add the bounding box to the image
             box = m.image.boxes.add()
-            box.x = bbox[0]
-            box.y = bbox[1]
-            box.w = bbox[2]
-            box.h = bbox[3]
+            box.x, box.y, box.w, box.h = bbox
         m.image.camType = id_
         if isDetection:
             m.image.trackMode = TrackMode.TRACK_MODE_DETECT

@@ -193,12 +263,13 @@ if __name__ == '__main__':
     core.newFrame.connect(gotNewFrame)
     core.coordsUpdated.connect(gotCoords)

-    # Set up the QTimer to check client connection every 10 seconds
-    timer = QTimer()
-    timer.timeout.connect(check_client_connection)
-    timer.start(10000)  # 10 seconds
+    # Start the connection thread
+    connection_thread = ConnectionThread()
+    connection_thread.start()

     try:
         app.exec_()
     except KeyboardInterrupt:
+        connection_thread.stop()
+        connection_thread.join()
         sys.exit(0)
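The ConnectionThread above assumes a peer that speaks two small UDP exchanges: "DISCOVER_SERVER" answered with "SERVER_RESPONSE" on port 12345, and "HEARTBEAT" answered with "HEARTBEAT_ACK" on port 29170. The client side is not part of this diff; a minimal counterpart sketch under those assumptions (function names here are hypothetical) could look like:

    import socket

    # Hypothetical client-side responder for the protocol ConnectionThread expects:
    # replies "HEARTBEAT_ACK" to "HEARTBEAT" probes on UDP port 29170.
    def heartbeat_responder(port=29170):
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.bind(('0.0.0.0', port))
        while True:
            data, addr = sock.recvfrom(1024)
            if data == b"HEARTBEAT":
                sock.sendto(b"HEARTBEAT_ACK", addr)

    # Hypothetical discovery: broadcast "DISCOVER_SERVER" to port 12345 and
    # wait for the server's "SERVER_RESPONSE" to learn its address.
    def discover_server(port=12345, timeout=5):
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        sock.settimeout(timeout)
        try:
            sock.sendto(b"DISCOVER_SERVER", ('<broadcast>', port))
            data, addr = sock.recvfrom(1024)
            if data == b"SERVER_RESPONSE":
                return addr[0]
        except socket.timeout:
            return None
        finally:
            sock.close()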

2
config.yaml

@@ -6,4 +6,4 @@ rtsp_links: [
 thickness: 2
-debug: false
+debug: False
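app.py's new ConnectionThread gates its logging on a module-level debug flag. The line that loads it is not in the visible hunks; presumably it mirrors the rtsp_links lookup shown above, e.g.:

    # Assumption: loaded the same way as rtsp_links in app.py; this line is not part of the diff.
    debug = config_manager.configs['debug'].get()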

132
core.py

@@ -1,34 +1,35 @@
+#import os
+#os.environ['YOLO_VERBOSE'] = "false"
+import datetime
+import pygame
+import cv2
+import numpy as np
+import torch
 from threading import Event, Thread
 from typing import List
-import numpy as np
 from PyQt5.QtCore import QThread, pyqtSlot, pyqtSignal, QUrl, QDir, pyqtProperty
-#from icecream import ic
-from matplotlib import pyplot as plt
 from detector import Detector
 from detector.utils import get_bbox_by_point
 from tracker import Tracker
 from video_streamer.videostreamer import VideoStreamer
 from time import sleep
 import time
 from PyQt5.QtCore import QObject, pyqtSignal
 import ctypes
 from ctypes import c_int64

+showTrack = False

 class Core(QThread):
-    newFrame = pyqtSignal(object, int, bool,ctypes.c_int64)
+    newFrame = pyqtSignal(object, int, bool, ctypes.c_int64)
     coordsUpdated = pyqtSignal(int, object, bool)

-    def __init__(self, video_sources: List[VideoStreamer], parent=None):
+    def __init__(self, video_sources: List[VideoStreamer], tracker=None, detector=None, parent=None):
         super(QThread, self).__init__(parent)
-        self.__detector = Detector(classes=[0, 2, 5, 7])
-        self.__tracker = Tracker()
+        self.__detector = detector
+        self.__tracker = tracker
         self.__video_sources = video_sources
         self.__processing_source = video_sources[0]
@@ -44,7 +45,12 @@ class Core(QThread):
         self.__tracking_thread = None
         self.__processing_id = 0
-        # ic()
+        self.__frame = None  # Frame property for Pygame
+
+    @pyqtProperty(np.ndarray)
+    def frame(self):
+        return self.__frame

     def set_thickness(self, thickness: int):
         self.__thickness = thickness
@@ -60,6 +66,7 @@ class Core(QThread):
     def __detection(self):
         while self.__is_detecting:
             try:
+                torch.cuda.empty_cache()
                 source = self.__processing_source
                 roi = self.__detection_roi
                 frame = source.get_frame()
@@ -71,10 +78,8 @@ class Core(QThread):
                     bbox = result[1:]
                     bbox[:2] += roi[:2]
                     global_bboxes.append(bbox)
-                    # color = (0, 0, 255) if cls == 0 else (80, 127, 255)
-                    # self.__draw_bbox(frame, bbox, color)
-                self.newFrame.emit(global_bboxes, self.__processing_id, True)
+                self.newFrame.emit(global_bboxes, self.__processing_id, True, c_int64(int(time.time() * 1e3)))
                 self.__detection_bboxes = np.array(global_bboxes)
                 self.__detection_frame = frame.copy()
                 sleep(0.03)
@@ -82,21 +87,94 @@ class Core(QThread):
                 print(e)
                 sleep(0.1)

     def __tracking(self):
         source = self.__processing_source
+        if showTrack:
+            pygame.init()
+            # Get actual screen resolution
+            info = pygame.display.Info()
+            screen_width, screen_height = info.current_w, info.current_h
+            screen = pygame.display.set_mode((screen_width, screen_height), pygame.FULLSCREEN)
+            pygame.display.set_caption('Tracking Frame')
+            clock = pygame.time.Clock()  # Add a clock to control frame rate
         while self.__is_tracking:
+            if showTrack:
+                for event in pygame.event.get():  # Prevent freezing by handling events
+                    if event.type == pygame.QUIT:
+                        pygame.quit()
+                        return
             ctime = c_int64(int(time.time() * 1000))  # Convert to c_int64
             frame = source.get_frame()
             bbox, success = self.__tracker.update(frame)
+            center = None
             if bbox is not None:
                 center = bbox[:2] + bbox[2:] // 2
                 self.coordsUpdated.emit(self.__processing_id, center, success)
                 self.newFrame.emit([bbox], self.__processing_id, False, ctime)
-            sleep(0.01)
+            else:
+                self.newFrame.emit([bbox], self.__processing_id, False, ctime)
+                sleep(0.05)
+            if showTrack:
+                x, y, w, h = map(int, bbox)
+                box_color = (0, 255, 0) if success else (255, 0, 0)
+                cv2.rectangle(frame, (x, y), (x + w, y + h), box_color, 2)
+                # Convert OpenCV frame (BGR) to RGB
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                font = cv2.FONT_HERSHEY_SIMPLEX
+                font_scale = 2.25
+                font_color = (255, 255, 0)
+                thickness = 6
+                position = (50, 450)  # Bottom-left corner of the text in the image
+                now = datetime.datetime.now()
+                time_string = now.strftime("%H:%M:%S.%f")[:-3]
+                # Use cv2.putText() to write the time on the image
+                cv2.putText(frame, time_string, position, font, font_scale, font_color, thickness, cv2.LINE_AA)
+                cv2.putText(frame, f"{ctime}", (50, 380), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
+                # print(ctime)
+                frame = cv2.flip(frame, 1)  # Flip horizontally
+                # Resize frame while maintaining aspect ratio
+                frame_height, frame_width, _ = frame.shape
+                aspect_ratio = frame_width / frame_height
+                if aspect_ratio > (screen_width / screen_height):  # Wider than screen
+                    new_width = screen_width
+                    new_height = int(screen_width / aspect_ratio)
+                else:  # Taller than screen
+                    new_height = screen_height
+                    new_width = int(screen_height * aspect_ratio)
+                resized_frame = cv2.resize(frame, (new_width, new_height))
+                # Convert to Pygame surface without unnecessary rotation
+                frame_surface = pygame.surfarray.make_surface(resized_frame)
+                # Optional: if rotation is needed, use pygame.transform.rotate()
+                frame_surface = pygame.transform.rotate(frame_surface, -90)  # Example rotation
+                # Center the frame
+                x_offset = (screen_width - new_width) // 2
+                y_offset = (screen_height - new_height) // 2
+                screen.fill((0, 0, 0))  # Clear screen
+                screen.blit(frame_surface, (x_offset, y_offset))
+                pygame.display.flip()
+                clock.tick(30)  # Limit FPS to prevent excessive CPU usage

     def start_detect(self, x: int, y: int, w: int, h: int):
         self.__detection_roi = [x, y, x + w, y + h]
@@ -110,12 +188,13 @@ class Core(QThread):
     def stop_detection(self):
         self.__is_detecting = False
-        if self.__detection_thread is not None:
+        if self.__detection_thread is not None and self.__detection_thread.is_alive():
             self.__detection_thread.join()
             self.__detection_thread = None

     def start_track(self, x: int, y: int, w: int = 0, h: int = 0):
+        print(f"start tracking: {x}, {y}, {w}, {h}")
         try:
             self.__is_detecting = False
             self.__is_tracking = False
@@ -140,21 +219,22 @@ class Core(QThread):
         if self.__tracking_thread is not None:
             self.__tracking_thread.join()

+        self.stop_track()
         self.__is_tracking = True
         self.__tracking_thread = Thread(target=self.__tracking)
         self.__tracking_thread.start()
         sleep(0.03)

     def stop_track(self):
+        if showTrack:
+            pygame.quit()
+        print("stop tracking")
         self.stop_detection()
         self.__tracker.stop()
         self.__is_tracking = False
-        if self.__tracking_thread is not None:
-            self.__tracking_thread.join()
         self.__tracking_thread = None

     def __draw_bbox(self, img: np.ndarray, bbox, color):
         thickness = self.__thickness
-        # cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2] + bbox[0], bbox[3] + bbox[1]),
-        #               color, thickness)
+        cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2] + bbox[0], bbox[3] + bbox[1]),
+                      color, thickness)
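In the Pygame branch above, pygame.surfarray.make_surface indexes arrays as (width, height, channels), while OpenCV frames are (height, width, channels); the code compensates after the fact with cv2.flip and pygame.transform.rotate. A common alternative is to transpose the axes before building the surface (a sketch, not part of the commit; depending on the desired orientation an extra flip may still be needed):

    import numpy as np
    import pygame

    def frame_to_surface(frame_rgb: np.ndarray) -> pygame.Surface:
        # Swap rows and columns so the array matches surfarray's
        # (width, height, channels) layout.
        return pygame.surfarray.make_surface(np.transpose(frame_rgb, (1, 0, 2)))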

2
detector/demo.py

@@ -7,7 +7,7 @@ from utils import get_bbox_by_point
 if __name__ == '__main__':
     detector = Detector(classes=[0, 2, 5, 7])
-    cap = cv2.VideoCapture(1)
+    cap = cv2.VideoCapture(0)
     display_name = 'detector'
     cv2.namedWindow(display_name, cv2.WINDOW_NORMAL)
     cv2.setWindowProperty(display_name, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
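The demo now opens capture device 0 instead of 1. Camera indices are machine-specific; a quick probe (a throwaway sketch, not in the repo) confirms which indices are usable:

    import cv2

    for idx in range(4):
        cap = cv2.VideoCapture(idx)
        print(f"camera {idx}: {'available' if cap.isOpened() else 'not found'}")
        cap.release()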

47
gpuMonitor.py

@@ -0,0 +1,47 @@
+import pynvml
+import time
+from colorama import Fore, Style, init
+import os
+
+# Initialize colorama
+init(autoreset=True)
+
+def monitor_gpu_ram_usage(interval=2, threshold_gb=2):
+    # Initialize NVML
+    pynvml.nvmlInit()
+    try:
+        device_count = pynvml.nvmlDeviceGetCount()
+        print(f"Found {device_count} GPU(s).")
+        while True:
+            for i in range(device_count):
+                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+                info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+                print(f"GPU {i}:")
+                print(f"  Total RAM: {info.total / 1024 ** 2:.2f} MB")
+                if info.used / 1024 ** 2 >= 2.5 * 1024:
+                    print(Fore.RED + f"  Used RAM: {info.used / 1024 ** 2:.2f} MB")
+                    os.system("aplay /home/rog/repos/Tracker/NE-Smart-Tracker/Oxygen-Sys-Warning.wav")
+                else:
+                    print(f"  Used RAM: {info.used / 1024 ** 2:.2f} MB")
+                print(f"  Free RAM: {info.free / 1024 ** 2:.2f} MB")
+                print(Fore.GREEN + "-" * 30)
+            print(Fore.GREEN)
+            time.sleep(interval)  # Wait for the specified interval before checking again
+    except KeyboardInterrupt:
+        print("Monitoring stopped by user.")
+    finally:
+        # Shutdown NVML
+        pynvml.nvmlShutdown()
+
+if __name__ == "__main__":
+    monitor_gpu_ram_usage(interval=2, threshold_gb=2)  # Check every 2 seconds, threshold is 2 GB
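One detail worth flagging: monitor_gpu_ram_usage accepts a threshold_gb parameter, but the alert branch compares against a hardcoded 2.5 GB (2.5 * 1024 MB), so the argument currently has no effect. If the parameter is meant to drive the alert, the comparison would read (a suggested fix, not what the commit contains):

    used_mb = info.used / 1024 ** 2
    if used_mb >= threshold_gb * 1024:  # threshold_gb in GB, used_mb in MB
        print(Fore.RED + f"  Used RAM: {used_mb:.2f} MB")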

1
tracker/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/functional.py

@@ -12,6 +12,7 @@
 import torch
 import torch.autograd as ag

 __all__ = ['prroi_pool2d']

2
tracker/ltr/models/backbone/resnet.py

@@ -6,6 +6,7 @@ from torchvision.models.resnet import model_urls
 from .base import Backbone

 class Bottleneck(nn.Module):
     expansion = 4

@@ -22,6 +23,7 @@ class Bottleneck(nn.Module):
         self.downsample = downsample
         self.stride = stride

     def forward(self, x):
         residual = x

1
tracker/ltr/models/bbreg/atom_iou_net.py

@@ -2,6 +2,7 @@ import torch.nn as nn
 import torch
 from ltr.models.layers.blocks import LinearBlock
 from ltr.external.PreciseRoIPooling.pytorch.prroi_pool import PrRoIPool2D
+torch.cuda.empty_cache()

 def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
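This commit inserts torch.cuda.empty_cache() at import time here and inside per-call paths in the files below. Note what the call actually does: it returns unused blocks from PyTorch's caching allocator to the driver, it cannot free tensors that are still referenced, and in hot paths it forces the next CUDA op to re-allocate through the driver, which costs time. The effect is visible through the allocator counters (a small sketch):

    import torch

    if torch.cuda.is_available():
        before = torch.cuda.memory_reserved()  # bytes held by the caching allocator
        torch.cuda.empty_cache()               # releases only the *unused* cached blocks
        after = torch.cuda.memory_reserved()
        print(f"reserved: {before / 2**20:.1f} MiB -> {after / 2**20:.1f} MiB")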

2
tracker/ltr/models/layers/distance.py

@@ -14,12 +14,14 @@ class DistanceMap(nn.Module):
         super().__init__()
         self.num_bins = num_bins
         self.bin_displacement = bin_displacement
+        torch.cuda.empty_cache()

     def forward(self, center, output_sz):
         """Create the distance map.
         args:
             center: Torch tensor with (y,x) center position. Dims (batch, 2)
             output_sz: Size of output distance map. 2-dimensional tuple."""
+        torch.cuda.empty_cache()
         center = center.view(-1,2)

2
tracker/ltr/models/target_classifier/features.py

@@ -4,9 +4,11 @@ from ltr.models.layers.normalization import InstanceL2Norm

 def residual_bottleneck(feature_dim=256, num_blocks=1, l2norm=True, final_conv=False, norm_scale=1.0, out_dim=None,
                         interp_cat=False, final_relu=False, final_pool=False, input_dim=None, final_stride=1):
     """Construct a network block based on the Bottleneck block used in ResNet."""
     if out_dim is None:
         out_dim = feature_dim
     if input_dim is None:

6
tracker/pytracking/features/augmentation.py

@@ -6,16 +6,18 @@ import cv2 as cv
 import random

 from pytracking.features.preprocessing import numpy_to_torch, torch_to_numpy

 class Transform:
     """Base data augmentation transform class."""
     def __init__(self, output_sz = None, shift = None):
         self.output_sz = output_sz
         self.shift = (0,0) if shift is None else shift
+        torch.cuda.empty_cache()

     def crop_to_output(self, image):
+        torch.cuda.empty_cache()
         if isinstance(image, torch.Tensor):
             imsz = image.shape[2:]
             if self.output_sz is None:

@@ -67,6 +69,7 @@ class Rotate(Transform):
         super().__init__(output_sz, shift)
         self.angle = math.pi * angle/180

     def __call__(self, image, is_mask=False):
         if isinstance(image, torch.Tensor):
             return self.crop_to_output(numpy_to_torch(self(torch_to_numpy(image))))

@@ -90,6 +93,7 @@ class Blur(Transform):
         self.filter[0] = self.filter[0].view(1,1,-1,1) / self.filter[0].sum()
         self.filter[1] = self.filter[1].view(1,1,1,-1) / self.filter[1].sum()

     def __call__(self, image, is_mask=False):
         if isinstance(image, torch.Tensor):
             sz = image.shape[2:]

7
tracker/pytracking/features/preprocessing.py

@@ -4,10 +4,12 @@ import numpy as np

 def numpy_to_torch(a: np.ndarray):
+    torch.cuda.empty_cache()
     return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0)

 def torch_to_numpy(a: torch.Tensor):
+    torch.cuda.empty_cache()
     return a.squeeze(0).permute(1,2,0).numpy()

@@ -20,7 +22,7 @@ def sample_patch_transformed(im, pos, scale, image_sz, transforms, is_mask=False
         image_sz: Size to resize the image samples to before extraction.
         transforms: A set of image transforms to apply.
     """
+    torch.cuda.empty_cache()
     # Get image patche
     im_patch, _ = sample_patch(im, pos, scale*image_sz, image_sz, is_mask=is_mask)

@@ -39,6 +41,7 @@ def sample_patch_multiscale(im, pos, scales, image_sz, mode: str='replicate', ma
         mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major'
         max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode
     """
+    torch.cuda.empty_cache()
     if isinstance(scales, (int, float)):
         scales = [scales]

@@ -62,7 +65,7 @@ def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, o
         mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major'
         max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode
     """
+    torch.cuda.empty_cache()
     # if mode not in ['replicate', 'inside']:
     #     raise ValueError('Unknown border mode \'{}\'.'.format(mode))

1
tracker/pytracking/libs/dcf.py

@@ -3,6 +3,7 @@ import torch

 def max2d(a: torch.Tensor) -> (torch.Tensor, torch.Tensor):
     """Computes maximum and argmax in the last two dimensions."""
+    torch.cuda.empty_cache()
     max_val_row, argmax_row = torch.max(a, dim=-2)
     max_val, argmax_col = torch.max(max_val_row, dim=-1)

3
tracker/pytracking/tracker/dimp/dimp.py

@@ -65,6 +65,7 @@ class DiMP():
     def track(self, image) -> dict:
+        torch.cuda.empty_cache()
         # Convert image
         im = numpy_to_torch(image)

@@ -213,7 +214,7 @@ class DiMP():
         # Compute augmentation size
         aug_expansion_factor = self.params.get('augmentation_expansion_factor', None)
+        ic(self.params.get('augmentation_expansion_factor', None))
         aug_expansion_sz = (self.img_sample_sz * aug_expansion_factor).long()

2
tracker/pytracking/utils/params.py

@@ -5,7 +5,7 @@ class TrackerParams:
     """Class for tracker parameters."""
     image_sample_size = 18 * 16
-    search_area_scale = 5
+    search_area_scale = 7

     # Learning parameters
     sample_memory_size = 50
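For context on this change: in pytracking, search_area_scale sets the size of the square search crop relative to the target, and the crop is then resampled to image_sample_size pixels (18 * 16 = 288 here), so raising it from 5 to 7 widens the region searched each frame at the cost of resolution on the target itself. Rough arithmetic with a hypothetical target size:

    # Hypothetical 64x128 px target, for illustration only.
    target_w, target_h = 64, 128
    search_area_scale = 7
    # pytracking sizes the crop so its area is search_area_scale**2 times the target area:
    crop_side = search_area_scale * (target_w * target_h) ** 0.5  # ~633 px per side
    # That ~633 px crop is resampled to 288x288, so each target pixel gets less
    # resolution than it would at search_area_scale = 5 (~452 px crop).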

BIN
video_streamer/vision_service.cpython-37m-x86_64-linux-gnu.so
