You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
26 lines
714 B
26 lines
714 B
from typing import List, Optional

import numpy as np
# from icecream import ic
from ultralytics import YOLO
|
|
|
|
|
|
class Detector:
    """Thin wrapper around an Ultralytics ``YOLO`` model.

    ``predict`` returns a compact integer array with one row per detection,
    laid out as ``[class_id, x, y, w, h]``.
    """

    def __init__(self, model_name: str = 'yolov8n.pt',
                 classes: Optional[List[int]] = None):
        """Load the model and remember an optional class filter.

        Args:
            model_name: Weights file / model name passed straight to ``YOLO``.
            classes: Class indices to keep. ``None`` or an empty list means
                no filter is passed to the model.
        """
        self.__model = YOLO(model_name)
        self.__classes = classes

    def predict(self, img: np.ndarray) -> np.ndarray:
        """Run the model on ``img`` and return its detections.

        Args:
            img: Image handed unchanged to ``YOLO.predict``.

        Returns:
            Integer array of shape ``(N, 5)``. Each row is
            ``[class_id, x, y, w, h]``, where ``(x, y)`` has been shifted by
            half the box size, i.e. from the box centre reported by
            Ultralytics to the top-left corner. Values come from
            ``astype(int)``, so fractional coordinates are truncated toward
            zero. Only the first entry of the model's result list is used.
        """
        # Only forward ``classes`` when a non-empty filter was configured;
        # otherwise the model's own default (no filtering) applies.
        predict_kwargs = {'classes': self.__classes} if self.__classes else {}
        first_result = self.__model.predict(img, **predict_kwargs)[0]

        boxes = first_result.boxes
        class_ids = boxes.cls.cpu().numpy()
        # np.array copies: on CPU, ``.numpy()`` shares memory with the tensor,
        # and the in-place shift below must not write back into it.
        xywh = np.array(boxes.xywh.cpu().numpy())
        # Move (x, y) from the box centre to its top-left corner.
        xywh[:, :2] -= xywh[:, 2:] / 2

        detections = np.hstack((class_ids[..., np.newaxis], xywh))
        return detections.astype(int)
|
|
|