118 changes: 109 additions & 9 deletions megatron/core/dist_checkpointing/strategies/filesystem_async.py
@@ -8,6 +8,7 @@
import os
import pickle
import queue
import time as time_module
from functools import partial
from heapq import heappop, heappush
from itertools import chain
@@ -346,6 +347,8 @@ def write_preloaded_data(
results_queue: mp.SimpleQueue,
count_queue: mp.JoinableQueue,
use_fsync: bool,
max_item_retries: int = 3,
item_retry_delay: float = 10.0,
**kwargs,
) -> None:
"""
@@ -367,6 +370,84 @@
use_msc = kwargs.get("use_msc", False)

local_results = []
local_output = None

def write_item_with_retry(
transform_list,
stream,
data,
write_item,
storage_key,
use_fsync,
use_msc,
max_item_retries,
item_retry_delay,
**extra_kwargs
):
"""
Wraps _write_item with retry logic

Args:
transform_list: List of storage writer transforms
stream: File stream to write to
data: Data to write (bytes or tensor)
write_item: WriteItem containing metadata
storage_key: Storage key for the item
use_fsync: Whether to call fsync after writing
use_msc: Whether using multistorageclient
max_item_retries: Maximum number of retry attempts for this item
item_retry_delay: Delay in seconds between retries
**extra_kwargs: Additional arguments for _write_item

Returns:
WriteResult from _write_item

Raises:
Exception: Re-raises the last exception if all retries fail
"""
last_exception = None
for attempt in range(max_item_retries):
try:
result = _write_item(
*transform_list, stream, data, write_item, storage_key, **extra_kwargs
)

# Perform fsync if requested and write was successful
if use_fsync:
try:
if use_msc:
stream.fsync()
else:
os.fsync(stream.fileno())
except Exception as fsync_err:
logger.warning(
f"fsync failed for item {write_item.index}: {type(fsync_err).__name__}: {str(fsync_err)}"
)
# Continue despite fsync failure, but log it

return result

except Exception as e:
last_exception = e
is_last_attempt = (attempt == max_item_retries - 1)

if is_last_attempt:
logger.error(
f"Failed to write item {write_item.index} after {max_item_retries} attempts. "
f"Last error: {type(e).__name__}: {str(e)}"
)
raise
else:
logger.warning(
f"Write item {write_item.index} failed on attempt {attempt + 1}/{max_item_retries}. "
f"Error: {type(e).__name__}: {str(e)}. Retrying in {item_retry_delay}s..."
)
time_module.sleep(item_retry_delay)

# Should not reach here, but just in case
if last_exception:
raise last_exception

try:
file_name, storage_key, (bytes_data, tensor_data) = write_bucket
extra_kwargs = {}
@@ -380,30 +461,49 @@ def write_preloaded_data(
open_file = msc.open
else:
open_file = open

# Collect the write results for this bucket
local_results = []

with open_file(file_name, "wb") as stream:
for write_item, data in bytes_data:
local_results.append(
- _write_item(
- *transform_list, stream, data, write_item, storage_key, **extra_kwargs
write_item_with_retry(
transform_list, stream, data, write_item, storage_key,
use_fsync, use_msc, max_item_retries, item_retry_delay, **extra_kwargs
)
)

for write_item, tensor in tensor_data:
assert tensor.is_cpu
local_results.append(
- _write_item(
- *transform_list, stream, tensor, write_item, storage_key, **extra_kwargs
write_item_with_retry(
transform_list, stream, tensor, write_item, storage_key,
use_fsync, use_msc, max_item_retries, item_retry_delay, **extra_kwargs
)
)

# fsync is already attempted per item inside write_item_with_retry;
# a final file-level fsync is still performed here when requested.
if use_fsync:
- if use_msc:
- stream.fsync()
- else:
- os.fsync(stream.fileno())
try:
if use_msc:
stream.fsync()
else:
os.fsync(stream.fileno())
except Exception as fsync_err:
logger.warning(
f"fsync failed for file {file_name}: {type(fsync_err).__name__}: {str(fsync_err)}"
)
# Continue despite fsync failure, but log it

local_output = (local_proc_idx, local_results)
logger.debug(f"{local_proc_idx} completed successfully")

except Exception as e:
logger.debug(f"{local_proc_idx} failed")
logger.error(
f"{local_proc_idx} failed with {type(e).__name__}: {str(e)}"
)
local_output = (local_proc_idx, e) # type: ignore[assignment]

results_queue.put(local_output)
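
For reference, the fixed-delay retry pattern that write_item_with_retry follows can be illustrated with a small standalone sketch; retry_write and write_fn below are illustrative placeholders, not names from this PR:

import logging
import time

logger = logging.getLogger(__name__)

def retry_write(write_fn, max_retries: int = 3, retry_delay: float = 10.0):
    """Call write_fn(), retrying up to max_retries times with a fixed delay."""
    for attempt in range(max_retries):
        try:
            return write_fn()
        except Exception as e:
            if attempt == max_retries - 1:
                # Last attempt: log and re-raise the error to the caller
                logger.error("write failed after %d attempts: %s", max_retries, e)
                raise
            logger.warning(
                "write attempt %d/%d failed (%s); retrying in %.1fs",
                attempt + 1, max_retries, e, retry_delay,
            )
            time.sleep(retry_delay)

# Hypothetical usage: result = retry_write(lambda: stream.write(data))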