Feat: Add DP-SGD Transformer example using Flax NNX API | Issue #120 (#126)
debanganghosh08 wants to merge 17 commits into google-deepmind:main
Conversation
Force-pushed from 7cbfbb1 to 944df7c
examples/dp_sgd_transformer_nnx.py (outdated)
Add a timeout to prevent indefinite blocking.

That's a good catch. A timeout is definitely best practice to avoid hangs in CI/CD. I've updated download_data to include a 10-second timeout. I'm also moving the flax dependency into a proper requirements file as you suggested.
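For reference, a minimal sketch of what the timeout change could look like, assuming download_data fetches the dataset with urllib (the helper's actual signature in the PR may differ):

```python
import urllib.request


def download_data(url: str, destination: str, timeout: float = 10.0) -> None:
  """Downloads `url` to `destination`, failing fast instead of hanging."""
  # The timeout bounds how long the connection may block, so a stalled
  # network call fails the job quickly rather than hanging CI indefinitely.
  with urllib.request.urlopen(url, timeout=timeout) as response:
    with open(destination, "wb") as f:
      f.write(response.read())
```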
examples/dp_sgd_transformer_nnx.py (outdated)

No, it is needed: the CI/CD checks don't install flax as a dependency, so when the pytype check runs, the code fails. Hence this line is important for passing all the CI/CD checks.
As a longer-term note, we can ask @RamSaw or @ryan112358 to add flax installation to the CI/CD check so this doesn't come up again.
So try adding it to the requirements.txt located in the docs folder.
The requirements.txt in the docs folder is intended to contain only the requirements needed for documentation, and the ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and update .github/workflows/ci.yml to install it.
Or you can add it to the "dev" requirements in pyproject.toml
from absl import app
from absl import flags
-import flax.linen as nn
+import flax.linen as nn  # pytype: disable=import-error
ryan112358 left a comment:

Looks great, very clean, nice work! Left some comments.
x: Input batch (single example or microbatch).
y: Target batch (single example or microbatch).
graphdef: The static graph definition of the NNX model.
other: Non-trainable state (e.g., RNG counts).
What else other than the rng counts is captured here? Is it possible to call this argument prng and have it typed as a jax.Array, then somehow wire it through to flax? I ask because when you call clipped_grad, if the loss function contains a prng key it needs special handling.
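To make the question concrete, a minimal hedged sketch of the direction being suggested: the opaque non-trainable state is replaced by an explicit jax.Array key argument, so whatever computes per-example gradients (e.g. clipped_grad with a prng_argnum, as mentioned below) can recognise and handle it. The toy model and loss are placeholders, not the PR's code.

```python
import jax
import jax.numpy as jnp


def per_example_loss(params, prng: jax.Array, x, y):
  """Loss for a single example, with the PRNG key as an explicit argument."""
  # Keeping the key as a plain jax.Array (rather than hiding it inside an
  # "other" state bundle) lets the gradient-clipping machinery split or
  # reseed it per example.
  dropout_key = jax.random.fold_in(prng, 0)
  del dropout_key  # a real model would consume this for dropout, etc.
  logits = x @ params["w"]
  return jnp.mean((logits - y) ** 2)
```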
examples/dp_sgd_transformer_nnx.py (outdated)
Give this a descriptive name like model
examples/dp_sgd_transformer_nnx.py (outdated)
You might need to pass prng_argnum here as well to ensure the random key is handled appropriately. But it might require slight refactoring of your loss function
examples/dp_sgd_transformer_nnx.py (outdated)
Usually we want to keep this to the default (True), unless we're doing user-level DP. If you set this to True (or remove it), can you remove the line that adds an extra batch axis in pure_loss_fn?
examples/dp_sgd_transformer_nnx.py (outdated)
grad_fn already aggregates gradients across the batch dimension, so I think this is a bug
# Aggregate gradients (mean across batch)
mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads)

# Add Privacy Noise
I'll leave it up to your discretion, but I think these inline comments can be removed.
examples/dp_sgd_transformer_nnx.py (outdated)
In an ideal world this would use poisson sampling / jax_privacy.batch_selection. It's fine to leave a TODO for now and add it in a follow-up
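As a generic illustration of Poisson sampling (not the jax_privacy.batch_selection API, which may differ), each example is included independently with a fixed probability:

```python
import jax
import jax.numpy as jnp


def poisson_sample_indices(key: jax.Array, num_examples: int,
                           sampling_prob: float) -> jax.Array:
  # Each example is kept independently with probability `sampling_prob`,
  # which is the subsampling assumption behind standard DP-SGD accounting.
  mask = jax.random.bernoulli(key, p=sampling_prob, shape=(num_examples,))
  # Fixed-size output for jit-friendliness; unused slots are filled with -1.
  return jnp.nonzero(mask, size=num_examples, fill_value=-1)[0]
```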
examples/dp_sgd_transformer_nnx.py (outdated)
The stddev should be grad_fn.sensitivity() * noise_multiplier. Can you add NOISE_MULTIPLIER to the list of constants above?
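A hedged sketch of the calibration being asked for, assuming grad_fn already returns batch-aggregated clipped gradients (per the earlier comment) and exposes a sensitivity() method as referenced here; NOISE_MULTIPLIER and the privatize helper are illustrative names, not the library's API:

```python
import jax
import jax.numpy as jnp

NOISE_MULTIPLIER = 1.0  # illustrative value; would live with the other constants


def privatize(grads, grad_fn, key: jax.Array):
  # Gaussian noise with stddev = sensitivity * noise multiplier, applied
  # leaf-by-leaf with independent keys.
  stddev = grad_fn.sensitivity() * NOISE_MULTIPLIER
  leaves, treedef = jax.tree.flatten(grads)
  keys = jax.random.split(key, len(leaves))
  noisy = [g + stddev * jax.random.normal(k, g.shape, g.dtype)
           for g, k in zip(leaves, keys)]
  return jax.tree.unflatten(treedef, noisy)
```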
Force-pushed from 1d03537 to 9eac33d
Hi @ryan112358, I've pushed an update addressing all your feedback. Here is a summary of the changes I made:

✅ Verification: The script was verified for 10 steps locally, achieving a stable loss and passing a 10.00/10 pylint check. Let me know if further changes are required!
#128 might fix the CI failures and make them easier to debug.

That's a good approach for moving the current CI/CD to a modular DAG architecture, and it should improve DX.
@debanganghosh08, now that the new CI pipeline and dependency flow have been introduced, there will be CI failures from here on: the library you added in examples/req...txt will no longer be considered. Kindly pull the latest changes from upstream main first, then delete the examples/req..txt file and add the deps to pyproject.toml; you can see there is an optional-dependencies section with a place for [examples], kindly add it there. Central optional deps are now managed in the root pyproject.toml file.
…alse) per maintainer review
Force-pushed from b6d6d66 to d5a7943
Thanks for the heads-up and the clear guidance on the new dependency flow, @amyssnippet! I've just pushed an update aligning with the new modular CI. I pulled the latest upstream changes, migrated flax to the [project.optional-dependencies] section in pyproject.toml, and cleaned up the temporary requirements file. Everything should be in sync now!
amyssnippet left a comment:

Check the Files changed tab; there are still some files visible. Kindly fix them all; I already left comments.
.github/workflows/ci.yml (outdated)

This block of CI should not be here; it is unusual and not required.
pyproject.toml (outdated)

I have already created arrays to manage all optional dependencies; check it here: https://github.com/google-deepmind/jax_privacy/blob/main/pyproject.toml
I made these dependency changes in the previous CI task; make sure you pulled the changes properly, including this file.
examples/requirements.txt (outdated)

I guess it's still present here, which is not required.
Hello @ryan112358 and @Neerajpathak07, I've updated the implementation for both the NNX Transformer (#126) and ULS Transformer (#107) examples to align with the architectural suggestions provided. I performed a side-by-side experimental benchmark to evaluate the impact of moving from a manual loop to the library's internal execution_plan abstraction.

Key refactors implemented:
- Standardized orchestration: Switched to execution_plan.BandMFExecutionPlanConfig to wire the privatizer and clipped_grad. This ensures the noise addition and sampling strategies are mathematically synchronized with the library's core mechanisms.
- ULS integration: In the user-level example, I wrapped the plan.batch_selection_strategy within our UserSelectionStrategy, maintaining the required intra-user averaging while utilizing the standard batch_iterator.
- Production standards: Adopted the main(argv: Sequence[str]) entry point and migrated hyperparameters to centralized constants for better readability. Both files now achieve a 10.00/10 Pylint score.

Benchmarking observation: During local 10-step runs, I noted a significant initialization overhead (~45s). This is due to the Toeplitz.optimize_banded_toeplitz step required by the BandMF strategy. While this increases the wall-clock time for short CI checks, it is a fixed cost that will be fully amortized during production-scale training runs.

@Neerajpathak07, thanks for pointing out the BandMF configuration; it makes the examples much more idiomatic. @ryan112358, do you agree that the increased alignment with the core library's 'Plan' API is worth the trade-off in script simplicity for these examples?
Force-pushed from bff72ee to 9019d85
Force-pushed from 9d5e143 to 008a96d
Force-pushed from bf44580 to 1b84638
This PR introduces a comprehensive example of training a Transformer model with Differential Privacy using the new Flax NNX API. While JAX Privacy provides robust support for Linen and Haiku, this addition provides a template for users moving toward the functional-object paradigm of NNX.
Key Technical Implementations:
✔️ Exhaustive State Partitioning: Utilizes nnx.split(model, nnx.Param, ...) to strictly separate trainable parameters from non-trainable state (RNG counts, etc.), ensuring the JAX tracer maintains leaf parity across functional boundaries.
✔️ Rank-Normalized Loss: Implements a rank-injection strategy within the pure loss function to account for vmap dimension-stripping. By forcing a singleton batch dimension during the forward pass, the model correctly generates 4D causal masks required by the attention mechanism.
✔️ Privacy-Safe State Reconstruction: Uses an internal nnx.merge pattern to ensure that mutations to RNG states during training remain local to the functional trace, preventing TraceContextError regressions. A combined sketch of these NNX patterns is shown below.
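For readers unfamiliar with these NNX patterns, here is a small, hedged sketch of how the split/merge partitioning and the singleton-batch-axis trick fit together in a per-example loss. The toy nnx.Linear model, the pure_loss_fn name, and the surrounding wiring are illustrative assumptions, not the PR's exact code.

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Toy stand-in for the Transformer in the example.
model = nnx.Linear(8, 8, rngs=nnx.Rngs(0))

# Partition the object into a static graph definition, trainable Params,
# and everything else (RNG counts and other non-trainable state).
graphdef, params, other = nnx.split(model, nnx.Param, ...)


def pure_loss_fn(params, other, x, y):
  # Rebuild a functional copy inside the trace so any state mutation stays
  # local to it (the "privacy-safe state reconstruction" idea above).
  m = nnx.merge(graphdef, params, other)
  # vmap strips the batch axis, so re-insert a singleton batch dimension;
  # a real Transformer needs it to build its 4D causal attention masks.
  logits = m(jnp.expand_dims(x, 0))
  return jnp.mean((logits - jnp.expand_dims(y, 0)) ** 2)


# One loss value per example; gradients of this are what get clipped.
batch_x = jnp.ones((4, 8))
batch_y = jnp.zeros((4, 8))
per_example_losses = jax.vmap(
    pure_loss_fn, in_axes=(None, None, 0, 0))(params, other, batch_x, batch_y)
```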
✅ Verification: The script was validated on the Tiny Shakespeare dataset for 20 steps, achieving stable convergence under DP constraints (Default: CLIP_NORM=1.0).
Screenshot of output attached 👇
