|
117 | 117 | "output_type": "stream", |
118 | 118 | "text": [ |
119 | 119 | "Step 0 \t Loss 0.976274\n", |
120 | | - "Step 100 \t Loss 0.001985\n", |
121 | | - "Step 200 \t Loss 0.001541\n", |
122 | | - "Step 300 \t Loss 0.001189\n", |
123 | | - "Step 400 \t Loss 0.000884\n", |
124 | | - "Step 500 \t Loss 0.000625\n", |
125 | | - "Step 600 \t Loss 0.000413\n", |
126 | | - "Step 700 \t Loss 0.000251\n", |
127 | | - "Step 800 \t Loss 0.000130\n", |
128 | | - "Step 900 \t Loss 0.000049\n" |
| 120 | + "Step 100 \t Loss 0.001989\n", |
| 121 | + "Step 200 \t Loss 0.001537\n", |
| 122 | + "Step 300 \t Loss 0.001194\n", |
| 123 | + "Step 400 \t Loss 0.000885\n", |
| 124 | + "Step 500 \t Loss 0.000627\n", |
| 125 | + "Step 600 \t Loss 0.000420\n", |
| 126 | + "Step 700 \t Loss 0.000255\n", |
| 127 | + "Step 800 \t Loss 0.000134\n", |
| 128 | + "Step 900 \t Loss 0.000053\n" |
129 | 129 | ] |
130 | 130 | } |
131 | 131 | ], |
|
135 | 135 | "\n", |
136 | 136 | "key = jax.random.PRNGKey(0)\n", |
137 | 137 | "w = mlp.initialize(key)\n", |
138 | | - "w = mlp.project(w)\n", |
139 | 138 | "\n", |
140 | 139 | "for step in range(steps):\n", |
141 | 140 | " # compute outputs and activations\n", |
|
0 commit comments