-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimportance_generation.py
More file actions
61 lines (55 loc) · 2.5 KB
/
importance_generation.py
File metadata and controls
61 lines (55 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import os
import torch
from utils.common import *
def _parse_args(argv=None):
    """Parse the command line for the importance-assessment script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` (status 2) when ``--pretrain_dir`` is empty,
    before any network is constructed.
    """
    parser = argparse.ArgumentParser(description='Importance Assessment')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=('cifar10', 'imagenet', 'DUTS'), help='dataset')
    parser.add_argument('--data_dir', type=str, default='./data', help='path to dataset')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--pretrain_dir', type=str, default='checkpoints/googlenet.pt',
                        help='load the model from the specified checkpoint')
    parser.add_argument('--limit', type=int, default=5,
                        help='The num of batch to get importance score.')
    parser.add_argument('--net', type=str, default='googlenet',
                        choices=('resnet_50', 'vgg_16_bn', 'resnet_56',
                                 'resnet_110', 'densenet_40', 'googlenet', 'u2netp'),
                        help='net type')
    args = parser.parse_args(argv)
    if not args.pretrain_dir:
        # Fail fast with a usage message instead of building the net first.
        parser.error('please specify a pretrain model')
    return args


def _drop_module_prefix(state_dict):
    """Return a copy of *state_dict* with every ``'module.'`` substring removed from its keys.

    Presumably these checkpoints were saved from an ``nn.DataParallel``-wrapped
    model (which prefixes keys with ``module.``) — TODO confirm.
    """
    return {k.replace('module.', ''): v for k, v in state_dict.items()}


def _checkpoint_map_location(net_name):
    """Choose the ``torch.load`` ``map_location`` for *net_name*.

    With CUDA available, this keeps the original behaviour: ``vgg_16_bn`` and
    ``resnet_56`` are remapped onto ``cuda:0``, and every other net loads onto
    the device it was saved from (``None``). Without CUDA, everything is mapped
    to ``'cpu'`` so GPU-saved checkpoints can still be read.
    """
    cuda = torch.cuda.is_available()
    if net_name in ('vgg_16_bn', 'resnet_56'):
        return 'cuda:0' if cuda else 'cpu'
    return None if cuda else 'cpu'


def _load_pretrained(net, args):
    """Load the checkpoint at ``args.pretrain_dir`` into *net* in place.

    The checkpoint layout depends on ``args.net``:
      * ``u2netp``: a raw state dict. Only keys the model also has are loaded;
        the rest of the model keeps its current weights.
      * ``resnet_50``: a raw state dict, loaded strictly.
      * ``densenet_40`` / ``resnet_110``: ``checkpoint['state_dict']`` with
        ``'module.'`` stripped from its keys.
      * any other net: ``checkpoint['state_dict']`` as-is.

    NOTE: ``torch.load`` may unpickle data, so only point it at trusted files.
    """
    print('==> Resuming from checkpoint..')
    if args.net == 'u2netp':
        pretrained = torch.load(args.pretrain_dir, map_location='cpu')
        model_dict = net.state_dict()
        # Partial load: drop checkpoint entries the model does not define.
        model_dict.update({k: v for k, v in pretrained.items() if k in model_dict})
        net.load_state_dict(model_dict)
    else:
        checkpoint = torch.load(args.pretrain_dir,
                                map_location=_checkpoint_map_location(args.net))
        if args.net == 'resnet_50':
            net.load_state_dict(checkpoint)
        elif args.net in ('densenet_40', 'resnet_110'):
            net.load_state_dict(_drop_module_prefix(checkpoint['state_dict']))
        else:
            net.load_state_dict(checkpoint['state_dict'])
    print('Completed! ')


def main(argv=None):
    """Build the requested network, load its checkpoint, and run ``imp_score`` on it."""
    args = _parse_args(argv)
    net = get_network(args)
    _load_pretrained(net, args)
    imp_score(net, args)


if __name__ == '__main__':
    main()