#include "dimp_tracker.h"

#include <cmath>     // std::sqrt
#include <iostream>  // For debugging output
#include <string>

namespace cimp {

/// @brief Builds the tracker and its three networks and puts them in eval mode.
///
/// @param params                  Tracker parameters; copied into `params_`.
/// @param resnet_weights_dir      Passed to `cimp::resnet::resnet50`.
/// @param classifier_weights_dir  Passed to the classifier model constructor.
/// @param bbregressor_weights_dir Passed to the bbox-regressor model constructor.
/// @param device                  Device handed to every model and stored in `device_`.
DiMPTracker::DiMPTracker(const DiMPTrackerParams& params,
                         const std::string& resnet_weights_dir,
                         const std::string& classifier_weights_dir,
                         const std::string& bbregressor_weights_dir,
                         torch::Device device)
    : params_(params),
      device_(device),
      resnet_model_(cimp::resnet::resnet50(resnet_weights_dir, {"layer2", "layer3", "layer4"}, device)),
      classifier_model_(classifier_weights_dir, device),
      bbregressor_model_(bbregressor_weights_dir, device) {
    // The model constructors already receive the device; the explicit to() on
    // the backbone guarantees placement even if its factory does not.
    // The classifier and bbox regressor handle the device in their constructors.
    this->resnet_model_->to(this->device_);

    // Inference only: switch to evaluation mode.
    // (The classifier does not have an explicit eval() method.)
    this->resnet_model_->eval();
    this->bbregressor_model_.eval();

    // Sample size is stored as a float tensor so it can be used directly in
    // floating-point calculations.
    this->img_sample_sz_ =
        torch::tensor(params_.image_sample_size.vec(), torch::kInt64).to(torch::kFloat32);
    this->img_support_sz_ = this->img_sample_sz_.clone();

    std::cout << "DiMPTracker initialized." << std::endl;
    std::cout << " Device: " << (this->device_.is_cuda() ? "CUDA" : "CPU") << std::endl;
    if (this->device_.is_cuda()) {
        std::cout << " CUDA Device Index: " << this->device_.index() << std::endl;
    }
    std::cout << " ResNet, Classifier, and BBRegressor models constructed and set to eval mode." << std::endl;
}

/// @brief Sets the initial tracker state from the first frame (placeholder).
///
/// Currently computes `image_sz_`, `pos_`, `target_sz_`, `target_scale_`,
/// `base_target_sz_`, `init_sample_pos_` and `init_sample_scale_`. Sample
/// generation and classifier / IoU-net initialization are still TODO.
///
/// @param image_tensor_hwc_uchar First frame; forwarded to
///        convert_image_to_tensor_chw_float().
/// @param initial_bbox_xywh Four values [x, y, w, h]; any numeric dtype, any
///        device, any shape with exactly four elements. Width and height must
///        be positive.
/// @throws c10::Error if the bounding box is malformed or has a non-positive size.
void DiMPTracker::initialize(const torch::Tensor& image_tensor_hwc_uchar,
                             const torch::Tensor& initial_bbox_xywh) {
    std::cout << "DiMPTracker::initialize() called (placeholder)." << std::endl;

    // TODO: Implement full initialization logic
    // 1. Convert image_tensor_hwc_uchar to CHW float tensor, normalize
    // 2. Set initial pos_, target_sz_, image_sz_, target_scale_, base_target_sz_
    // 3. Call generate_init_samples
    // 4. Call init_classifier_internal
    // 5. Call init_iou_net_internal

    TORCH_CHECK(initial_bbox_xywh.defined() && initial_bbox_xywh.numel() == 4,
                "DiMPTracker::initialize: initial_bbox_xywh must hold exactly 4 values [x, y, w, h]");

    // Read the box on the CPU once, rather than paying one device
    // synchronization per item() call when the tensor lives on CUDA.
    const torch::Tensor bbox_cpu = initial_bbox_xywh.to(torch::kCPU, torch::kFloat32).reshape({-1});
    const float x = bbox_cpu[0].item<float>();
    const float y = bbox_cpu[1].item<float>();
    const float w = bbox_cpu[2].item<float>();
    const float h = bbox_cpu[3].item<float>();

    // A non-positive size would make target_scale_ zero and base_target_sz_
    // NaN/inf further down.
    TORCH_CHECK(w > 0.0f && h > 0.0f,
                "DiMPTracker::initialize: bounding box width and height must be positive");

    auto image_chw_float = convert_image_to_tensor_chw_float(image_tensor_hwc_uchar);
    this->image_sz_ = torch::tensor({image_chw_float.size(1), image_chw_float.size(2)},
                                    torch::kFloat32).to(device_);  // H, W

    // Box centre, stored as (y_center, x_center).
    this->pos_ = torch::tensor({y + (h - 1.0f) / 2.0f, x + (w - 1.0f) / 2.0f},
                               torch::kFloat32).to(device_);
    // Target size, stored as (height, width).
    this->target_sz_ = torch::tensor({h, w}, torch::kFloat32).to(device_);

    // Scale = sqrt(search area) / sqrt(sample area).
    const double search_area =
        torch::prod(this->target_sz_ * params_.search_area_scale).item<double>();
    this->target_scale_ =
        std::sqrt(search_area) / torch::prod(this->img_sample_sz_).sqrt().item<double>();
    this->base_target_sz_ = this->target_sz_ / this->target_scale_;

    this->init_sample_pos_ = this->pos_.round();
    this->init_sample_scale_ = this->target_scale_;

    // TODO: Call generate_init_samples, init_classifier_internal, init_iou_net_internal
}

/// @brief Tracks the target in a new frame (placeholder).
///
/// @param image_tensor_hwc_uchar Current frame; not used yet.
/// @return A dummy float32 bounding box [x, y, w, h] filled with zeros.
torch::Tensor DiMPTracker::track(const torch::Tensor& image_tensor_hwc_uchar) {
    std::cout << "DiMPTracker::track() called (placeholder)." << std::endl;

    // TODO: Implement full tracking logic
    // Return a dummy bounding box for now [x,y,w,h]
    return torch::tensor({0.0, 0.0, 0.0, 0.0}, torch::kFloat32);
}

// --- Helper Method Implementations (Placeholders or Basic Forms) ---

/// @brief Converts an HWC image into a normalized CHW float32 tensor on `device_`.
///
/// The image is moved to `device_`, cast to float32, permuted from HWC to CHW,
/// divided by 255 and normalized with the standard ImageNet mean / std. The
/// input tensor is not modified.
///
/// @param image_hwc_uchar Rank-3 image tensor laid out as (H, W, C); expected
///        to be uint8 with 3 channels. May live on the CPU or on CUDA.
/// @return Contiguous CHW float32 tensor on `device_`.
/// @throws c10::Error if the input is not a rank-3 tensor.
torch::Tensor DiMPTracker::convert_image_to_tensor_chw_float(const torch::Tensor& image_hwc_uchar) {
    TORCH_CHECK(image_hwc_uchar.defined() && image_hwc_uchar.dim() == 3,
                "DiMPTracker::convert_image_to_tensor_chw_float: expected a rank-3 HWC image tensor");

    // Move to the tracker device first: mean / std below live on device_, and
    // mixing a CPU image with CUDA constants would fail with a device mismatch.
    auto img_float = image_hwc_uchar.to(device_, torch::kFloat32);
    img_float = img_float.permute({2, 0, 1});  // HWC to CHW

    // Normalize: (img / 255.0 - mean) / std, with standard ImageNet mean / std.
    const auto stat_options = torch::TensorOptions().dtype(torch::kFloat32).device(device_);
    const torch::Tensor mean = torch::tensor({0.485, 0.456, 0.406}, stat_options).reshape({3, 1, 1});
    const torch::Tensor std_dev = torch::tensor({0.229, 0.224, 0.225}, stat_options).reshape({3, 1, 1});

    // div() allocates a fresh tensor, so the in-place ops that follow never
    // touch the caller's image.
    img_float = img_float.div(255.0);
    img_float = img_float.sub_(mean).div_(std_dev);

    return img_float.contiguous();
}

// ... Other private method placeholders would go here ...

}  // namespace cimp