Skip to content

Commit cb7dee9

Browse files
adaptive dimension
1 parent 3839bde commit cb7dee9

File tree

13 files changed

+2898
-58
lines changed

13 files changed

+2898
-58
lines changed

code/nnv/engine/nn/NN.m

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -260,12 +260,32 @@
260260

261261

262262
function outputSet = reachProb_ImageStar(obj, IS, reachOptions)
263-
264-
263+
264+
265265
pe = pyenv;
266266
py_dir = pe.Executable;
267267

268268

269+
if isa(IS, 'ImageStar')
270+
if isempty(IS.im_lb)
271+
[LB, UB] = getRanges(IS);
272+
else
273+
LB = IS.im_lb;
274+
UB = IS.im_ub;
275+
end
276+
277+
elseif isa(IS, 'Star')
278+
if isempty(IS.state_lb)
279+
[LB, UB] = getBox(IS);
280+
else
281+
LB = IS.state_lb;
282+
UB = IS.state_ub;
283+
end
284+
285+
else
286+
error('The input must be a Star or Image_Star object.');
287+
end
288+
269289
if isfield(reachOptions, 'coverage')
270290
coverage = reachOptions.coverage;
271291
else
@@ -276,19 +296,16 @@
276296
else
277297
confidence = 0.99;
278298
end
279-
if isempty(IS.im_lb)
280-
error('We assume the input ImageStar is a box, and also contains the feature, im_lb. In case your input is not a box but contains im_lb, then your reachset will be more conservative as we assume a box with lower bound im_lb and upper bound im_ub.')
281-
end
282299

283300
if isfield(reachOptions, 'device')
284-
run_device = reachOptions.device;
301+
train_device = reachOptions.device;
285302
else
286-
run_device = 'cpu';
303+
train_device = 'gpu';
287304
end
288305
if isfield(reachOptions, 'epochs')
289-
epochs = reachOptions.epochs;
306+
train_epochs = reachOptions.epochs;
290307
else
291-
epochs = 50;
308+
train_epochs = 50;
292309
end
293310

294311
if isfield(reachOptions, 'train_lr')
@@ -299,19 +316,21 @@
299316

300317
[N_dir , N , Ns] = CP_specification(coverage, confidence, numel(IS.im_lb) , train_device, 'single');
301318

302-
303-
SizeIn = size(IS.im_lb);
304-
SizeOut = size(evaluate(obj, IS.im_lb));
319+
320+
SizeIn = size(LB);
321+
SizeOut = size(evaluate(obj, LB));
305322
height = SizeIn(1);
306323
width = SizeIn(2);
307-
if length(SizeOut) == 2
308-
SizeOut = [ SizeOut , 1];
324+
325+
326+
if isfield(reachOptions, 'indices')
327+
indices = reachOptions.indices;
328+
else
329+
[J,I] = ndgrid(1:width,1:height);
330+
indices = [I(:), J(:)];
309331
end
310332

311-
[J,I] = ndgrid(1:width,1:height);
312-
indices = [I(:), J(:)];
313333

314-
315334
if isfield(reachOptions, 'mode')
316335
train_mode = reachOptions.mode;
317336
else
@@ -343,7 +362,7 @@
343362
params.guarantee = coverage;
344363
params.py_dir = py_dir;
345364

346-
The_class = ProbReach_ImageStar(obj, IS, indices, SizeOut, train_mode, params);
365+
The_class = ProbReach_ImageStar(obj, LB, UB, indices, SizeOut, train_mode, params);
347366

348367
outputSet = The_class.ProbReach();
349368

code/nnv/engine/nn/Prob_reach/ProbReach_ImageStar.m

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,35 @@
1313
end
1414

1515
methods
16-
function obj = ProbReach_ImageStar(model,IS,indices,output_dim,mode,params)
16+
function obj = ProbReach_ImageStar(model,LB, UB,indices,SizeOut,mode,params)
1717
obj.model = model;
18-
if isempty(IS.im_lb)
19-
error('We assume the input ImageStar is a box, and also contains the feature, im_lb. In case your input is not a box but contains im_lb, then your reachset will be more conservative as we assume a box with lower bound im_lb and upper bound im_ub.')
20-
end
21-
obj.LB = IS.im_lb;
22-
obj.de = IS.im_ub-IS.im_lb;
18+
obj.LB = LB;
19+
obj.de = UB-LB;
2320
obj.indices = indices;
24-
SizeIn = size(IS.im_lb);
25-
if length(SizeIn) == 2 %% n_channel = 1
21+
22+
SizeIn = size(LB);
23+
lenIn = length(SizeIn);
24+
if lenIn == 1
25+
obj.original_dim = [SizeIn , 1, 1];
26+
elseif lenIn == 2
2627
obj.original_dim = [SizeIn , 1];
27-
else
28+
elseif lenIn == 3
2829
obj.original_dim = SizeIn;
30+
else
31+
obj.original_dim = SizeIn(1:3);
2932
end
30-
obj.output_dim = output_dim;
33+
34+
lenOut = length(SizeOut);
35+
if lenOut == 1
36+
obj.output_dim = [SizeOut , 1, 1];
37+
elseif lenOut == 2
38+
obj.output_dim = [SizeOut , 1];
39+
elseif lenOut == 3
40+
obj.output_dim = SizeOut;
41+
else
42+
obj.output_dim = SizeOut(1:3);
43+
end
44+
3145
obj.mode = mode;
3246

3347
thisFile = mfilename('fullpath');
@@ -70,7 +84,7 @@
7084
out = obj.model.predict(x);
7185

7286
case 'dlnetwork'
73-
dlX = dlarray(X, 'SSC'); % 'SSC' = Spatial, Spatial, Channel
87+
dlX = dlarray(x);
7488
out = obj.model.predict(dlX);
7589

