Skip to content

Commit 12860c4

Browse files
committed
Adding grayscale classification
1 parent c9f5401 commit 12860c4

File tree

11 files changed

+268
-8
lines changed

11 files changed

+268
-8
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ nvm use v22.1.0
7979
8080
node scripts/classify-images.js
8181
node scripts/custom-model.js
82+
node scripts/grayscale-custom-model.js
8283
node scripts/transfer-learning.js
8384
```
8485

@@ -128,5 +129,12 @@ Thanks go to the fantastic bakers out there making hyper-realistic cakes that in
128129
5. https://www.thesugardreams.net
129130
6. https://juliemcallistercakes.com
130131
7. https://www.thesweetstopofrva.com
132+
8. https://apriljulian.com/portfolio
133+
9. https://cakeyall.com/
134+
10. https://www.layerscakestudio.com/
135+
11. https://www.apinchofspirit.com/gallery
136+
12. https://www.mycakebuds.com
137+
13. https://mannandcobakeshop.com/
138+
14. https://www.thesculptedslice.com/
131139

132140
Thanks also to [Unsplash](https://unsplash.com/) and their amazing contributors who allowed me to find non-cake images through the [JavaScript wrapper unsplash-js](https://www.npmjs.com/package/unsplash-js) of the [Unsplash API](https://unsplash.com/documentation).

cake-game/src/components/classifier-table-row/ClassifierTableRow.jsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ function ClassifierTableRow(props) {
3939
<th className="classification">{ formatClassificationCollections(props.result.models?.coco_ssd_predictions, 'class') }</th>
4040
<th className="classification">{ formatClassificationString(props.result.models?.my_transfer_model_classifier?.category) }</th>
4141
<th className="classification">{ formatClassificationString(props.result.models?.my_model_classifier.category) }</th>
42+
<th className="classification">{ formatClassificationString(props.result.models?.my_grayscale_model_classifier?.category) }</th>
4243
</tr>
4344
</>
4445
);

cake-game/src/functions/game_results.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ export async function handler(event, context) {
1717
return document._source;
1818
});
1919

20+
// console.log(`RESULTS: ${results}`)
21+
2022
return generateResponse(200, results);
2123
} catch (e) {
2224
console.log(e);

cake-game/src/routes/end/End.jsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ function End() {
7373
<th>COCO-SSD</th>
7474
<th>MobileNet Transfer Classifier</th>
7575
<th>Carly Model</th>
76+
<th>Carly Model (Grayscale)</th>
7677
</tr>
7778
</thead>
7879
<tbody>

model-classification-app/scripts/cake-image-urls.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

model-classification-app/scripts/cake-sites.json

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,41 @@
3434
{
3535
"url": "https://www.thesweetstopofrva.com/weddingcakes",
3636
"category": "cake"
37+
},
38+
{
39+
"url": "https://apriljulian.com/portfolio",
40+
"category": "cake"
41+
},
42+
{
43+
"url": "https://cakeyall.com/pages/info",
44+
"category": "cake"
45+
},
46+
{
47+
"url": "https://www.layerscakestudio.com/",
48+
"category": "cake"
49+
},
50+
{
51+
"url": "https://www.apinchofspirit.com/gallery",
52+
"category": "cake"
53+
},
54+
{
55+
"url": "https://www.mycakebuds.com/products",
56+
"category": "cake"
57+
},
58+
{
59+
"url": "https://mannandcobakeshop.com/pages/custom-cake",
60+
"category": "cake"
61+
},
62+
{
63+
"url": "https://www.thesculptedslice.com/gallery-home/celebration-cakes",
64+
"category": "cake"
65+
},
66+
{
67+
"url": "https://www.thesculptedslice.com/gallery-home/sculpted-cakes",
68+
"category": "cake"
69+
},
70+
{
71+
"url": "https://www.thesculptedslice.com/gallery-home/wedding-cakes",
72+
"category": "cake"
3773
}
3874
]

model-classification-app/scripts/classify-images.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ async function run() {
4040
// Reset index (uncomment if regenerating)
4141
//await clearIndex();
4242

43-
/*await getCakeImages();
44-
const cakeImageUrls = loadCakeImageUrls();
45-
await classifyImages('cake', cakeImageUrls);*/
43+
//await getCakeImages();
44+
//const cakeImageUrls = loadCakeImageUrls();
45+
//await classifyImages('cake', cakeImageUrls);
4646

4747
const objectImageUrls = await getUnsplashImageUrls();
4848
await classifyImages("not cake", objectImageUrls);
@@ -65,7 +65,7 @@ async function getCakeImages() {
6565
try {
6666
await page.goto(site.url);
6767
const currentPageImages = await page
68-
.locator("img")
68+
.getByRole("img")
6969
.evaluateAll((images) => {
7070
return images
7171
.map((image) => {

model-classification-app/scripts/elasticsearch-util.js

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,29 @@ async function updateDocumentWithClassification(documentId, category, prediction
8585
}
8686
}
8787

88+
/**
 * Stores the grayscale model's verdict for one image on its Elasticsearch document.
 *
 * Runs a Painless scripted update that sets the document's top-level
 * `my_grayscale_model_classifier` field to `{ category, predictions }`.
 * The client response is logged on success. A failure is logged and NOT
 * rethrown, so the returned promise always resolves.
 *
 * NOTE(review): ClassifierTableRow reads this value from
 * `result.models.my_grayscale_model_classifier`, while the script below writes
 * it directly under `_source` - confirm the two paths line up.
 *
 * @param {string} documentId - `_id` of the document to update.
 * @param {string} category - Class name the model picked for the image.
 * @param {Array<object>} predictions - Ranked predictions stored next to the category.
 * @returns {Promise<void>}
 */
async function updateDocumentWithGrayscaleClassification(documentId, category, predictions) {
  const classification = { category, predictions };

  try {
    const updateRequest = {
      index,
      id: documentId,
      script: {
        lang: 'painless',
        source: 'ctx._source.my_grayscale_model_classifier = params.classification',
        params: { classification },
      },
    };

    console.log(await esClient.update(updateRequest));
  } catch (e) {
    console.log(e);
  }
}
110+
88111
async function updateDocumentWithTransferClassification(documentId, category, predictions) {
89112
const myModelClassifier = {
90113
category: category,
@@ -107,4 +130,4 @@ async function updateDocumentWithTransferClassification(documentId, category, pr
107130
}
108131
}
109132

110-
module.exports = { esClient, clearIndex, addClassifiersToIndex, getAllImages, getFirstNImagesByCategory, updateDocumentWithClassification, updateDocumentWithTransferClassification };
133+
module.exports = { esClient, clearIndex, addClassifiersToIndex, getAllImages, getFirstNImagesByCategory, updateDocumentWithClassification, updateDocumentWithGrayscaleClassification, updateDocumentWithTransferClassification };
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
const tf = require("@tensorflow/tfjs-node");
2+
3+
const {
4+
getAllImages,
5+
getFirstNImagesByCategory,
6+
updateDocumentWithGrayscaleClassification,
7+
} = require("./elasticsearch-util");
8+
const { getGrayscaleImageTensor, getGrayscaleTensorsForImageSet, IMAGE_HEIGHT, IMAGE_WIDTH } = require("./tf-util");
9+
10+
// Index 0 is the "cake" class, index 1 the "not cake" class.
const CLASS_NAMES = ["cake", "not cake"];

// Build the custom grayscale model, then classify the indexed images with it.
// A failure anywhere in run() is reported here and turned into a non-zero exit
// code instead of being left as an unhandled promise rejection.
run().catch((error) => {
  console.error(error);
  process.exitCode = 1;
});
14+
15+
/**
 * Trains the grayscale CNN on a sample of indexed images, classifies every
 * indexed image with it, and saves the trained model to disk.
 *
 * Steps:
 *  1. Load up to 50 images per class and convert them to grayscale tensors.
 *  2. Build one-hot labels ([1, 0] for CLASS_NAMES[0], [0, 1] for CLASS_NAMES[1]).
 *  3. Shuffle images and labels together, train the model, classify, save.
 *
 * All tensors created here are released in a `finally` block, so a failure
 * during training, classification or saving no longer leaks them.
 *
 * @returns {Promise<void>}
 * @throws {Error} when no training images could be loaded, or when any of the
 *   underlying load/train/classify/save steps rejects.
 */
async function run() {
  const SAMPLES_PER_CLASS = 50;
  const BATCH_SIZE = 32;
  const NUM_EPOCHS = 10;
  // NOTE(review): confirm this directory is not shared with another training
  // script, otherwise the models overwrite each other.
  const MODEL_DIR = "./model";

  let cakeTensors = [];
  let notCakeTensors = [];
  let singleImageTensor;
  let labelsTensor;

  try {
    // Get a subset of the cake images
    const cakesResponse = await getFirstNImagesByCategory(CLASS_NAMES[0], SAMPLES_PER_CLASS);
    cakeTensors = await getGrayscaleTensorsForImageSet(cakesResponse);

    // Get a subset of the unsplash images for not cake images
    const notCakesResponse = await getFirstNImagesByCategory(CLASS_NAMES[1], SAMPLES_PER_CLASS);
    notCakeTensors = await getGrayscaleTensorsForImageSet(notCakesResponse);

    const images = cakeTensors.concat(notCakeTensors);
    if (images.length === 0) {
      // tf.stack() on an empty list fails with an opaque message; say what is wrong.
      throw new Error("No training images found - populate the index before training.");
    }

    const labels = [
      ...cakeTensors.map(() => [1, 0]),
      ...notCakeTensors.map(() => [0, 1]),
    ];

    // Shuffles both arrays in place with the same permutation so pairs stay aligned.
    tf.util.shuffleCombo(images, labels);
    singleImageTensor = tf.stack(images);
    labelsTensor = tf.tensor2d(labels);

    const model = createModel();

    await model.fit(singleImageTensor, labelsTensor, {
      batchSize: BATCH_SIZE, // Number of samples to work through before updating the internal model parameters
      epochs: NUM_EPOCHS, // Number of passes through the dataset
      shuffle: true, // Shuffle data before each pass
    });

    // Classify images
    await classifyAllImages(model);

    // Optional saving of model
    await model.save(`file://${MODEL_DIR}`);
  } finally {
    // Tidy up, whether or not the steps above succeeded
    singleImageTensor?.dispose();
    labelsTensor?.dispose();
    tf.dispose(cakeTensors);
    tf.dispose(notCakeTensors);
  }

  console.log('Classification complete!');
}
60+
61+
/* Functional implementation */
62+
// Convolutional Neural Network (CNN) example
63+
/**
 * Builds and compiles the Convolutional Neural Network (CNN) that sorts
 * single-channel (grayscale) images into the CLASS_NAMES categories.
 *
 * Layer order: conv2d(16) -> maxPooling2d -> conv2d(32) -> maxPooling2d
 * -> flatten -> dense(64, relu) -> dense(CLASS_NAMES.length, softmax).
 *
 * @returns the compiled sequential model, ready for `fit`.
 */
