|
%% Fairness Verification of Adult Classification Model (NN)
% Comparison for the models used in Fairify.
%
% For every model in modelPaths, a fixed random subset of samples is checked
% with NNV. For each sample the helper perturbation() builds an input set,
% and net.verify_robustness() is asked whether the network keeps predicting
% the sample's label over that whole set. Results are stored per model in:
%   res  - verification result per sample/epsilon (1, 0, or 2 = not verified)
%   time - verification time in seconds per sample/epsilon
%   met  - method used to compute each result

%% Setup
clear; clc;

% Suppress warnings
warning('off', 'nnet_cnn_onnx:onnx:WarnAPIDeprecation');
warning('off', 'nnet_cnn_onnx:onnx:FillingInClassNames');
warning('on', 'verbose')

modelDir = './adult_onnx'; % Directory containing ONNX models
onnxFiles = dir(fullfile(modelDir, '*.onnx')); % List all .onnx files

% Models to verify. Only one model is currently checked; swap in one of the
% alternatives below to verify other models.
modelPaths = "adult_my_models2/model_0.onnx";
% modelPaths = string(fullfile(modelDir, {onnxFiles.name}));
% modelPaths = "adult_onnx/AC-1.onnx";

load("adult_fairify2_data.mat", 'X', 'y'); % Load data once

%% Prepare data (model-independent, so it is done once outside the loop)
X_test_loaded = permute(X, [2, 1]); % one feature per row, one sample per column
y_test_loaded = y+1; % update labels (presumably 0-based -> 1-based; confirm against the data)

% Normalize features in X_test_loaded
min_values = min(X_test_loaded, [], 2);
max_values = max(X_test_loaded, [], 2);

% Ensure no division by zero for constant features
variableFeatures = max_values - min_values > 0;
min_values(~variableFeatures) = 0; % Avoids changing constant features
max_values(~variableFeatures) = 1; % Avoids division by zero

X_test_loaded = (X_test_loaded - min_values) ./ (max_values - min_values);

% Valid range of every feature *after* normalization. The input sets are
% built from normalized samples, so they must be clipped against these
% bounds rather than against the raw min_values/max_values.
norm_min_values = min(X_test_loaded, [], 2);
norm_max_values = max(X_test_loaded, [], 2);

% Count total observations
total_obs = size(X_test_loaded, 2);
% disp(['There are total ', num2str(total_obs), ' observations']);

% Number of observations we want to test (never more than are available)
numObs = min(100, total_obs);

% Randomly select observations (same subset for every model)
rng(500); % Set a seed for reproducibility
rand_indices = randsample(total_obs, numObs);

%% Verification settings
reachOptions = struct; % initialize
reachOptions.reachMethod = 'exact-star';
reachOptions.relaxFactor = 0.5;

% ADJUST epsilon value here
% epsilon = [0.01];
epsilon = [0.0,0.001,0.01];
% epsilon = [0.00001];
nE = numel(epsilon);

timeLimit = 600; % seconds allowed per epsilon value (10 minutes)

