feat: migrate pipeline to nnx #2885
mesakhcienet wants to merge 18 commits into AI-Hypercomputer:main from
Conversation
bvandermoon
left a comment
@gobbleturk what testing do you recommend for migrating pipeline parallelism to NNX? I'll send over an internal doc @hsuan-lun-chiang, @mesakhcienet, and others put together that shows the tests they have already run
@NuojCheng any thoughts here?
NuojCheng
left a comment
Some additional train compile test for pipeline NNX migration:
- Train compile test 1: https://paste.googleplex.com/5960957017849856
- Train compile test 2: https://paste.googleplex.com/5749974483730432
- Train compile test 3: https://paste.googleplex.com/5201745681711104
If the train compile tests above can pass without getting OOM + current tests in pipeline_parallelism_test.py can all pass, then I think it is good to go! Please ping me if the PR is ready for review.
There is also some Linen usage in
I don't see it get updated in this PR, but I think it probably should be? Another thing is the usage of the function in maxtext/src/maxtext/utils/pipeline_utils.py (lines 151 to 162 in 77f5334).
As far as I know, the current objective is to migrate the Linen pipeline to NNX while preserving the current Linen version. Please advise if any additional progress is required at this time. Thanks!
Shouldn't we have an NNX version of the functions in pipeline_utils.py as well?
Are we able to bridge the NNX version back to Linen at a higher layer? If so, then I think we could get rid of the old Linen code that is no longer used.
```python
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
return new_carry, nnx.state(layer)

final_carry, scanned_state = jax.lax.scan(layer_fn, inputs, (params, state))
```
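For context, a toy stand-in for this scan pattern (not MaxText code; shapes and the matmul "layer" are illustrative): `jax.lax.scan` consumes stacked per-layer params along axis 0 and threads the activations through as the carry.

```python
# Toy version of the layer_fn/scan pattern: params carry a leading layer axis,
# the carry is the activation passed from stage to stage.
import jax
import jax.numpy as jnp

def layer_fn(carry, layer_params):
    layer_out = carry @ layer_params["w"]  # stand-in layer computation
    new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
    return new_carry, layer_params         # scan also stacks per-step outputs

num_layers, dim = 3, 4
params = {"w": jnp.stack([jnp.eye(dim)] * num_layers)}  # layers on axis 0
inputs = jnp.ones((2, dim))
final_carry, scanned = jax.lax.scan(layer_fn, inputs, params)
# identity weights leave the carry unchanged
```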
you can alternatively use nnx.scan here which already does the moveaxis for you.
it also avoids the use of split and merge
Thank you for your review.
As far as I know, we should avoid nnx.scan and nnx.remat, as mentioned in the internal docs chat here.
@bvandermoon
Option 1: If we use
Option 2: Delay the full migration until
Please let me know which of these two solutions you prefer, thank you.
@NuojCheng The NNX pipeline classes (NNXPipeline, NNXCircularPipeline) already handle these internally with JAX-native equivalents:
So no NNX versions of those functions are needed; the NNX path bypasses them entirely. Maybe you have some suggestions, or is there any part where I am wrong? Please let me know. Thank you.
```python
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
scanned_state = nnx.State.merge(scanned_params, scanned_other)

self.scanned_layers = nnx.merge(graphdef, scanned_state)
```
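A toy-shape sketch of the moveaxis step (not MaxText code; shapes are illustrative): `jax.lax.scan` stacks per-layer leaves along axis 0, and `jnp.moveaxis` relocates that layer axis to the configured scan axis before the state is merged back.

```python
# Relocating the stacked layer axis from 0 to scan_axis across a params pytree.
import jax
import jax.numpy as jnp

scan_axis = 1
num_layers, rows, cols = 3, 2, 4
scanned_params = {"w": jnp.zeros((num_layers, rows, cols))}  # layer axis is 0
scanned_params = jax.tree.map(
    lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
# the layer axis now sits at position `scan_axis`: shape (2, 3, 4)
```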
Suggested change:
```diff
- self.scanned_layers = nnx.merge(graphdef, scanned_state)
+ nnx.update(self.scanned_layers, scanned_state)
```
Thank you @mesakhcienet. Let's go with option 1 please. That way we can continue running unit tests along the way, and we don't need to worry about the Linen/NNX versions diverging before the migration is fully done.
src/maxtext/layers/pipeline.py
Outdated
```diff
  if bsw_pps is not None:

-   @jax.shard_map(mesh=self.mesh, in_specs=((bsw_pps, bsw_pps), P("stage")), out_specs=bsw_pps, check_vma=True)
+   @jax.shard_map(mesh=self.mesh, in_specs=((bsw_pps, bsw_pps), P("stage")), out_specs=bsw_pps, check_vma=False)
```
please set check_vma=True in pipeline code
This reverts commit 6d6bea2.
refactor: replace Linen pipeline classes with to_linen_class wrappers

Remove PipelineBaseLinen, Pipeline, CircularPipeline (~740 lines). Add Pipeline = to_linen_class(NNXPipeline) and CircularPipeline = to_linen_class(NNXCircularPipeline). Update create_pipeline to accept stage_factory instead of layers.

test: update pipeline test to use NNX stage factory

refactor: switch Linen decoder pipeline path to NNX stage factory

Replace get_pipeline_stage_module with _get_nnx_decoder_block_classes and _build_nnx_pipeline_stage. Delete SequentialBlockDecoderLayers. Pipeline setup now uses a stage_factory callable (rngs -> NNX module).
…teration

run_one_iteration now:
1. Fetches params from BSW (params-only, matching shard_map specs)
2. Gathers metrics/mutables directly for the current repeat
3. Merges into full state for the forward pass
4. Scatter-updates only non-params back (params static in scan)

Fixes ValueError: pytree structure error in shard_map, where out_specs had None (leaf) at RNG paths but BSW had RngCount (pytree node).

fix(pipeline): pass params-only to weight_prefetching in scan_body

BSW prefetching only needs parameters. Non-param state (metrics, mutables) is now passed separately to run_one_iteration for direct gathering.

fix(pipeline): clean up run_one_iteration consistency

- Filter to non-params in the num_pipeline_repeats == 1 path for consistency
- Remove redundant get_microbatch_and_repeat_ids call
Description
Implement an NNX-based pipeline.
This PR extends PR#2831
Main changes:
- NNXPipeline, which is an NNX-based pipeline class.

Tests
We ran the pipeline process with the command below:
```shell
MODEL_NAME=llama2-7b python -m MaxText.train src/maxtext/configs/base.yml \
  run_name=pipeline_test_${MODEL_NAME}_nnx \
  base_output_directory=/dev/shm/pipeline_test_nnx \
  model_name=${MODEL_NAME} \
  dataset_type=synthetic \
  steps=15 \
  debug_sharding=true \
  per_device_batch_size=2 \
  max_target_length=32 \
  ici_pipeline_parallelism=2 \
  num_pipeline_microbatches=4 \
  num_layers_per_pipeline_stage=2 \
  enable_checkpointing=false \
  enable_nnx=true \
  pure_nnx_decoder=true \
  scan_layers_per_stage=false \
  async_checkpointing=false > nnx-porting-log/pipeline/custom_${MODEL_NAME}.log 2>&1
```

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.