diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 1189986c..0b5d14a1 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -130,6 +130,31 @@ def test_broadcast_arrays(shapes, data): raise +@pytest.mark.min_version("2025.12") +class TestBroadcastShapes: + + @given(shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes)) + def test_broadcast_shapes(self, shapes): + repro_snippet = ph.format_snippet(f"xp.broadcast_shapes(*shapes) with {shapes = }") + try: + out_shape = xp.broadcast_shapes(*shapes) + expected_shape = sh.broadcast_shapes(*shapes) + assert out_shape == expected_shape + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise + + def test_empty(self): + assert xp.broadcast_shapes() == () + + @given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1, min_side=3)) + def test_error(self, shapes): + s1, s2 = shapes + s1 = s1[:-1] + (s1[-1] + 1,) + with pytest.raises(ValueError): + xp.broadcast_shapes(s1, s2) + + @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data()) def test_broadcast_to(x, data): shape = data.draw(