Skip to content

Commit add98bc

Browse files
authored
fix: _get_physical_gpu_id requires device_manager (#53)
1 parent a53b342 commit add98bc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_update.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def run():
6868
world_size = int(os.getenv("WORLD_SIZE"))
6969
ctx = get_context("spawn")
7070
queue = ctx.Queue()
71-
_device_uuid = _get_physical_gpu_id(rank)
7271
ps = ParameterServer(auto_pg=True)
72+
_device_uuid = _get_physical_gpu_id(ps.device_manager, rank)
7373
named_tensors = dict(gen_test_tensors(rank))
7474
checkpoint_name = "test"
7575
proc = ctx.Process(target=checker_proc, args=(rank, _device_uuid, named_tensors, queue))

0 commit comments

Comments
 (0)