@@ -31,27 +31,35 @@ def test_getitem(shape, data):
3131 obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
3232 x = xp .asarray (obj , dtype = dtype )
3333 note (f"{ x = } " )
34- key = data .draw (xps .indices (shape = shape ), label = "key" )
34+ key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
3535
3636 out = x [key ]
3737
3838 ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
3939 _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
4040 if Ellipsis in _key :
41- start_a = _key .index (Ellipsis )
42- stop_a = start_a + (len (shape ) - (len (_key ) - 1 ))
43- slices = tuple (slice (None , None ) for _ in range (start_a , stop_a ))
44- _key = _key [:start_a ] + slices + _key [start_a + 1 :]
41+ nonexpanding_key = tuple (i for i in _key if i is not None )
42+ start_a = nonexpanding_key .index (Ellipsis )
43+ stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
44+ slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
45+ start_pos = _key .index (Ellipsis )
46+ _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
4547 axes_indices = []
4648 out_shape = []
47- for a , i in enumerate (_key ):
48- if isinstance (i , int ):
49- axes_indices .append ([i ])
49+ a = 0
50+ for i in _key :
51+ if i is None :
52+ out_shape .append (1 )
5053 else :
51- side = shape [a ]
52- indices = range (side )[i ]
53- axes_indices .append (indices )
54- out_shape .append (len (indices ))
54+ if isinstance (i , int ):
55+ axes_indices .append ([i ])
56+ else :
57+ assert isinstance (i , slice ) # sanity check
58+ side = shape [a ]
59+ indices = range (side )[i ]
60+ axes_indices .append (indices )
61+ out_shape .append (len (indices ))
62+ a += 1
5563 out_shape = tuple (out_shape )
5664 ph .assert_shape ("__getitem__" , out .shape , out_shape )
5765 assume (all (len (indices ) > 0 for indices in axes_indices ))
@@ -104,8 +112,6 @@ def test_setitem(shape, data):
104112 )
105113
106114
107- # TODO: make mask tests optional
108-
109115@pytest .mark .data_dependent_shapes
110116@given (hh .shapes (), st .data ())
111117def test_getitem_masking (shape , data ):
0 commit comments