diff --git a/src_files/models/utils/factory.py b/src_files/models/utils/factory.py index 12a2f34..8c2aea5 100644 --- a/src_files/models/utils/factory.py +++ b/src_files/models/utils/factory.py @@ -44,11 +44,11 @@ def create_model(args,load_head=False): model_path = "./tresnet_l.pth" print('done') state = torch.load(model_path, map_location='cpu') + if 'model' in state: + key = 'model' + else: + key = 'state_dict' if not load_head: - if 'model' in state: - key = 'model' - else: - key = 'state_dict' filtered_dict = {k: v for k, v in state[key].items() if (k in model.state_dict() and 'head.fc' not in k)} model.load_state_dict(filtered_dict, strict=False)