Skip to content

Commit bf44580

Browse files
Merge branch 'main' into feat/nnx-transformer-dp-sgd
2 parents 008a96d + 8975bfc commit bf44580

35 files changed

+682
-1009
lines changed

.github/workflows/ci.yml

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,110 @@ on:
77
branches: ["main"]
88

99
jobs:
10-
build-and-test:
11-
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
12-
runs-on: "${{ matrix.os }}"
10+
lint:
11+
name: "Lint: Python ${{ matrix.python-version }} on ${{ matrix.os }}"
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
matrix:
15+
python-version: ["3.11"]
16+
os: [ubuntu-latest]
17+
steps:
18+
- uses: "actions/checkout@v6"
19+
- uses: "actions/setup-python@v6"
20+
with:
21+
python-version: ${{ matrix.python-version }}
22+
cache: "pip"
23+
- name: Install Linting Tools
24+
run: |
25+
python -m pip install --upgrade pip
26+
pip install -e ".[lint,examples,test]"
27+
- name: Run flake
28+
run: |
29+
flake8 jax_privacy tests examples
30+
- name: Run pydocstyle
31+
run: |
32+
pydocstyle --convention=google --add-ignore=D101,D102,D103,D105,D202,D402 jax_privacy/
33+
- name: Run pylint
34+
shell: bash
35+
run: |
36+
pylint jax_privacy || pylint-exit -efail -wfail -cfail -rfail $?
37+
pylint examples || pylint-exit -efail -wfail -cfail -rfail $?
38+
pylint tests -d W0101,W0212,C0114 || pylint-exit -efail -wfail -cfail -rfail $?
39+
- name: Run pytype
40+
run: |
41+
pytype jax_privacy -k
42+
pytype tests -k
43+
pytype examples -k
1344
45+
test:
46+
name: "Unit Test: Python ${{ matrix.python-version }} on ${{ matrix.os }}"
47+
runs-on: ${{ matrix.os }}
1448
strategy:
1549
matrix:
1650
python-version: ["3.11"]
1751
os: [ubuntu-latest]
52+
steps:
53+
- uses: "actions/checkout@v6"
54+
- uses: "actions/setup-python@v6"
55+
with:
56+
python-version: ${{ matrix.python-version }}
57+
cache: "pip"
58+
- name: Install Test Dependencies
59+
run: |
60+
python -m pip install --upgrade pip
61+
pip install -e ".[test,examples]"
62+
- name: Run Doctests
63+
run: |
64+
pytest --doctest-modules jax_privacy
65+
- name: Run Standard Tests
66+
run: |
67+
pytest -n auto tests/ -k "not matrix_factorization and not distributed_noise_generation_test and not sharding_utils_test"
68+
pytest -n auto tests/ -k "distributed_noise_generation_test"
69+
pytest -n auto tests/ -k "sharding_utils_test"
1870
71+
matrix-tests:
72+
name: "Matrix Tests: Python ${{ matrix.python-version }} on ${{ matrix.os }}"
73+
needs: lint
74+
runs-on: "${{ matrix.os }}"
75+
strategy:
76+
matrix:
77+
python-version: ["3.11"]
78+
os: [ubuntu-latest]
1979
steps:
2080
- uses: "actions/checkout@v6"
2181
- uses: "actions/setup-python@v6"
2282
with:
23-
python-version: "${{ matrix.python-version }}"
24-
- name: Install example requirements
25-
run: pip install -r examples/requirements.txt
26-
- name: Run CI tests
27-
run: bash test.sh
83+
python-version: "${{ matrix.python-version }}"
84+
cache: "pip"
85+
- name: Install Dependencies
86+
run: |
87+
python -m pip install --upgrade pip
88+
pip install -e ".[test]"
89+
- name: Run Heavy Tests
90+
run: |
91+
export HYPOTHESIS_PROFILE=dpftrl_default
92+
pytest -n auto tests/ -k "matrix_factorization" --ignore=tests/matrix_factorization/buffered_toeplitz_test.py
2893
shell: bash
94+
95+
docs:
96+
name: "Docs: Python ${{ matrix.python-version }} on ${{ matrix.os }}"
97+
needs: lint
98+
runs-on: "${{ matrix.os }}"
99+
strategy:
100+
matrix:
101+
python-version: ["3.11"]
102+
os: [ubuntu-latest]
103+
steps:
104+
- uses: "actions/checkout@v6"
105+
- uses: "actions/setup-python@v6"
106+
with:
107+
python-version: "${{ matrix.python-version }}"
108+
cache: "pip"
109+
- name: Install Docs Dependencies
110+
run: |
111+
python -m pip install --upgrade pip
112+
pip install -e ".[docs]"
113+
- name: Build Sphinx
114+
run: |
115+
cd docs
116+
sphinx-build -W -b html . _build/html

