#pragma once

#include <torch/torch.h>

#include <string>
#include <vector>
#include <map>
#include <optional>

// Model headers are included directly; replace them with forward declarations
// if circular dependencies ever arise.
#include "resnet/resnet.h"
#include "classifier/classifier.h"
#include "bb_regressor/bb_regressor.h"

namespace cimp {

struct DiMPTrackerParams {
    // --- Device ---
    // torch::Device device = torch::kCUDA; // Will be set by DiMPTracker constructor

    // --- Input / Preprocessing ---
    // Owning storage: a torch::IntArrayRef member initialized from a brace list
    // would dangle, since IntArrayRef is a non-owning view.
    std::vector<int64_t> image_sample_size = {288, 288}; // Target size of the cropped image sample
    std::string border_mode = "replicate"; // Border mode for patch extraction
    double patch_max_scale_change = 1.5;   // Max scale change for multiscale sampling

    // --- Target Model ---
    double search_area_scale = 5.0;   // Scale factor for the search area relative to the target size
    double target_inside_ratio = 0.2; // Ratio for keeping the target inside image boundaries

    // --- Classifier ---
    // Augmentation parameters, grouped in a sub-struct
    struct AugmentationParams {
        double augmentation_expansion_factor = 2.0;
        double random_shift_factor = 0.0;               // Typically 0 for DiMP, but can be non-zero
        std::vector<double> relativeshift = {0.0, 0.0}; // Relative shifts; usually more entries
        std::vector<double> blur = {};                  // Sigmas for Gaussian blur
        std::vector<double> rotate = {};                // Angles for rotation
        struct DropoutAug {
            int num = 0;       // Number of dropout samples
            float prob = 0.0f; // Dropout probability
        } dropout;
    } augmentation;
    bool use_augmentation = true;

    int sample_memory_size = 50; // For the classifier's target_boxes memory
    int net_opt_iter = 10;       // Optimizer iterations for filter learning

    // --- IoU Net (BB Regressor) ---
    bool use_iou_net = true;
    double box_jitter_pos = 0.1; // Jitter for proposal generation (relative to square_box_sz)
    double box_jitter_sz = 0.1;  // Jitter for proposal generation
    int box_refinement_iter = 5; // Iterations for box optimization
    double box_refinement_step_length = 1.0;
    double box_refinement_step_decay = 1.0;
    double maximal_aspect_ratio = 5.0;
    int iounet_k = 5; // Number of top proposals to average for the final box

    // --- Localization ---
    double target_not_found_threshold = 0.25; // Threshold below which the target is considered lost
    double target_neighborhood_scale = 2.2;   // Scale for masking the neighborhood around the max score
    bool update_scale_when_uncertain = true;

    // TODO: Add other parameters from the DiMP Python code as needed,
    // e.g. feature_stride, kernel_size (these might be derived from the network).
};
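
// Example override of the defaults (a minimal sketch; the values below are
// illustrative, not tuned):
//
//   cimp::DiMPTrackerParams params;
//   params.search_area_scale = 6.0;        // widen the search region
//   params.net_opt_iter = 5;               // fewer filter-learning iterations
//   params.augmentation.blur = {1.0, 2.0}; // enable Gaussian-blur augmentation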

class DiMPTracker {
public:
    DiMPTracker(const DiMPTrackerParams& params,
                const std::string& resnet_weights_dir,
                const std::string& classifier_weights_dir,
                const std::string& bbregressor_weights_dir,
                torch::Device device);

    // Initialize the tracker with the first frame and bounding box.
    // image_tensor_hwc_uchar: HWC uint8 tensor (convert from cv::Mat beforehand)
    // initial_bbox_xywh: [x, y, w, h] tensor for the target in the first frame
    void initialize(const torch::Tensor& image_tensor_hwc_uchar, const torch::Tensor& initial_bbox_xywh);

    // Track the target in a subsequent frame.
    // image_tensor_hwc_uchar: HWC uint8 tensor
    // Returns: [x, y, w, h] tensor for the predicted bounding box
    torch::Tensor track(const torch::Tensor& image_tensor_hwc_uchar);
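
    // Typical call sequence (a sketch only; the weight-directory layout and the
    // frame-source helper below are hypothetical, not part of this API):
    //
    //   cimp::DiMPTrackerParams params;
    //   cimp::DiMPTracker tracker(params, "weights/resnet", "weights/classifier",
    //                             "weights/bb_regressor", torch::Device(torch::kCUDA));
    //   tracker.initialize(first_frame_hwc_uchar, init_bbox_xywh);
    //   while (has_next_frame())
    //       torch::Tensor bbox_xywh = tracker.track(next_frame_hwc_uchar);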

private:
    // --- Core Models ---
    cimp::resnet::ResNet resnet_model_;
    Classifier classifier_model_;   // Classifier is in the global namespace
    BBRegressor bbregressor_model_; // BBRegressor is in the global namespace

    // --- Parameters & Device ---
    DiMPTrackerParams params_;
    torch::Device device_;

    // --- Tracker State ---
    torch::Tensor pos_;             // Target position (y_center, x_center) in image coordinates
    torch::Tensor target_sz_;       // Target size (height, width) in image coordinates
    torch::Tensor image_sz_;        // Current image size (height, width)
    double target_scale_;           // Current scale factor of the target
    torch::Tensor base_target_sz_;  // Target size at scale 1.0
    torch::Tensor img_sample_sz_;   // Size of the image sample patch (e.g., {288, 288})
    torch::Tensor img_support_sz_;  // Usually the same as img_sample_sz_
    torch::Tensor init_sample_pos_; // Position used for generating initial samples
    double init_sample_scale_;      // Scale used for generating initial samples

    // Learned components
    torch::Tensor target_filter_;               // Learned DiMP classification filter: [num_filters, C, H, W]
    std::vector<torch::Tensor> iou_modulation_; // Learned IoU modulation vectors: list of [1, C, 1, 1]

    // Feature/kernel sizes (often derived during initialization)
    torch::Tensor feature_sz_;  // Size of the classification feature map (e.g., {18, 18})
    torch::Tensor kernel_size_; // Size of the classification filter (e.g., {4, 4})
    // torch::Tensor output_sz_; // output_sz = feature_sz + (kernel_size + 1) % 2
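    // Worked example with the example sizes above: output_sz = 18 + (4 + 1) % 2 = 19
    // per spatial dimension (illustrative; the actual values are derived from the
    // loaded network at initialization).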

    // Augmentation transforms (might be more complex in C++);
    // for now the logic lives in generate_init_samples.
    // std::vector<std::function<torch::Tensor(torch::Tensor)>> transforms_;

    // Stored target boxes for classifier training
    torch::Tensor stored_target_boxes_; // [memory_size, 4]

    // --- Helper Methods (to be implemented in the .cpp) ---
    torch::Tensor convert_image_to_tensor_chw_float(const torch::Tensor& image_hwc_uchar);
    std::pair<std::vector<torch::Tensor>, torch::Tensor> generate_init_samples(const torch::Tensor& image_chw_float);
    void init_classifier_internal(const std::vector<torch::Tensor>& init_backbone_feat_list,
                                  const torch::Tensor& init_target_boxes_aug);
    void init_iou_net_internal(const std::vector<torch::Tensor>& init_backbone_feat_list,
                               const torch::Tensor& initial_bbox_for_iou);

    std::pair<std::map<std::string, torch::Tensor>, torch::Tensor> extract_backbone_features(
        const torch::Tensor& image_chw_float,
        const torch::Tensor& pos,
        const torch::Tensor& scales, // 1D tensor of scales
        const torch::IntArrayRef& sample_sz);
    torch::Tensor get_classification_features(const std::map<std::string, torch::Tensor>& backbone_feat);
    std::vector<torch::Tensor> get_iou_backbone_features(const std::map<std::string, torch::Tensor>& backbone_feat);
    std::vector<torch::Tensor> get_iou_features(const std::map<std::string, torch::Tensor>& backbone_feat);

    std::pair<torch::Tensor, torch::Tensor> get_sample_location(const torch::Tensor& sample_coords_xyxy);
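    // (In the pytracking reference, the returned sample position is the crop
    // center, 0.5 * (coords[:2] + coords[2:] - 1), and the sample scale is
    // sqrt(crop_area / sample_area); this port is assumed to follow suit.)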

    torch::Tensor get_centered_sample_pos();

    torch::Tensor classify_target(const torch::Tensor& test_x_clf_feat);

    struct LocalizationResult {
        torch::Tensor translation_vec_yx; // (y, x) displacement
        int64_t scale_idx;
        torch::Tensor scores_peak_map;    // The score map from the peak scale
        std::string flag;                 // "normal", "not_found", "uncertain"
    };
    LocalizationResult localize_target(const torch::Tensor& scores_raw,
                                       const torch::Tensor& sample_pos_yx,
                                       const torch::Tensor& sample_scales);
    LocalizationResult localize_advanced(const torch::Tensor& scores_scaled,
                                         const torch::Tensor& sample_pos_yx,
                                         const torch::Tensor& sample_scales);

    void update_state(const torch::Tensor& new_pos_yx);

    torch::Tensor get_iounet_box(const torch::Tensor& pos_yx, const torch::Tensor& sz_hw,
                                 const torch::Tensor& sample_pos_yx, double sample_scale);
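    // Expected mapping, following the pytracking reference (a sketch; the actual
    // .cpp implementation may differ):
    //   box_center = (pos_yx - sample_pos_yx) / sample_scale + (img_sample_sz_ - 1) / 2
    //   box_sz     = sz_hw / sample_scale
    // with the result returned as [x, y, w, h] in crop coordinates.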

    void refine_target_box(const std::map<std::string, torch::Tensor>& backbone_feat,
                           const torch::Tensor& sample_pos_yx,
                           double sample_scale,
                           int64_t scale_idx,
                           bool update_scale_flag);
    std::pair<torch::Tensor, torch::Tensor> optimize_boxes_default(
        const std::vector<torch::Tensor>& iou_features,
        const torch::Tensor& init_boxes_xywh); // proposals_xywh

    // Image processing / patch sampling helpers
    std::pair<torch::Tensor, torch::Tensor> sample_patch_multiscale_affine(
        const torch::Tensor& im_chw_float,
        const torch::Tensor& pos_yx,
        const torch::Tensor& scales, // 1D tensor of scales
        const torch::IntArrayRef& output_sz_hw,
        const std::string& border_mode = "replicate",
        std::optional<double> max_scale_change = std::nullopt);
    std::pair<torch::Tensor, torch::Tensor> sample_patch_transformed_affine(
        const torch::Tensor& im_chw_float,
        const torch::Tensor& pos_yx,
        double scale,
        const torch::IntArrayRef& aug_expansion_sz_hw,     // Size of patch to extract before transform
        const std::vector<torch::Tensor>& affine_matrices, // One 2x3 affine matrix per transform
        const torch::IntArrayRef& out_sz_hw);              // Final output size after transform

    // Augmentation helpers
    // ...
};

} // namespace cimp