Skip to content

Commit c04352a

Browse files
authored
Merge pull request #243 from mldiego/master
Some experimenting with GPU, fix few issues
2 parents 2b13d82 + 5ae592c commit c04352a

File tree

366 files changed

+613
-38
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

366 files changed

+613
-38
lines changed

code/nnv/engine/nn/layers/FullyConnectedLayer.m

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -284,15 +284,8 @@
284284

285285
n = in_image.numPred;
286286
V(1, 1, :, in_image.numPred + 1) = zeros(obj.OutputSize, 1, 'like', in_image.V);
287-
for i=1:n+1
288-
I = in_image.V(:,:,:,i);
289-
I = reshape(I,N,1); % flatten input
290-
if i==1
291-
V(1, 1,:,i) = obj.Weights*I + obj.Bias;
292-
else
293-
V(1, 1,:,i) = obj.Weights*I;
294-
end
295-
end
287+
V(1, 1, :, :) = obj.Weights*reshape(in_image.V, N, n + 1);
288+
V(1, 1, :, 1) = reshape(V(1, 1, :, 1), obj.OutputSize, 1) + obj.Bias;
296289
% output set
297290
image = ImageStar(V, in_image.C, in_image.d, in_image.pred_lb, in_image.pred_ub);
298291
else % reach Star set

code/nnv/engine/nn/layers/ImageInputLayer.m

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,12 @@
116116
image = in_image;
117117
elseif strcmp(obj.Normalization, 'zerocenter')
118118
image = in_image.affineMap([], -obj.Mean);
119+
elseif strcmp(obj.Normalization, 'zscore')
120+
image = in_image.affineMap([], -obj.Mean);
121+
layer_std = obj.StandardDeviation;
122+
for nc = 1:image.numChannel
123+
image.V(:,:,nc,:) = image.V(:,:,nc,:)/layer_std(nc);
124+
end
119125
else
120126
error('The normalization method is not supported yet.')
121127
end

code/nnv/engine/set/Box.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
% Speeding up implementation
4141
gens = diag(vec); % generate matrix
4242
if numel(gens) > 1
43-
gens(:,all(gens(gens==0))) = []; % delete colums with no info
43+
gens(:,all(gens==0)) = []; % delete colums with no info
4444
end
4545
obj.generators = gens;
4646
catch

code/nnv/engine/set/VolumeStar.m

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,16 +554,16 @@
554554
S.V = gather(S.V);
555555
S.C = gather(S.C);
556556
S.d = gather(S.d);
557-
S.predicate_lb = gather(S.predicate_lb);
558-
S.predicate_ub = gather(S.predicate_ub);
557+
S.pred_lb = gather(S.pred_lb);
558+
S.pred_ub = gather(S.pred_ub);
559559
S.vol_lb = gather(S.vol_lb);
560560
S.vol_ub = gather(S.vol_ub);
561561
elseif strcmp(deviceTarget, 'gpu')
562562
S.V = gpuArray(S.V);
563563
S.C = gpuArray(S.C);
564564
S.d = gpuArray(S.d);
565-
S.predicate_lb = gpuArray(S.predicate_lb);
566-
S.predicate_ub = gpuArray(S.predicate_ub);
565+
S.pred_lb = gpuArray(S.pred_lb);
566+
S.pred_ub = gpuArray(S.pred_ub);
567567
S.vol_lb = gpuArray(S.vol_lb);
568568
S.vol_ub = gpuArray(S.vol_ub);
569569
else

code/nnv/engine/utils/lpsolver.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@
4949
% Define solver parameters
5050
params = struct; % for now, leave default options/params
5151
params.OutputFlag = 0; % no display
52-
params.OptimalityTol = 1e-09;
53-
params.FeasibilityTol = 1e-09;
52+
% params.OptimalityTol = 1e-09;
53+
% params.FeasibilityTol = 1e-09;
5454
result = gurobi(model, params);
5555
fval = result.objval; % get fval value from results
5656
% get exitflag and match those of linprog for easier parsing

code/nnv/engine/utils/verify_specification.m

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
if ~isa(Set, "Star")
3333
Set = Set.toStar;
3434
end
35+
if isa(Set.V, 'gpuArray')
36+
Set = Set.changeDevice('cpu');
37+
end
3538
S = Set.intersectHalfSpace(property.G, property.g); % compute intersection with unsafe/not robust region
3639
if isempty(S)
3740
result = 1; % no intersection with unsafe region = safe (unsat)
@@ -57,6 +60,7 @@
5760
continue; % does nothing, just need an statement, wanted to make this clear
5861
else
5962
result = 2; % unknown if approx, sat if exact
63+
return;
6064
end
6165
end
6266
cp = cp+1;
4.44 KB
4.38 KB
2.69 KB
5.08 KB

0 commit comments

Comments
 (0)