megatron/core/optimizer/distrib_optimizer.py: 24 changes (19 additions, 5 deletions)
@@ -918,8 +918,12 @@ def _set_main_param_and_optimizer_states(self, model_param, tensors):
                continue

            if k == "param":
                if self.config.store_param_remainders and self.config.bf16:
                    v = v.to(torch.int16)
                    self.optimizer.set_scaled_state(sharded_model_param, "master_param", v)
                else:
                    if v.dtype != torch.float32:
                        v = v.to(torch.float32)
                    self.optimizer.set_scaled_state(sharded_model_param, k, v)
            else:
                main_param = self.optimizer.param_groups[group_index]["params"][group_order]
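Note on the `store_param_remainders` branch: with bf16 model params, the fp32 master value can be recovered from the bf16 weight plus the 16 low-order mantissa bits, so the optimizer only needs an int16 "remainder" per element instead of a full fp32 copy. A minimal standalone sketch of that idea (not the Megatron/TransformerEngine implementation, and using truncation rather than proper bf16 rounding):

```python
import torch

# Illustrative only: split an fp32 tensor into a bf16-like high half and an
# int16 "remainder", then reassemble it bit-exactly.
def split_fp32(x_fp32: torch.Tensor):
    bits = x_fp32.view(torch.int32)
    hi = (bits >> 16).to(torch.int16)     # top 16 bits (bf16 payload, truncated)
    lo = (bits & 0xFFFF).to(torch.int16)  # low 16 bits: the stored remainder
    return hi, lo

def merge_fp32(hi: torch.Tensor, lo: torch.Tensor) -> torch.Tensor:
    bits = (hi.to(torch.int32) << 16) | (lo.to(torch.int32) & 0xFFFF)
    return bits.view(torch.float32)

x = torch.randn(1024)
assert torch.equal(merge_fp32(*split_fp32(x)), x)
```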
@@ -963,6 +967,16 @@ def get_parameter_state_dp_reshardable(self):
            state[gbuf_idx] = dtype_state
        return state

    def _get_dtype_by_key(self, key):
        if key == "param":
            return torch.float32
        elif key == "exp_avg":
            return self.config.exp_avg_dtype
        elif key == "exp_avg_sq":
            return self.config.exp_avg_sq_dtype
        else:
            raise ValueError(f"Invalid key: {key}")

    def get_parameter_state_dp_zero(
        self,
        use_gloo_comm: bool = True,
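The new `_get_dtype_by_key` helper centralizes the per-key dtype choice so the gather/scatter buffers below match the configured `exp_avg_dtype`/`exp_avg_sq_dtype` instead of being hard-coded to fp32. A standalone sketch of how the allocation changes (the bf16 `exp_avg` default here is an illustrative assumption; the real values come from the optimizer config):

```python
import torch

def dtype_by_key(key, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.float32):
    # Mirrors the new helper; the dtype defaults are illustrative assumptions.
    if key == "param":
        return torch.float32
    if key == "exp_avg":
        return exp_avg_dtype
    if key == "exp_avg_sq":
        return exp_avg_sq_dtype
    raise ValueError(f"Invalid key: {key}")

# Each per-key buffer is now allocated in that key's own dtype rather than fp32.
local_shards = {
    key: torch.zeros((16,), dtype=dtype_by_key(key), device="cpu")
    for key in ("param", "exp_avg", "exp_avg_sq")
}
```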
Expand Down Expand Up @@ -1019,7 +1033,7 @@ def get_parameter_state_dp_zero(
if data_parallel_rank == 0 or return_on_all_ranks:
world_tensors = {
key: torch.zeros(
(buffer_numel_unpadded,), dtype=torch.float32, device="cpu"
(buffer_numel_unpadded,), dtype=self._get_dtype_by_key(key), device="cpu"
)
for key in ("param", "exp_avg", "exp_avg_sq")
}
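Allocating the rank-0 `world_tensors` in each state's own dtype mainly saves host memory when the moments are kept in 16-bit formats. Rough arithmetic, assuming a 1B-element buffer and bf16 `exp_avg`/`exp_avg_sq`:

```python
import torch

numel = 1_000_000_000  # hypothetical unpadded buffer size
fp32_only = 3 * numel * torch.finfo(torch.float32).bits // 8
dtype_aware = sum(
    numel * torch.finfo(dt).bits // 8
    for dt in (torch.float32, torch.bfloat16, torch.bfloat16)  # param, exp_avg, exp_avg_sq (assumed)
)
print(fp32_only / 2**30, dtype_aware / 2**30)  # ~11.2 GiB vs ~7.5 GiB of CPU memory
```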
@@ -1042,7 +1056,7 @@ def get_parameter_state_dp_zero(
                assert gbuf_world_numel_unpadded <= gbuf_world_numel

                local_shards = {
                    key: torch.zeros((gbuf_local_numel,), dtype=torch.float32, device="cpu")
                    key: torch.zeros((gbuf_local_numel,), dtype=self._get_dtype_by_key(key), device="cpu")
                    for key in ("param", "exp_avg", "exp_avg_sq")
                }

@@ -1066,7 +1080,7 @@ def get_parameter_state_dp_zero(
                    device = "cpu" if use_gloo_comm else torch.cuda.current_device()
                    recv_tensors = [
                        torch.zeros(
                            (gbuf_local_numel,), dtype=torch.float32, device=device
                            (gbuf_local_numel,), dtype=self._get_dtype_by_key(key), device=device
                        )
                        for _ in range(data_parallel_world_size)
                    ]
@@ -1885,7 +1899,7 @@ def load_parameter_state_from_dp_zero_legacy(self, state_dict):

                    # Contiguous local shards (received from DP rank 0).
                    recv_tensor = torch.zeros(
                        (gbuf_local_numel,), dtype=torch.float32, device="cpu"
                        (gbuf_local_numel,), dtype=self._get_dtype_by_key(key), device="cpu"
                    )

                    # Scatter tensor list.
@@ -1999,7 +2013,7 @@ def load_parameter_state_from_dp_zero(self, state_dict, *, update_legacy_format=

                    # Contiguous local shards (received from DP rank 0).
                    recv_tensor = torch.zeros(
                        (gbuf_local_numel,), dtype=torch.float32, device="cpu"
                        (gbuf_local_numel,), dtype=self._get_dtype_by_key(key), device="cpu"
                    )

                    # Scatter tensor list.