Skip to content

Conversation

@hexane360
Copy link

This is a draft PR implementing higher-order interpolations in jax.scipy.ndimage.

The IIR spline prefilters are implemented with an affine associative scan.
This might not be mergeable immediately; it seems like the spline filters have some stability problems with single-precision floats, especially with orders above 3. Therefore, the tests for spline_filter and map_coordinates with order > 3 have significantly reduced precision. Feedback or contributions are welcome.

Closes: #3928

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @hexane360, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a substantial upgrade to jax.scipy.ndimage by enabling higher-order interpolations (up to order 5) within the map_coordinates function. This is achieved through the integration of new spline pre-filtering capabilities, which leverage an affine associative scan for efficient computation. The enhancement allows for more sophisticated and accurate image and data manipulation, although some precision considerations for higher-order single-precision floats are noted.

Highlights

  • Extended Interpolation Orders: The map_coordinates function now supports interpolation orders from 0 to 5, significantly expanding its capabilities beyond the previous limit of order 1.
  • Spline Prefiltering Implementation: New functions spline_filter and spline_filter1d have been added to perform multidimensional and one-dimensional spline pre-filtering, respectively, which are crucial for higher-order interpolations.
  • Affine Associative Scan for Prefilters: The IIR spline prefilters are implemented using an affine associative scan, a technique that enables efficient and parallel computation of these filters.
  • Precision Adjustments for Higher Orders: Tests for spline_filter and map_coordinates with orders greater than 3 have reduced precision due to observed stability problems with single-precision floats.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements higher-order interpolations for jax.scipy.ndimage.map_coordinates, which is a great enhancement. The implementation uses an affine associative scan for the IIR spline prefilters, which is a solid approach for performance on accelerators. The code is well-structured, with new interpolation orders and spline filtering logic cleanly integrated. The tests are comprehensive, covering the new orders and prefiltering options, and correctly handling differences with SciPy's boundary behavior.

My feedback includes a few suggestions for improving code clarity and maintainability, such as adding comments to explain complex formulas and magic numbers, moving an import to the top level, and simplifying some code patterns. Overall, this is a high-quality contribution.

Comment on lines +101 to +112
def _quartic_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
(index, t) = _filter_index_and_weight(coordinate, even=True)
t_sq = t**2
y = t + 1
t1 = 1 - t
return [
(index - 2, (0.5 - t)**4 / 24.0),
(index - 1, y * (y * (y * (5.0 - y) / 6.0 - 1.25) + 5.0 / 24.0) + 55.0 / 96.0),
(index, t_sq * (t_sq * 0.25 - 0.625) + 115.0 / 192.0),
(index + 1, t1 * (t1 * (t1 * (5.0 - t1) / 6.0 - 1.25) + 5.0 / 24.0) + 55.0 / 96.0),
(index + 2, (t + 0.5)**4 / 24.0),
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These formulas for quartic spline weights are quite complex. For future maintainability, it would be very helpful to add a comment referencing the source of these equations (e.g., a paper or a specific part of the SciPy implementation). This would also apply to _quintic_indices_and_weights.

n_pad = 0
if order > 1 and prefilter:
if mode in ('nearest', 'constant'):
n_pad = 12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The padding size n_pad = 12 seems to be a magic number. Could you add a comment explaining how this value was chosen? It seems related to the maximum spline order supported, but making this explicit would improve clarity.

idx = jnp.arange(0, arr.size - 1, dtype=arr.dtype)
z_n = z**(arr.dtype.type(arr.size) - 1)
return (
jnp.sum(z**idx * (arr[:-1] + z_n * arr[:0:-1]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The slice arr[:0:-1] is a bit obscure. Using arr[1:][::-1] is equivalent and more explicit about reversing the array starting from the second element. This would improve readability.

Suggested change
jnp.sum(z**idx * (arr[:-1] + z_n * arr[:0:-1]))
jnp.sum(z**idx * (arr[:-1] + z_n * arr[1:][::-1]))

def _init_wrap_causal(arr: Array, z: float) -> Array:
idx = jnp.arange(1, arr.size, dtype=arr.dtype)
return (
arr[0] + jnp.sum(z**idx * arr[:0:-1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The slice arr[:0:-1] is a bit obscure. Using arr[1:][::-1] is equivalent and more explicit about reversing the array starting from the second element. This would improve readability.

Suggested change
arr[0] + jnp.sum(z**idx * arr[:0:-1])
arr[0] + jnp.sum(z**idx * arr[1:][::-1])

def _spline_filter1d(
input: Array, order: int, axis: int, mode: str = 'mirror',
) -> Array:
from jax._src.lax.control_flow.loops import associative_scan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import of associative_scan is inside the _spline_filter1d function. According to PEP 8, imports should usually be at the top of the file. Please move this import to the top-level imports of the file.

Comment on lines +323 to +325
(causal_fn, anticausal_fn) = _SPLINE_BOUNDARY_FNS.get(mode, (None, None))
if causal_fn is None or anticausal_fn is None:
raise ValueError("Boundary mode '{}' not supported for pre-filtering".format(mode))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This get followed by a check for the default value can be made slightly more direct by checking for key existence first.

Suggested change
(causal_fn, anticausal_fn) = _SPLINE_BOUNDARY_FNS.get(mode, (None, None))
if causal_fn is None or anticausal_fn is None:
raise ValueError("Boundary mode '{}' not supported for pre-filtering".format(mode))
if mode not in _SPLINE_BOUNDARY_FNS:
raise ValueError("Boundary mode '{}' not supported for pre-filtering".format(mode))
causal_fn, anticausal_fn = _SPLINE_BOUNDARY_FNS[mode]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implement higher order interpolation in jax's map_coordinates

1 participant