Skip to content

Commit c441044

Browse files
authored
Merge branch 'master' into HeyangQin/enable_hpz_nograd
2 parents 3f89662 + 57d629a commit c441044

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

deepspeed/runtime/zero/partition_parameters.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,11 @@ def wrapped_fn(*args, **kwargs) -> Tensor:
250250

251251
def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:
252252

253-
def new_tensor(cls, *args) -> Tensor:
253+
def new_tensor(cls, *args, **kwargs) -> Tensor:
254254
device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
255-
tensor = _orig_torch_empty(0, device=device).new_empty(*args)
255+
if not args:
256+
args = (0, )
257+
tensor = _orig_torch_empty(0, device=device).new_empty(*args, **kwargs)
256258
if tensor.is_floating_point():
257259
tensor = tensor.to(dtype)
258260

0 commit comments

Comments
 (0)