-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.js
More file actions
114 lines (95 loc) · 3.21 KB
/
data.js
File metadata and controls
114 lines (95 loc) · 3.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
const tf = require('@tensorflow/tfjs');
const orm = require('./orm');
// Class labels for the binary classifier; index 0 = 'rejected', index 1 = 'correct'.
const CLASSES = ['rejected', 'correct'];
// Learning rate for the Adam optimizer created in getModel().
const LEARNING_RATE = 0.001;
// Module-level cache of metric rows loaded from the DB; null until load() has run.
// (was `var` — use `let` so the binding is block-scoped and not hoisted.)
let rawMetrics = null;
/**
 * Loads up to 100k metric documents from the DB into the module-level
 * `rawMetrics` cache, dropping the `filename` attribute from every row.
 * @returns {Promise<Object[]>} the cached rows (also stored in `rawMetrics`).
 */
async function load() {
  console.log("DB load...");
  const Metric = await orm.getCollection();
  // TODO: improve and randomize sampling, maybe load by batches
  const rows = await Metric.find().limit(100000);
  rawMetrics = rows.map((row) => {
    const { filename, ...keepAttrs } = row;
    return keepAttrs;
  });
  return rawMetrics;
}
/**
 * Builds and compiles a small feed-forward classifier: one 15-unit sigmoid
 * hidden layer followed by a 2-unit softmax output (one unit per CLASSES
 * entry). Ensures the metric rows are loaded first, because the input shape
 * is derived from the number of feature columns in the first row.
 * @returns {Promise<tf.Sequential>} the compiled model.
 */
async function getModel() {
  if (rawMetrics == null) {
    // BUG FIX: the original called load() without `await`, so rawMetrics
    // was still null when dereferenced below (getData already awaited it).
    await load();
  }
  const model = tf.sequential();
  model.add(tf.layers.dense(
    {
      units: 15,
      inputShape: [Object.keys(rawMetrics[0]).length - 1], // how many features we have minus label
      activation: 'sigmoid' // for analog activation
    }
  ));
  model.add(tf.layers.dense(
    {
      units: 2,
      activation: 'softmax'
    }
  ));
  const optimizer = tf.train.adam(LEARNING_RATE);
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy', 'categoricalCrossentropy', 'precision', 'recall'],
  });
  return model;
}
/**
 * Converts the cached metric rows into train/test tensors, splitting each
 * class separately so both classes appear in both splits.
 * Assumes each row's LAST property is the (truthy/falsy) label and every
 * other property is a numeric feature — TODO confirm against the DB schema.
 * @param {number} testSplit - fraction of each class held out for testing (0..1).
 * @returns {Promise<tf.Tensor[]>} [xTrain, yTrain, xTest, yTest].
 */
async function getData(testSplit) {
  if (rawMetrics == null) {
    await load();
  }
  return tf.tidy(() => {
    const xTrains = [];
    const yTrains = [];
    const xTests = [];
    const yTests = [];
    // Bucket features/labels per class so the split is stratified.
    const dataByClass = [];
    const targetsByClass = [];
    for (let i = 0; i < CLASSES.length; ++i) {
      dataByClass.push([]);
      targetsByClass.push([]);
    }
    for (const row of rawMetrics) {
      // BUG FIX: `example` was assigned without a declaration, creating
      // an implicit global shared across calls.
      const example = Object.values(row);
      const target = example[example.length - 1] ? 1 : 0;
      const data = example.slice(0, example.length - 1);
      dataByClass[target].push(data);
      targetsByClass[target].push(target);
    }
    for (let i = 0; i < CLASSES.length; ++i) {
      const [xTrain, yTrain, xTest, yTest] =
        convertToTensors(dataByClass[i], targetsByClass[i], testSplit);
      xTrains.push(xTrain);
      yTrains.push(yTrain);
      xTests.push(xTest);
      yTests.push(yTest);
    }
    const concatAxis = 0;
    return [
      tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis),
      tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis)
    ];
  });
}
/**
 * Converts parallel feature/label arrays into train/test tensor pairs.
 * The last `round(numExamples * testSplit)` rows become the test set.
 * @param {number[][]} data - feature rows, all of equal length.
 * @param {number[]} targets - class index per row (0..CLASSES.length-1).
 * @param {number} testSplit - fraction of rows held out for testing (0..1).
 * @returns {tf.Tensor[]} [xTrain, yTrain, xTest, yTest].
 * @throws {Error} if data and targets have different lengths.
 */
function convertToTensors(data, targets, testSplit) {
  const numExamples = data.length;
  if (numExamples !== targets.length) {
    throw new Error('data and split have different numbers of examples');
  }
  const numTestExamples = Math.round(numExamples * testSplit);
  const numTrainExamples = numExamples - numTestExamples;
  const xDims = data[0].length;
  // Create a 2D `tf.Tensor` to hold the feature data.
  const xs = tf.tensor2d(data, [numExamples, xDims]);
  // Create a 1D `tf.Tensor` to hold the labels using one-hot encoding
  const ys = tf.oneHot(tf.tensor1d(targets).toInt(), CLASSES.length);
  // Split the data into training and test sets, using `slice`.
  const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
  const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
  const yTrain = ys.slice([0, 0], [numTrainExamples, CLASSES.length]);
  // BUG FIX: yTest sliced from offset [0, 0], re-using the TRAINING labels
  // for the test features; it must start at numTrainExamples like xTest.
  const yTest = ys.slice([numTrainExamples, 0], [numTestExamples, CLASSES.length]);
  return [xTrain, yTrain, xTest, yTest];
}
module.exports = { getData, getModel, CLASSES };