|
@ -123,78 +123,21 @@ torch::Tensor BottleneckImpl::forward(torch::Tensor x) { |
|
|
// conv1 -> bn1 -> relu
|
|
|
// conv1 -> bn1 -> relu
|
|
|
x = conv1->forward(x); |
|
|
x = conv1->forward(x); |
|
|
|
|
|
|
|
|
if (!this->is_training() && bn1) { |
|
|
|
|
|
const auto& bn_module = *bn1; |
|
|
|
|
|
torch::Tensor input_double = x.to(torch::kFloat64); |
|
|
|
|
|
torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64); |
|
|
|
|
|
torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64); |
|
|
|
|
|
double eps_double = bn_module.options.eps(); |
|
|
|
|
|
|
|
|
|
|
|
auto c = x.size(1); |
|
|
|
|
|
running_mean_double = running_mean_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
running_var_double = running_var_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double)); |
|
|
|
|
|
if (weight_double.defined()) out_double = out_double * weight_double; |
|
|
|
|
|
if (bias_double.defined()) out_double = out_double + bias_double; |
|
|
|
|
|
x = out_double.to(original_dtype); |
|
|
|
|
|
} else if (bn1) { |
|
|
|
|
|
|
|
|
if (bn1) { |
|
|
x = bn1->forward(x); |
|
|
x = bn1->forward(x); |
|
|
} |
|
|
} |
|
|
x = relu->forward(x); |
|
|
x = relu->forward(x); |
|
|
|
|
|
|
|
|
// conv2 -> bn2 -> relu
|
|
|
// conv2 -> bn2 -> relu
|
|
|
x = conv2->forward(x); |
|
|
x = conv2->forward(x); |
|
|
if (!this->is_training() && bn2) { |
|
|
|
|
|
const auto& bn_module = *bn2; |
|
|
|
|
|
torch::Tensor input_double = x.to(torch::kFloat64); |
|
|
|
|
|
torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64); |
|
|
|
|
|
torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64); |
|
|
|
|
|
double eps_double = bn_module.options.eps(); |
|
|
|
|
|
|
|
|
|
|
|
auto c = x.size(1); |
|
|
|
|
|
running_mean_double = running_mean_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
running_var_double = running_var_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double)); |
|
|
|
|
|
if (weight_double.defined()) out_double = out_double * weight_double; |
|
|
|
|
|
if (bias_double.defined()) out_double = out_double + bias_double; |
|
|
|
|
|
x = out_double.to(original_dtype); |
|
|
|
|
|
} else if (bn2) { |
|
|
|
|
|
|
|
|
if (bn2) { |
|
|
x = bn2->forward(x); |
|
|
x = bn2->forward(x); |
|
|
} |
|
|
} |
|
|
x = relu->forward(x); |
|
|
x = relu->forward(x); |
|
|
|
|
|
|
|
|
// conv3 -> bn3
|
|
|
// conv3 -> bn3
|
|
|
x = conv3->forward(x); |
|
|
x = conv3->forward(x); |
|
|
if (!this->is_training() && bn3) { |
|
|
|
|
|
const auto& bn_module = *bn3; |
|
|
|
|
|
torch::Tensor input_double = x.to(torch::kFloat64); |
|
|
|
|
|
torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64); |
|
|
|
|
|
torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64); |
|
|
|
|
|
double eps_double = bn_module.options.eps(); |
|
|
|
|
|
|
|
|
|
|
|
auto c = x.size(1); |
|
|
|
|
|
running_mean_double = running_mean_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
running_var_double = running_var_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double)); |
|
|
|
|
|
if (weight_double.defined()) out_double = out_double * weight_double; |
|
|
|
|
|
if (bias_double.defined()) out_double = out_double + bias_double; |
|
|
|
|
|
x = out_double.to(original_dtype); |
|
|
|
|
|
} else if (bn3) { |
|
|
|
|
|
|
|
|
if (bn3) { |
|
|
x = bn3->forward(x); |
|
|
x = bn3->forward(x); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -302,39 +245,7 @@ std::map<std::string, torch::Tensor> ResNetImpl::forward(torch::Tensor x) { |
|
|
torch::ScalarType original_dtype_resnet_bn1 = x.scalar_type(); |
|
|
torch::ScalarType original_dtype_resnet_bn1 = x.scalar_type(); |
|
|
|
|
|
|
|
|
// Apply bn1
|
|
|
// Apply bn1
|
|
|
if (!this->is_training() && bn1) { |
|
|
|
|
|
const auto& bn_module = *bn1; |
|
|
|
|
|
torch::Tensor input_double = x.to(torch::kFloat64); |
|
|
|
|
|
torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64); |
|
|
|
|
|
torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64); |
|
|
|
|
|
double eps_double = bn_module.options.eps(); |
|
|
|
|
|
|
|
|
|
|
|
auto c = x.size(1); |
|
|
|
|
|
torch::Tensor reshaped_running_mean = running_mean_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
torch::Tensor reshaped_running_var = running_var_double.reshape({1, c, 1, 1}); |
|
|
|
|
|
torch::Tensor reshaped_weight = weight_double.defined() ? weight_double.reshape({1, c, 1, 1}) : torch::Tensor(); |
|
|
|
|
|
torch::Tensor reshaped_bias = bias_double.defined() ? bias_double.reshape({1, c, 1, 1}) : torch::Tensor(); |
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor centered_x = input_double - reshaped_running_mean; |
|
|
|
|
|
if (should_output("bn1_centered_x")) outputs["bn1_centered_x"] = centered_x.clone(); |
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor variance_plus_eps = reshaped_running_var + eps_double; |
|
|
|
|
|
if (should_output("bn1_variance_plus_eps")) outputs["bn1_variance_plus_eps"] = variance_plus_eps.clone(); |
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor inv_std = torch::rsqrt(variance_plus_eps); // Using rsqrt for potential match
|
|
|
|
|
|
if (should_output("bn1_inv_std")) outputs["bn1_inv_std"] = inv_std.clone(); |
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor normalized_x = centered_x * inv_std; |
|
|
|
|
|
if (should_output("bn1_normalized_x")) outputs["bn1_normalized_x"] = normalized_x.clone(); |
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor out_double = normalized_x; |
|
|
|
|
|
if (reshaped_weight.defined()) out_double = out_double * reshaped_weight; |
|
|
|
|
|
if (reshaped_bias.defined()) out_double = out_double + reshaped_bias; |
|
|
|
|
|
|
|
|
|
|
|
x = out_double.to(original_dtype_resnet_bn1); |
|
|
|
|
|
} else if (bn1) { // Training mode or if manual is disabled
|
|
|
|
|
|
|
|
|
if (bn1) { |
|
|
x = bn1->forward(x); |
|
|
x = bn1->forward(x); |
|
|
} |
|
|
} |
|
|
// End apply bn1
|
|
|
// End apply bn1
|
|
|