Skip to content

Commit 691bf92

Browse files
danielsuo Google-ML-Automation
authored and committed
[pmap] Created NamedSharding arrays when jax_pmap_shmap_merge=True in jax.device_put_replicated and jax.device_put_sharded.
PiperOrigin-RevId: 845530346
1 parent 9a391d7 commit 691bf92

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

jax/_src/api.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
from jax._src.lib import xla_client as xc
6969
from jax._src.lib import pmap_lib
7070
from jax._src.sharding import Sharding
71-
from jax._src.mesh import get_concrete_mesh, get_abstract_mesh
71+
from jax._src.mesh import get_concrete_mesh, get_abstract_mesh, Mesh
7272
from jax._src.sharding_impls import (PmapSharding, PartitionSpec as P,
7373
NamedSharding)
7474
from jax._src.layout import Format
@@ -2800,8 +2800,12 @@ def _device_put_sharded(*xs):
28002800
raise ValueError("the shards passed to device_put_sharded must have "
28012801
f"consistent shape and dtype, but got {a1} and {a2}.")
28022802
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
2803-
sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
2804-
sharding = PmapSharding(np.array(devices), sharding_spec)
2803+
if config.pmap_shmap_merge.value:
2804+
mesh = Mesh(np.array(devices), ('_device_put_sharded',))
2805+
sharding = NamedSharding(mesh, P('_device_put_sharded'))
2806+
else:
2807+
sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
2808+
sharding = PmapSharding(np.array(devices), sharding_spec)
28052809
if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended):
28062810
return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices)
28072811
if config.pmap_no_rank_reduction.value:
@@ -2856,15 +2860,19 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
28562860
def _device_put_replicated(x):
28572861
aval = core.unmapped_aval(len(devices), 0, core.get_aval(x))
28582862
assert isinstance(aval, ShapedArray)
2859-
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
28602863
if config.pmap_no_rank_reduction.value:
28612864
if isinstance(x, (np.ndarray, basearray.Array)):
28622865
buf = device_put(x[None], devices[0])
28632866
else:
28642867
buf = device_put(x, devices[0])[None]
28652868
else:
28662869
buf = device_put(x, devices[0])
2867-
sharding = PmapSharding(np.array(devices), sharding_spec)
2870+
if config.pmap_shmap_merge.value:
2871+
mesh = Mesh(np.array(devices), ('_device_put_replicated',))
2872+
sharding = NamedSharding(mesh, P('_device_put_replicated'))
2873+
else:
2874+
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
2875+
sharding = PmapSharding(np.array(devices), sharding_spec)
28682876
if dtypes.issubdtype(aval.dtype, dtypes.extended):
28692877
return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices)
28702878
return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices)

tests/pmap_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2951,7 +2951,10 @@ def test_device_put_sharded(self):
29512951
x = [np.arange(i, i + 4) for i in range(n_devices)]
29522952
y = jax.device_put_sharded(x, devices)
29532953
self.assertIsInstance(y, array.ArrayImpl)
2954-
self.assertIsInstance(y.sharding, jax.sharding.PmapSharding)
2954+
if config.pmap_shmap_merge.value:
2955+
self.assertIsInstance(y.sharding, jax.NamedSharding)
2956+
else:
2957+
self.assertIsInstance(y.sharding, jax.sharding.PmapSharding)
29552958
for s in y.addressable_shards:
29562959
self.assertArraysEqual(s.data, y[s.index])
29572960
self.assertEqual(s.replica_id, 0)

0 commit comments

Comments (0)