function createModel() {
  /* 2d convolution layer: a filter (kernel) of 3x3 slides over the image to
   * pick out features of interest. Both convolutions use ReLU, f(x) = max(0, x).
   * See https://www.kaggle.com/discussions/general/463431
   */
  const convolution = (settings) =>
    tf.layers.conv2d({ ...settings, kernelSize: 3, activation: "relu" });

  /* Max pooling shrinks the output of the preceding convolution, which cuts
   * the computational load further on and helps against overfitting.
   * See https://deeplizard.com/learn/video/ZjM_XQa5s6s
   */
  const pooling = () => tf.layers.maxPooling2d({ poolSize: 2, strides: 2 });

  const layers = [
    // First layer declares the input: the trailing 1 is the single grayscale channel.
    convolution({ inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, 1], filters: 16 }),
    pooling(),
    convolution({ filters: 32 }),
    pooling(),

    // Flattens the inputs to 1D, making the outputs 2D
    tf.layers.flatten(),

    /* Dense layers connect every neuron to all neurons of the previous layer;
     * they do the actual classification on top of the convolutional features.
     * See https://towardsdatascience.com/introduction-to-convolutional-neural-network-cnn-de
     * (link is truncated in the original comment)
     */
    tf.layers.dense({ units: 64, activation: "relu" }),

    // softmax turns K real values into K values that sum to 1, one per class
    tf.layers.dense({ units: CLASS_NAMES.length, activation: "softmax" }),
  ];

  const model = tf.sequential();
  for (const layer of layers) {
    model.add(layer);
  }

  model.compile({
    optimizer: tf.train.adam(), // Stochastic Optimization method
    loss: "binaryCrossentropy",
    metrics: ["accuracy"],
  });

  return model;
}
135+
136+
/**
 * Classifies every indexed image with the given model and stores the ranked
 * predictions on each image's Elasticsearch document.
 *
 * Images are processed one at a time: each update is awaited before the next
 * image is loaded, so updates reach Elasticsearch in order and the function
 * only resolves once all of them have been attempted.
 *
 * @param model - Trained model exposing `predict(tensor)`; its output is read
 *   as one probability per entry of CLASS_NAMES.
 * @returns {Promise<void>}
 */
