#pragma once

#include <torch/torch.h>
#include <string>
#include <vector>
#include <map>
#include <optional>

// The model headers below could be forward-declared instead to avoid circular
// dependencies; for now, include them directly since the tracker depends on them.
#include "resnet/resnet.h"
#include "classifier/classifier.h"
#include "bb_regressor/bb_regressor.h"

namespace cimp {

struct DiMPTrackerParams {
    // --- Device ---
    // torch::Device device = torch::kCUDA; // Set by the DiMPTracker constructor instead.

    // --- Input / Preprocessing ---
    // Note: stored as std::vector<int64_t> rather than torch::IntArrayRef, because
    // IntArrayRef is a non-owning view and would dangle if default-initialized from
    // a temporary initializer list. A vector converts implicitly to IntArrayRef.
    std::vector<int64_t> image_sample_size = {288, 288}; // Target size of the cropped image sample
    std::string border_mode = "replicate";               // Border mode for patch extraction
    double patch_max_scale_change = 1.5;                 // Max scale change for multiscale sampling

    // --- Target Model ---
    double search_area_scale = 5.0;   // Scale factor of the search area relative to the target size
    double target_inside_ratio = 0.2; // Ratio for keeping the target inside image boundaries

    // --- Classifier ---
    // Augmentation parameters used when generating the initial training samples.
    struct AugmentationParams {
        double augmentation_expansion_factor = 2.0;
        double random_shift_factor = 0.0;               // Typically 0 for DiMP, but can be non-zero
        std::vector<double> relativeshift = {0.0, 0.0}; // Example; usually more shifts are used
        std::vector<double> blur = {};                  // Sigmas for Gaussian blur
        std::vector<double> rotate = {};                // Angles for rotation
        struct DropoutAug {
            int num = 0;       // Number of dropout samples
            float prob = 0.0f; // Dropout probability
        } dropout;
    } augmentation;

    bool use_augmentation = true;
    int sample_memory_size = 50; // Size of the classifier's target_boxes memory
    int net_opt_iter = 10;       // Optimizer iterations for filter learning

    // --- IoU Net (BB Regressor) ---
    bool use_iou_net = true;
    double box_jitter_pos = 0.1; // Jitter for proposal generation (relative to square_box_sz)
    double box_jitter_sz = 0.1;  // Jitter for proposal generation
    int box_refinement_iter = 5; // Iterations for box optimization
    double box_refinement_step_length = 1.0;
    double box_refinement_step_decay = 1.0;
    double maximal_aspect_ratio = 5.0;
    int iounet_k = 5; // Number of top proposals to average for the final box

    // --- Localization ---
    double target_not_found_threshold = 0.25; // Score threshold below which the target is considered lost
    double target_neighborhood_scale = 2.2;   // Scale for masking the neighborhood around the max score
    bool update_scale_when_uncertain = true;

    // TODO: Add other parameters from the DiMP Python code as needed,
    // e.g. feature_stride, kernel_size (these might be derived from the network).
};
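
// A minimal configuration sketch. The helper name make_example_params and the
// overridden values are illustrative only (not part of the tracker API, not
// tuned); it just shows how the nested parameter struct is meant to be filled in.
inline DiMPTrackerParams make_example_params() {
    DiMPTrackerParams p;
    p.search_area_scale = 6.0;             // widen the search area
    p.augmentation.blur = {2.0};           // one Gaussian-blur augmentation
    p.augmentation.rotate = {10.0, -10.0}; // two rotation augmentations
    return p;
}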

class DiMPTracker {
public:
    DiMPTracker(const DiMPTrackerParams& params,
                const std::string& resnet_weights_dir,
                const std::string& classifier_weights_dir,
                const std::string& bbregressor_weights_dir,
                torch::Device device);

    // Initialize the tracker with the first frame and bounding box.
    // image_tensor_hwc_uchar: HWC uint8 tensor (convert a cv::Mat before calling)
    // initial_bbox_xywh: [x, y, w, h] tensor for the target in the first frame
    void initialize(const torch::Tensor& image_tensor_hwc_uchar, const torch::Tensor& initial_bbox_xywh);

    // Track the target in subsequent frames.
    // image_tensor_hwc_uchar: HWC uint8 tensor (convert a cv::Mat before calling)
    // Returns an [x, y, w, h] tensor for the predicted bounding box.
    torch::Tensor track(const torch::Tensor& image_tensor_hwc_uchar);

private:
    // --- Core Models ---
    cimp::resnet::ResNet resnet_model_;
    Classifier classifier_model_;   // Classifier lives in the global namespace
    BBRegressor bbregressor_model_; // BBRegressor lives in the global namespace

    // --- Parameters & Device ---
    DiMPTrackerParams params_;
    torch::Device device_;

    // --- Tracker State ---
    torch::Tensor pos_;             // Target position (y_center, x_center) in image coordinates
    torch::Tensor target_sz_;       // Target size (height, width) in image coordinates
    torch::Tensor image_sz_;        // Current image size (height, width)
    double target_scale_;           // Current scale factor of the target
    torch::Tensor base_target_sz_;  // Target size at scale 1.0
    torch::Tensor img_sample_sz_;   // Size of the image sample patch (e.g., {288, 288})
    torch::Tensor img_support_sz_;  // Usually the same as img_sample_sz_

    torch::Tensor init_sample_pos_; // Position used for generating initial samples
    double init_sample_scale_;      // Scale used for generating initial samples

    // Learned components
    torch::Tensor target_filter_;               // Learned DiMP classification filter: [num_filters, C, H, W]
    std::vector<torch::Tensor> iou_modulation_; // Learned IoU modulation vectors: list of [1, C, 1, 1]

    // Feature/kernel sizes (derived during initialization)
    torch::Tensor feature_sz_;  // Size of the classification feature map (e.g., {18, 18})
    torch::Tensor kernel_size_; // Size of the classification filter (e.g., {4, 4})
    // torch::Tensor output_sz_; // output_sz = feature_sz + (kernel_size + 1) % 2
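    // Worked example of the formula above, using the sizes noted there:
    // feature_sz = 18, kernel_size = 4 gives output_sz = 18 + (4 + 1) % 2 = 19,
    // i.e. an even-sized filter yields an odd-sized score map (presumably so the
    // map has a single, well-defined center pixel).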

    // Augmentation transforms (may be more involved in C++); for now the
    // logic lives in generate_init_samples.
    // std::vector<std::function<torch::Tensor(torch::Tensor)>> transforms_;

    // Stored target boxes for classifier training
    torch::Tensor stored_target_boxes_; // [memory_size, 4]

    // --- Helper Methods (implemented in the .cpp file) ---
    torch::Tensor convert_image_to_tensor_chw_float(const torch::Tensor& image_hwc_uchar);

    std::pair<std::vector<torch::Tensor>, torch::Tensor> generate_init_samples(const torch::Tensor& image_chw_float);

    void init_classifier_internal(const std::vector<torch::Tensor>& init_backbone_feat_list, const torch::Tensor& init_target_boxes_aug);
    void init_iou_net_internal(const std::vector<torch::Tensor>& init_backbone_feat_list, const torch::Tensor& initial_bbox_for_iou);

    std::pair<std::map<std::string, torch::Tensor>, torch::Tensor> extract_backbone_features(
        const torch::Tensor& image_chw_float,
        const torch::Tensor& pos,
        const torch::Tensor& scales, // 1D tensor of scales
        const torch::IntArrayRef& sample_sz);

    torch::Tensor get_classification_features(const std::map<std::string, torch::Tensor>& backbone_feat);
    std::vector<torch::Tensor> get_iou_backbone_features(const std::map<std::string, torch::Tensor>& backbone_feat);
    std::vector<torch::Tensor> get_iou_features(const std::map<std::string, torch::Tensor>& backbone_feat);

    std::pair<torch::Tensor, torch::Tensor> get_sample_location(const torch::Tensor& sample_coords_xyxy);
    torch::Tensor get_centered_sample_pos();

    torch::Tensor classify_target(const torch::Tensor& test_x_clf_feat);

    struct LocalizationResult {
        torch::Tensor translation_vec_yx; // (y, x) displacement
        int64_t scale_idx;
        torch::Tensor scores_peak_map;    // Score map from the peak scale
        std::string flag;                 // "normal", "not_found", or "uncertain"
    };
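
    // How the flag is typically consumed (a sketch, not necessarily the exact
    // .cpp logic): "not_found" skips the state update, "uncertain" updates the
    // position but only updates the scale if params_.update_scale_when_uncertain
    // is set, and "normal" applies the full update.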

    LocalizationResult localize_target(const torch::Tensor& scores_raw,
                                       const torch::Tensor& sample_pos_yx,
                                       const torch::Tensor& sample_scales);
    LocalizationResult localize_advanced(const torch::Tensor& scores_scaled,
                                         const torch::Tensor& sample_pos_yx,
                                         const torch::Tensor& sample_scales);

    void update_state(const torch::Tensor& new_pos_yx);
    torch::Tensor get_iounet_box(const torch::Tensor& pos_yx, const torch::Tensor& sz_hw,
                                 const torch::Tensor& sample_pos_yx, double sample_scale);

    void refine_target_box(const std::map<std::string, torch::Tensor>& backbone_feat,
                           const torch::Tensor& sample_pos_yx,
                           double sample_scale,
                           int64_t scale_idx,
                           bool update_scale_flag);

    std::pair<torch::Tensor, torch::Tensor> optimize_boxes_default(
        const std::vector<torch::Tensor>& iou_features,
        const torch::Tensor& init_boxes_xywh); // proposals in [x, y, w, h] format

    // Image processing / patch sampling helpers
    std::pair<torch::Tensor, torch::Tensor> sample_patch_multiscale_affine(
        const torch::Tensor& im_chw_float,
        const torch::Tensor& pos_yx,
        const torch::Tensor& scales, // 1D tensor of scales
        const torch::IntArrayRef& output_sz_hw,
        const std::string& border_mode = "replicate",
        std::optional<double> max_scale_change = std::nullopt);

    std::pair<torch::Tensor, torch::Tensor> sample_patch_transformed_affine(
        const torch::Tensor& im_chw_float,
        const torch::Tensor& pos_yx,
        double scale,
        const torch::IntArrayRef& aug_expansion_sz_hw,     // Size of patch to extract before transforming
        const std::vector<torch::Tensor>& affine_matrices, // One 2x3 affine matrix per transform
        const torch::IntArrayRef& out_sz_hw);              // Final output size after transforming

    // Augmentation helpers
    // ...
};

} // namespace cimp
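
// Example usage (a sketch only; the weight directory paths and frame tensors
// below are hypothetical placeholders):
//
//   cimp::DiMPTrackerParams params;
//   torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
//   cimp::DiMPTracker tracker(params,
//                             "weights/resnet",
//                             "weights/classifier",
//                             "weights/bb_regressor",
//                             device);
//
//   // First frame: HWC uint8 image tensor plus an [x, y, w, h] target box.
//   tracker.initialize(first_frame_hwc_uchar, initial_bbox_xywh);
//
//   // Each subsequent frame returns the predicted [x, y, w, h] box.
//   torch::Tensor bbox = tracker.track(next_frame_hwc_uchar);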