|
## Speculative Decoding
The basic idea here is to have a smaller, faster model act as a "draft" that
predicts the next few tokens, which are then verified by a larger model, the
target model that we actually want to use.

Let's say we specify that we want the draft model to predict 5 tokens. It does
so in the normal autoregressive manner, and once it has predicted the 5 tokens
they are passed to the target model along with the original prompt tokens.

So we start with an initial prompt, just as we normally would. This is passed
to the draft model, which is set to predict 5 tokens, sampling each one as
usual.

These 5 tokens (together with the token embeddings for the initial prompt) are
then passed to the target model. The target model will process them as it
would normally process a prompt; the additional 5 tokens at the end are just
ordinary prompt tokens as far as the target model is concerned.

And recall that when a model processes a prompt it calculates the predictions
for every position simultaneously in a single pass.

```console
Drafted tokens: [A, B, C, D, E]
(5 cheap runs of the draft model)

The target model processes: [Prompt, A, B, C, D, E]
(1 expensive run of the target model)
```
Because the target model calculates logits for every position in parallel:
```console
It sees Prompt -> Predicts validation for A
It sees ...A   -> Predicts validation for B
It sees ...B   -> Predicts validation for C
It sees ...C   -> Predicts validation for D
It sees ...D   -> Predicts validation for E
It sees ...E   -> Predicts the brand new token F
```
It does all 6 of these calculations in that single forward pass.
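
To make the "one pass, many predictions" point concrete, here is a toy sketch.
Everything in it is made up for illustration (the `target_forward` function,
the logit values, the vocabulary size); the point is only the shape of the
result: one forward pass yields one row of logits per position, and taking the
argmax of each row gives all 6 predictions at once.

```python
# Toy stand-in for one target-model forward pass: returns one row of fake
# logits per input position (shape [seq_len][vocab]). A real model would
# compute these; the values here are arbitrary.
def target_forward(tokens, vocab=4):
    return [[(i * 7 + v * 3) % 5 for v in range(vocab)]
            for i, _ in enumerate(tokens)]

def argmax(row):
    # Index of the highest logit in one row.
    return max(range(len(row)), key=lambda v: row[v])

seq = ["Prompt", "A", "B", "C", "D", "E"]  # prompt + 5 drafted tokens
logits = target_forward(seq)               # one "expensive" pass
preds = [argmax(row) for row in logits]    # 6 predictions from that one pass:
# preds[0] validates A, preds[1] validates B, ..., preds[5] is the new token F
assert len(preds) == len(seq)
```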

Now, let's say the target model's prediction at the position of token C does
not match what the draft model predicted; say the target predicted X instead.
In that case we reject tokens C, D, and E and output X as the next predicted
token from the target model:
```console
[A, B, X]
```