Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 268 additions & 28 deletions jax/_src/scipy/ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,95 @@ def _round_half_away_from_zero(a: Array) -> Array:
return a if dtypes.issubdtype(a.dtype, np.integer) else lax.round(a)


def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
index = _round_half_away_from_zero(coordinate).astype(np.int32)
def _round_half_to_posinf(a: Array) -> Array:
return a if dtypes.issubdtype(a.dtype, np.integer) else lax.floor(a + 0.5)


def _filter_index_and_weight(coordinate: Array, even: bool = False) -> tuple[Array, Array]:
lower = jnp.floor(coordinate + 0.5 if even else coordinate)
lower_dist = coordinate - lower
# (index, dist to lower knot)
return (lower.astype(np.int32), lower_dist)


def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
index = _round_half_to_posinf(coordinate).astype(np.int32)
weight = coordinate.dtype.type(1)
return [(index, weight)]


def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
lower = jnp.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = lower.astype(np.int32)
return [(index, lower_weight), (index + 1, upper_weight)]
def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
(index, lower_dist) = _filter_index_and_weight(coordinate)
return [(index, 1 - lower_dist), (index + 1, lower_dist)]


def _quadratic_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
(index, t) = _filter_index_and_weight(coordinate, even=True)
# t from -0.5 to 0.5
return [
(index - 1, 0.5 * (0.5 - t)**2),
(index, 0.75 - t * t),
(index + 1, 0.5 * (t + 0.5)**2),
]


def _cubic_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
(index, t) = _filter_index_and_weight(coordinate)
t1 = 1 - t
return [
(index - 1, t1 * t1 * t1 / 6.),
(index, (4. + 3. * t * t * (t - 2.0)) / 6.),
(index + 1, (4. + 3. * t1 * t1 * (t1 - 2.0)) / 6.),
(index + 2, t * t * t / 6.),
]


def _quartic_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
(index, t) = _filter_index_and_weight(coordinate, even=True)
t_sq = t**2
y = t + 1
t1 = 1 - t
return [
(index - 2, (0.5 - t)**4 / 24.0),
(index - 1, y * (y * (y * (5.0 - y) / 6.0 - 1.25) + 5.0 / 24.0) + 55.0 / 96.0),
(index, t_sq * (t_sq * 0.25 - 0.625) + 115.0 / 192.0),
(index + 1, t1 * (t1 * (t1 * (5.0 - t1) / 6.0 - 1.25) + 5.0 / 24.0) + 55.0 / 96.0),
(index + 2, (t + 0.5)**4 / 24.0),
]
Comment on lines +101 to +112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These formulas for quartic spline weights are quite complex. For future maintainability, it would be very helpful to add a comment referencing the source of these equations (e.g., a paper or a specific part of the SciPy implementation). This would also apply to _quintic_indices_and_weights.



def _quintic_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
(index, t) = _filter_index_and_weight(coordinate)
t1 = 1 - t
t_sq = t * t
t1_sq = t1 * t1
y = t + 1
y1 = t1 + 1
return [
(index - 2, t1 * t1_sq * t1_sq / 120.0),
(index - 1, y * (y * (y * (y * (y / 24.0 - 0.375) + 1.25) - 1.75) + 0.625) + 0.425),
(index, t_sq * (t_sq * (0.25 - t / 12.0) - 0.5) + 0.55),
(index + 1, t1_sq * (t1_sq * (0.25 - t1 / 12.0) - 0.5) + 0.55),
(index + 2, y1 * (y1 * (y1 * (y1 * (y1 / 24.0 - 0.375) + 1.25) - 1.75) + 0.625) + 0.425),
(index + 3, t * t_sq * t_sq / 120.0),
]


_INTERP_FNS: dict[int, Callable[[Array], list[tuple[Array, Array]]]] = {
0: _nearest_indices_and_weights,
1: _linear_indices_and_weights,
2: _quadratic_indices_and_weights,
3: _cubic_indices_and_weights,
4: _quartic_indices_and_weights,
5: _quintic_indices_and_weights,
}


