Skip to content

Commit f1ffae3

Browse files
committed
include project in initalize
1 parent 4762d67 commit f1ffae3

File tree

3 files changed

+14
-15
lines changed

3 files changed

+14
-15
lines changed

examples/hello-mnist.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@
179179
{
180180
"data": {
181181
"application/vnd.jupyter.widget-view+json": {
182-
"model_id": "8a432072d1bb4a39b5438417ac15d5f5",
182+
"model_id": "d26bcf36d33e4b108018d97c00b975fc",
183183
"version_major": 2,
184184
"version_minor": 0
185185
},
@@ -202,8 +202,8 @@
202202
"\n",
203203
"error = SquareError()\n",
204204
"\n",
205-
"w = mlp.initialize(jax.random.PRNGKey(0))\n",
206-
"w = mlp.project(w)\n",
205+
"key = jax.random.PRNGKey(0)\n",
206+
"w = mlp.initialize(key)\n",
207207
" \n",
208208
"progress_bar = tqdm(range(steps), desc=f\"Loss: {0:.4f}\")\n",
209209
"for step in progress_bar:\n",
@@ -248,7 +248,7 @@
248248
"output_type": "stream",
249249
"text": [
250250
"Accuracy on shown samples: 5/5\n",
251-
"Overall test accuracy: 97.58%\n"
251+
"Overall test accuracy: 97.52%\n"
252252
]
253253
}
254254
],

examples/hello-world.ipynb

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,15 @@
117117
"output_type": "stream",
118118
"text": [
119119
"Step 0 \t Loss 0.976274\n",
120-
"Step 100 \t Loss 0.001985\n",
121-
"Step 200 \t Loss 0.001541\n",
122-
"Step 300 \t Loss 0.001189\n",
123-
"Step 400 \t Loss 0.000884\n",
124-
"Step 500 \t Loss 0.000625\n",
125-
"Step 600 \t Loss 0.000413\n",
126-
"Step 700 \t Loss 0.000251\n",
127-
"Step 800 \t Loss 0.000130\n",
128-
"Step 900 \t Loss 0.000049\n"
120+
"Step 100 \t Loss 0.001989\n",
121+
"Step 200 \t Loss 0.001537\n",
122+
"Step 300 \t Loss 0.001194\n",
123+
"Step 400 \t Loss 0.000885\n",
124+
"Step 500 \t Loss 0.000627\n",
125+
"Step 600 \t Loss 0.000420\n",
126+
"Step 700 \t Loss 0.000255\n",
127+
"Step 800 \t Loss 0.000134\n",
128+
"Step 900 \t Loss 0.000053\n"
129129
]
130130
}
131131
],
@@ -135,7 +135,6 @@
135135
"\n",
136136
"key = jax.random.PRNGKey(0)\n",
137137
"w = mlp.initialize(key)\n",
138-
"w = mlp.project(w)\n",
139138
"\n",
140139
"for step in range(steps):\n",
141140
" # compute outputs and activations\n",

modula/atom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def backward(self, w, acts, grad_output):
3232

3333
def initialize(self, key):
3434
weight = jax.random.normal(key, shape=(self.fanout, self.fanin))
35-
return [weight]
35+
return self.project([weight])
3636

3737
def project(self, w):
3838
weight = w[0]

0 commit comments

Comments
 (0)