diff --git a/AUTHORS b/AUTHORS index 1595c46..b5f563f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -22,4 +22,5 @@ Contributors: * Lily Wang * Josh Vermaas * Irfan Alibay -* Zhiyi Wu \ No newline at end of file +* Zhiyi Wu +* Olivier Languin-Cattoën diff --git a/CHANGELOG b/CHANGELOG index 318bab8..04c0480 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -13,14 +13,26 @@ The rules for this file: * accompany each entry with github issue/PR number (Issue #xyz) ------------------------------------------------------------------------------ -??/??/???? IAlibay - * 1.0.3 +??/??/???? IAlibay, ollyfutur + * 1.1.0 Changes * Python 3.13 and 3.14 are now supported (PR #140) * Support for Python 3.9 and 3.10 is now dropped as per SPEC0 (PR #140) + Enhancements + + * `Grid` now accepts binary operations with any operand that can be + broadcasted to the grid's shape according to `numpy` broadcasting rules + (PR #142) + + Fixes + + * Attempting binary operations on grids with different edges now raises an + exception (PR #142) + + 10/21/2023 IAlibay, orbeckst, lilyminium * 1.0.2 diff --git a/gridData/core.py b/gridData/core.py index e66a822..3395d62 100644 --- a/gridData/core.py +++ b/gridData/core.py @@ -711,7 +711,7 @@ def check_compatible(self, other): `other` is compatible if - 1) `other` is a scalar + 1) `other` is a scalar or an array-like broadcastable to the grid 2) `other` is a grid defined on the same edges In order to make `other` compatible, resample it on the same @@ -732,13 +732,22 @@ def check_compatible(self, other): -------- :meth:`resample` """ - - if not (numpy.isreal(other) or self == other): + if isinstance(other, Grid): + is_compatible = all( + numpy.allclose(other_edge, self_edge) + for other_edge, self_edge in zip(other.edges, self.edges) + ) + else: + try: + is_compatible = numpy.broadcast(self.grid, other).shape == self.grid.shape + except ValueError: + is_compatible = False + if not is_compatible: raise TypeError( "The argument cannot be arithmetically combined with the grid. " - "It must be a scalar or a grid with identical edges. " - "Use Grid.resample(other.edges) to make a new grid that is " - "compatible with other.") + "It must be broadcastable to the grid's shape or a `Grid` with identical edges. " + "Use `Grid.resample(other.edges)` to make a new grid that is " + "compatible with `other`.") return True def _interpolationFunctionFactory(self, spline_order=None, cval=None): diff --git a/gridData/tests/test_grid.py b/gridData/tests/test_grid.py index 8465be4..bc6c9c5 100644 --- a/gridData/tests/test_grid.py +++ b/gridData/tests/test_grid.py @@ -107,12 +107,18 @@ def test_power(self, data): def test_compatibility_type(self, data): assert data['grid'].check_compatible(data['grid']) assert data['grid'].check_compatible(3) - g = Grid(data['griddata'], origin=data['origin'] - 1, delta=data['delta']) + g = Grid(data['griddata'], origin=data['origin'], delta=data['delta']) assert data['grid'].check_compatible(g) + assert data['grid'].check_compatible(g.grid) def test_wrong_compatibile_type(self, data): + g = Grid(data['griddata'], origin=data['origin'] + 1, delta=data['delta']) + with pytest.raises(TypeError): + data['grid'].check_compatible(g) + + arr = np.zeros(data['griddata'].shape[-1] + 1) # Not broadcastable with pytest.raises(TypeError): - data['grid'].check_compatible("foo") + data['grid'].check_compatible(arr) def test_non_orthonormal_boxes(self, data): delta = np.eye(3)