You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
5.0 KiB
97 lines
5.0 KiB
#include "dimp_tracker.h"
|
|
#include <iostream> // For debugging output
|
|
|
|
namespace cimp {
|
|
|
|
/**
 * Construct a DiMP tracker and load its three networks.
 *
 * @param params                  Tracker hyper-parameters (copied into params_).
 * @param resnet_weights_dir      Directory holding the ResNet-50 backbone weights; the backbone
 *                                is built to expose its "layer2", "layer3" and "layer4" outputs.
 * @param classifier_weights_dir  Directory holding the target-classifier weights.
 * @param bbregressor_weights_dir Directory holding the bounding-box regressor weights.
 * @param device                  Device the networks are placed on.
 *
 * Side effects: prints a short initialization summary to stdout.
 */
DiMPTracker::DiMPTracker(const DiMPTrackerParams& params,
                         const std::string& resnet_weights_dir,
                         const std::string& classifier_weights_dir,
                         const std::string& bbregressor_weights_dir,
                         torch::Device device)
    : params_(params),
      device_(device),
      resnet_model_(cimp::resnet::resnet50(resnet_weights_dir, {"layer2", "layer3", "layer4"}, device)),
      classifier_model_(classifier_weights_dir, device),
      bbregressor_model_(bbregressor_weights_dir, device)
{
    // The classifier and bbox regressor receive the device in their constructors.
    // The backbone is moved explicitly in case resnet50() does not already place
    // every parameter on `device`.
    this->resnet_model_->to(this->device_);

    // Inference only. The classifier has no explicit eval() method, so only the
    // backbone and the regressor are switched to evaluation mode.
    this->resnet_model_->eval();
    this->bbregressor_model_.eval();

    // Sample size as a float32 tensor (built via int64) for later scale arithmetic.
    // The support size starts out identical to it.
    this->img_sample_sz_ = torch::tensor(params_.image_sample_size.vec(), torch::kInt64).to(torch::kFloat32);
    this->img_support_sz_ = this->img_sample_sz_.clone();

    std::cout << "DiMPTracker initialized." << std::endl;
    std::cout << " Device: " << (this->device_.is_cuda() ? "CUDA" : "CPU") << std::endl;
    if (this->device_.is_cuda()) {
        // c10::DeviceIndex is an 8-bit integer in many libtorch releases. Streaming it
        // directly would print it as a character, so widen it to int first.
        std::cout << " CUDA Device Index: " << static_cast<int>(this->device_.index()) << std::endl;
    }
    // NOTE(review): the classifier is not actually put into eval mode (see above);
    // this message over-claims slightly.
    std::cout << " ResNet, Classifier, and BBRegressor models constructed and set to eval mode." << std::endl;
}
|
|
|
|
// --- Placeholder for initialize() ---
|
|
/**
 * Start tracking a new target (partial implementation).
 *
 * Sets up the geometric tracker state from the first frame:
 *   - image_sz_
 *   - pos_
 *   - target_sz_
 *   - target_scale_
 *   - base_target_sz_
 *   - init_sample_pos_
 *   - init_sample_scale_
 * Sample generation and classifier / IoU-net initialization are still TODO.
 *
 * @param image_tensor_hwc_uchar First frame; passed to convert_image_to_tensor_chw_float().
 * @param initial_bbox_xywh      Initial box as [x, y, w, h] (exactly 4 elements; any dtype,
 *                               any device). x, y is treated as the top-left corner.
 * @throws c10::Error if the box does not have 4 elements or its width/height is not positive.
 */
void DiMPTracker::initialize(const torch::Tensor& image_tensor_hwc_uchar, const torch::Tensor& initial_bbox_xywh) {
    std::cout << "DiMPTracker::initialize() called (placeholder)." << std::endl;

    // TODO: Implement full initialization logic
    //   1. Convert image_tensor_hwc_uchar to CHW float tensor, normalize            (done below)
    //   2. Set initial pos_, target_sz_, image_sz_, target_scale_, base_target_sz_  (done below)
    //   3. Call generate_init_samples
    //   4. Call init_classifier_internal
    //   5. Call init_iou_net_internal

    const auto float_on_device = torch::TensorOptions().dtype(torch::kFloat32).device(device_);

    auto image_chw_float = convert_image_to_tensor_chw_float(image_tensor_hwc_uchar);
    this->image_sz_ = torch::tensor({image_chw_float.size(1), image_chw_float.size(2)}, float_on_device); // H, W

    TORCH_CHECK(initial_bbox_xywh.numel() == 4,
                "DiMPTracker::initialize: expected initial_bbox_xywh with 4 elements [x, y, w, h], got shape ",
                initial_bbox_xywh.sizes());

    // Read the box once on the host. This replaces six separate .item() calls, each
    // of which would synchronize with the device when the box lives on CUDA.
    const torch::Tensor bbox_cpu = initial_bbox_xywh.to(torch::kCPU).to(torch::kFloat32).reshape({4});
    const auto bbox = bbox_cpu.accessor<float, 1>();
    const float box_x = bbox[0];
    const float box_y = bbox[1];
    const float box_w = bbox[2];
    const float box_h = bbox[3];

    // A degenerate box would make search_area (and thus target_scale_) zero, turning
    // base_target_sz_ into inf/NaN. This also rejects NaN sizes.
    TORCH_CHECK(box_w > 0.0f && box_h > 0.0f,
                "DiMPTracker::initialize: bounding box width and height must be positive, got w=",
                box_w, ", h=", box_h);

    // Box centre in (y, x) order, using the (size - 1) / 2 convention.
    this->pos_ = torch::tensor({box_y + (box_h - 1.0f) / 2.0f,
                                box_x + (box_w - 1.0f) / 2.0f}, float_on_device); // y_center, x_center
    this->target_sz_ = torch::tensor({box_h, box_w}, float_on_device);           // height, width

    // Scale so that the search area (target area * search_area_scale) maps onto the sample size.
    double search_area = torch::prod(this->target_sz_ * params_.search_area_scale).item<double>();
    this->target_scale_ = std::sqrt(search_area) / torch::prod(this->img_sample_sz_).sqrt().item<double>();
    this->base_target_sz_ = this->target_sz_ / this->target_scale_;

    this->init_sample_pos_ = this->pos_.round();
    this->init_sample_scale_ = this->target_scale_;

    // TODO: Call generate_init_samples, init_classifier_internal, init_iou_net_internal
}
|
|
|
|
// --- Placeholder for track() ---
|
|
/**
 * Track the target in a new frame (placeholder).
 *
 * @param image_tensor_hwc_uchar Next frame (currently ignored).
 * @return A 4-element float32 box [x, y, w, h]; for now always all zeros.
 *
 * Side effects: logs the call to stdout.
 */
torch::Tensor DiMPTracker::track(const torch::Tensor& image_tensor_hwc_uchar) {
    (void)image_tensor_hwc_uchar; // unused until the tracking logic is implemented

    std::cout << "DiMPTracker::track() called (placeholder)." << std::endl;

    // TODO: Implement full tracking logic.
    // Until then, hand back an all-zero [x, y, w, h] box.
    return torch::zeros({4}, torch::kFloat32);
}
|
|
|
|
// --- Helper Method Implementations (Placeholders or Basic Forms) ---
|
|
/**
 * Convert an HWC image tensor into a normalized CHW float32 tensor on device_.
 *
 * Pixels are scaled by 1/255 and then normalized with the standard ImageNet
 * per-channel statistics: (img / 255 - mean) / std.
 *
 * @param image_hwc_uchar Image of shape (H, W, 3), presumably uint8 in [0, 255]. It may
 *                        live on any device; it is moved to device_ before normalization.
 * @return Contiguous float32 tensor of shape (3, H, W) on device_.
 * @throws c10::Error if the input is not a 3-D tensor with 3 channels in its last dimension.
 */
torch::Tensor DiMPTracker::convert_image_to_tensor_chw_float(const torch::Tensor& image_hwc_uchar) {
    TORCH_CHECK(image_hwc_uchar.dim() == 3 && image_hwc_uchar.size(2) == 3,
                "convert_image_to_tensor_chw_float: expected an HWC image with 3 channels, got shape ",
                image_hwc_uchar.sizes());

    // Move to the tracker device *before* normalizing. The mean/std tensors below live on
    // device_, so a CPU image with a CUDA tracker would otherwise fail with a device mismatch.
    auto img_float = image_hwc_uchar.to(device_, torch::kFloat32);
    img_float = img_float.permute({2, 0, 1}); // HWC -> CHW (a strided view; made contiguous below)

    // Standard ImageNet statistics, shaped (3, 1, 1) to broadcast over a CHW tensor.
    const auto stat_options = torch::TensorOptions().dtype(torch::kFloat32).device(device_);
    const torch::Tensor mean = torch::tensor({0.485, 0.456, 0.406}, stat_options).reshape({3, 1, 1});
    const torch::Tensor std_dev = torch::tensor({0.229, 0.224, 0.225}, stat_options).reshape({3, 1, 1});

    // div() is out-of-place, so the in-place ops that follow never modify the caller's tensor
    // (the .to() above may return the input itself when it is already float32 on device_).
    img_float = img_float.div(255.0);
    img_float.sub_(mean).div_(std_dev);

    return img_float.contiguous();
}
|
|
|
|
// ... Other private method placeholders would go here ...
|
|
|
|
} // namespace cimp
|