6 changes: 3 additions & 3 deletions pyproject.toml
@@ -14,9 +14,9 @@ dynamic = ["version"]
readme = "README.md"
dependencies = [
"torch>=2.7",
"opentelemetry-exporter-otlp-proto-http>=1.37.0",
"opentelemetry-sdk>=1.37.0",
"opentelemetry-api>=1.37.0",
"opentelemetry-exporter-otlp-proto-http>=1.39.0",
"opentelemetry-sdk>=1.39.0",
"opentelemetry-api>=1.39.0",
]

[project.urls]
4 changes: 4 additions & 0 deletions torchft/__init__.py
@@ -10,6 +10,8 @@
from torchft.optim import OptimizerWrapper as Optimizer
from torchft.otel import setup_logger
from torchft.process_group import (
ProcessGroupAccelerator,
ProcessGroupBabyAccelerator,
ProcessGroupBabyNCCL,
ProcessGroupBabyXCCL,
ProcessGroupGloo,
@@ -26,6 +28,8 @@
"DistributedSampler",
"Manager",
"Optimizer",
"ProcessGroupAccelerator",
"ProcessGroupBabyAccelerator",
"ProcessGroupNCCL",
"ProcessGroupXCCL",
"ProcessGroupBabyNCCL",
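The newly exported ProcessGroupBabyAccelerator replaces ProcessGroupBabyNCCL throughout the rest of this diff; a minimal usage sketch built only from the constructor and configure() calls shown in the files below (the store address, replica id, and ranks are placeholder values, and the backend is presumably picked from torch.accelerator.current_accelerator() rather than being hard-coded to NCCL):

from datetime import timedelta

from torchft import ProcessGroupBabyAccelerator

# Construct the baby process group and wire it up to a TCPStore-backed rendezvous.
pg = ProcessGroupBabyAccelerator(timeout=timedelta(seconds=60))
pg.configure(
    store_addr="localhost:29500/prefix",  # placeholder TCPStore address
    replica_id="0",
    rank=0,
    world_size=2,
)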
6 changes: 3 additions & 3 deletions torchft/_test/diloco_trainer.py
@@ -19,7 +19,7 @@
from torchft.manager_integ_test import MyModel, Runner
from torchft.process_group import (
FakeProcessGroupWrapper,
ProcessGroupBabyNCCL,
ProcessGroupBabyAccelerator,
ProcessGroupGloo,
)

@@ -151,8 +151,8 @@ def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
return outer_optimizers

def setup_pg(self) -> FakeProcessGroupWrapper:
if self.device.type == "cuda":
return FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
if self.device.type == torch.accelerator.current_accelerator().type:
return FakeProcessGroupWrapper(ProcessGroupBabyAccelerator())
else:
return FakeProcessGroupWrapper(
ProcessGroupGloo(timeout=timedelta(seconds=10))
34 changes: 18 additions & 16 deletions torchft/_test/managed_work_test.py
@@ -20,6 +20,8 @@
from torchft.manager import _ManagedWork, Manager


DEVICE = torch.accelerator.current_accelerator()

class SimpleWork(Work):
"""A simple implementation of torch.distributed.Work for testing."""

@@ -42,16 +44,16 @@ class TestManagedWork(unittest.TestCase):
@parameterized.parameterized.expand(
[
("cpu", torch.device("cpu")),
("cuda", torch.device("cuda:0")),
(DEVICE.type, DEVICE),
]
)
def test_callbacks_execute_after_wait(
self, name: str, device: torch.device
) -> None:
"""Test that callbacks are only executed after wait() is called."""
# Skip if CUDA is requested but not available
if device.type == "cuda" and not torch.cuda.is_available():
self.skipTest("CUDA not available")
# Skip if accelerator is requested but not available
if device.type == DEVICE.type and not torch.accelerator.is_available():
self.skipTest("accelerator not available")

# Create a tensor to work with
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
@@ -99,16 +101,16 @@ def callback(fut: Future[object]) -> List[torch.Tensor]:
@parameterized.parameterized.expand(
[
("cpu", torch.device("cpu")),
("cuda", torch.device("cuda:0")),
(DEVICE.type, DEVICE),
]
)
def test_multiple_callbacks_execute_in_order(
self, name: str, device: torch.device
) -> None:
"""Test that multiple callbacks are executed in the order they were added."""
# Skip if CUDA is requested but not available
if device.type == "cuda" and not torch.cuda.is_available():
self.skipTest("CUDA not available")
# Skip if accelerator is requested but not available
if device.type == DEVICE.type and not torch.accelerator.is_available():
self.skipTest("accelerator not available")

# Create a tensor to work with
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
@@ -169,14 +171,14 @@ def callback3(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
@parameterized.parameterized.expand(
[
("cpu", torch.device("cpu")),
("cuda", torch.device("cuda:0")),
(DEVICE.type, DEVICE),
]
)
def test_future_then_api(self, name: str, device: torch.device) -> None:
"""Test that the future's then API works correctly with ManagedWork."""
# Skip if CUDA is requested but not available
if device.type == "cuda" and not torch.cuda.is_available():
self.skipTest("CUDA not available")
# Skip if accelerator is requested but not available
if device.type == DEVICE.type and not torch.accelerator.is_available():
self.skipTest("accelerator not available")

# Create a tensor to work with
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
@@ -224,7 +226,7 @@ def callback(fut: Future[object]) -> List[torch.Tensor]:
@parameterized.parameterized.expand(
[
("cpu", torch.device("cpu")),
("cuda", torch.device("cuda:0")),
(DEVICE.type, DEVICE),
]
)
def test_callbacks_changing_return_types(
@@ -237,9 +239,9 @@ def test_callbacks_changing_return_types(
2. Using Future.value() instead of nonlocal
3. Verifying tensors are modified in-place for both approaches
"""
# Skip if CUDA is requested but not available
if device.type == "cuda" and not torch.cuda.is_available():
self.skipTest("CUDA not available")
# Skip if accelerator is requested but not available
if device.type == DEVICE.type and not torch.accelerator.is_available():
self.skipTest("accelerator not available")

# Create tensors to work with
tensor1: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
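A note on the module-level DEVICE constant introduced in this test file: torch.accelerator.current_accelerator() returns None on hosts without an accelerator, so DEVICE.type in the parameterized expansion would raise before the in-test skip check runs. A guarded variant (an assumption for illustration, not part of this diff) keeps the CPU-only case working:

# Hypothetical defensive fallback: use CPU when no accelerator is detected,
# so that DEVICE.type is always defined for the parameterized test cases.
DEVICE = torch.accelerator.current_accelerator() or torch.device("cpu")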
9 changes: 6 additions & 3 deletions torchft/checkpointing/http_transport.py
@@ -22,6 +22,7 @@
from torchft.checkpointing._serialization import _streaming_load, _streaming_save
from torchft.checkpointing.transport import CheckpointTransport
from torchft.http import _IPv6HTTPServer
from torchft.utils import get_stream_context

logger: logging.Logger = logging.getLogger(__name__)

@@ -56,8 +57,10 @@ def __init__(self, timeout: timedelta, num_chunks: int) -> None:
self._timeout = timeout
self._state_dict: Optional[T] = None
self._num_chunks = num_chunks
self._stream: Optional[torch.cuda.Stream] = (
torch.cuda.Stream() if torch.cuda.is_available() else None
self._stream: Optional[torch.Stream] = (
torch.Stream(torch.accelerator.current_accelerator())
if torch.accelerator.is_available()
else None
)

# staged checkpoint information
@@ -223,7 +226,7 @@ def send_checkpoint(
values, spec = tree_flatten(state_dict)

with (
torch.cuda.stream(self._stream)
get_stream_context(self._stream)
if self._stream is not None
else nullcontext()
):
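get_stream_context is imported from torchft.utils above, but its implementation is not part of this diff. A minimal sketch of what such a helper could look like, assuming it generalizes the old torch.cuda.stream(...) context manager via the device-generic torch.accelerator stream APIs (only the name and the single-stream argument are taken from the call site; everything else is an assumption):

from contextlib import contextmanager
from typing import Generator, Optional

import torch


@contextmanager
def get_stream_context(stream: Optional[torch.Stream]) -> Generator[None, None, None]:
    """Run the enclosed block with `stream` as the current accelerator stream."""
    # No stream or no accelerator: behave like contextlib.nullcontext().
    if stream is None or not torch.accelerator.is_available():
        yield
        return
    # Swap the current stream for the duration of the block, then restore it,
    # mirroring what torch.cuda.stream(...) did in the CUDA-only code path.
    prev = torch.accelerator.current_stream()
    torch.accelerator.set_stream(stream)
    try:
        yield
    finally:
        torch.accelerator.set_stream(prev)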
12 changes: 6 additions & 6 deletions torchft/checkpointing/http_transport_test.py
@@ -21,6 +21,7 @@
run_multi_recovery_test,
)

DEVICE = torch.accelerator.current_accelerator()

class TestHTTPTransport(TestCase):
@parameterized.expand(
@@ -33,8 +34,8 @@ def test_checkpoint_server(self, name: str, num_chunks: int) -> None:
expected: Dict[str, object] = {
"state": "dict",
"tensor": torch.rand(5, 2),
"cuda": torch.rand(
2, 3, device="cuda" if torch.cuda.is_available() else "cpu"
"accelerator": torch.rand(
2, 3, device=DEVICE.type if torch.accelerator.is_available() else "cpu"
),
}
state_dict_fn = MagicMock()
@@ -123,17 +124,16 @@ def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
run_multi_recovery_test(self, init, device=device)

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
@skipUnless(torch.cuda.is_available(), "CUDA is not available")
def test_multi_http_transport_cuda(self) -> None:
device = torch.device("cuda")
@skipUnless(torch.accelerator.is_available(), "accelerator is not available")
def test_multi_http_transport_accelerator(self) -> None:

def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
return HTTPTransport(
timeout=timedelta(seconds=10),
num_chunks=0,
)

run_multi_recovery_test(self, init, device=device)
run_multi_recovery_test(self, init, device=DEVICE)

def test_benchmark(self) -> None:
bench_main(
13 changes: 9 additions & 4 deletions torchft/checkpointing/pg_transport_bench.py
@@ -13,7 +13,7 @@
import torch.distributed as dist

from torchft.checkpointing.pg_transport import _timeit, PGTransport
from torchft.process_group import ProcessGroupBabyNCCL
from torchft.process_group import ProcessGroupBabyAccelerator

logger: logging.Logger = logging.getLogger(__name__)

@@ -47,12 +47,17 @@ def main(argv: list[str]) -> None:
store_addr: str = f"localhost:{store.port}"

def run(rank: int) -> None:
torch.cuda.set_device(rank)
if torch.accelerator.is_available():
torch.accelerator.set_device_index(rank)

device = torch.device(DEVICE)
device = torch.device(
DEVICE
if DEVICE != "accelerator"
else torch.accelerator.current_accelerator().type
)

with _timeit("init_pg"):
pg = ProcessGroupBabyNCCL(timeout=timeout)
pg = ProcessGroupBabyAccelerator(timeout=timeout)
pg.configure(store_addr=store_addr, replica_id="0", rank=rank, world_size=2)

t = torch.zeros(10, device=device, dtype=torch.float32)
22 changes: 11 additions & 11 deletions torchft/checkpointing/pg_transport_test.py
@@ -17,7 +17,7 @@
make_state_dict,
run_multi_recovery_test,
)
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
from torchft.process_group import ProcessGroupBabyAccelerator, ProcessGroupGloo


class PGTransportTest(TestCase):
@@ -45,18 +45,18 @@ def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
run_multi_recovery_test(self, init, device=device)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
def test_pg_transport_baby_nccl(self) -> None:
@skipUnless(torch.accelerator.device_count() >= 3, "need three accelerator devices")
def test_pg_transport_baby_accelerator(self) -> None:
store: TCPStore = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
device: torch.device = torch.device("cuda")
device: torch.device = torch.accelerator.current_accelerator()
timeout: timedelta = timedelta(seconds=10)

def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
torch.cuda.set_device(rank)
torch.accelerator.set_device_index(rank)

pg = ProcessGroupBabyNCCL(timeout=timeout)
pg = ProcessGroupBabyAccelerator(timeout=timeout)
pg.configure(
store_addr=f"localhost:{store.port}/prefix",
replica_id="0",
@@ -69,21 +69,21 @@ def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
run_multi_recovery_test(self, init, device=device)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
def test_pg_transport_baby_nccl_inplace(self) -> None:
@skipUnless(torch.accelerator.device_count() >= 3, "need three accelerator devices")
def test_pg_transport_baby_accelerator_inplace(self) -> None:
store: TCPStore = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
device: torch.device = torch.device("cuda")
device: torch.device = torch.accelerator.current_accelerator()
timeout: timedelta = timedelta(seconds=10)

def state_dict() -> dict[str, object]:
return make_state_dict(device)

def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
torch.cuda.set_device(rank)
torch.accelerator.set_device_index(rank)

pg = ProcessGroupBabyNCCL(timeout=timeout)
pg = ProcessGroupBabyAccelerator(timeout=timeout)
pg.configure(
store_addr=f"localhost:{store.port}/prefix",
replica_id="0",