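# NOTE(review): the text below was captured from a web code viewer together
# with this script; it is kept as comments so the file parses as Python.
#   "You cannot select more than 25 topics. Topics must start with a letter
#   or number, can include dashes ('-') and can be up to 35 characters long."
#   Viewer metadata: 23 lines, 1.2 KiB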

from pathlib import Path

import torch
# Debug dump for the ResNet BatchNorm-1 (bn1) layer. Prints shape, dtype,
# device and the first few values of the tensors found in:
#   1. the TorchScript archive holding the sample-0 bn1 output,
#   2. the TorchScript archive holding the conv1 output fed into bn1,
#   3. the exported bn1 weight / bias / running_mean / running_var tensors.
# All paths are relative to the current working directory.

RESNET_DEBUG_DIR = Path("test/output/resnet")
EXPORTED_BACKBONE_DIR = Path("exported_weights/backbone")
BN1_TENSOR_NAMES = ("weight", "bias", "running_mean", "running_var")
PREVIEW_COUNT = 5


def describe_tensor(name, tensor, n=PREVIEW_COUNT):
    """Return a one-line summary of *tensor*: shape, dtype, device, first *n* values.

    The preview is taken from a detached, flattened view, so parameters do
    not print autograd noise such as ``grad_fn=<SliceBackward0>``.
    """
    preview = tensor.detach().flatten()[:n]
    return (
        f"{name}: shape={tensor.shape}, dtype={tensor.dtype}, "
        f"device={tensor.device}, first{n}={preview}"
    )


def dump_scripted_archive(path):
    """Print a summary line for every parameter and buffer in a TorchScript archive.

    The archive is loaded onto the CPU with ``torch.jit.load``. Parameters are
    listed before buffers. If the archive holds neither, a message is printed
    instead of silently printing nothing.
    """
    module = torch.jit.load(path, map_location="cpu")
    entries = [*module.named_parameters(), *module.named_buffers()]
    if not entries:
        print(f"(no parameters or buffers in {path})")
    for name, tensor in entries:
        print(describe_tensor(name, tensor))


def dump_exported_bn1_stats(directory=EXPORTED_BACKBONE_DIR):
    """Print a summary line for each exported ``bn1_<name>.pt`` tensor in *directory*."""
    directory = Path(directory)
    for stat in BN1_TENSOR_NAMES:
        # NOTE(review): torch.load is pickle-based; only point it at trusted,
        # locally produced exports.
        tensor = torch.load(directory / f"bn1_{stat}.pt", map_location="cpu")
        print(describe_tensor(f"bn1_{stat}", tensor))


def main():
    """Print the bn1 output, the bn1 input, and the exported bn1 params/stats."""
    print("--- BN1 Output ---")
    dump_scripted_archive(RESNET_DEBUG_DIR / "sample_0_bn1_output.pt")

    print("\n--- BN1 Input (Conv1 Output for BN1 Input) ---")
    dump_scripted_archive(
        RESNET_DEBUG_DIR / "sample_0_debug_resnet_conv1_output_for_bn1_input.pt"
    )

    print("\n--- BN1 Parameters and Stats from Exported Weights ---")
    dump_exported_bn1_stats(EXPORTED_BACKBONE_DIR)


if __name__ == "__main__":
    main()