%% Loop through each model
for k = 1:numel(modelPaths)
    onnx_model_path = fullfile(modelPaths(k));

    % Load the ONNX file as DAGNetwork
    netONNX = importONNXNetwork(onnx_model_path, 'OutputLayerType', 'classification', 'InputDataFormats', {'BC'});

    % analyzeNetwork(netONNX)

    % Convert the DAGNetwork to NNV format
    net = matlab2nnv(netONNX);

    % Workaround: manually set the output size on the converted network
    net.OutputSize = 2;

    % disp(net)

    % %
    % % Test accuracy --> verify matches with python
    % % NOTE(review): this check uses min() to pick the predicted class;
    % % confirm that is intended before re-enabling it.
    % %
    % total_corr = 0;
    % for i=1:total_obs
    %     im = X_test_loaded(:, i);
    %     predictedLabels = net.evaluate(im);
    %     [~, Pred] = min(predictedLabels);
    %     TrueLabel = y_test_loaded(i);
    %     if Pred == TrueLabel
    %         total_corr = total_corr + 1;
    %     end
    % end
    % disp(['Test Accuracy: ', num2str(total_corr/total_obs)]);

    % Set up results
    res = zeros(numObs,nE);              % robust result
    time = zeros(numObs,nE);             % computation time
    met = repmat("exact", [numObs, nE]); % method used to compute result

    disp("Model: " + string(onnx_model_path));

    for e = 1:nE
        % Wall-clock budget for this epsilon. Using tic/toc instead of a
        % timer object means nothing is left running if verification errors
        % out, and no base-workspace flag is needed.
        epsilonClock = tic;
        numVerified = 0;

        for i = 1:numObs
            idx = rand_indices(i);
            IS = perturbation(X_test_loaded(:, idx), epsilon(e), norm_min_values, norm_max_values);

            t = tic; % Start timing the verification for each sample
            res(i,e) = net.verify_robustness(IS, reachOptions, y_test_loaded(idx)); % robust result
            time(i,e) = toc(t); % store computation time
            met(i,e) = "exact";
            numVerified = i;

            % Check the budget only after a sample finishes, so a running
            % verification is never interrupted.
            if toc(epsilonClock) >= timeLimit
                disp(['Timeout reached for epsilon = ', num2str(epsilon(e)), ': stopping verification for this epsilon.']);
                res(i+1:end,e) = 2; % Mark remaining as unknown
                break; % Exit the inner loop after timeout
            end
        end

        % Get summary results
        N = numObs;
        rob = sum(res(:,e) == 1);
        not_rob = sum(res(:,e) == 0);
        unk = sum(res(:,e) == 2);
        totalTime = sum(time(:,e));
        % Average over the samples that were actually verified; samples
        % skipped by a timeout contribute no time.
        avgTime = totalTime/max(numVerified, 1);

        % Print results to screen
        disp("======= ROBUSTNESS RESULTS e: "+string(epsilon(e))+" ==========")
        disp(" ");
        disp("Number of fair samples = "+string(rob)+ ", equivalent to " + string(100*rob/N) + "% of the samples.");
        disp("Number of non-fair samples = " +string(not_rob)+ ", equivalent to " + string(100*not_rob/N) + "% of the samples.")
        disp("Number of unknown samples = "+string(unk)+ ", equivalent to " + string(100*unk/N) + "% of the samples.");
        disp(" ");
        disp("It took a total of "+string(totalTime) + " seconds to compute the verification results, an average of "+string(avgTime)+" seconds per sample");
    end
end
| 180 | + |
| 181 | + |
%% Helper Function
% Adjusted for fairness check -> only apply perturbation to desired feature.
function IS = perturbation(x, epsilon, min_values, max_values)
    % PERTURBATION Build the input set used to check one sample.
    %
    %   IS = perturbation(x, epsilon, min_values, max_values)
    %
    %   The sensitive attribute (row 9) of x is flipped: a value of 1
    %   becomes 0, any other value becomes 1. Rows 1, 10, 11 and 12 are then
    %   allowed to vary by +/- epsilon around their value in x; every other
    %   row is kept fixed. The resulting lower/upper bounds are clipped to
    %   [min_values, max_values] and returned as an ImageStar.
    %
    %   Inputs:
    %     x          - input sample, one feature per row (at least 12 rows)
    %     epsilon    - perturbation radius applied to the non-sensitive rows
    %     min_values - per-feature lower clipping bound (same scale as x)
    %     max_values - per-feature upper clipping bound (same scale as x)
    %
    %   Output:
    %     IS - ImageStar built from the single-precision lower/upper bounds
    %
    %   Errors if x has fewer rows than the highest feature index used.

    sensitive_rows = 9;
    nonsensitive_rows = [1,10,11,12];

    % Validate once, before any row is indexed.
    if size(x, 1) < max([sensitive_rows, nonsensitive_rows])
        error('The input data does not have enough rows.');
    end

    % Flip the sensitive attribute
    if x(sensitive_rows) == 1
        x(sensitive_rows) = 0;
    else
        x(sensitive_rows) = 1;
    end

    % Apply epsilon perturbation to non-sensitive numerical features;
    % all remaining rows keep a zero disturbance.
    disturbance = zeros(size(x), "like", x);
    disturbance(nonsensitive_rows, :) = epsilon;

    % Calculate disturbed lower and upper bounds considering min and max values
    lb = max(x - disturbance, min_values);
    ub = min(x + disturbance, max_values);
    IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
end
| 214 | + |
0 commit comments