
Commit d07d45d

deploy: 955a9c0
1 parent 90d6e02 commit d07d45d

30 files changed (+5651 / -5672 lines)

.doctrees/environment.pickle
-1.13 KB (binary file not shown)

Three additional binary files changed: -50 Bytes, -910 Bytes, -122 Bytes (binary files not shown)

.doctrees/nbsphinx/examples/hello-mnist.ipynb

Lines changed: 13 additions & 13 deletions
Large diffs are not rendered by default.

.doctrees/nbsphinx/examples/hello-world.ipynb

Lines changed: 23 additions & 28 deletions
@@ -80,30 +80,33 @@
   },
   {
    "cell_type": "markdown",
-   "id": "a7af8a3d-77dc-4007-a6de-03ab617bf3fa",
+   "id": "d7d5fc30",
    "metadata": {},
    "source": [
-    "Next, we choose our error measure. Error measures allow us to both compute the loss of the model, and also to compute the derivative of the loss with respect to model outputs. For simplicity we will just use squared error."
+    "Next, we set up a loss function and create a jitted function for both evaluating the loss and also returning its gradient."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 3,
-   "id": "a7ea38a1-2684-437f-88fc-cb1f2a44133c",
+   "id": "8b719f14",
    "metadata": {},
    "outputs": [],
    "source": [
-    "from modula.error import SquareError\n",
+    "def mse(w, inputs, targets):\n",
+    "    outputs = mlp(inputs, w)\n",
+    "    loss = ((outputs-targets) ** 2).mean()\n",
+    "    return loss\n",
     "\n",
-    "error = SquareError()"
+    "mse_and_grad = jax.jit(jax.value_and_grad(mse))"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "1c4b8252-b3f0-4d16-9b48-9d8d582c1abe",
    "metadata": {},
    "source": [
-    "Finally we are ready to train our model. The method `mlp.backward` takes as input the weights, activations and the gradient of the error. It returns the gradient of the loss with respect to both the model weights and the inputs. The method `mlp.dualize` takes in the gradient of the weights and solves for the vector of unit modular norm that maximizes the linearized improvement in loss."
+    "Finally we are ready to train our model. We will apply the method `mlp.dualize` to the gradient of the loss to solve for the vector of unit modular norm that maximizes the linearized improvement in loss."
    ]
   },
   {
@@ -116,16 +119,16 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Step 0 \t Loss 0.976274\n",
-    "Step 100 \t Loss 0.001989\n",
-    "Step 200 \t Loss 0.001537\n",
-    "Step 300 \t Loss 0.001194\n",
-    "Step 400 \t Loss 0.000885\n",
-    "Step 500 \t Loss 0.000627\n",
-    "Step 600 \t Loss 0.000420\n",
-    "Step 700 \t Loss 0.000255\n",
-    "Step 800 \t Loss 0.000134\n",
-    "Step 900 \t Loss 0.000053\n"
+    "Step 0 \t Loss 0.976154\n",
+    "Step 100 \t Loss 0.001773\n",
+    "Step 200 \t Loss 0.001371\n",
+    "Step 300 \t Loss 0.001002\n",
+    "Step 400 \t Loss 0.000696\n",
+    "Step 500 \t Loss 0.000453\n",
+    "Step 600 \t Loss 0.000282\n",
+    "Step 700 \t Loss 0.000152\n",
+    "Step 800 \t Loss 0.000061\n",
+    "Step 900 \t Loss 0.000011\n"
    ]
   }
  ],
@@ -137,17 +140,9 @@
    "w = mlp.initialize(key)\n",
    "\n",
    "for step in range(steps):\n",
-    "    # compute outputs and activations\n",
-    "    outputs, activations = mlp(inputs, w)\n",
-    "    \n",
-    "    # compute loss\n",
-    "    loss = error(outputs, targets)\n",
-    "    \n",
-    "    # compute error gradient\n",
-    "    error_grad = error.grad(outputs, targets)\n",
-    "    \n",
-    "    # compute gradient of weights\n",
-    "    grad_w, _ = mlp.backward(w, activations, error_grad)\n",
+    "\n",
+    "    # compute loss and gradient of weights\n",
+    "    loss, grad_w = mse_and_grad(w, inputs, targets)\n",
     "    \n",
     "    # dualize gradient\n",
     "    d_w = mlp.dualize(grad_w)\n",
@@ -179,7 +174,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.16"
+   "version": "3.12.8"
    }
  },
  "nbformat": 4,

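For readers following the hello-world example, the pieces touched by this commit assemble into the training loop sketched below. This is a sketch rather than the notebook verbatim: the final weight update, the learning rate `lr`, the step count, the random seed, and the data arrays `inputs` and `targets` are not shown in the hunks above and are assumed here for illustration; `mlp`, `mlp.initialize`, and `mlp.dualize` come from the Modula example itself.

import jax

# loss function from the new notebook cell: mean squared error of the model outputs
def mse(w, inputs, targets):
    outputs = mlp(inputs, w)
    loss = ((outputs - targets) ** 2).mean()
    return loss

# jitted function that returns both the loss and its gradient with respect to w
mse_and_grad = jax.jit(jax.value_and_grad(mse))

key = jax.random.PRNGKey(0)  # assumed seed; not shown in the diff
w = mlp.initialize(key)

steps, lr = 1000, 0.1  # assumed hyperparameters; not shown in the diff

for step in range(steps):

    # compute loss and gradient of weights
    loss, grad_w = mse_and_grad(w, inputs, targets)

    # dualize gradient: the unit-modular-norm direction that maximizes
    # the linearized improvement in loss
    d_w = mlp.dualize(grad_w)

    # assumed update rule, treating w as a list of arrays;
    # the actual update step lies outside the hunks shown above
    w = [weight - lr * d for weight, d in zip(w, d_w)]

    if step % 100 == 0:
        print(f"Step {step} \t Loss {loss:.6f}")

The change replaces the library's explicit `mlp.backward` pass with standard JAX autodiff: `jax.value_and_grad` produces the weight gradient directly from the scalar loss, and `jax.jit` compiles the combined loss-and-gradient computation.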
.doctrees/nbsphinx/examples/weight-erasure.ipynb

Lines changed: 15 additions & 17 deletions
Large diffs are not rendered by default.
Three additional files changed: 955 Bytes, 355 Bytes, 198 Bytes (previews not rendered)