async function classifyAllImages(model) {
  const imagesResponse = await getAllImages();
  const images = imagesResponse.hits.hits.map((hit) => ({
    id: hit._id,
    url: hit._source.image_url,
  }));

  // `const` keeps the loop variable local; without a declaration it becomes an
  // implicit global (and a ReferenceError in strict mode).
  for (const image of images) {
    console.log(image.url);

    const tensor = await getGrayscaleImageTensor(image.url);
    let batch;
    let output;
    let probabilities;
    try {
      // predict() expects a leading batch dimension
      batch = tensor.expandDims();
      output = model.predict(batch);
      probabilities = await output.data();
    } finally {
      // Tensors are not garbage collected; release every one created for this image.
      output?.dispose();
      batch?.dispose();
      tensor.dispose();
    }

    const predictions = Array.from(probabilities, (probability, i) => ({
      probability,
      className: CLASS_NAMES[i], // class names share the model's output order
    }))
      .sort((a, b) => b.probability - a.probability) // most likely class first
      .slice(0, 2);

    console.log(predictions);
    await updateDocumentWithGrayscaleClassification(
      image.id,
      predictions[0].className,
      predictions
    );
  }
}

model-classification-app/scripts/tf-util.js

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ async function getTensorFromImage(imageUrl) {
4242

4343
return tensors;
4444
}
45+
46+
/**
 * Loads a grayscale tensor for every image in an Elasticsearch search response.
 *
 * Images are loaded one at a time, in hit order. The tensors are not disposed
 * here; releasing them is left to the caller.
 *
 * @param {object} results - Search response; `results.hits.hits[]._source.image_url`
 *   is read for each hit.
 * @returns {Promise<Array>} One grayscale tensor per hit, in hit order.
 */