.pylintrc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ unsafe-load-any-extension=no
3535
extension-pkg-whitelist=
3636

3737

38+
[STRING]
39+
40+
# This flag controls whether inconsistent-quotes generates a warning when the
41+
# character used as a quote delimiter is used inconsistently within a module.
42+
check-quote-consistency=yes
43+
3844
[MESSAGES CONTROL]
3945

4046
# Only show warnings with the listed confidence levels. Leave empty to show
@@ -128,6 +134,7 @@ disable=apply-builtin,
128134
too-many-branches,
129135
too-many-instance-attributes,
130136
too-many-locals,
137+
too-many-positional-arguments,
131138
too-many-public-methods,
132139
too-many-return-statements,
133140
too-many-statements,
@@ -142,6 +149,7 @@ disable=apply-builtin,
142149
xrange-builtin,
143150
wrong-import-order,
144151
zip-builtin-not-iterating,
152+
g-bad-todo,
145153

146154

147155
[REPORTS]

CONTRIBUTING.md

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,27 @@
33
We welcome small patches related to bug fixes, but we do not plan to accept
44
major changes to this repository at this time.
55

6+
## Coordination & Claiming Issues
7+
8+
To avoid duplicate effort and ensure your work can be merged, please follow
9+
these steps:
10+
11+
* **Check for existing PRs and Issues:** Before starting work, search the
12+
issue tracker and active PRs to see if the feature/bug is already being
13+
addressed.
14+
15+
* **Claim the issue:** If an issue exists, comment on it expressing your
16+
intent to work on it (e.g., "I'd like to work on this"). A maintainer will
17+
then assign it to you.
18+
19+
* **Wait for Assignment:** Do not start a large-scale implementation until a
20+
maintainer has acknowledged your comment. This prevents multiple people from
21+
working on the same fix simultaneously.
22+
23+
* **Stale Assignments:** If an assigned issue hasn't seen progress or
24+
communication for 14 days, the assignment may be cleared to allow others to
25+
contribute.
26+
627
## Contributor License Agreement
728

829
Contributions to this project must be accompanied by a Contributor License
@@ -25,45 +46,47 @@ information on using pull requests.
2546
## Style guide
2647