7690
case 'NN'
@@ -101,7 +115,7 @@
101115

102116
%%%%%%%%%%%%%%
103117
parfor i=1:N
104-
% disp(i)
118+
disp(i)
105119
Rand = rand(n_channel*N_perturbed,1);
106120
Rand_matrix = obj.mat_generator_no_third(Rand);
107121
d_at = zeros(height,width,n_channel);
-8.7 MB
Binary file not shown.
Binary file not shown.

code/nnv/engine/nn/Prob_reach/Trainer_ReLU.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def estimate_lipschitz(x, y, num_samples=1000):
3939
y = torch.tensor(mat_data['dYV'].T, dtype=torch.float32) # Shape [10000, 10]
4040
dims = mat_data['dims'].flatten().astype(int).tolist()
4141
epochs = int(mat_data['epochs'].flatten()[0])
42+
lr = mat_data['lr'].flatten()[0]
4243

4344
# Estimate λ before training
4445
lam = max( 10.0 , 5*estimate_lipschitz(x, y) )

code/nnv/engine/utils/Prob_reach.m

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,26 @@
33
pe = pyenv;
44
py_dir = pe.Executable;
55

6+
if isa(In_ImS, 'ImageStar')
7+
if isempty(In_ImS.im_lb)
8+
[LB, UB] = getRanges(In_ImS);
9+
else
10+
LB = In_ImS.im_lb;
11+
UB = In_ImS.im_ub;
12+
end
13+
14+
elseif isa(In_ImS, 'Star')
15+
if isempty(In_ImS.state_lb)
16+
[LB, UB] = getBox(In_ImS);
17+
else
18+
LB = In_ImS.state_lb;
19+
UB = In_ImS.state_ub;
20+
end
21+
22+
else
23+
error('The input must be a Star or Image_Star object.');
24+
end
25+
626

727
if isfield(reachOptions, 'coverage')
828
coverage = reachOptions.coverage;
@@ -14,15 +34,26 @@
1434
else
1535
confidence = 0.99;
1636
end
17-
if isempty(In_ImS.im_lb)
18-
error('We assume the input ImageStar is a box, and also contains the feature, im_lb. In case your input is not a box but contains im_lb, then your reachset will be more conservative as we assume a box with lower bound im_lb and upper bound im_ub.')
19-
end
37+
2038

2139
if isfield(reachOptions, 'train_device')
2240
train_device = reachOptions.train_device;
2341
else
2442
train_device = 'gpu';
2543
end
44+
45+
if isfield(reachOptions, 'train_mode')
46+
train_mode = reachOptions.train_mode;
47+
else
48+
train_mode = 'Linear';
49+
end
50+
51+
if isfield(reachOptions, 'surrogate_dim')
52+
surrogate_dim = reachOptions.surrogate_dim;
53+
else
54+
surrogate_dim = [-1, -1];
55+
end
56+
2657
if isfield(reachOptions, 'train_epochs')
2758
train_epochs = reachOptions.train_epochs;
2859
else
@@ -32,18 +63,17 @@
3263
if isfield(reachOptions, 'train_lr')
3364
train_lr = reachOptions.train_lr;
3465
else
35-
train_lr = 0.01;
66+
train_lr = 0.0001; %%% The prefrence for lr is 0.0001 in 'Linear' and 0.01 in 'relu' mode.
3667
end
3768

69+
70+
3871
[N_dir , N , Ns] = CP_specification(coverage, confidence, numel(In_ImS.im_lb) , train_device, 'single');
3972

40-
SizeIn = size(In_ImS.im_lb);
41-
SizeOut = size(forward(Net, In_ImS.im_lb));
73+
SizeIn = size(LB);
74+
SizeOut = size(forward(Net, LB));
4275
height = SizeIn(1);
4376
width = SizeIn(2);
44-
if length(SizeOut) == 2
45-
SizeOut = [ SizeOut , 1];
46-
end
4777

4878

4979
if isfield(reachOptions, 'indices')
@@ -53,17 +83,6 @@
5383
indices = [I(:), J(:)];
5484
end
5585

56-
if isfield(reachOptions, 'train_mode')
57-
train_mode = reachOptions.train_mode;
58-
else
59-
train_mode = 'Linear';
60-
end
61-
62-
if isfield(reachOptions, 'surrogate_dim')
63-
surrogate_dim = reachOptions.surrogate_dim;
64-
else
65-
surrogate_dim = [-1, -1];
66-
end
6786

6887
if isfield(reachOptions, 'threshold_normal')
6988
threshold_normal = reachOptions.threshold_normal;
@@ -84,15 +103,12 @@
84103
params.py_dir = py_dir;
85104

86105

87-
obj = ProbReach_ImageStar(Net,In_ImS,indices,SizeOut,train_mode, params);
106+
obj = ProbReach_ImageStar(Net,LB, UB,indices, SizeOut, train_mode, params);
88107
Out_ImS = obj.ProbReach();
89108

90109
end
91110

92111

93-
94-
95-
96112
function out = forward(model, x)
97113

98114
model_source = class(model);
@@ -106,7 +122,7 @@
106122
out = model.predict(x);
107123

108124
case 'dlnetwork'
109-
dlX = dlarray(X, 'SSC'); % 'SSC' = Spatial, Spatial, Channel
125+
dlX = dlarray(x);
110126
out = model.predict(dlX);
111127

112128
case 'NN'
@@ -115,4 +131,4 @@
115131
otherwise
116132
error("Unknown model source: " + model_source + ". We only cover NN, SeriesNetwork, dlnetwork and DAGNetwork.");
117133
end
118-
end
134+
end

0 commit comments

Comments
 (0)