-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathCNNTest.m
More file actions
108 lines (94 loc) · 3.08 KB
/
CNNTest.m
File metadata and controls
108 lines (94 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
function [DataSet, SalImg, PrecEER] = CNNTest(varargin)
%CNNTEST Evaluate each saved training epoch of a saliency CNN on one test class.
%   [DataSet, SalImg, PrecEER] = CNNTest('expDir', D, 'TestDir', T, ...
%       'TestClassName', C, 'GPUID', G)
%   loads the checkpoints D/net-epoch-1.mat ... D/net-epoch-N.mat, runs
%   every one of them on the test images of class C listed in
%   T/TrainTestSplit.mat, and scores the predictions with EvalResult.
%
%   Options (name/value pairs parsed by vl_argparse):
%     expDir        - directory with the net-epoch-*.mat checkpoints;
%                     net-epoch-0.mat must contain the variable 'AvgImage'.
%     GPUID         - when non-empty, inputs and networks are moved to the
%                     GPU. Only emptiness is tested; the value itself does
%                     not select a device here.
%     TestDir       - root of the test images/masks; must contain
%                     TrainTestSplit.mat holding the struct array 'TestSet'
%                     (fields ClassName, ImageName, MaskName, ImageLabel).
%     TestClassName - compared (strcmp) against TestSet(k).ClassName;
%                     exactly one entry must match.
%     imdb          - accepted for backward compatibility; unused.
%
%   Outputs:
%     DataSet - NumImages-by-1 struct array with fields Image and Mask
%               (as returned by imread), ImageLabel, and ImageName (file
%               name without folder or extension).
%     SalImg  - NumImages-by-1 cell of saliency maps, resized back to each
%               image's size, produced by the epoch with the highest
%               PrecEER (the earliest epoch wins ties).
%     PrecEER - N-by-1 vector; entry k is the 4th output of EvalResult
%               for epoch k.
%
%   Errors:
%     CNNTest:noCheckpoints - fewer than two net-epoch-*.mat files found.
%     CNNTest:badClass      - TestClassName does not match exactly one class.

% ---- Options -----------------------------------------------------------
opts.expDir = [];
opts.GPUID = [];
opts.TestDir = [];
opts.TestClassName = [];
opts.imdb = [];   % unused; kept so callers that pass 'imdb' still parse
opts = vl_argparse(opts, varargin) ;

% ---- Checkpoints -------------------------------------------------------
% net-epoch-0.mat is not a trained epoch (it supplies AvgImage), so the
% number of trained epochs is the file count minus one.
ModelList = dir([opts.expDir '/net-epoch*.mat']);
NumEpoches = numel(ModelList) - 1;
if NumEpoches < 1
  error('CNNTest:noCheckpoints', ...
    'No trained checkpoints (net-epoch-N.mat, N >= 1) found in "%s".', ...
    opts.expDir);
end
ModelPath = cell(1, NumEpoches);
for i = 1:NumEpoches
  ModelPath{i} = [opts.expDir '/net-epoch-' num2str(i) '.mat'];
end

% Build the epoch-0 path directly: a strrep of '-1.mat' on the epoch-1
% path would also rewrite any '-1.mat' occurring inside expDir itself.
averageImage = load([opts.expDir '/net-epoch-0.mat'], 'AvgImage');
% Only the first spatial position is kept, i.e. one value per channel.
% NOTE(review): presumably AvgImage is 1x1xC or constant over H and W.
averageImage = averageImage.AvgImage(1, 1, :);
inputVar = 'input';

% ---- Test set ----------------------------------------------------------
Split = load([opts.TestDir '/TrainTestSplit.mat'], 'TestSet');
ClassID = strcmp({Split.TestSet.ClassName}, opts.TestClassName);
if nnz(ClassID) ~= 1
  % Zero or several matches would otherwise fail obscurely when the
  % fields of TestSet are accessed below.
  error('CNNTest:badClass', ...
    'Expected exactly one test class matching TestClassName, found %d.', ...
    nnz(ClassID));
end
TestSet = Split.TestSet(ClassID);
NumImages = numel(TestSet.ImageName);
ImageList = cell(NumImages, 1);   % mean-subtracted single images (network input)
MaskList = cell(NumImages, 1);    % masks as single column vectors (EvalResult input)
DataSet = struct('Image', ImageList, 'Mask', ImageList, ...
  'ImageLabel', ImageList, 'ImageName', ImageList);
for imId = 1:NumImages
  Image = imread([opts.TestDir '/' TestSet.ImageName{imId}]);
  Mask = imread([opts.TestDir '/' TestSet.MaskName{imId}]);
  ImageList{imId} = bsxfun(@minus, single(Image), averageImage);
  MaskList{imId} = vec(single(Mask));
  DataSet(imId).Image = Image;
  DataSet(imId).Mask = Mask;
  DataSet(imId).ImageLabel = TestSet.ImageLabel(imId);
  [~, ImageName, ~] = fileparts(TestSet.ImageName{imId});
  DataSet(imId).ImageName = ImageName;
end

% ---- Evaluate every epoch ---------------------------------------------
PrecEER = zeros(NumEpoches, 1);
SalImg = cell(NumImages, 1);
% Start below any possible score so the first epoch's maps are always
% kept, even when its score is 0 (comparing against the zero-initialised
% PrecEER could leave SalImg empty).
BestPrecEER = -Inf;
for EpochID = 1:NumEpoches
  disp(['EpochID: ' num2str(EpochID) '/' num2str(NumEpoches)])
  tic
  [net, predVar] = LoadNet(ModelPath{EpochID}, opts.GPUID);
  SalImgList = cell(NumImages, 1);    % column-vector maps for EvalResult
  EpochSalImg = cell(NumImages, 1);   % full-size maps for this epoch
  for imId = 1:NumImages
    Image = ImageList{imId};
    sz = [size(Image,1), size(Image,2)] ;
    % Network input is resized to a multiple of 32 per side, at least 32 so
    % tiny images do not round to a zero size. NOTE(review): 32 is
    % presumably the network's total downsampling factor -- confirm.
    sz_ = max(32, round(sz / 32) * 32) ;
    Image_ = imresize(Image, sz_) ;
    if ~isempty(opts.GPUID)
      Image_ = gpuArray(Image_) ;
    end
    net.eval({inputVar, Image_});
    scores = gather(net.vars(predVar).value);
    SalMap = imresize(scores, sz);   % back to the original image size
    SalImgList{imId} = vec(SalMap);
    EpochSalImg{imId} = SalMap;
    net.reset();   % clear the network's stored variable values before the next image
  end
  [~, ~, ~, TempPrecEER] = EvalResult(SalImgList, MaskList);
  toc
  PrecEER(EpochID) = TempPrecEER;
  if TempPrecEER > BestPrecEER
    BestPrecEER = TempPrecEER;
    SalImg = EpochSalImg;
  end
end
end
function [net, predVar] = LoadNet(ModelPath, GPUID)
%LOADNET Load a DagNN checkpoint and prepare it for saliency inference.
%   [net, predVar] = LoadNet(ModelPath, GPUID) loads the MAT-file
%   ModelPath, rebuilds its 'net' variable with dagnn.DagNN.loadobj, and
%   removes every layer whose name contains 'Classifier' as well as the
%   layer named 'objective'. It then marks the variable 'ObjSalPred' as
%   precious so its value is kept after evaluation, and switches the
%   network to test mode. When GPUID is non-empty, it moves the network
%   to the GPU. predVar is the index of 'ObjSalPred' in net.vars.
checkpoint = load(ModelPath) ;
net = dagnn.DagNN.loadobj(checkpoint.net);

% Strip the training-only head: all 'Classifier' layers plus the loss.
layerNames = {net.layers.name}';
hasClassifier = ~cellfun(@isempty, strfind(layerNames, 'Classifier'));
doomed = [layerNames(hasClassifier); {'objective'}];
net.removeLayer(doomed);

predVar = net.getVarIndex('ObjSalPred') ;
net.vars(predVar).precious = true;
net.mode = 'test' ;
if ~isempty(GPUID)
  net.move('gpu') ;
end
end
function tf = NonEmpty(x)
%NONEMPTY Logical true when X has at least one element.
%   TF = NonEmpty(X) is the negation of isempty(X), which makes it
%   convenient as a cellfun predicate.
tf = not(isempty(x));
end
function col = vec(A)
%VEC Return the elements of A as a single column, in column-major order.
col = reshape(A, [], 1);
end