2748
Code in this library generally follows the
28-
[Google Style Guide](https://google.github.io/styleguide/pyguide.html).
29-
We aim to keep APIs, names, and design patterns in line with the broader JAX
30-
ecosystem as much as possible, with immutability and functional purity being a
31-
key guiding principles we adhere to across our library. Below are some more
32-
detailed conventions depending on what is being contributed.
33-
34-
1. **Public Facing Functions**: Public facing functions are those that are
35-
exposed to the **users** of JAX Privacy (usually surfaced via \_\_init\_\_.py).
36-
Public facing functions and classes should **always** have full docstrings, type
37-
annotations, and example usages in the form of
38-
[doctests](https://docs.python.org/3/library/doctest.html). Doctests
39-
provides useful documentation that stays up-to-date with code changes, and
40-
is a useful litmus test on the simplicity and usability of the API surface.
41-
42-
1. **Internal functions used across files**: For maintainability of the core
43-
library, it is sometimes beneficial to define a function in one file and have it
44-
be used by another file within the jax_privacy package. These functions are not
45-
intended to be consumed by JAX Privacy users (although may be encountered by
46-
developers / contributors). These functions should generally have
47-
descriptive names, type annotations. Internal functions should have a one-line
48-
docstring explaining what they do. A full docstring is encouraged if the
49-
function has non-obvious side effects, complex arguments,
50-
or implements a multi-step algorithm that isn't clear from the code alone.
51-
52-
1. **File-local private functions**: These function should always have a
53-
leading "_". This signals to developers that the function is not part of the
54-
public API and is subject to change without notice. These functions should have
55-
1-line docstrings; type annotations are optional and context-dependent. Example
56-
usages are not needed as they can be found in the corresponding _test.py file.
57-
58-
1. **Nested functions**: Functions defined within other functions should
59-
generally be as simple as possible; we prefer to keep the boilerplate minimal
60-
on these (no docstrings + type annotations), inline comments can be used, but
61-
should be used sparingly.
49+
[Google Style Guide](https://google.github.io/styleguide/pyguide.html). We aim
50+
to keep APIs, names, and design patterns in line with the broader JAX ecosystem
51+
as much as possible, with immutability and functional purity being a key guiding
52+
principles we adhere to across our library. Below are some more detailed
53+
conventions depending on what is being contributed.
54+
55+
1. **Public Facing Functions**: Public facing functions are those that are
56+
exposed to the **users** of JAX Privacy (usually surfaced via
57+
\_\_init\_\_.py). Public facing functions and classes should **always** have
58+
full docstrings, type annotations, and example usages in the form of
59+
[doctests](https://docs.python.org/3/library/doctest.html). Doctests
60+
provides useful documentation that stays up-to-date with code changes, and
61+
is a useful litmus test on the simplicity and usability of the API surface.
62+
63+
1. **Internal functions used across files**: For maintainability of the core
64+
library, it is sometimes beneficial to define a function in one file and
65+
have it be used by another file within the jax_privacy package. These
66+
functions are not intended to be consumed by JAX Privacy users (although may
67+
be encountered by developers / contributors). These functions should
68+
generally have descriptive names, type annotations. Internal functions
69+
should have a one-line docstring explaining what they do. A full docstring
70+
is encouraged if the function has non-obvious side effects, complex
71+
arguments, or implements a multi-step algorithm that isn't clear from the
72+
code alone.
73+
74+
1. **File-local private functions**: These function should always have a
75+
leading "_". This signals to developers that the function is not part of the
76+
public API and is subject to change without notice. These functions should
77+
have 1-line docstrings; type annotations are optional and context-dependent.
78+
Example usages are not needed as they can be found in the corresponding
79+
_test.py file.
80+
81+
1. **Nested functions**: Functions defined within other functions should
82+
generally be as simple as possible; we prefer to keep the boilerplate
83+
minimal on these (no docstrings + type annotations), inline comments can be
84+
used, but should be used sparingly.
6285

6386
## Linting and testing
6487

65-
We use `flake8`, `pylint` and `pytype` for linting and type checking. Please
66-
run the following commands locally before submitting a pull request:
88+
We use `flake8`, `pylint` and `pytype` for linting and type checking. Please run
89+
the following commands locally before submitting a pull request:
6790

6891
```bash
6992
$ flake8 jax_privacy/**.py

MANIFEST.in

Lines changed: 0 additions & 1 deletion
This file was deleted.

conftest.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

docs/core_library.rst

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,3 @@ Experimental Modules
2525

2626
experimental.execution_plan
2727
experimental.compilation_utils
28-
29-
30-
Other References
31-
----------------
32-
.. autosummary::
33-
:toctree: _autosummary_output
34-
:nosignatures:
35-
36-
experimental.microbatching

examples/distributed_noise_generation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
from absl import flags
3737
import jax
3838
import jax.numpy as jnp
39-
# pylint: disable=g-importing-member,no-name-in-module
40-
from jax.experimental.shard import reshard
4139
from jax_privacy import noise_addition
4240
from jax_privacy.matrix_factorization import toeplitz
4341

@@ -112,15 +110,17 @@ def bert_model_params(hidden_size: int) -> Any:
112110
is_leaf=lambda x: isinstance(x, tuple),
113111
)
114112

115-
return reshard(model_params, jax.sharding.PartitionSpec())
113+
return jax.sharding.reshard(model_params, jax.sharding.PartitionSpec())
116114

117115

118116
def toy_model_params(hidden_size: int) -> jax.Array:
119117
"""Returns model parameters for a toy model."""
120118
# This is a toy example where the model is just a 2D array of size (H, H^2).
121119
leaf_shape = (hidden_size, hidden_size**2)
122120

123-
return reshard(jnp.zeros(leaf_shape), jax.sharding.PartitionSpec('x', 'y'))
121+
return jax.sharding.reshard(
122+
jnp.zeros(leaf_shape), jax.sharding.PartitionSpec('x', 'y')
123+
)
124124

125125

126126
def generate_noise(
@@ -162,10 +162,10 @@ def run(pytree_like_model_params):
162162
t0 = time.time()
163163
compiled_run = run.lower(model_params).compile()
164164
t1 = time.time()
165-
print(f'[BandMF] Compilation time: {t1 - t0:.3f} seconds')
165+
print(f'[BandMF] Compilation time: {t1-t0:.3f} seconds')
166166
state, noisy_grad = jax.block_until_ready(compiled_run(model_params))
167167
t2 = time.time()
168-
print(f'[BandMF] Per-step run time: {(t2 - t1) / steps:.3f} seconds')
168+
print(f'[BandMF] Per-step run time: {(t2-t1)/steps:.3f} seconds')
169169

170170
return state, noisy_grad
171171

examples/keras_api_example.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
import os
2323
from absl import app
2424

25-
os.environ["KERAS_BACKEND"] = "jax" # pylint: disable=wrong-import-position
26-
from jax_privacy.keras import keras_api # pylint: disable=g-import-not-at-top
25+
os.environ["KERAS_BACKEND"] = "jax"
26+
# pylint: disable=g-import-not-at-top,wrong-import-position
27+
from jax_privacy import keras_api
2728
import keras
2829
from keras import layers
2930
import numpy as np

0 commit comments

Comments
 (0)