@functools.partial(api.jit, static_argnums=(2, 3, 4))
def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
order: int, mode: str, cval: ArrayLike) -> Array:
@functools.partial(api.jit, static_argnums=(3, 4, 5, 6))
def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike], n_pad: int,
dtype: type[np.dtype], order: int, mode: str, cval: ArrayLike) -> Array:
input_arr = jnp.asarray(input)
coordinate_arrs = [jnp.asarray(c) for c in coordinates]
coordinate_arrs = [jnp.asarray(c) + n_pad for c in coordinates]
cval = jnp.asarray(cval, input_arr.dtype)

if len(coordinates) != input_arr.ndim:
Expand All @@ -91,13 +161,11 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
else:
is_valid = lambda index, size: True

if order == 0:
interp_fun = _nearest_indices_and_weights
elif order == 1:
interp_fun = _linear_indices_and_weights
else:
interp_fun = _INTERP_FNS.get(int(order))
if interp_fun is None:
raise NotImplementedError(
'jax.scipy.ndimage.map_coordinates currently requires order<=1')
'jax.scipy.ndimage.map_coordinates does not yet support order {}. '
'Currently supported orders are {}.'.format(int(order), set(_INTERP_FNS)))

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
Expand All @@ -120,15 +188,15 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
contribution = jnp.where(all_valid, input_arr[indices], cval)
outputs.append(_nonempty_prod(weights) * contribution) # type: ignore
result = _nonempty_sum(outputs)
if dtypes.issubdtype(input_arr.dtype, np.integer):
if dtypes.issubdtype(dtype, np.integer):
result = _round_half_away_from_zero(result)
return result.astype(input_arr.dtype)
return result.astype(dtype)


def map_coordinates(
input: ArrayLike, coordinates: Sequence[ArrayLike], order: int,
mode: str = 'constant', cval: ArrayLike = 0.0,
):
mode: str = 'constant', cval: ArrayLike = 0.0, prefilter: bool = True,
) -> Array:
"""
Map the input array to new coordinates using interpolation.

Expand All @@ -141,11 +209,8 @@ def map_coordinates(
input: N-dimensional input array from which values are interpolated.
coordinates: length-N sequence of arrays specifying the coordinates
at which to evaluate the interpolated values
order: The order of interpolation. JAX supports the following:

* 0: Nearest-neighbor
* 1: Linear

order: The order of interpolation. JAX supports orders 0-5, where 0 is nearest-neighbor
interpolation, 1 is linear interpolation, 3 is cubic interpolation, etc.
mode: Points outside the boundaries of the input are filled according to the given mode.
JAX supports one of ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``. Note the
``'wrap'`` mode in JAX behaves as ``'grid-wrap'`` mode in SciPy, and ``'constant'``
Expand All @@ -156,6 +221,8 @@ def map_coordinates(
ones, for backwards compatibility reasons. Default is 'constant'.
cval: Value used for points outside the boundaries of the input if ``mode='constant'``
Default is 0.0.
prefilter: Determines if the array is prefiltered with :func:`spline_prefilter` before
use. The default is `True`. Only has an effect for ``order > 1``.

Returns:
The interpolated values at the specified coordinates.
Expand All @@ -177,4 +244,177 @@ def map_coordinates(
This function interprets the ``mode`` argument as documented by SciPy, but
not as implemented by SciPy.
"""
return _map_coordinates(input, coordinates, order, mode, cval)

input = jnp.asarray(input)
dtype = input.dtype

n_pad = 0
if order > 1 and prefilter:
if mode in ('nearest', 'constant'):
n_pad = 12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The padding size n_pad = 12 seems to be a magic number. Could you add a comment explaining how this value was chosen? It seems related to the maximum spline order supported, but making this explicit would improve clarity.

if mode == 'nearest':
input = jnp.pad(input, n_pad, 'edge')
else:
input = jnp.pad(input, n_pad, 'constant', constant_values=cval)
mode = 'mirror'
input = spline_filter(input.astype(float), order, mode)

return _map_coordinates(input, coordinates, n_pad, dtype, order, mode, cval)


