You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

26 lines
714 B

from typing import List, Optional

import numpy as np
# from icecream import ic
from ultralytics import YOLO
class Detector:
    """Thin wrapper around an Ultralytics ``YOLO`` model returning integer boxes.

    :meth:`predict` returns one row per detection, ``[class_id, x, y, w, h]``,
    where ``(x, y)`` is the top-left corner and ``w``/``h`` are width and
    height; every value is cast to ``int``.
    """

    def __init__(self, model_name: str = 'yolov8n.pt',
                 classes: Optional[List[int]] = None):
        """Load the model.

        Args:
            model_name: Weights file or model name handed to ``YOLO(...)``.
            classes: Class ids to keep. When falsy (``None`` or an empty
                list) no ``classes`` filter is passed to the model at all.
        """
        self.__model = YOLO(model_name)
        self.__classes = classes

    def predict(self, img: np.ndarray) -> np.ndarray:
        """Run the model on ``img`` and return detections as an int array.

        Args:
            img: Input handed straight to ``YOLO.predict``. Only the first
                result is used, so this is meant for a single image.

        Returns:
            An ``(N, 5)`` integer array of ``[class_id, x, y, w, h]`` rows,
            with ``(x, y)`` the top-left corner of each box.
        """
        # Forward ``classes`` only when a non-empty filter was configured;
        # otherwise call the model with no class restriction.
        predict_kwargs = {'classes': self.__classes} if self.__classes else {}
        result = self.__model.predict(img, **predict_kwargs)[0]

        boxes = result.boxes
        class_ids = boxes.cls.cpu().numpy()
        # np.array copies: Tensor.numpy() may share memory with the tensor,
        # and the in-place shift below must not write back into it.
        xywh = np.array(boxes.xywh.cpu().numpy())
        # Boxes arrive centre-based; move (cx, cy) to the top-left corner.
        xywh[:, :2] -= xywh[:, 2:] / 2
        # Note: astype(int) truncates toward zero; it does not round.
        return np.hstack((class_ids[:, np.newaxis], xywh)).astype(int)