|
80 | 80 | }, |
81 | 81 | { |
82 | 82 | "cell_type": "markdown", |
83 | | - "id": "a7af8a3d-77dc-4007-a6de-03ab617bf3fa", |
| 83 | + "id": "d7d5fc30", |
84 | 84 | "metadata": {}, |
85 | 85 | "source": [ |
86 | | - "Next, we choose our error measure. Error measures allow us to both compute the loss of the model, and also to compute the derivative of the loss with respect to model outputs. For simplicity we will just use squared error." |
| 86 | + "Next, we set up a loss function and create a jitted function for both evaluating the loss and also returning its gradient." |
87 | 87 | ] |
88 | 88 | }, |
89 | 89 | { |
90 | 90 | "cell_type": "code", |
91 | 91 | "execution_count": 3, |
92 | | - "id": "a7ea38a1-2684-437f-88fc-cb1f2a44133c", |
| 92 | + "id": "8b719f14", |
93 | 93 | "metadata": {}, |
94 | 94 | "outputs": [], |
95 | 95 | "source": [ |
96 | | - "from modula.error import SquareError\n", |
| 96 | + "def mse(w, inputs, targets):\n", |
| 97 | + " outputs = mlp(inputs, w)\n", |
| 98 | + " loss = ((outputs-targets) ** 2).mean()\n", |
| 99 | + " return loss\n", |
97 | 100 | "\n", |
98 | | - "error = SquareError()" |
| 101 | + "mse_and_grad = jax.jit(jax.value_and_grad(mse))" |
99 | 102 | ] |
100 | 103 | }, |
101 | 104 | { |
102 | 105 | "cell_type": "markdown", |
103 | 106 | "id": "1c4b8252-b3f0-4d16-9b48-9d8d582c1abe", |
104 | 107 | "metadata": {}, |
105 | 108 | "source": [ |
106 | | - "Finally we are ready to train our model. The method `mlp.backward` takes as input the weights, activations and the gradient of the error. It returns the gradient of the loss with respect to both the model weights and the inputs. The method `mlp.dualize` takes in the gradient of the weights and solves for the vector of unit modular norm that maximizes the linearized improvement in loss." |
| 109 | + "Finally we are ready to train our model. We will apply the method `mlp.dualize` to the gradient of the loss to solve for the vector of unit modular norm that maximizes the linearized improvement in loss." |
107 | 110 | ] |
108 | 111 | }, |
109 | 112 | { |
|
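For readers skimming the diff, here is a minimal standalone sketch (not part of the notebook) of the pattern used in the cell above: `jax.value_and_grad` wraps a scalar loss so that a single call returns both the loss and its gradient with respect to the first argument, and `jax.jit` compiles that call. The toy linear model, array shapes, and names below are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # toy scalar loss for illustration: mean squared error of a linear model
    return ((x @ w - y) ** 2).mean()

# one call now returns (loss, gradient w.r.t. w), compiled on first use
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))

w = jnp.ones(3)
x = jnp.ones((4, 3))
y = jnp.zeros(4)

loss, grad_w = loss_and_grad(w, x, y)  # loss is a scalar, grad_w has the same shape as w
```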
116 | 119 | "name": "stdout", |
117 | 120 | "output_type": "stream", |
118 | 121 | "text": [ |
119 | | - "Step 0 \t Loss 0.976274\n", |
120 | | - "Step 100 \t Loss 0.001989\n", |
121 | | - "Step 200 \t Loss 0.001537\n", |
122 | | - "Step 300 \t Loss 0.001194\n", |
123 | | - "Step 400 \t Loss 0.000885\n", |
124 | | - "Step 500 \t Loss 0.000627\n", |
125 | | - "Step 600 \t Loss 0.000420\n", |
126 | | - "Step 700 \t Loss 0.000255\n", |
127 | | - "Step 800 \t Loss 0.000134\n", |
128 | | - "Step 900 \t Loss 0.000053\n" |
| 122 | + "Step 0 \t Loss 0.976154\n", |
| 123 | + "Step 100 \t Loss 0.001773\n", |
| 124 | + "Step 200 \t Loss 0.001371\n", |
| 125 | + "Step 300 \t Loss 0.001002\n", |
| 126 | + "Step 400 \t Loss 0.000696\n", |
| 127 | + "Step 500 \t Loss 0.000453\n", |
| 128 | + "Step 600 \t Loss 0.000282\n", |
| 129 | + "Step 700 \t Loss 0.000152\n", |
| 130 | + "Step 800 \t Loss 0.000061\n", |
| 131 | + "Step 900 \t Loss 0.000011\n" |
129 | 132 | ] |
130 | 133 | } |
131 | 134 | ], |
|
137 | 140 | "w = mlp.initialize(key)\n", |
138 | 141 | "\n", |
139 | 142 | "for step in range(steps):\n", |
140 | | - " # compute outputs and activations\n", |
141 | | - " outputs, activations = mlp(inputs, w)\n", |
142 | | - " \n", |
143 | | - " # compute loss\n", |
144 | | - " loss = error(outputs, targets)\n", |
145 | | - " \n", |
146 | | - " # compute error gradient\n", |
147 | | - " error_grad = error.grad(outputs, targets)\n", |
148 | | - " \n", |
149 | | - " # compute gradient of weights\n", |
150 | | - " grad_w, _ = mlp.backward(w, activations, error_grad)\n", |
| 143 | + "\n", |
| 144 | + " # compute loss and gradient of weights\n", |
| 145 | + " loss, grad_w = mse_and_grad(w, inputs, targets)\n", |
151 | 146 | " \n", |
152 | 147 | " # dualize gradient\n", |
153 | 148 | " d_w = mlp.dualize(grad_w)\n", |
|
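The weight update itself falls outside this diff, so the following is a hedged sketch only: the dualized direction `d_w` is typically applied with a fixed step size, roughly as below. The learning rate value and the list-of-arrays weight layout are assumptions, not the notebook's actual elided lines.

```python
lr = 0.1  # assumed step size, not taken from the notebook

for step in range(steps):
    # loss and raw gradient from the jitted function defined earlier
    loss, grad_w = mse_and_grad(w, inputs, targets)

    # direction of unit modular norm that maximizes the linearized loss decrease
    d_w = mlp.dualize(grad_w)

    # assumed update: step of size lr along the dualized direction
    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]

    if step % 100 == 0:
        print(f"Step {step} \t Loss {float(loss):.6f}")
```

Because the dualized direction has unit modular norm, `lr` directly sets the size of each step as measured in the modular norm.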
179 | 174 | "name": "python", |
180 | 175 | "nbconvert_exporter": "python", |
181 | 176 | "pygments_lexer": "ipython3", |
182 | | - "version": "3.10.16" |
| 177 | + "version": "3.12.8" |
183 | 178 | } |
184 | 179 | }, |
185 | 180 | "nbformat": 4, |
|