
Commit 47bf6bc

sarvam mla support
1 parent 9f257cd

File tree

6 files changed: 320 additions, 55 deletions

src/megatron/bridge/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -196,6 +196,7 @@
 from megatron.bridge.models.t5_provider import T5ModelProvider
 from megatron.bridge.models.sarvam import (
     SarvamMoEBridge,
+    SarvamMLABridge,
 )


@@ -349,4 +350,5 @@
     "NemotronNano12Bv2Provider",
     "NemotronNano12Bv2VLModelProvider",
     "SarvamMoEBridge",
+    "SarvamMLABridge",
 ]
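With this export added, the new bridge is importable from the package root alongside the existing MoE bridge. A minimal import check (illustrative only; assumes an environment with this commit of megatron.bridge installed):

# Sanity check for the new public export (not part of the commit).
from megatron.bridge.models import SarvamMLABridge, SarvamMoEBridge

print(SarvamMLABridge.__name__, SarvamMoEBridge.__name__)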

src/megatron/bridge/models/sarvam/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,8 +13,10 @@
 # limitations under the License.

 from megatron.bridge.models.sarvam.sarvam_moe_bridge import SarvamMoEBridge
+from megatron.bridge.models.sarvam.sarvam_mla_bridge import SarvamMLABridge


 __all__ = [
     "SarvamMoEBridge",
+    "SarvamMLABridge",
 ]
src/megatron/bridge/models/sarvam/common.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+
+
+def get_common_config(hf_pretrained: PreTrainedCausalLM) -> dict:
+    """
+    Returns a dictionary of common configurations for the Sarvam family of models.
+    """
+    hf_config = hf_pretrained.config
+
+    config = {}
+
+    config["num_layers"] = hf_config.num_hidden_layers
+    config["hidden_size"] = hf_config.hidden_size
+    config["ffn_hidden_size"] = hf_config.intermediate_size
+    config["moe_ffn_hidden_size"] = hf_config.moe_intermediate_size
+    config["num_attention_heads"] = hf_config.num_attention_heads
+    config["num_moe_experts"] = hf_config.num_experts
+    config["moe_router_topk"] = hf_config.num_experts_per_tok
+    config["moe_shared_expert_intermediate_size"] = (
+        hf_config.num_shared_experts * hf_config.moe_intermediate_size
+    )
+    config["moe_layer_freq"] = [0] * hf_config.first_k_dense_replace + [1] * (
+        hf_config.num_hidden_layers - hf_config.first_k_dense_replace
+    )
+    config["vocab_size"] = hf_config.vocab_size
+    config["seq_length"] = hf_config.max_position_embeddings
+    config["generation_config"] = getattr(hf_pretrained, "generation_config", None)
+    config["rotary_base"] = hf_config.rope_theta
+
+    return config
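For orientation, here is a toy call to get_common_config (illustrative only, not part of the commit: every number is invented, and the real argument is a PreTrainedCausalLM, mocked here with SimpleNamespace). It shows the shape of the returned dict, in particular the moe_layer_freq pattern of first_k_dense_replace dense layers followed by MoE layers:

# Illustrative toy input for get_common_config; all values below are made up.
from types import SimpleNamespace

from megatron.bridge.models.sarvam.common import get_common_config

toy_config = SimpleNamespace(
    num_hidden_layers=4,
    hidden_size=1024,
    intermediate_size=4096,
    moe_intermediate_size=512,
    num_attention_heads=16,
    num_experts=8,
    num_experts_per_tok=2,
    num_shared_experts=1,
    first_k_dense_replace=1,
    vocab_size=32000,
    max_position_embeddings=4096,
    rope_theta=10000.0,
)
toy_model = SimpleNamespace(config=toy_config)  # no generation_config attribute

common = get_common_config(toy_model)
print(common["moe_layer_freq"])                       # [0, 1, 1, 1]: one dense layer, then MoE
print(common["moe_shared_expert_intermediate_size"])  # 1 * 512 = 512
print(common["generation_config"])                    # None (getattr fallback)
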
src/megatron/bridge/models/sarvam/sarvam_mla_bridge.py

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
+from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.bridge.models.conversion.param_mapping import (
+    AutoMapping,
+    GatedMLPMapping,
+)
+from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+from megatron.bridge.models.sarvam.common import get_common_config
+from megatron.bridge.models.sarvam.sarvam_provider import SarvamMLAModelProvider
+from megatron.core.models.gpt.gpt_model import GPTModel
+
+
+@MegatronModelBridge.register_bridge(source="SarvamMLAForCausalLM", target=GPTModel)
+class SarvamMLABridge(MegatronModelBridge):
+    """
+    Megatron Hub Bridge for Sarvam MLA Causal LM.
+
+    This bridge handles the conversion between HuggingFace SarvamMLAForCausalLM
+    and Megatron-Core GPTModel formats. Sarvam MLA models use multi-latent attention
+    architecture.
+    """
+
+    def provider_bridge(
+        self, hf_pretrained: PreTrainedCausalLM
+    ) -> SarvamMLAModelProvider:
+        hf_config = hf_pretrained.config
+        config = get_common_config(hf_pretrained)
+
+        config["fp16"] = (
+            self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16
+        )
+        config["bf16"] = (
+            self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16
+        )
+        config["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32)
+        config["kv_channels"] = hf_config.hidden_size // hf_config.num_attention_heads
+
+        # MLA
+        config["kv_lora_rank"] = hf_config.kv_lora_rank
+        config["qk_head_dim"] = hf_config.qk_nope_head_dim
+        config["qk_pos_emb_head_dim"] = hf_config.qk_rope_head_dim
+        config["v_head_dim"] = hf_config.v_head_dim
+
+        if hasattr(hf_config, "rope_scaling") and hf_config.rope_scaling is not None:
+            config["rotary_scaling_factor"] = hf_config.rope_scaling["factor"]
+            config["mscale"] = hf_config.rope_scaling["mscale"]
+            config["mscale_all_dim"] = hf_config.rope_scaling["mscale_all_dim"]
+        else:
+            config["rotary_scaling_factor"] = 1.0
+            config["mscale"] = 1.0
+            config["mscale_all_dim"] = 1.0
+
+        return SarvamMLAModelProvider(**config)
+
+    def mapping_registry(self) -> MegatronMappingRegistry:
+
+        param_mappings = {
+            # Embed
+            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
+
+            # Attention
+            "decoder.layers.*.input_layernorm.weight": "model.layers.*.input_layernorm.weight",
+            "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
+            # In sarvam, HF weight `model.layers.*.post_attention_layernorm.weight` is mapped to the following mcore weights depending on the layer type:
+            # (a) `decoder.layers.*.pre_mlp_layernorm.weight`, if the layer is MoE
+            # (b) `decoder.layers.*.mlp.linear_fc1.layer_norm_weight`, if the layer is dense
+            "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight",
+            "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
+            "decoder.layers.*.self_attention.linear_q_proj.weight": "model.layers.*.self_attn.q_proj.weight",
+            "decoder.layers.*.self_attention.linear_kv_down_proj.weight": "model.layers.*.self_attn.kv_a_proj_with_mqa.weight",
+            "decoder.layers.*.self_attention.linear_kv_up_proj.weight": "model.layers.*.self_attn.kv_b_proj.weight",
+            "decoder.layers.*.self_attention.linear_kv_up_proj.layer_norm_weight": "model.layers.*.self_attn.kv_a_layernorm.weight",
+            # Mcore local spec
+            "decoder.layers.*.self_attention.kv_layernorm.weight": "model.layers.*.self_attn.kv_a_layernorm.weight",
+
+            # Dense MLP
+            "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
+
+            # Moe
+            "decoder.layers.*.mlp.experts.linear_fc2.weight*": "model.layers.*.mlp.experts.*.down_proj.weight",
+            "decoder.layers.*.mlp.shared_experts.linear_fc2.weight": "model.layers.*.mlp.shared_experts.down_proj.weight",
+            "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.e_score_correction_bias",
+            "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight",
+
+            # LM Head
+            "decoder.final_layernorm.weight": "model.norm.weight",
+            "output_layer.weight": "lm_head.weight",
+        }
+
+        mapping_list = []
+        for megatron_param, hf_param in param_mappings.items():
+            mapping_list.append(
+                AutoMapping(hf_param=hf_param, megatron_param=megatron_param)
+            )
+
+        mapping_list.extend(
+            [
+                GatedMLPMapping(
+                    megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
+                    gate="model.layers.*.mlp.gate_proj.weight",
+                    up="model.layers.*.mlp.up_proj.weight",
+                ),
+                GatedMLPMapping(
+                    megatron_param="decoder.layers.*.mlp.experts.linear_fc1.weight*",
+                    gate="model.layers.*.mlp.experts.*.gate_proj.weight",
+                    up="model.layers.*.mlp.experts.*.up_proj.weight",
+                ),
+                GatedMLPMapping(
+                    megatron_param="decoder.layers.*.mlp.shared_experts.linear_fc1.weight",
+                    gate="model.layers.*.mlp.shared_experts.gate_proj.weight",
+                    up="model.layers.*.mlp.shared_experts.up_proj.weight",
+                ),
+            ]
+        )
+
+        return MegatronMappingRegistry(*mapping_list)
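The rope_scaling branch above falls back to neutral values (1.0) when the HF config carries no rope_scaling section. A standalone sketch of just that branch, mirroring the diff's logic on mocked configs (the helper name and sample values are invented; this is not part of the bridge API):

# Mirrors the rope_scaling fallback from provider_bridge above; illustrative only.
from types import SimpleNamespace


def mla_rope_settings(hf_config) -> dict:
    config = {}
    if hasattr(hf_config, "rope_scaling") and hf_config.rope_scaling is not None:
        # Take the scaling values straight from the HF config.
        config["rotary_scaling_factor"] = hf_config.rope_scaling["factor"]
        config["mscale"] = hf_config.rope_scaling["mscale"]
        config["mscale_all_dim"] = hf_config.rope_scaling["mscale_all_dim"]
    else:
        # No rope_scaling section: neutral defaults.
        config["rotary_scaling_factor"] = 1.0
        config["mscale"] = 1.0
        config["mscale_all_dim"] = 1.0
    return config


with_scaling = SimpleNamespace(
    rope_scaling={"factor": 40.0, "mscale": 1.0, "mscale_all_dim": 1.0}
)
without_scaling = SimpleNamespace(rope_scaling=None)

print(mla_rope_settings(with_scaling))     # values taken from rope_scaling
print(mla_rope_settings(without_scaling))  # all three fall back to 1.0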

src/megatron/bridge/models/sarvam/sarvam_moe_bridge.py

Lines changed: 30 additions & 37 deletions
@@ -13,17 +13,17 @@
 # limitations under the License.

 import torch
-from megatron.core.models.gpt.gpt_model import GPTModel
-
 from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
 from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
 from megatron.bridge.models.conversion.param_mapping import (
     AutoMapping,
-    GatedMLPMapping,
     ConcatenatedQKVMapping,
+    GatedMLPMapping,
 )
 from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+from megatron.bridge.models.sarvam.common import get_common_config
 from megatron.bridge.models.sarvam.sarvam_provider import SarvamMoEModelProvider
+from megatron.core.models.gpt.gpt_model import GPTModel


 @MegatronModelBridge.register_bridge(source="SarvamMoEForCausalLM", target=GPTModel)
@@ -36,59 +36,51 @@ class SarvamMoEBridge(MegatronModelBridge):
     architecture with QKV layernorm.
     """

-    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> SarvamMoEModelProvider:
+    def provider_bridge(
+        self, hf_pretrained: PreTrainedCausalLM
+    ) -> SarvamMoEModelProvider:
         hf_config = hf_pretrained.config
+        config = get_common_config(hf_pretrained)

-        provider = SarvamMoEModelProvider(
-            num_layers=hf_config.num_hidden_layers,
-            hidden_size=hf_config.hidden_size,
-            ffn_hidden_size=hf_config.intermediate_size,
-            moe_ffn_hidden_size=hf_config.moe_intermediate_size, # Maps to moe_intermediate_size in HF
-            num_attention_heads=hf_config.num_attention_heads,
-            kv_channels=hf_config.head_dim,
-            num_query_groups=hf_config.num_key_value_heads,
-            num_moe_experts=hf_config.num_experts,
-            moe_router_topk=hf_config.num_experts_per_tok, # Maps to num_experts_per_tok in HF
-            moe_shared_expert_intermediate_size=hf_config.num_shared_experts * hf_config.moe_intermediate_size,
-            moe_router_enable_expert_bias=hf_config.moe_router_enable_expert_bias,
-            moe_layer_freq=[0] * hf_config.first_k_dense_replace + [1] * (hf_config.num_hidden_layers - hf_config.first_k_dense_replace),
-            vocab_size=hf_config.vocab_size,
-            seq_length=hf_config.max_position_embeddings,
-            generation_config=hf_pretrained.generation_config,
-            rotary_base=hf_config.rope_theta,
-            fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
-            bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
-            params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
+        config["fp16"] = (
+            self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16
+        )
+        config["bf16"] = (
+            self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16
         )
+        config["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32)

-        return provider
+        # GQA
+        config["num_query_groups"] = hf_config.num_key_value_heads
+        config["kv_channels"] = hf_config.head_dim
+
+        return SarvamMoEModelProvider(**config)

     def mapping_registry(self) -> MegatronMappingRegistry:
-        # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format
-        # First create simple 1:1 parameter mappings using a dictionary for readability

-        # Dictionary maps Megatron parameter names -> HF parameter names
         param_mappings = {
             # Embed
             "embedding.word_embeddings.weight": "model.word_embeddings.weight",

             # Attention
             "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
-            "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
+            # In sarvam, HF weight `model.layers.*.post_attention_layernorm.weight` is mapped to the following mcore weights depending on the layer type:
+            # (a) `decoder.layers.*.pre_mlp_layernorm.weight`, if the layer is MoE
+            # (b) `decoder.layers.*.mlp.linear_fc1.layer_norm_weight`, if the layer is dense
             "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight",
-            "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.expert_bias",
-            "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight",
-
+            "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
             "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.attention.query_layernorm.weight",
             "decoder.layers.*.self_attention.k_layernorm.weight": "model.layers.*.attention.key_layernorm.weight",
             "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.attention.dense.weight",
-
+
             # Dense MLP
             "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
+
+            # MoE
+            "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.expert_bias",
+            "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight",
             "decoder.layers.*.mlp.experts.linear_fc2.weight*": "model.layers.*.mlp.experts.*.down_proj.weight",
             "decoder.layers.*.mlp.shared_experts.linear_fc2.weight": "model.layers.*.mlp.shared_experts.down_proj.weight",
-
-            "final_layernorm.weight": "final_layernorm.weight",

             # LM Head
             "decoder.final_layernorm.weight": "model.norm.weight",
@@ -97,7 +89,9 @@ def mapping_registry(self) -> MegatronMappingRegistry:

         mapping_list = []
         for megatron_param, hf_param in param_mappings.items():
-            mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param))
+            mapping_list.append(
+                AutoMapping(hf_param=hf_param, megatron_param=megatron_param)
+            )

         mapping_list.extend(
             [
@@ -120,7 +114,6 @@ def mapping_registry(self) -> MegatronMappingRegistry:
                     gate="model.layers.*.mlp.shared_experts.gate_proj.weight",
                     up="model.layers.*.mlp.shared_experts.up_proj.weight",
                 ),
-
             ]
         )

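Both bridges now map the single HF weight model.layers.*.post_attention_layernorm.weight onto one of two Megatron parameters depending on whether the layer is dense or MoE, as the added comments describe. Purely as an illustration of that comment (the real wildcard resolution is handled by the mapping machinery, not by this hypothetical helper):

# Illustrative only: spell out, per layer, which Megatron parameter the HF
# post_attention_layernorm weight corresponds to, given moe_layer_freq
# (0 = dense layer, 1 = MoE layer).
def post_attn_layernorm_targets(moe_layer_freq: list) -> dict:
    targets = {}
    for i, is_moe in enumerate(moe_layer_freq):
        hf_name = f"model.layers.{i}.post_attention_layernorm.weight"
        if is_moe:
            targets[hf_name] = f"decoder.layers.{i}.pre_mlp_layernorm.weight"
        else:
            targets[hf_name] = f"decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight"
    return targets


# e.g. first_k_dense_replace = 1 and num_hidden_layers = 4  ->  [0, 1, 1, 1]
for hf_name, mcore_name in post_attn_layernorm_targets([0, 1, 1, 1]).items():
    print(f"{hf_name} -> {mcore_name}")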