def _init_mirror_causal(arr: Array, z: float) -> Array:
idx = jnp.arange(0, arr.size - 1, dtype=arr.dtype)
z_n = z**(arr.dtype.type(arr.size) - 1)
return (
jnp.sum(z**idx * (arr[:-1] + z_n * arr[:0:-1]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The slice arr[:0:-1] is a bit obscure. Using arr[1:][::-1] is equivalent and more explicit about reversing the array starting from the second element. This would improve readability.

Suggested change
jnp.sum(z**idx * (arr[:-1] + z_n * arr[:0:-1]))
jnp.sum(z**idx * (arr[:-1] + z_n * arr[1:][::-1]))

) / (1 - z_n**2)

def _init_mirror_anticausal(arr: Array, z: float) -> Array:
return z / (z**2 - 1) * (z * arr[-2] + arr[-1])

def _init_wrap_causal(arr: Array, z: float) -> Array:
idx = jnp.arange(1, arr.size, dtype=arr.dtype)
return (
arr[0] + jnp.sum(z**idx * arr[:0:-1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The slice arr[:0:-1] is a bit obscure. Using arr[1:][::-1] is equivalent and more explicit about reversing the array starting from the second element. This would improve readability.

Suggested change
arr[0] + jnp.sum(z**idx * arr[:0:-1])
arr[0] + jnp.sum(z**idx * arr[1:][::-1])

) / (1 - z**arr.size)

def _init_wrap_anticausal(arr: Array, z: float) -> Array:
idx = jnp.arange(1, arr.size, dtype=arr.dtype)
return (
arr[-1] + jnp.sum(z**idx * arr[:-1])
) * z / (z**arr.size - 1)

def _init_reflect_causal(arr: Array, z: float) -> Array:
idx = jnp.arange(arr.size, dtype=arr.dtype)
z_n = z**arr.dtype.type(arr.size)
return arr[0] + z / (1 - z_n**2) * jnp.sum(z**idx * (arr + z_n * arr[::-1]))

def _init_reflect_anticausal(arr: Array, z: float) -> Array:
return z / (z - 1) * arr[-1]

_SPLINE_BOUNDARY_FNS: dict[str, tuple[Callable[[Array, float], Array], Callable[[Array, float], Array]]] = {
'reflect': (_init_reflect_causal, _init_reflect_anticausal),
'wrap': (_init_wrap_causal, _init_wrap_anticausal),
'mirror': (_init_mirror_causal, _init_mirror_anticausal),
# closest b.c. to nearest
'nearest': (_init_reflect_causal, _init_reflect_anticausal),
# default to mirror boundary
'constant': (_init_mirror_causal, _init_mirror_anticausal),
}

_SPLINE_FILTER_POLES: dict[int, list[float]] = {
2: [-0.171572875253809902396622551580603843],
3: [-0.267949192431122706472553658494127633],
4: [-0.361341225900220177092212841325675255, -0.013725429297339121360331226939128204],
5: [-0.430575347099973791851434783493520110, -0.043096288203264653822712376822550182],
}


@functools.partial(api.jit, static_argnums=(1, 2, 3))
def _spline_filter1d(
input: Array, order: int, axis: int, mode: str = 'mirror',
) -> Array:
from jax._src.lax.control_flow.loops import associative_scan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import of associative_scan is inside the _spline_filter1d function. According to PEP 8, imports should usually be at the top of the file. Please move this import to the top-level imports of the file.


poles = _SPLINE_FILTER_POLES.get(order)
if poles is None:
raise ValueError("Spline order '{}' not supported for pre-filtering".format(order))

(causal_fn, anticausal_fn) = _SPLINE_BOUNDARY_FNS.get(mode, (None, None))
if causal_fn is None or anticausal_fn is None:
raise ValueError("Boundary mode '{}' not supported for pre-filtering".format(mode))
Comment on lines +323 to +325
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This get followed by a check for the default value can be made slightly more direct by checking for key existence first.

Suggested change
(causal_fn, anticausal_fn) = _SPLINE_BOUNDARY_FNS.get(mode, (None, None))
if causal_fn is None or anticausal_fn is None:
raise ValueError("Boundary mode '{}' not supported for pre-filtering".format(mode))
if mode not in _SPLINE_BOUNDARY_FNS:
raise ValueError("Boundary mode '{}' not supported for pre-filtering".format(mode))
causal_fn, anticausal_fn = _SPLINE_BOUNDARY_FNS[mode]


gain = functools.reduce(operator.mul, (
(1.0 - z) * (1.0 - 1.0 / z) for z in poles
))
arr = input.astype(float) * gain

# compose an affine transform (y = k*x + b)
# t1 @ t0 => y = (k0*k1)*x + (b0 + k0*b1)
def compose_affine(t1: tuple[Array, Array], t0: tuple[Array, Array]) -> tuple[Array, Array]:
return (t0[0] * t1[0], t0[1] + t0[0]*t1[1])

for z in poles:
# causal
init = jnp.apply_along_axis(lambda arr: jnp.array([causal_fn(arr, z)]), axis, arr)
arr_rest = lax.slicing.slice_in_dim(arr, 1, None, axis=axis)
K, B = associative_scan(compose_affine, (jnp.full_like(arr_rest, z), arr_rest), axis=axis)
arr = lax.concatenate([init, K * init + B], axis)

# anticausal
init = jnp.apply_along_axis(lambda arr: jnp.array([anticausal_fn(arr, z)]), axis, arr)
arr_rest = lax.slicing.slice_in_dim(arr, None, -1, axis=axis)
K, B = associative_scan(compose_affine, (jnp.full_like(arr_rest, z), -z * arr_rest), axis=axis, reverse=True)
arr = lax.concatenate([K * init + B, init], axis)

if dtypes.issubdtype(input.dtype, np.integer):
arr = _round_half_away_from_zero(arr)
return arr.astype(input.dtype)


def spline_filter(
input: ArrayLike,
order: int = 3,
mode: str = 'mirror',
) -> Array:
"""
Applies a multidimensional spline pre-filter.

JAX implementation of :func:`scipy.ndimage.spline_filter`.

Given an input array, this function pre-calculates the B-spline coefficients
for an interpolation with the given order and boundary conditions. These
coefficients can then be consumed by interpolation functions with ``prefilter=False``.

Args:
input: N-dimensional input array for which prefiltering is performed
order: The order of the spline. Supported orders are 2-5.
mode: Boundary mode to use. See :func:`map_coordinates` for more details.
Modes 'nearest' and 'constant' cannot be used, as they have no analytic
solution for the prefilter. Instead, pad the array by the filter size
prior to pre-filtering.

Returns:
An array of B-spline coefficients with the same shape and dtype as ``input``.
"""
arr = jnp.asarray(input)

for ax in range(arr.ndim):
arr = spline_filter1d(arr, order, ax, mode)
return arr


def spline_filter1d(
input: ArrayLike,
order: int = 3,
axis: int = -1,
mode: str = 'mirror',
) -> Array:
"""
Applies a one-dimensional spline pre-filter.

JAX implementation of :func:`scipy.ndimage.spline_filter1d`.

Given an input array, this function pre-calculates the B-spline coefficients
for an interpolation with the given order and boundary conditions along the given axis.
These coefficients can then be consumed by interpolation functions with ``prefilter=False``.

Args:
input: N-dimensional input array for which prefiltering is performed
order: The order of the spline. Supported orders are 2-5.
axis: Axis to apply the spline filter along.
mode: Boundary mode to use. See :func:`map_coordinates` for more details.
Modes 'nearest' and 'constant' cannot be used, as they have no analytic
solution for the prefilter. Instead, pad the array by the filter size
prior to pre-filtering.

Returns:
An array of B-spline coefficients with the same shape and dtype as ``input``.
"""
if mode in ('nearest', 'constant'):
raise ValueError("Boundary mode '{}' has no exact filter. "
"Instead, pad the array by the filter size "
"and use mode 'mirror'".format(mode))
input = jnp.asarray(input)
axis = util.canonicalize_axis(axis, input.ndim)
return _spline_filter1d(input, order, axis, mode)
2 changes: 2 additions & 0 deletions jax/scipy/ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@

from jax._src.scipy.ndimage import (
map_coordinates as map_coordinates,
spline_filter as spline_filter,
spline_filter1d as spline_filter1d,
)
Loading