checkpoint_engine/pin_memory.py (9 additions, 1 deletion)
@@ -209,7 +209,9 @@ def _pin(t: torch.Tensor):
         torch.cuda.set_device(device_index)
         cudart = torch.cuda.cudart()
         r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
-        assert r == 0, f"pin memory error, error code: {r}"
+        if r != 0:
+            error_msg = cudart.cudaGetErrorString(r)
+            raise RuntimeError(f"pin memory error, error code: {r}, error message: {error_msg}")
 
     # TODO: should only support /dev/shm? but we found files in disk also work?
     size = os.stat(file_path).st_size
@@ -254,6 +256,12 @@ def _pin(t: torch.Tensor):
     # Remove the file after successfully loading. This will avoid doubling the memory usage.
     # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
     os.remove(file_path)
+    if not metas:
+        # TODO: should we still return this buffer?
+        assert buffer.nbytes == 0, f"buffer nbytes {buffer.nbytes} should be 0"
+        logger.warning(f"[rank{rank}] no metas found in {file_path}, skip pin memory")
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=[], manually_pinned=False)
+
     _pin(buffer)
     logger.info(
         f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"

Collaborator comment on the `# TODO: should we still return this buffer?` line:
Yeah, skipping this buffer seems to be a better idea, to prevent future bugs. (Not a hard requirement.)
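As a rough sketch of the reviewer's suggestion (not the merged change, which keeps the early return above): the loader could skip the empty buffer entirely and let the caller drop it. The MemoryBufferSketch dataclass and load_or_skip helper below are hypothetical stand-ins, mirroring only the fields visible in the diff.

from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class MemoryBufferSketch:
    # Hypothetical stand-in for the repo's MemoryBuffer, using only the
    # fields that appear in the diff (buffer, size, metas, manually_pinned).
    buffer: torch.Tensor
    size: int
    metas: list = field(default_factory=list)
    manually_pinned: bool = False


def load_or_skip(buffer: torch.Tensor, metas: list) -> Optional[MemoryBufferSketch]:
    # Reviewer's idea: when there is nothing to load, return nothing at all
    # instead of an empty MemoryBuffer, so downstream code never sees it.
    if not metas:
        return None
    return MemoryBufferSketch(
        buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True
    )

A caller would then simply filter out the None results before pinning, which is what "skipping this buffer" amounts to.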
checkpoint_engine/ps.py (5 additions, 1 deletion)
@@ -391,7 +391,11 @@ def _unpin(t: torch.Tensor):
             )
             cudart = torch.cuda.cudart()
             r = cudart.cudaHostUnregister(t.data_ptr())
-            assert r == 0, f"unpin memory error, error code: {r}"
+            if r != 0:
+                error_msg = cudart.cudaGetErrorString(r)
+                raise RuntimeError(
+                    f"unpin memory error, error code: {r}, error message: {error_msg}"
+                )
 
         # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
         try:
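For reference, a minimal standalone sketch of the error-handling pattern both files now use, assuming only the torch.cuda.cudart() bindings that appear in the diff (cudaHostRegister, cudaHostUnregister, cudaGetErrorString). The helper names pin_host_tensor and unpin_host_tensor are illustrative, not part of the repository, and a CUDA device must already be selected (the original code calls torch.cuda.set_device first).

import torch


def pin_host_tensor(t: torch.Tensor) -> None:
    # Page-lock the CPU tensor's existing storage in place via the CUDA runtime,
    # raising with the human-readable CUDA error instead of a bare assert.
    cudart = torch.cuda.cudart()
    r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
    if r != 0:
        raise RuntimeError(
            f"pin memory error, error code: {r}, error message: {cudart.cudaGetErrorString(r)}"
        )


def unpin_host_tensor(t: torch.Tensor) -> None:
    # Counterpart for manually registered memory: undo the registration and
    # surface any failure the same way.
    cudart = torch.cuda.cudart()
    r = cudart.cudaHostUnregister(t.data_ptr())
    if r != 0:
        raise RuntimeError(
            f"unpin memory error, error code: {r}, error message: {cudart.cudaGetErrorString(r)}"
        )

Translating the nonzero return code into its string (for example cudaErrorHostMemoryAlreadyRegistered) makes failures much easier to diagnose than the previous bare assert on the numeric code.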