Commit 3e6ce4a

write text
1 parent e9d49da commit 3e6ce4a

2 files changed: 65 additions & 22 deletions

docs/README.md

Lines changed: 1 addition & 0 deletions
@@ -8,5 +8,6 @@ To build these docs locally do:
 ```bash
 cd docs
 pip install -r requirements.txt
+conda install -c conda-forge pandoc
 make livedirhtml
 ```

examples/hello-world.ipynb

Lines changed: 64 additions & 22 deletions
@@ -8,29 +8,44 @@
 "# Hello, World!"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "847730fa-390b-4b0a-8600-55fb76f9cc38",
+"metadata": {},
+"source": [
+"On this page, we will build a simple training loop to fit an MLP to some randomly generated data. We start by sampling some data. Modula uses JAX to handle array computations, so we use JAX to sample the data. JAX requires us to explicitly pass in the state of the random number generator."
+]
+},
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 1,
 "id": "5a7a804b-06ec-4773-864c-db8a3b01c3e1",
 "metadata": {},
 "outputs": [],
 "source": [
 "import jax\n",
 "import jax.numpy as jnp\n",
 "\n",
-"input_dim = 28 * 28\n",
+"input_dim = 784\n",
 "output_dim = 10\n",
 "batch_size = 128\n",
 "\n",
-"# Generate random training data\n",
 "key = jax.random.PRNGKey(0)\n",
 "inputs = jax.random.normal(key, (input_dim, batch_size))\n",
 "targets = jax.random.normal(key, (output_dim, batch_size))"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "3809ea7f-cd49-4b2f-98a9-0bcd420fbcac",
+"metadata": {},
+"source": [
+"Next, we will build our neural network. We import the basic Linear and ReLU modules and compose them using the `@` operator. Calling `mlp.jit()` tries to make all the internal module methods more efficient using [just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html) from JAX."
+]
+},
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 2,
 "id": "a7a14a1b-1428-4432-8e89-6b7cfed3d765",
 "metadata": {},
 "outputs": [
@@ -53,43 +68,70 @@
 "width = 256\n",
 "\n",
 "mlp = Linear(output_dim, width)\n",
-"mlp @= ReLU() @ Linear(width, width) \n",
-"mlp @= ReLU() @ Linear(width, input_dim)\n",
+"mlp @= ReLU() \n",
+"mlp @= Linear(width, width) \n",
+"mlp @= ReLU() \n",
+"mlp @= Linear(width, input_dim)\n",
 "\n",
 "print(mlp)\n",
 "\n",
 "mlp.jit()"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "a7af8a3d-77dc-4007-a6de-03ab617bf3fa",
+"metadata": {},
+"source": [
+"Next, we choose our error measure. Error measures allow us to compute both the loss of the model and the derivative of the loss with respect to the model outputs. For simplicity we will just use squared error."
+]
+},
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 3,
+"id": "a7ea38a1-2684-437f-88fc-cb1f2a44133c",
+"metadata": {},
+"outputs": [],
+"source": [
+"from modula.error import SquareError\n",
+"\n",
+"error = SquareError()"
+]
+},
+{
+"cell_type": "markdown",
+"id": "1c4b8252-b3f0-4d16-9b48-9d8d582c1abe",
+"metadata": {},
+"source": [
+"Finally, we are ready to train our model. The method `mlp.backward` takes as input the weights, the activations, and the gradient of the error. It returns the gradient of the loss with respect to both the model weights and the inputs. The method `mlp.dualize` takes in the gradient of the weights and solves for the vector of unit modular norm that maximizes the linearized improvement in loss."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
 "id": "080bbf4f-0b73-4d6a-a3d5-f64a2875da9c",
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Step 0, Loss 0.9790326952934265\n",
-"Step 100, Loss 0.0018738203216344118\n",
-"Step 200, Loss 0.0014391584554687142\n",
-"Step 300, Loss 0.0010814154520630836\n",
-"Step 400, Loss 0.0008106177556328475\n",
-"Step 500, Loss 0.0005738214822486043\n",
-"Step 600, Loss 0.0003808117180597037\n",
-"Step 700, Loss 0.00022766715846955776\n",
-"Step 800, Loss 0.00011454012565081939\n",
-"Step 900, Loss 3.979807297582738e-05\n"
+"Step 0 \t Loss 0.976274\n",
+"Step 100 \t Loss 0.001985\n",
+"Step 200 \t Loss 0.001541\n",
+"Step 300 \t Loss 0.001189\n",
+"Step 400 \t Loss 0.000884\n",
+"Step 500 \t Loss 0.000625\n",
+"Step 600 \t Loss 0.000413\n",
+"Step 700 \t Loss 0.000251\n",
+"Step 800 \t Loss 0.000130\n",
+"Step 900 \t Loss 0.000049\n"
 ]
 }
 ],
 "source": [
-"from modula.error import SquareError\n",
-"\n",
 "steps = 1000\n",
 "learning_rate = 0.1\n",
-"error = SquareError()\n",
 "\n",
 "key = jax.random.PRNGKey(0)\n",
 "w = mlp.initialize(key)\n",
@@ -118,7 +160,7 @@
 "    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]\n",
 "\n",
 "    if step % 100 == 0:\n",
-"        print(f\"Step {step}, Loss {loss}\")\n"
+"        print(f\"Step {step:3d} \\t Loss {loss:.6f}\")\n"
 ]
 }
 ],
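
This hunk only touches the tail of the training loop, so most of the loop is not visible in the diff. Based on the methods the new markdown cell names (`mlp.initialize`, `mlp.backward`, `mlp.dualize`, `SquareError`) and the update line shown above, here is a rough sketch of how the pieces might fit together; it relies on `mlp`, `error`, `inputs`, and `targets` defined earlier in the notebook, and the forward-pass and error-gradient calls (`mlp(inputs, w)`, `error(...)`, `error.grad(...)`) are assumed names for illustration, not confirmed Modula API.

```python
# Hypothetical sketch; only mlp.initialize, mlp.backward, mlp.dualize,
# SquareError, and the weight-update line are taken from the diff.
# The forward and error calls below are assumptions, not confirmed API.
key = jax.random.PRNGKey(0)
w = mlp.initialize(key)

lr = 0.1  # the notebook sets learning_rate = 0.1
for step in range(1000):
    outputs, activations = mlp(inputs, w)        # assumed forward-pass signature
    loss = error(outputs, targets)               # assumed loss call
    error_grad = error.grad(outputs, targets)    # assumed error-gradient call

    # backward: gradient of the loss w.r.t. the weights (and the inputs)
    grad_w, grad_inputs = mlp.backward(w, activations, error_grad)

    # dualize: turn the weight gradient into a unit-modular-norm update direction
    d_w = mlp.dualize(grad_w)

    # update step as shown in the hunk above
    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]

    if step % 100 == 0:
        print(f"Step {step:3d} \t Loss {loss:.6f}")
```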
@@ -138,7 +180,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.12.8"
+"version": "3.10.16"
 }
 },
 "nbformat": 4,
