```diff
@@ -68,7 +68,7 @@
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.mesh import get_concrete_mesh, get_abstract_mesh
+from jax._src.mesh import get_concrete_mesh, get_abstract_mesh, Mesh
 from jax._src.sharding_impls import (PmapSharding, PartitionSpec as P,
                                      NamedSharding)
 from jax._src.layout import Format
```
```diff
@@ -2800,8 +2800,12 @@ def _device_put_sharded(*xs):
       raise ValueError("the shards passed to device_put_sharded must have "
                        f"consistent shape and dtype, but got {a1} and {a2}.")
     stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
-    sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
-    sharding = PmapSharding(np.array(devices), sharding_spec)
+    if config.pmap_shmap_merge.value:
+      mesh = Mesh(np.array(devices), ('_device_put_sharded',))
+      sharding = NamedSharding(mesh, P('_device_put_sharded'))
+    else:
+      sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
+      sharding = PmapSharding(np.array(devices), sharding_spec)
     if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended):
       return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices)
     if config.pmap_no_rank_reduction.value:
```
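When `config.pmap_shmap_merge` is enabled, the stacked result of `device_put_sharded` now carries a `NamedSharding` over a one-dimensional `Mesh` of the target devices (with a synthetic axis name) instead of a `PmapSharding`. The sketch below builds the same kind of sharding by hand with public APIs; it is only an illustration of the layout the new branch describes, and the axis name `'shards'` stands in for the internal `'_device_put_sharded'` name.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
# One shard per device; device_put_sharded stacks them along a new leading axis.
shards = [np.full((4,), i, dtype=np.float32) for i in range(len(devices))]
y = jax.device_put_sharded(shards, devices)
print(y.shape)  # (len(devices), 4)

# The layout the new branch attaches: a 1-D mesh over the target devices,
# partitioning only the stacked leading dimension.
mesh = Mesh(np.array(devices), ('shards',))
named = NamedSharding(mesh, P('shards'))
print(named.shard_shape(y.shape))  # (1, 4): one slice of the stack per device
```

Both representations describe the same one-slice-per-device placement, which appears to be the point of gating the change on `pmap_shmap_merge` rather than changing behavior unconditionally.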
```diff
@@ -2856,15 +2860,19 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
   def _device_put_replicated(x):
     aval = core.unmapped_aval(len(devices), 0, core.get_aval(x))
     assert isinstance(aval, ShapedArray)
-    sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
     if config.pmap_no_rank_reduction.value:
       if isinstance(x, (np.ndarray, basearray.Array)):
         buf = device_put(x[None], devices[0])
       else:
         buf = device_put(x, devices[0])[None]
     else:
       buf = device_put(x, devices[0])
-    sharding = PmapSharding(np.array(devices), sharding_spec)
+    if config.pmap_shmap_merge.value:
+      mesh = Mesh(np.array(devices), ('_device_put_replicated',))
+      sharding = NamedSharding(mesh, P('_device_put_replicated'))
+    else:
+      sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
+      sharding = PmapSharding(np.array(devices), sharding_spec)
     if dtypes.issubdtype(aval.dtype, dtypes.extended):
       return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices)
     return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices)
```
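`_device_put_replicated` takes the analogous branch: the per-device copies are stacked and, under the flag, the result is sharded with a `NamedSharding` over a 1-D mesh using the synthetic `'_device_put_replicated'` axis rather than a `PmapSharding`. Below is a small usage sketch of the public wrapper, assuming a single process; it only checks behavior that should hold regardless of which sharding class is attached.

```python
import jax
import numpy as np

devices = jax.devices()
x = np.arange(6, dtype=np.float32).reshape(2, 3)

# Stacks a copy of x per device and shards the new leading axis across them.
r = jax.device_put_replicated(x, devices)
print(r.shape)                                    # (len(devices), 2, 3)
print(len(r.addressable_shards) == len(devices))  # True: one shard per device

# Every slice along the device axis is a full copy of x.
for i in range(len(devices)):
    np.testing.assert_array_equal(np.asarray(r)[i], x)
```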