from pytracking.features.net_wrappers import NetWithBackbone


class TrackerParams:
    """Container for the tracker's hyper-parameters.

    Every setting is a class attribute. All instances therefore see the same
    values, and share the same mutable objects such as ``augmentation`` and
    ``net``, unless an instance assigns its own attribute. Use :meth:`get` and
    :meth:`has` for name-based lookup with an optional fallback.
    """

    # Sampling parameters
    image_sample_size = 18 * 16
    search_area_scale = 5

    # Learning parameters
    sample_memory_size = 50
    learning_rate = 0.01
    train_skipping = 20

    # Net optimization parameters
    net_opt_iter = 10

    # Initial-frame augmentation parameters
    use_augmentation = True
    # NOTE(review): this dict lives on the class and is shared by every
    # instance. Mutating it through one instance changes it for all of them,
    # so copy it before modifying.
    augmentation = {
        'fliplr': True,
        'rotate': [10, -10, 45, -45],
        'blur': [(3, 1), (1, 3), (2, 2)],
        'relativeshift': [(0.6, 0.6), (-0.6, 0.6), (0.6, -0.6), (-0.6, -0.6)],
        'dropout': (2, 0.2),
    }

    augmentation_expansion_factor = 2
    random_shift_factor = 1 / 3

    # Advanced localization parameters
    target_not_found_threshold = 0.25
    target_neighborhood_scale = 2.2
    update_scale_when_uncertain = True

    # IoU-Net parameters
    iounet_k = 3
    box_jitter_pos = 0.1
    box_jitter_sz = 0.5
    maximal_aspect_ratio = 6
    box_refinement_iter = 5
    box_refinement_step_length = 1
    box_refinement_step_decay = 1

    # Network wrapper. It is built once, when this class body runs (at module
    # import), and is shared by all instances.
    # NOTE(review): the checkpoint path is relative. Presumably NetWithBackbone
    # resolves it against a configured network directory; confirm.
    net = NetWithBackbone(net_path='dimp50.pth')

    # NOTE(review): presumably selects how VOT annotations are converted;
    # confirm against the consuming tracker code.
    vot_anno_conversion_type = 'preserve_area'

    def get(self, name: str, *default):
        """Look up the parameter called ``name``.

        Args:
            name: Attribute name of the parameter.
            *default: At most one optional fallback value. It is returned when
                no parameter called ``name`` exists.

        Returns:
            The parameter value. If the parameter is missing and a fallback
            was supplied, the fallback is returned instead.

        Raises:
            ValueError: If more than one fallback value is passed.
            AttributeError: If the parameter is missing and no fallback was
                supplied.
        """
        if not default:
            # No fallback given: getattr raises AttributeError for unknown names.
            return getattr(self, name)
        if len(default) == 1:
            (fallback,) = default
            return getattr(self, name, fallback)
        raise ValueError('Can only give one default value.')

    def has(self, name: str):
        """Return True if a parameter called ``name`` exists, otherwise False."""
        return hasattr(self, name)