async function getGrayscaleTensorsForImageSet(results) {
  const tensors = [];

  // `const` keeps the loop variable local; without a declaration it becomes an
  // implicit global (and a ReferenceError in strict mode).
  for (const hit of results.hits.hits) {
    tensors.push(await getGrayscaleImageTensor(hit._source.image_url));
  }

  return tensors;
}
4555

4656
async function getResizedImageTensor(imageUrl) {
4757
const decodedImage = await getTensorFromImage(imageUrl);
@@ -52,4 +62,15 @@ async function getTensorFromImage(imageUrl) {
5262
return resizedImage;
5363
}
5464

55-
module.exports = { getResizedImageTensor, getTensorFromImage, getTensorsForImageSet, IMAGE_HEIGHT, IMAGE_WIDTH };
65+
/**
 * Loads an image and returns it as a resized grayscale tensor.
 *
 * The decoded image is resized with bilinear interpolation (alignCorners = true)
 * and then converted from RGB to a single grayscale channel. The intermediate
 * decoded and resized tensors are disposed before returning, since tensor memory
 * is not reclaimed by the garbage collector. Only the returned tensor is left
 * for the caller to dispose.
 *
 * Assumes getTensorFromImage returns a fresh tensor owned by the caller - TODO confirm.
 *
 * NOTE(review): tf.image.resizeBilinear documents its size argument as
 * [newHeight, newWidth]. The [IMAGE_WIDTH, IMAGE_HEIGHT] order is kept as-is
 * because the grayscale model's inputShape uses the same order. It only matters
 * if the two constants differ.
 *
 * @param {string} imageUrl - Location of the image to load.
 * @returns {Promise} The grayscale tensor.
 */
async function getGrayscaleImageTensor(imageUrl) {
  const decodedImage = await getTensorFromImage(imageUrl);
  let resizedImage;

  try {
    resizedImage = tf.image.resizeBilinear(
      decodedImage,
      [IMAGE_WIDTH, IMAGE_HEIGHT],
      true
    );

    return tf.image.rgbToGrayscale(resizedImage);
  } finally {
    // Runs on success and on failure; the returned grayscale tensor is untouched.
    resizedImage?.dispose();
    decodedImage.dispose();
  }
}
75+
76+
module.exports = { getGrayscaleImageTensor, getResizedImageTensor, getTensorFromImage, getTensorsForImageSet, getGrayscaleTensorsForImageSet, IMAGE_HEIGHT, IMAGE_WIDTH };

0 commit comments

Comments
 (0)