 # limitations under the License.
 
 import torch
-from megatron.core.models.gpt.gpt_model import GPTModel
-
 from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
 from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
 from megatron.bridge.models.conversion.param_mapping import (
     AutoMapping,
-    GatedMLPMapping,
     ConcatenatedQKVMapping,
+    GatedMLPMapping,
 )
 from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+from megatron.bridge.models.sarvam.common import get_common_config
 from megatron.bridge.models.sarvam.sarvam_provider import SarvamMoEModelProvider
+from megatron.core.models.gpt.gpt_model import GPTModel
 
 
 @MegatronModelBridge.register_bridge(source="SarvamMoEForCausalLM", target=GPTModel)
@@ -36,59 +36,51 @@ class SarvamMoEBridge(MegatronModelBridge):
     architecture with QKV layernorm.
     """
 
-    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> SarvamMoEModelProvider:
+    def provider_bridge(
+        self, hf_pretrained: PreTrainedCausalLM
+    ) -> SarvamMoEModelProvider:
         hf_config = hf_pretrained.config
+        config = get_common_config(hf_pretrained)
 
-        provider = SarvamMoEModelProvider(
-            num_layers=hf_config.num_hidden_layers,
-            hidden_size=hf_config.hidden_size,
-            ffn_hidden_size=hf_config.intermediate_size,
-            moe_ffn_hidden_size=hf_config.moe_intermediate_size,  # Maps to moe_intermediate_size in HF
-            num_attention_heads=hf_config.num_attention_heads,
-            kv_channels=hf_config.head_dim,
-            num_query_groups=hf_config.num_key_value_heads,
-            num_moe_experts=hf_config.num_experts,
-            moe_router_topk=hf_config.num_experts_per_tok,  # Maps to num_experts_per_tok in HF
-            moe_shared_expert_intermediate_size=hf_config.num_shared_experts * hf_config.moe_intermediate_size,
-            moe_router_enable_expert_bias=hf_config.moe_router_enable_expert_bias,
-            moe_layer_freq=[0] * hf_config.first_k_dense_replace + [1] * (hf_config.num_hidden_layers - hf_config.first_k_dense_replace),
-            vocab_size=hf_config.vocab_size,
-            seq_length=hf_config.max_position_embeddings,
-            generation_config=hf_pretrained.generation_config,
-            rotary_base=hf_config.rope_theta,
-            fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
-            bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
-            params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
+        config["fp16"] = (
+            self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16
+        )
+        config["bf16"] = (
+            self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16
         )
+        config["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32)
 
-        return provider
+        # GQA
+        config["num_query_groups"] = hf_config.num_key_value_heads
+        config["kv_channels"] = hf_config.head_dim
+
+        return SarvamMoEModelProvider(**config)
 
     def mapping_registry(self) -> MegatronMappingRegistry:
-        # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format
-        # First create simple 1:1 parameter mappings using a dictionary for readability
 
-        # Dictionary maps Megatron parameter names -> HF parameter names
         param_mappings = {
             # Embed
             "embedding.word_embeddings.weight": "model.word_embeddings.weight",
 
             # Attention
             "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
-            "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
+            # In Sarvam, the HF weight `model.layers.*.post_attention_layernorm.weight` is mapped to one of the following mcore weights, depending on the layer type:
+            # (a) `decoder.layers.*.pre_mlp_layernorm.weight`, if the layer is MoE
+            # (b) `decoder.layers.*.mlp.linear_fc1.layer_norm_weight`, if the layer is dense
7870 "decoder.layers.*.pre_mlp_layernorm.weight" : "model.layers.*.post_attention_layernorm.weight" ,
79- "decoder.layers.*.mlp.router.expert_bias" : "model.layers.*.mlp.gate.expert_bias" ,
80- "decoder.layers.*.mlp.router.weight" : "model.layers.*.mlp.gate.weight" ,
81-
71+ "decoder.layers.*.mlp.linear_fc1.layer_norm_weight" : "model.layers.*.post_attention_layernorm.weight" ,
8272 "decoder.layers.*.self_attention.q_layernorm.weight" : "model.layers.*.attention.query_layernorm.weight" ,
8373 "decoder.layers.*.self_attention.k_layernorm.weight" : "model.layers.*.attention.key_layernorm.weight" ,
8474 "decoder.layers.*.self_attention.linear_proj.weight" : "model.layers.*.attention.dense.weight" ,
85-
75+
8676 # Dense MLP
8777 "decoder.layers.*.mlp.linear_fc2.weight" : "model.layers.*.mlp.down_proj.weight" ,
78+
79+ # MoE
80+ "decoder.layers.*.mlp.router.expert_bias" : "model.layers.*.mlp.gate.expert_bias" ,
81+ "decoder.layers.*.mlp.router.weight" : "model.layers.*.mlp.gate.weight" ,
8882 "decoder.layers.*.mlp.experts.linear_fc2.weight*" : "model.layers.*.mlp.experts.*.down_proj.weight" ,
8983 "decoder.layers.*.mlp.shared_experts.linear_fc2.weight" : "model.layers.*.mlp.shared_experts.down_proj.weight" ,
90-
91- "final_layernorm.weight" : "final_layernorm.weight" ,
9284
9385 # LM Head
9486 "decoder.final_layernorm.weight" : "model.norm.weight" ,
@@ -97,7 +89,9 @@ def mapping_registry(self) -> MegatronMappingRegistry:
 
         mapping_list = []
         for megatron_param, hf_param in param_mappings.items():
-            mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param))
+            mapping_list.append(
+                AutoMapping(hf_param=hf_param, megatron_param=megatron_param)
+            )
 
         mapping_list.extend(
             [
@@ -120,7 +114,6 @@ def mapping_registry(self) -> MegatronMappingRegistry:
                     gate="model.layers.*.mlp.shared_experts.gate_proj.weight",
                     up="model.layers.*.mlp.shared_experts.up_proj.weight",
                 ),
-
             ]
         )
 
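For readers checking the mapping logic: the dense/MoE split drives both the moe_layer_freq list and the dual mapping of model.layers.*.post_attention_layernorm.weight described in the comment above. The standalone sketch below illustrates that split; the numeric values are made up for illustration, and only the formulas and parameter names come from this change.

# Illustrative values only; these are not Sarvam's actual configuration.
num_hidden_layers = 8
first_k_dense_replace = 2  # leading dense layers before the MoE layers start

# Same construction as the moe_layer_freq argument: 0 marks a dense MLP layer, 1 an MoE layer.
moe_layer_freq = [0] * first_k_dense_replace + [1] * (num_hidden_layers - first_k_dense_replace)
assert moe_layer_freq == [0, 0, 1, 1, 1, 1, 1, 1]

# Per the comment in mapping_registry, the HF post_attention_layernorm weight lands on a
# different mcore parameter depending on whether the layer is dense or MoE.
def post_attention_layernorm_target(is_moe_layer: bool) -> str:
    if is_moe_layer:
        return "decoder.layers.*.pre_mlp_layernorm.weight"
    return "decoder.layers.*.mlp.linear_fc1.layer_norm_weight"

targets = [post_attention_layernorm_target(bool(flag)) for flag in moe_layer_freq]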