@@ -6,6 +6,7 @@

 nemotron_h_aligned_custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ '<SPECIAL_10>System\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'user' %}{{ '<SPECIAL_11>User\n' + message['content'].strip() + '\n' + '<SPECIAL_11>Assistant\n' }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() + '\n' }}{% endif %}{% endfor %}"""
 nemotron_nano_v2_custom_template = """{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'system' %}{{ '<SPECIAL_10>System\n' + content.replace('/think', '').replace('/no_think', '').strip() + '\n' }}{% elif message['role'] == 'user' %}{{ '<SPECIAL_11>User\n' + content.replace('/think', '').replace('/no_think', '').strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<SPECIAL_11>Assistant\n' + content.strip() + '\n<SPECIAL_12>\n' }}{% endif %}{% endfor %}"""
+identity_template = """{% for message in messages %}{{ message['content'] }}{% endfor %}"""

 from megatron.core.datasets.megatron_tokenizer import MegatronLegacyTokenizer
 from megatron.training.datasets.sft_dataset import IGNORE_INDEX
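For reference, a minimal sketch (plain Jinja2, independent of the repo's tokenizer) of what the new `identity_template` renders: message contents concatenated verbatim, with none of the role headers or `SPECIAL` tokens the Nemotron templates emit.

```python
from jinja2 import Template

identity_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"

messages = [
    {"role": "user", "content": "2 + 2 = "},
    {"role": "assistant", "content": "4"},
]

# Raw contents are simply concatenated: no role markers, no special tokens.
print(Template(identity_template).render(messages=messages))  # -> "2 + 2 = 4"
```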
@@ -59,6 +60,14 @@ def __init__(
                 has_system_role=True,
             )
         elif prompt_format == "identity":
+            self._prompt_config = PromptConfig(
+                assistant_prefix_len=0,
+                pad_token_id=tokenizer.convert_tokens_to_ids("<unk>"),
+                custom_chat_template=identity_template,
+                has_bos=False,
+                has_system_role=True,
+            )
+        elif prompt_format == "default":
             self._prompt_config = PromptConfig(
                 assistant_prefix_len=0,
                 pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
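For context, a hypothetical mirror of the `PromptConfig` fields these branches populate (the real class is defined elsewhere in the repo; field names are taken from the calls above, types are assumed):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PromptConfig:
    assistant_prefix_len: int
    pad_token_id: int
    custom_chat_template: Optional[str]
    has_bos: bool
    has_system_role: bool
```

Note the design choice: "identity" pads with the tokenizer's `<unk>` id, while the new "default" branch falls back from `pad_token_id` to `eos_token_id`, since many tokenizers define no pad token.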
@@ -106,7 +115,9 @@ def tokenize_conversation(

         target = tokens.copy()

-        if self._prompt_format == "identity":
+        # When using the default prompt format, we do not replace any tokens with IGNORE_INDEX;
+        # the loss is computed on every token, for simplicity.
+        if self._prompt_format == "default":
             return tokens, target

         # Mask system and user tokens in the target.
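A hedged sketch of the behavior the new comment describes, assuming the conventional `IGNORE_INDEX` value of -100 and hypothetical `(start, end)` span bookkeeping:

```python
IGNORE_INDEX = -100  # assumed value; the real constant is imported from sft_dataset above

def build_target(tokens, assistant_spans, prompt_format):
    target = tokens.copy()
    if prompt_format == "default":
        return target  # every token keeps its label: loss on all positions
    # Other formats train only on assistant tokens: mask everything,
    # then restore the labels inside each assistant span.
    masked = [IGNORE_INDEX] * len(tokens)
    for start, end in assistant_spans:
        masked[start:end] = target[start:end]
    return masked
```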
@@ -116,7 +127,7 @@ def tokenize_conversation(
             if turn["role"].lower() == "assistant" and len(turn["content"]) == 0:
                 raise ValueError(f"empty assistant turn in conversation: {conversation}.")
             if turn["role"].lower() == "assistant":
-                assert conversation[turn_idx - 1]["role"].lower() == "user"
+                assert conversation[turn_idx - 1]["role"].lower() in ("user", "tool")

             turn_tokens = self._tokenizer.apply_chat_template(
                 [turn], tokenize=True, chat_template=self._prompt_config.custom_chat_template
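An illustrative (hypothetical) conversation that the relaxed assertion now accepts, since an assistant turn may follow a tool response as well as a user message:

```python
conversation = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "assistant", "content": "calling get_weather(city='Paris')"},
    {"role": "tool", "content": "18 C, clear"},
    {"role": "assistant", "content": "It's 18 C and clear in Paris."},
]

for turn_idx, turn in enumerate(conversation):
    if turn["role"].lower() == "assistant":
        # Both user->assistant and tool->assistant orderings pass.
        assert conversation[turn_idx - 1]["role"].lower() in ("user", "tool")
```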