You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
23 lines
1.2 KiB
23 lines
1.2 KiB
import torch
|
|
|
|
# Print shape, dtype, device and the first few values of the tensors saved for
# the first BatchNorm layer (bn1): its output, its input (conv1's output as fed
# to bn1), and its exported parameters / running statistics.

BN1_OUTPUT_PATH = 'test/output/resnet/sample_0_bn1_output.pt'
BN1_INPUT_PATH = (
    'test/output/resnet/sample_0_debug_resnet_conv1_output_for_bn1_input.pt'
)
BN1_WEIGHTS_DIR = 'exported_weights/backbone'
BN1_TENSOR_NAMES = ('weight', 'bias', 'running_mean', 'running_var')


def describe_tensor(name, tensor, count=5):
    """Return a one-line summary of *tensor* labelled with *name*.

    The summary lists shape, dtype, device and the first *count* elements of
    the flattened tensor, formatted as ``first{count}=...``.

    Args:
        name: Label printed before the summary.
        tensor: Any ``torch.Tensor`` (including ``nn.Parameter``).
        count: Number of leading flattened elements to show (default 5).

    Returns:
        The formatted summary string.
    """
    # Detach so slices of parameters (requires_grad=True) print as plain
    # values instead of carrying a trailing ``grad_fn=<...>`` annotation.
    head = tensor.detach().flatten()[:count]
    return (
        f'{name}: shape={tensor.shape}, dtype={tensor.dtype}, '
        f'device={tensor.device}, first{count}={head}'
    )


def dump_scripted_tensors(path):
    """Load a TorchScript archive onto the CPU and summarize its tensors.

    Prints one line per parameter, then one line per buffer, in the order
    yielded by ``named_parameters()`` / ``named_buffers()``.

    Args:
        path: Path (or file-like object) accepted by ``torch.jit.load``.
    """
    # NOTE(review): these debug dumps are presumably produced by a
    # TorchScript/libtorch exporter that stores tensors as module
    # parameters/buffers -- confirm against the writer.
    module = torch.jit.load(path, map_location='cpu')
    for name, param in module.named_parameters():
        print(describe_tensor(name, param))
    for name, buf in module.named_buffers():
        print(describe_tensor(name, buf))


def main():
    """Print summaries of bn1's output, its input, and its exported tensors."""
    print('--- BN1 Output ---')
    dump_scripted_tensors(BN1_OUTPUT_PATH)

    print('\n--- BN1 Input (Conv1 Output for BN1 Input) ---')
    dump_scripted_tensors(BN1_INPUT_PATH)

    print('\n--- BN1 Parameters and Stats from Exported Weights ---')
    for param in BN1_TENSOR_NAMES:
        # Read with torch.load (not torch.jit.load); each file is used as a
        # single tensor below.
        tensor = torch.load(f'{BN1_WEIGHTS_DIR}/bn1_{param}.pt', map_location='cpu')
        print(describe_tensor(f'bn1_{param}', tensor))


if __name__ == '__main__':
    main()