"""Debug script: dump BN1-related tensors from a ResNet export.

Prints the shape, dtype, device and first five values of:
  * every parameter/buffer of the TorchScript file saved as the BN1 output,
  * every parameter/buffer of the TorchScript file saved as the conv1 output
    that feeds BN1,
  * the exported BN1 ``weight``, ``bias``, ``running_mean`` and
    ``running_var`` tensors loaded with ``torch.load``.
"""
import torch

BN1_OUTPUT_PATH = 'test/output/resnet/sample_0_bn1_output.pt'
BN1_INPUT_PATH = (
    'test/output/resnet/sample_0_debug_resnet_conv1_output_for_bn1_input.pt'
)
# Formatted with each entry of BN1_PARAM_NAMES.
BN1_WEIGHTS_TEMPLATE = 'exported_weights/backbone/bn1_{}.pt'
BN1_PARAM_NAMES = ('weight', 'bias', 'running_mean', 'running_var')


def format_tensor(name, tensor, preview=5):
    """Return a one-line summary of *tensor*.

    Args:
        name: Label printed before the summary.
        tensor: Any ``torch.Tensor`` (including ``nn.Parameter``).
        preview: How many leading elements (of the flattened tensor) to show.

    Returns:
        ``"<name>: shape=..., dtype=..., device=..., first<preview>=..."``.
    """
    # detach() so parameter previews are not printed with autograd
    # noise such as "grad_fn=<SliceBackward0>".
    head = tensor.detach().flatten()[:preview]
    return (
        f'{name}: shape={tensor.shape}, dtype={tensor.dtype}, '
        f'device={tensor.device}, first{preview}={head}'
    )


def inspect_module(title, path):
    """Print *title*, then a summary of every parameter and buffer in *path*.

    *path* must be a TorchScript file loadable by ``torch.jit.load``; it is
    mapped onto the CPU.
    """
    print(title)
    module = torch.jit.load(path, map_location='cpu')
    for name, param in module.named_parameters():
        print(format_tensor(name, param))
    for name, buf in module.named_buffers():
        print(format_tensor(name, buf))


def main():
    """Dump BN1 output, BN1 input and the exported BN1 weights/statistics."""
    inspect_module('--- BN1 Output ---', BN1_OUTPUT_PATH)
    inspect_module(
        '\n--- BN1 Input (Conv1 Output for BN1 Input) ---', BN1_INPUT_PATH
    )

    print('\n--- BN1 Parameters and Stats from Exported Weights ---')
    for param in BN1_PARAM_NAMES:
        # Local debug artifacts; if these ever come from an untrusted source,
        # load with weights_only=True instead of unpickling arbitrary objects.
        t = torch.load(BN1_WEIGHTS_TEMPLATE.format(param), map_location='cpu')
        print(format_tensor(f'bn1_{param}', t))


if __name__ == '__main__':
    main()