6 | 6 | import torch |
7 | 7 | from torch import Tensor |
8 | 8 |
9 | | -from megatron.core import parallel_state, tensor_parallel |
| 9 | +from megatron.core import tensor_parallel |
10 | 10 | from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk |
11 | 11 | from megatron.core.dist_checkpointing.mapping import ShardedStateDict |
12 | 12 | from megatron.core.inference.contexts import BaseInferenceContext |
@@ -26,11 +26,9 @@
26 | 26 | from megatron.core.tensor_parallel import gather_from_sequence_parallel_region |
27 | 27 | from megatron.core.transformer.enums import CudaGraphScope, ModelType |
28 | 28 | from megatron.core.transformer.multi_token_prediction import ( |
29 | | - MTPLossAutoScaler, |
30 | | - MTPLossLoggingHelper, |
31 | 29 | MultiTokenPredictionBlock, |
32 | | - roll_tensor, |
33 | | - tie_word_embeddings_state_dict, |
| 30 | + mtp_on_this_rank, |
| 31 | + process_mtp_loss, |
34 | 32 | ) |
35 | 33 | from megatron.core.transformer.spec_utils import ModuleSpec |
36 | 34 | from megatron.core.transformer.transformer_block import TransformerBlock |
@@ -144,7 +142,9 @@ def __init__( |
144 | 142 | self.rotary_base = rotary_base |
145 | 143 | self.rotary_scaling = rope_scaling |
146 | 144 | self.mtp_block_spec = mtp_block_spec |
147 | | - self.mtp_process = mtp_block_spec is not None |
| 145 | + self.mtp_process = mtp_block_spec is not None and mtp_on_this_rank( |
| 146 | + self.config, ignore_virtual=False, vp_stage=vp_stage |
| 147 | + ) |
148 | 148 |
149 | 149 | if self.pre_process or self.mtp_process: |
150 | 150 | self.embedding = LanguageModelEmbedding( |
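With this change, `self.mtp_process` is true only when an MTP block spec is provided *and* `mtp_on_this_rank` reports that the current (virtual) pipeline stage actually hosts the MTP layers, so MTP-specific setup only happens on ranks that run the MTP block. The helper itself lives in `multi_token_prediction.py` and is not part of this diff; the sketch below is only a guess at the kind of placement check it performs, and both the last-stage rule and the simplified `vp_stage` handling are assumptions.

```python
# Minimal sketch of the placement check mtp_on_this_rank is assumed to perform.
# The real helper in megatron.core.transformer.multi_token_prediction may use a
# different rule; treat every detail here as an assumption.
from megatron.core import parallel_state


def mtp_on_this_rank_sketch(config, ignore_virtual: bool = False, vp_stage=None) -> bool:
    """Return True when this pipeline stage is assumed to host the MTP layers."""
    if config.mtp_num_layers is None:
        # The model has no MTP layers at all.
        return False
    # MTP layers follow the main decoder stack, so this sketch places them on the
    # last pipeline stage. A real implementation would also consult vp_stage to
    # select the correct virtual stage under interleaved pipeline scheduling.
    return parallel_state.is_pipeline_last_stage(ignore_virtual=ignore_virtual)
```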
@@ -609,56 +609,19 @@ def _postprocess( |
609 | 609 | return hidden_states |
610 | 610 |
611 | 611 | if self.config.mtp_num_layers is not None: |
612 | | - mtp_labels = labels.clone() |
613 | | - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) |
614 | | - hidden_states = hidden_states_list[0] |
615 | | - if loss_mask is None: |
616 | | - # if loss_mask is not provided, use all ones as loss_mask |
617 | | - loss_mask = torch.ones_like(mtp_labels) |
618 | | - for mtp_layer_number in range(self.config.mtp_num_layers): |
619 | | - # output |
620 | | - mtp_logits, _ = self.output_layer( |
621 | | - hidden_states_list[mtp_layer_number + 1], |
622 | | - weight=output_weight, |
623 | | - runtime_gather_output=runtime_gather_output, |
624 | | - ) |
625 | | - # Calc loss for the current Multi-Token Prediction (MTP) layers. |
626 | | - mtp_labels, _ = roll_tensor( |
627 | | - mtp_labels, |
628 | | - shifts=-1, |
629 | | - dims=-1, |
630 | | - cp_group=self.cp_group, |
631 | | - packed_seq_params=packed_seq_params, |
632 | | - ) |
633 | | - loss_mask, num_tokens = roll_tensor( |
634 | | - loss_mask, |
635 | | - shifts=-1, |
636 | | - dims=-1, |
637 | | - cp_group=self.cp_group, |
638 | | - packed_seq_params=packed_seq_params, |
639 | | - ) |
640 | | - mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) |
641 | | - mtp_loss = loss_mask * mtp_loss |
642 | | - if self.training: |
643 | | - # TODO(shifangx): remove the use of parallel_state here |
644 | | - # after moving loss logging to loss_func in pretrain_gpt.py |
645 | | - MTPLossLoggingHelper.save_loss_to_tracker( |
646 | | - torch.sum(mtp_loss) / num_tokens, |
647 | | - mtp_layer_number, |
648 | | - self.config.mtp_num_layers, |
649 | | - avg_group=parallel_state.get_data_parallel_group( |
650 | | - with_context_parallel=True |
651 | | - ), |
652 | | - ) |
653 | | - mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers |
654 | | - if self.config.calculate_per_token_loss: |
655 | | - hidden_states = MTPLossAutoScaler.apply( |
656 | | - hidden_states, mtp_loss_scale * mtp_loss |
657 | | - ) |
658 | | - else: |
659 | | - hidden_states = MTPLossAutoScaler.apply( |
660 | | - hidden_states, mtp_loss_scale * mtp_loss / num_tokens |
661 | | - ) |
| 612 | + hidden_states = process_mtp_loss( |
| 613 | + hidden_states=hidden_states, |
| 614 | + labels=labels, |
| 615 | + loss_mask=loss_mask, |
| 616 | + output_layer=self.output_layer, |
| 617 | + output_weight=output_weight, |
| 618 | + runtime_gather_output=runtime_gather_output, |
| 619 | + is_training=self.training, |
| 620 | + compute_language_model_loss=self.compute_language_model_loss, |
| 621 | + config=self.config, |
| 622 | + cp_group=self.pg_collection.cp, |
| 623 | + packed_seq_params=packed_seq_params, |
| 624 | + ) |
662 | 625 | sequence_parallel_override = False |
663 | 626 |
664 | 627 | if in_inference_mode and inference_context.config.materialize_only_last_token_logits: |
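The inline MTP loss computation above is collapsed into a single `process_mtp_loss` call. Judging from the removed lines, the helper presumably splits off one extra chunk of `hidden_states` per MTP layer, rolls the labels and loss mask left by one token per layer, computes each layer's masked cross-entropy, logs it during training, and attaches the scaled loss to the trunk `hidden_states` through `MTPLossAutoScaler` so it backpropagates with the main loss. The sketch below is reconstructed from that removed code; the actual helper in `multi_token_prediction.py` may differ, for example in its exact signature or in how the loss-averaging group is obtained.

```python
# Sketch of what process_mtp_loss is assumed to encapsulate, reconstructed from
# the inline code removed above; the real helper may differ in the details.
import torch
from megatron.core import parallel_state
from megatron.core.transformer.multi_token_prediction import (
    MTPLossAutoScaler,
    MTPLossLoggingHelper,
    roll_tensor,
)


def process_mtp_loss_sketch(
    hidden_states, labels, loss_mask, output_layer, output_weight,
    runtime_gather_output, is_training, compute_language_model_loss,
    config, cp_group, packed_seq_params,
):
    mtp_labels = labels.clone()
    # The trunk hidden states arrive concatenated with one extra chunk per MTP layer.
    hidden_states_list = torch.chunk(hidden_states, 1 + config.mtp_num_layers, dim=0)
    hidden_states = hidden_states_list[0]
    if loss_mask is None:
        # If no loss mask is provided, count every token.
        loss_mask = torch.ones_like(mtp_labels)
    for mtp_layer_number in range(config.mtp_num_layers):
        # Project the MTP layer's hidden states to logits.
        mtp_logits, _ = output_layer(
            hidden_states_list[mtp_layer_number + 1],
            weight=output_weight,
            runtime_gather_output=runtime_gather_output,
        )
        # Each MTP head predicts one more token ahead, so shift labels and mask left.
        mtp_labels, _ = roll_tensor(
            mtp_labels, shifts=-1, dims=-1, cp_group=cp_group,
            packed_seq_params=packed_seq_params,
        )
        loss_mask, num_tokens = roll_tensor(
            loss_mask, shifts=-1, dims=-1, cp_group=cp_group,
            packed_seq_params=packed_seq_params,
        )
        mtp_loss = loss_mask * compute_language_model_loss(mtp_labels, mtp_logits)
        if is_training:
            # Record the per-layer MTP loss for logging.
            MTPLossLoggingHelper.save_loss_to_tracker(
                torch.sum(mtp_loss) / num_tokens,
                mtp_layer_number,
                config.mtp_num_layers,
                avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
            )
        # Attach the scaled MTP loss to the trunk hidden states via MTPLossAutoScaler,
        # an autograd helper that injects the auxiliary loss gradient in backward.
        mtp_loss_scale = config.mtp_loss_scaling_factor / config.mtp_num_layers
        if config.calculate_per_token_loss:
            hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
        else:
            hidden_states = MTPLossAutoScaler.apply(
                hidden_states, mtp_loss_scale * mtp_loss / num_tokens
            )
    return hidden_states
```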
@@ -715,27 +678,6 @@ def _postprocess( |
715 | 678 |
716 | 679 | return loss |
717 | 680 |
718 | | - def shared_embedding_or_output_weight(self) -> Tensor: |
719 | | - """Gets the embedding weight or output logit weights when share input embedding and |
720 | | - output weights set to True or when use Multi-Token Prediction (MTP) feature. |
721 | | -
722 | | - Returns: |
723 | | - Tensor: During pre processing or MTP process it returns the input embeddings weight. |
724 | | - Otherwise, during post processing it returns the final output layers weight. |
725 | | - """ |
726 | | - if self.pre_process or self.mtp_process: |
727 | | - # Multi-Token Prediction (MTP) need both embedding layer and output layer. |
728 | | - # So there will be both embedding layer and output layer in the mtp process stage. |
729 | | - # In this case, if share_embeddings_and_output_weights is True, the shared weights |
730 | | - # will be stored in embedding layer, and output layer will not have any weight. |
731 | | - assert hasattr( |
732 | | - self, 'embedding' |
733 | | - ), f"embedding is needed in this pipeline stage, but it is not initialized." |
734 | | - return self.embedding.word_embeddings.weight |
735 | | - elif self.post_process: |
736 | | - return self.output_layer.weight |
737 | | - return None |
738 | | - |
739 | 681 | def build_schedule_plan( |
740 | 682 | self, |
741 | 683 | input_ids: Tensor, |
@@ -826,20 +768,4 @@ def sharded_state_dict( |
826 | 768 | output_extra_state and output_extra_state.data |
827 | 769 | ), f'Expected output layer extra state to be empty, got: {output_extra_state}' |
828 | 770 |
829 | | - # Multi-Token Prediction (MTP) need embedding layer in mtp process stage. |
830 | | - # If MTP is not placed in the pre processing stage, we need to maintain a copy of |
831 | | - # embedding layer in the mtp process stage and tie it to the embedding in the pre |
832 | | - # processing stage. |
833 | | - # Now MTP loss is computed in post processing stage, so the output_layer is not needed. |
834 | | - if self.mtp_process and not self.pre_process: |
835 | | - emb_weight_key = f'{prefix}embedding.word_embeddings.weight' |
836 | | - emb_weight = self.embedding.word_embeddings.weight |
837 | | - tie_word_embeddings_state_dict( |
838 | | - sharded_state_dict, |
839 | | - emb_weight, |
840 | | - emb_weight_key, |
841 | | - tp_group=self.tp_group, |
842 | | - dp_cp_group=metadata['dp_cp_group'], |
843 | | - ) |
844 | | - |
845 | 771 | return sharded_state_dict |