Skip to content

Commit 540c92f

Browse files
committed
added load_images_MNIST() to utils
1 parent 8f1bc7d commit 540c92f

File tree

3 files changed

+71
-7
lines changed

3 files changed

+71
-7
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
function [images, labels, db_image_nos] = load_images_MNIST(args)
2+
3+
arguments
4+
args.database = "mnist"
5+
args.n % number of images
6+
args.matlabnet % if a matlab neural network is supplied, only
7+
% the correctly classified images will be returned
8+
end
9+
10+
database = args.database;
11+
n = args.n;
12+
matlabnet = args.matlabnet;
13+
14+
images = {};
15+
labels = {};
16+
db_image_nos = {};
17+
if strcmp(database, "mnist")
18+
if n > 10000
19+
error('Maximum 10000 mnist images available.')
20+
end
21+
% Load data (no download necessary)
22+
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
23+
'nndatasets','DigitDataset');
24+
% Images
25+
imds = imageDatastore(digitDatasetPath, ...
26+
'IncludeSubfolders',true,'LabelSource','foldernames');
27+
28+
numClasses = 10;
29+
if mod(n, numClasses) ~= 0
30+
error(['For MNIST, to have balanced dataset, number of images must be a multiple of ' num2str(numClasses)])
31+
end
32+
NPerClass = n/numClasses;
33+
34+
db_img_no = 1;
35+
no_of_images_chosen_from_class = 0;
36+
class_no = 1;
37+
while class_no <= numClasses
38+
% Load one image in dataset
39+
[img, fileInfo] = readimage(imds, db_img_no);
40+
img = single(img); % change precision
41+
label = single(fileInfo.Label);
42+
43+
append_image = 1;
44+
if ~isempty(matlabnet)
45+
[~, pred] = max(predict(matlabnet, img));
46+
if label ~= pred
47+
append_image = 0;
48+
end
49+
end
50+
51+
if append_image
52+
images{end + 1} = img;
53+
labels{end + 1} = label;
54+
db_image_nos{end + 1} = db_img_no;
55+
no_of_images_chosen_from_class = no_of_images_chosen_from_class + 1;
56+
if no_of_images_chosen_from_class == NPerClass
57+
class_no = class_no + 1;
58+
no_of_images_chosen_from_class = 0;
59+
db_img_no = (class_no - 1)*1000;
60+
end
61+
end
62+
db_img_no = db_img_no + 1;
63+
end
64+
65+
else
66+
error(['Unsupported database ' database])
67+
end
68+
end
69+

code/nnv/examples/Tutorial/NN/MNIST/weightPerturb/single_layer/conv_expt_any_layer.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
otherwise
6161
error("Not supported yet.")
6262
end
63-
[images, labels] = load_images(database = database, ...
63+
[images, labels] = load_images_MNIST(database = database, ...
6464
n = expt.data.n_images, ...
6565
matlabnet = matlabnet);
6666

code/nnv/tests/weightPerturb/test_conv_layer_perturbation.m

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,8 @@
2424
reachOptions.numCores = 1;
2525
% reachOptions.device = 'gpu';
2626
reachOptions.device = 'cpu';
27-
% reachOptions.delete_old_sets = 1;
28-
reachOptions.free_mem_frac_for_LPs_in_verify_specification = 0.1;
29-
% reachOptions.dis_opt = 'display';
30-
% reachOptions.disp_intersection_result = 1;
31-
% reachOptions.debug = 1;
3227

33-
[images, labels] = load_images(database = "mnist", ...
28+
[images, labels] = load_images_MNIST(database = "mnist", ...
3429
n = 10, ...
3530
matlabnet = matlabnet);
3631
img = images{1};

0 commit comments

Comments
 (0)