-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathevaluate_localization.py
More file actions
62 lines (54 loc) · 3.35 KB
/
evaluate_localization.py
File metadata and controls
62 lines (54 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import argparse
import os
import numpy as np
from attribution_evaluation.evaluation import localization
from attribution_evaluation.models import settings
import torchvision
def _artifact_stem(args):
    """Build the name fragment shared by the attribution file and the plot file.

    The fragment is ``<model>_<setting>_<dataset>_<exp>_<config>_<layer><save_suffix>``,
    where ``<dataset>`` is ``os.path.basename(args.dataset_path)``. Building it in one
    place keeps the plot name in lockstep with the attribution file it was derived from.
    """
    parts = [
        args.model,
        args.setting,
        os.path.basename(args.dataset_path),
        args.exp,
        args.config,
        str(args.layer),
    ]
    return "_".join(parts) + args.save_suffix


def main(args):
    """Compute localization scores for saved attributions, print summary stats, and save a box plot.

    Args:
        args: Parsed command-line namespace. Fields read here: ``seed``, ``dataset_path``,
            ``attributions_path``, ``model``, ``setting``, ``exp``, ``config``, ``layer``,
            ``save_suffix`` and ``save_path``.

    Side effects:
        Seeds the torch and NumPy RNGs, prints statistics to stdout, creates
        ``args.save_path`` if it does not exist, and writes the plot to a ``.png`` file in it.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # NOTE: torch.load unpickles its input; only point it at trusted, locally produced files.
    test_data_dict = torch.load(os.path.join(args.dataset_path, 'test.pt'))
    scale = test_data_dict["scale"]
    # Drop the first entry of "input_dims" (presumably the channel dimension -- TODO confirm).
    img_dims = test_data_dict["input_dims"][1:]

    stem = _artifact_stem(args)
    attributions = torch.load(os.path.join(args.attributions_path, 'attributions_' + stem + '.pt'))

    # Presumably a 1-D torch tensor (it is used with .median() and .tolist()) -- TODO confirm.
    localization_scores = localization.get_localization_score(
        attributions, only_corners=settings.eval_only_corners(args.setting), img_dims=img_dims, scale=scale)

    print("Localization Scores:")
    print("Number of data points:", len(localization_scores))
    print("Mean:", "{:.4f}".format(localization_scores.mean()))
    print("Standard Deviation:", "{:.4f}".format(localization_scores.std()))
    print("Median:", "{:.4f}".format(localization_scores.median()))
    print("Min:", "{:.4f}".format(localization_scores.min()))
    print("Max:", "{:.4f}".format(localization_scores.max()))

    fig, _ = localization.plot_localization_scores_single(
        localization_scores.tolist(), args.model, args.setting, args.exp, args.config, args.layer, scale=scale)

    # Make sure the output directory exists so savefig does not fail after all the work above.
    # An empty save_path means "current directory" (os.path.join('', name) == name), so skip it.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    full_save_path = os.path.join(args.save_path, "localization_" + stem + ".png")
    print("Saving box plot at", full_save_path)
    fig.savefig(full_save_path)
if __name__ == "__main__":
    # Script entry point: declare the command-line interface, parse it, and run main().
    cli = argparse.ArgumentParser(
        description="Computes the localization scores for a set of attributions and plots them on a box plot.")

    # (flag, add_argument keyword options), listed in the order they should appear in --help.
    cli_options = (
        ('--dataset_path', dict(type=str, required=True,
                                help="Path of directory containing the dataset")),
        ('--seed', dict(type=int, default=1, help="Random seed value")),
        ('--model', dict(type=str, required=True, choices=["vgg11", "resnet18"],
                         help="Model to evaluate on")),
        ('--setting', dict(type=str, required=True, choices=["GridPG", "DiFull", "DiPart"],
                           help="Setting to evaluate on")),
        ('--layer', dict(type=str, required=True, choices=["Input", "Middle", "Final"],
                         help="Layer to evaluate on")),
        ('--attributions_path', dict(type=str, required=True,
                                     help="Path of directory from which to load attributions")),
        ('--save_path', dict(type=str, required=True,
                             help="Path of directory in which to save plot")),
        ('--save_suffix', dict(type=str, default='',
                               help="Suffix to add to the output file name")),
        ('--exp', dict(type=str, required=True,
                       help="Attribution method to evaluate")),
        ('--config', dict(type=str, required=True,
                          help="Configuration of the attribution method to be used")),
    )
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    main(cli.parse_args())