Commit bffc801

Merge branch 'main' into helenn-rl-training-cudagraphs-functional
2 parents 9ff653a + 9d71cb1 commit bffc801

27 files changed: +1238, -208 lines

.gitlab/stages/01.build.yml

Lines changed: 6 additions & 0 deletions
@@ -121,6 +121,7 @@ test:build_image:
     KUBERNETES_SERVICE_MEMORY_LIMIT: 90Gi
     SHARED_PATH: /builds/$CI_PROJECT_PATH/shared
   script:
+    - apk add skopeo
     - |
       set -x
@@ -132,6 +133,11 @@ test:build_image:
           ${IMAGE}:${CI_PIPELINE_ID}-arm64

       docker manifest push ${IMAGE}:${CI_PIPELINE_ID}
+
+      if [[ "$CI_COMMIT_BRANCH" == "ci-rebuild-mcore-nemo-image" || "$CI_COMMIT_BRANCH" == "main" || "$CI_COMMIT_BRANCH" == "dev" ]]; then
+        skopeo copy --all docker://${IMAGE}:${CI_PIPELINE_ID} docker://${IMAGE}:${CI_COMMIT_BRANCH}
+      fi
+
     - echo "MCORE_MR_COMMIT=$CI_COMMIT_SHA" | tee -a build.env
     - echo "MCORE_BACKWARDS_COMMIT=$MCORE_BACKWARDS_COMMIT" | tee -a build.env
     - cat build.env
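
Aside (not part of the diff): the new script lines install skopeo in the job image and, only for the three long-lived branches in the condition, retag the just-pushed multi-arch manifest with a branch-named tag. A minimal Python sketch of the same gating logic follows; the function name promote_image is hypothetical, and the environment variables are the standard GitLab CI ones used in the hunk.

    import os
    import subprocess

    def promote_image() -> None:
        image = os.environ["IMAGE"]
        pipeline_id = os.environ["CI_PIPELINE_ID"]
        branch = os.environ.get("CI_COMMIT_BRANCH", "")

        # Only long-lived branches get a floating, branch-named tag.
        if branch in ("ci-rebuild-mcore-nemo-image", "main", "dev"):
            subprocess.run(
                [
                    "skopeo", "copy", "--all",  # --all copies every architecture in the manifest list
                    f"docker://{image}:{pipeline_id}",
                    f"docker://{image}:{branch}",
                ],
                check=True,
            )

    if __name__ == "__main__":
        promote_image()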

mamba_builders.py

Lines changed: 6 additions & 2 deletions
@@ -8,15 +8,18 @@
 from megatron.training.arguments import core_transformer_config_from_args
 from megatron.core.models.mamba.mamba_layer_specs import mamba_inference_stack_spec

+
 def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None):
     print_rank_0('building MAMBA model ...')
     if config is None:
         config = core_transformer_config_from_args(args, TransformerConfig)
     assert args.use_legacy_models is False, "Mamba only supported in Mcore!"

     if config.transformer_impl == "inference_optimized":
-        mamba_stack_spec = mamba_inference_stack_spec
-        assert not config.inference_fuse_tp_communication, "inference_fuse_tp_communication is not supported for Mamba"
+        mamba_stack_spec = mamba_inference_stack_spec
+        assert (
+            not config.inference_fuse_tp_communication
+        ), "inference_fuse_tp_communication is not supported for Mamba"
     elif args.spec is not None:
         mamba_stack_spec = import_module(args.spec)
     else:
@@ -39,6 +42,7 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, p
         rotary_percent=args.rotary_percent,
         rotary_base=args.rotary_base,
         pg_collection=pg_collection,
+        vp_stage=vp_stage,
     )

     for l in range(model.decoder.num_layers_per_pipeline_rank):
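
Aside (not part of the diff): besides reflowing the assert, this hunk threads vp_stage through to the model constructor so virtual-pipeline placement is known at build time. The spec-selection order the builder uses is sketched below under simplified assumptions; select_mamba_stack_spec and its arguments are hypothetical stand-ins, not Megatron APIs.

    from importlib import import_module

    def select_mamba_stack_spec(config, args, inference_stack_spec, default_stack_spec):
        # Inference-optimized builds take precedence and refuse fused TP communication,
        # mirroring the assert in the hunk above.
        if getattr(config, "transformer_impl", None) == "inference_optimized":
            assert not config.inference_fuse_tp_communication, (
                "inference_fuse_tp_communication is not supported for Mamba"
            )
            return inference_stack_spec
        # An explicit --spec module path wins next; Megatron resolves it with its own
        # import_module helper, importlib is used here only for the sketch.
        if getattr(args, "spec", None) is not None:
            return import_module(args.spec)
        # Otherwise fall back to the default stack spec (the else-branch not shown in the hunk).
        return default_stack_spec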

megatron/core/models/common/language_module/language_module.py

Lines changed: 27 additions & 3 deletions
@@ -23,6 +23,7 @@
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.enums import AttnBackend, CudaGraphScope
 from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.multi_token_prediction import tie_word_embeddings_state_dict
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.utils import ensure_metadata_has_dp_cp_group
 from megatron.core.utils import (
@@ -255,12 +256,20 @@ def setup_embeddings_and_output_layer(self) -> None:
             LanguageModule.embedding_warning_printed = True

     def shared_embedding_or_output_weight(self) -> Tensor:
-        """Gets the emedding weight or output logit weights when share embedding and output weights set to True.
+        """Gets the embedding weight or output logit weights when share embedding and output weights set to True
+        or when use Multi-Token Prediction (MTP).

         Returns:
-            Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight
+            Tensor: During pre processing or MTP process it returns the input embeddings weight while during post processing it returns the final output layers weight
         """
-        if self.pre_process:
+        if self.pre_process or getattr(self, 'mtp_process', False):
+            # Multi-Token Prediction (MTP) need both embedding layer and output layer.
+            # So there will be both embedding layer and output layer in the mtp process stage.
+            # When share_embeddings_and_output_weights is True, the embedding weight is the
+            # canonical shared weight and is passed to the output layer during forward.
+            assert hasattr(
+                self, 'embedding'
+            ), f"embedding is needed in this pipeline stage, but it is not initialized."
             return self.embedding.word_embeddings.weight
         elif self.post_process:
             return self.output_layer.weight
@@ -293,6 +302,21 @@ def sharded_state_dict(
         output_layer_weight_key = f'{prefix}output_layer.weight'
         output_layer_bias_key = f'{prefix}output_layer.bias'

+        # Multi-Token Prediction (MTP) needs embedding layer in mtp process stage.
+        # If MTP is not placed in the pre processing stage, we need to maintain a copy of
+        # embedding layer in the mtp process stage and tie it to the embedding in the pre
+        # processing stage.
+        # Note: MTP loss is computed at post_process stage, so the output_layer on mtp_process
+        # rank doesn't need special tying - it's not used for loss computation.
+        if getattr(self, 'mtp_process', False) and not self.pre_process:
+            emb_weight = self.embedding.word_embeddings.weight
+            tie_word_embeddings_state_dict(
+                sharded_state_dict,
+                emb_weight,
+                first_stage_word_emb_key,
+                tp_group=self.tp_group,
+                dp_cp_group=metadata['dp_cp_group'],
+            )
         if self.share_embeddings_and_output_weights:
             self.tie_embeddings_and_output_weights_state_dict(
                 sharded_state_dict, output_layer_weight_key, first_stage_word_emb_key, metadata
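
Aside (not part of the diff): the idea behind the tie_word_embeddings_state_dict call is that a pipeline stage holding only an MTP copy of the word embeddings should checkpoint that copy under the first stage's key instead of storing a second tensor. The sketch below is a deliberately simplified, hypothetical version of that idea using a plain dict; the real helper operates on Megatron's sharded state dict and takes TP and DP-CP process groups, as the call in the hunk shows.

    from typing import Dict

    import torch

    def tie_weight_to_key(state_dict: Dict[str, dict], local_weight: torch.Tensor, canonical_key: str) -> None:
        # Point the local copy at the canonical key so only one tensor exists in the checkpoint;
        # on load, every rank holding a copy reads the same data and stays in sync.
        state_dict[canonical_key] = {
            "data": local_weight,
            "replica": True,  # marks this rank's copy as a replica of the pre-process stage weight
        }

    # Usage in the spirit of the diff: an MTP-only stage that is not the pre-process stage
    # registers its word-embedding copy under the first stage's key.
    embeddings = torch.nn.Embedding(1000, 64)
    sharded_sd: Dict[str, dict] = {}
    tie_weight_to_key(sharded_sd, embeddings.weight, "embedding.word_embeddings.weight")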

megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
         # get flags for latter use
         is_mtp = isinstance(self.layer, MultiTokenPredictionLayer)
         is_moe = (
-            isinstance(self.layer.transformer_layer.mlp, MoELayer)
+            isinstance(self.layer.mtp_model_layer.mlp, MoELayer)
             if is_mtp
             else isinstance(self.layer.mlp, MoELayer)
         )

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 2 additions & 2 deletions
@@ -613,9 +613,9 @@ def build_mtp_layer_callables(layer):
         multi-token prediction layer nodes (attention, MLP, etc.)
     """

-    forward_funcs, backward_dw = build_transformer_layer_callables(layer.transformer_layer)
+    forward_funcs, backward_dw = build_transformer_layer_callables(layer.mtp_model_layer)
     attn_forward, dispatch_forward, mlp_forward, combine_forward, _ = forward_funcs
-    is_moe = isinstance(layer.transformer_layer.mlp, MoELayer)
+    is_moe = isinstance(layer.mtp_model_layer.mlp, MoELayer)
     assert is_moe, "MTP layer in a2a overlap only supports MoE layer for now."

     def submodule_mtp_attn_forward(node, hidden_states):

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 1 addition & 1 deletion
@@ -704,7 +704,7 @@ def get_gpt_mtp_block_spec_for_backend(
         raise ValueError(f"Invalid spec: {spec}")

     mtp_layer_spec = get_mtp_layer_spec_for_backend(
-        transformer_layer_spec=transformer_layer_spec, backend=backend
+        mtp_model_layer_spec=transformer_layer_spec, backend=backend
     )
     mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
     mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
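
Aside (not part of the diff): the three hunks above are one consistent rename. The MTP layer's wrapped module is now mtp_model_layer (and the spec keyword mtp_model_layer_spec) instead of transformer_layer, which lets the wrapped layer be something other than a transformer layer, for example the nested Mamba stack added later in this commit. A toy sketch of the dispatch pattern the rename supports; class names here are placeholders, not Megatron classes.

    class PlainLayer:
        def __init__(self, mlp):
            self.mlp = mlp

    class MTPLayer:
        def __init__(self, inner_layer):
            self.mtp_model_layer = inner_layer  # formerly `transformer_layer`

    def layer_uses_moe(layer, moe_cls) -> bool:
        # For an MTP wrapper, look at the wrapped model layer; otherwise use the layer itself.
        inner = layer.mtp_model_layer if isinstance(layer, MTPLayer) else layer
        return isinstance(inner.mlp, moe_cls)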

megatron/core/models/gpt/gpt_model.py

Lines changed: 19 additions & 93 deletions
@@ -6,7 +6,7 @@
 import torch
 from torch import Tensor

-from megatron.core import parallel_state, tensor_parallel
+from megatron.core import tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.inference.contexts import BaseInferenceContext
@@ -26,11 +26,9 @@
 from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.core.transformer.enums import CudaGraphScope, ModelType
 from megatron.core.transformer.multi_token_prediction import (
-    MTPLossAutoScaler,
-    MTPLossLoggingHelper,
     MultiTokenPredictionBlock,
-    roll_tensor,
-    tie_word_embeddings_state_dict,
+    mtp_on_this_rank,
+    process_mtp_loss,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
@@ -144,7 +142,9 @@ def __init__(
         self.rotary_base = rotary_base
         self.rotary_scaling = rope_scaling
         self.mtp_block_spec = mtp_block_spec
-        self.mtp_process = mtp_block_spec is not None
+        self.mtp_process = mtp_block_spec is not None and mtp_on_this_rank(
+            self.config, ignore_virtual=False, vp_stage=vp_stage
+        )

         if self.pre_process or self.mtp_process:
             self.embedding = LanguageModelEmbedding(
@@ -609,56 +609,19 @@ def _postprocess(
             return hidden_states

         if self.config.mtp_num_layers is not None:
-            mtp_labels = labels.clone()
-            hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
-            hidden_states = hidden_states_list[0]
-            if loss_mask is None:
-                # if loss_mask is not provided, use all ones as loss_mask
-                loss_mask = torch.ones_like(mtp_labels)
-            for mtp_layer_number in range(self.config.mtp_num_layers):
-                # output
-                mtp_logits, _ = self.output_layer(
-                    hidden_states_list[mtp_layer_number + 1],
-                    weight=output_weight,
-                    runtime_gather_output=runtime_gather_output,
-                )
-                # Calc loss for the current Multi-Token Prediction (MTP) layers.
-                mtp_labels, _ = roll_tensor(
-                    mtp_labels,
-                    shifts=-1,
-                    dims=-1,
-                    cp_group=self.cp_group,
-                    packed_seq_params=packed_seq_params,
-                )
-                loss_mask, num_tokens = roll_tensor(
-                    loss_mask,
-                    shifts=-1,
-                    dims=-1,
-                    cp_group=self.cp_group,
-                    packed_seq_params=packed_seq_params,
-                )
-                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
-                mtp_loss = loss_mask * mtp_loss
-                if self.training:
-                    # TODO(shifangx): remove the use of parallel_state here
-                    # after moving loss logging to loss_func in pretrain_gpt.py
-                    MTPLossLoggingHelper.save_loss_to_tracker(
-                        torch.sum(mtp_loss) / num_tokens,
-                        mtp_layer_number,
-                        self.config.mtp_num_layers,
-                        avg_group=parallel_state.get_data_parallel_group(
-                            with_context_parallel=True
-                        ),
-                    )
-                mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
-                if self.config.calculate_per_token_loss:
-                    hidden_states = MTPLossAutoScaler.apply(
-                        hidden_states, mtp_loss_scale * mtp_loss
-                    )
-                else:
-                    hidden_states = MTPLossAutoScaler.apply(
-                        hidden_states, mtp_loss_scale * mtp_loss / num_tokens
-                    )
+            hidden_states = process_mtp_loss(
+                hidden_states=hidden_states,
+                labels=labels,
+                loss_mask=loss_mask,
+                output_layer=self.output_layer,
+                output_weight=output_weight,
+                runtime_gather_output=runtime_gather_output,
+                is_training=self.training,
+                compute_language_model_loss=self.compute_language_model_loss,
+                config=self.config,
+                cp_group=self.pg_collection.cp,
+                packed_seq_params=packed_seq_params,
+            )
         sequence_parallel_override = False

         if in_inference_mode and inference_context.config.materialize_only_last_token_logits:
@@ -715,27 +678,6 @@ def _postprocess(

         return loss

-    def shared_embedding_or_output_weight(self) -> Tensor:
-        """Gets the embedding weight or output logit weights when share input embedding and
-        output weights set to True or when use Multi-Token Prediction (MTP) feature.
-
-        Returns:
-            Tensor: During pre processing or MTP process it returns the input embeddings weight.
-                Otherwise, during post processing it returns the final output layers weight.
-        """
-        if self.pre_process or self.mtp_process:
-            # Multi-Token Prediction (MTP) need both embedding layer and output layer.
-            # So there will be both embedding layer and output layer in the mtp process stage.
-            # In this case, if share_embeddings_and_output_weights is True, the shared weights
-            # will be stored in embedding layer, and output layer will not have any weight.
-            assert hasattr(
-                self, 'embedding'
-            ), f"embedding is needed in this pipeline stage, but it is not initialized."
-            return self.embedding.word_embeddings.weight
-        elif self.post_process:
-            return self.output_layer.weight
-        return None
-
     def build_schedule_plan(
         self,
         input_ids: Tensor,
@@ -826,20 +768,4 @@ def sharded_state_dict(
             output_extra_state and output_extra_state.data
         ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

-        # Multi-Token Prediction (MTP) need embedding layer in mtp process stage.
-        # If MTP is not placed in the pre processing stage, we need to maintain a copy of
-        # embedding layer in the mtp process stage and tie it to the embedding in the pre
-        # processing stage.
-        # Now MTP loss is computed in post processing stage, so the output_layer is not needed.
-        if self.mtp_process and not self.pre_process:
-            emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
-            emb_weight = self.embedding.word_embeddings.weight
-            tie_word_embeddings_state_dict(
-                sharded_state_dict,
-                emb_weight,
-                emb_weight_key,
-                tp_group=self.tp_group,
-                dp_cp_group=metadata['dp_cp_group'],
-            )
-
         return sharded_state_dict
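
Aside (not part of the diff): the inline block deleted above is what process_mtp_loss now encapsulates: split the stacked hidden states into the main output plus one chunk per MTP layer, roll labels and loss mask one token per depth, compute each layer's loss with the shared output head, and fold the scaled result back into the main hidden states through the loss auto-scaler. Below is a self-contained, simplified sketch of that flow; it uses plain torch.roll and cross-entropy in place of roll_tensor, compute_language_model_loss, and MTPLossAutoScaler, assumes [seq, batch, hidden] activations with [batch, seq] labels, and returns the loss instead of attaching it to the graph.

    import torch
    import torch.nn.functional as F

    def mtp_loss_sketch(hidden_states, labels, loss_mask, output_layer, num_mtp_layers,
                        loss_scaling_factor=0.1):
        # hidden_states stacks the main-model output plus one extra chunk per MTP layer
        # along the sequence dimension, as torch.chunk assumes in the deleted code.
        chunks = torch.chunk(hidden_states, 1 + num_mtp_layers, dim=0)
        main_hidden, mtp_hidden = chunks[0], chunks[1:]

        mtp_labels = labels.clone()                      # [batch, seq]
        if loss_mask is not None:
            mask = loss_mask.clone().float()
        else:
            mask = torch.ones_like(labels, dtype=torch.float)
        total_loss = hidden_states.new_zeros(())

        for k in range(num_mtp_layers):
            # Each successive MTP head predicts one more token ahead, so shift by one per depth.
            mtp_labels = torch.roll(mtp_labels, shifts=-1, dims=-1)
            mask = torch.roll(mask, shifts=-1, dims=-1)
            mask[..., -1] = 0                            # the rolled-in position has no valid target

            logits = output_layer(mtp_hidden[k])         # [seq, batch, vocab] in this sketch
            logits = logits.transpose(0, 1).contiguous() # -> [batch, seq, vocab] to match labels
            ce = F.cross_entropy(
                logits.view(-1, logits.size(-1)), mtp_labels.reshape(-1), reduction="none"
            ).view_as(mask)
            total_loss = total_loss + (ce * mask).sum() / mask.sum().clamp(min=1)

        # The real code attaches the scaled loss to the main hidden states via MTPLossAutoScaler;
        # here it is simply returned alongside the main chunk.
        scale = loss_scaling_factor / max(num_mtp_layers, 1)
        return main_hidden, scale * total_loss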

megatron/core/models/mamba/mamba_layer_specs.py

Lines changed: 33 additions & 0 deletions
@@ -1,6 +1,7 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 from megatron.core.extensions.transformer_engine import (
+    TEColumnParallelLinear,
     TEDotProductAttention,
     TELayerNormColumnParallelLinear,
     TENorm,
@@ -19,20 +20,49 @@
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
+from megatron.core.transformer.multi_token_prediction import (
+    MultiTokenPredictionBlock,
+    MultiTokenPredictionBlockSubmodules,
+    MultiTokenPredictionLayer,
+    MultiTokenPredictionLayerSubmodules,
+)
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import (
     MoETransformerLayer,
     TransformerLayer,
     TransformerLayerSubmodules,
 )

+# This should be private and should not be used outside of this file.
 moe = get_moe_module_spec(
     use_te=True,
     num_experts=8,  # Can be any positive integer (must not be None).
     moe_grouped_gemm=True,
     moe_use_legacy_grouped_gemm=False,
 )

+
+# MTP block spec for Mamba - provides norms and projection only.
+# Inner layers are built by MultiTokenPredictionLayer using nested MambaStack
+_mamba_mtp_block_spec = ModuleSpec(
+    module=MultiTokenPredictionBlock,
+    submodules=MultiTokenPredictionBlockSubmodules(
+        layer_specs=[
+            ModuleSpec(
+                module=MultiTokenPredictionLayer,
+                submodules=MultiTokenPredictionLayerSubmodules(
+                    enorm=TENorm,
+                    hnorm=TENorm,
+                    eh_proj=TEColumnParallelLinear,
+                    mtp_model_layer=None,  # Built via pattern + mamba_submodules
+                    layer_norm=TENorm,
+                ),
+            )
+        ]
+    ),
+)
+
+
 mamba_stack_spec = ModuleSpec(
     module=MambaStack,
     submodules=MambaStackSubmodules(
@@ -87,9 +117,11 @@
                 pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add
             ),
         ),
+        mtp_block_spec=_mamba_mtp_block_spec,
     ),
 )

+
 mamba_inference_stack_spec = ModuleSpec(
     module=MambaStack,
     submodules=MambaStackSubmodules(
@@ -147,5 +179,6 @@
                 pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add
             ),
         ),
+        mtp_block_spec=_mamba_mtp_block_spec,
     ),
 )
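
Aside (not part of the diff): _mamba_mtp_block_spec is a module-level ModuleSpec that both Mamba stack specs now reference through mtp_block_spec; its inner mtp_model_layer is left None so, per the comments in the hunk, the MTP layer can wrap a nested MambaStack built at model-construction time. The toy sketch below mirrors how a block spec replicates one layer spec per configured MTP depth, as in get_gpt_mtp_block_spec_for_backend earlier in this commit; ModuleSpecSketch and MTPBlockSubmodulesSketch are simplified stand-ins for the real classes.

    from dataclasses import dataclass, field
    from typing import Any, List, Optional

    @dataclass
    class ModuleSpecSketch:          # simplified stand-in for Megatron's ModuleSpec
        module: Any
        submodules: Any = None

    @dataclass
    class MTPBlockSubmodulesSketch:  # simplified stand-in for MultiTokenPredictionBlockSubmodules
        layer_specs: List[ModuleSpecSketch] = field(default_factory=list)

    def build_mtp_block_spec(layer_spec: ModuleSpecSketch, mtp_num_layers: Optional[int]) -> ModuleSpecSketch:
        # Replicate the single layer spec once per configured MTP depth (0 when MTP is disabled),
        # matching the [mtp_layer_spec] * mtp_num_layers pattern shown in gpt_layer_specs.py.
        num_layers = mtp_num_layers or 0
        return ModuleSpecSketch(
            module="MultiTokenPredictionBlock",
            submodules=MTPBlockSubmodulesSketch(layer_specs=[layer_spec] * num_layers),
        )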
