-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.js
More file actions
26 lines (22 loc) · 872 Bytes
/
test.js
File metadata and controls
26 lines (22 loc) · 872 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
const tf = require('@tensorflow/tfjs-node-gpu');
const { GenerateGraph, updateGraph } = require('./index.js');
// Trains a small dense network on the toy relation y = 2x - 1 and forwards
// per-epoch training logs to the graph helpers exported by ./index.js.
//
// Only the first layer needs `inputShape`: tf.Sequential infers each later
// layer's input from the previous layer's output, so the `inputShape` values
// previously passed to layers 2-4 were ignored (and contradicted the real shapes).
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
model.add(tf.layers.dense({ units: 2 }));
model.add(tf.layers.dense({ units: 2 }));
model.add(tf.layers.dense({ units: 1 }));

// NOTE(review): 'accuracy' is a classification metric and is not meaningful for
// this regression (MSE) task. It is kept because updateGraph may read `logs.acc`;
// confirm against ./index.js before replacing it with e.g. 'mae'.
model.compile({ loss: 'meanSquaredError', optimizer: 'sgd', metrics: ['accuracy'] });

// Six (x, y) samples of y = 2x - 1, each shaped [6, 1].
const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);

GenerateGraph(model);

model
  .fit(xs, ys, {
    epochs: 100,
    callbacks: {
      onEpochEnd: async (epoch, logs) => {
        // Report the raw loss; the previous `* 100` scaling was printed as if it
        // were the loss itself.
        console.log(`Epoch: ${epoch} Loss: ${logs.loss} Accuracy: ${logs.acc}`);
        // Awaiting is harmless if updateGraph is synchronous. If it returns a
        // Promise, training now waits for it and its errors are no longer lost.
        await updateGraph(epoch, logs);
      },
    },
  })
  .then(() => {
    console.log('Training complete');
  })
  .catch((err) => {
    // Without this handler a failed fit() surfaced as an unhandled rejection.
    console.error('Training failed:', err);
    process.exitCode = 1;
  })
  .finally(() => {
    // Release the native memory held by the training tensors.
    xs.dispose();
    ys.dispose();
  });