 @inline function reduce_group(op, val::T, neutral, shuffle::Val{true}, ::Val{maxthreads}) where {T, maxthreads}
     # shared mem for partial sums
     assume(threads_per_simdgroup() == 32)
-    shared = MtlThreadGroupArray(T, 32)
+    shared = KI.localmemory(T, 32)
 
     wid = simdgroup_index_in_threadgroup()
     lane = thread_index_in_simdgroup()
⋮
     end
 
     # wait for all partial reductions
-    threadgroup_barrier(MemoryFlagThreadGroup)
+    KI.barrier()
 
     # read from shared memory only if that warp existed
-    val = if thread_index_in_threadgroup() <= fld1(threads_per_threadgroup().x, 32)
+    val = if KI.get_local_id().x <= fld1(KI.get_local_size().x, 32)
         @inbounds shared[lane]
     else
         neutral
⋮
 
 # Reduce a value across a group, using local memory for communication
 @inline function reduce_group(op, val::T, neutral, shuffle::Val{false}, ::Val{maxthreads}) where {T, maxthreads}
-    threads = threads_per_threadgroup().x
-    thread = thread_position_in_threadgroup().x
+    threads = KI.get_local_size().x
+    thread = KI.get_local_id().x
 
     # local mem for a complete reduction
-    shared = MtlThreadGroupArray(T, (maxthreads,))
+    shared = KI.localmemory(T, (maxthreads,))
     @inbounds shared[thread] = val
 
     # perform a reduction
     d = 1
     while d < threads
-        threadgroup_barrier(MemoryFlagThreadGroup)
+        KI.barrier()
         index = 2 * d * (thread-1) + 1
         @inbounds if index <= threads
             other_val = if index + d <= threads
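The hunk above ports the group-wide tree reduction from Metal threadgroup intrinsics to their KI equivalents. As a self-contained sketch of the same pattern, assuming the `KI` intrinsics behave as they are used in this diff (the function name is illustrative, not part of the source):

    @inline function tree_reduce_sketch(op, val::T, neutral, ::Val{maxthreads}) where {T, maxthreads}
        threads = KI.get_local_size().x
        thread  = KI.get_local_id().x
        shared  = KI.localmemory(T, (maxthreads,))
        @inbounds shared[thread] = val
        d = 1
        while d < threads
            KI.barrier()                        # stores from the previous round must be visible
            index = 2 * d * (thread - 1) + 1    # left element of the pair this thread combines
            @inbounds if index <= threads
                other = index + d <= threads ? shared[index + d] : neutral
                shared[index] = op(shared[index], other)
            end
            d *= 2
        end
        KI.barrier()
        @inbounds shared[1]                     # element 1 now holds the group-wide result
    end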
@@ -94,9 +94,9 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
                                   ::Val{Rother}, ::Val{Rlen}, ::Val{grain}, shuffle, R, As...) where {Rreduce, Rother, Rlen, grain}
     # decompose the 1D hardware indices into separate ones for reduction (across items
     # and possibly groups if it doesn't fit) and other elements (remaining groups)
-    localIdx_reduce = thread_position_in_threadgroup().x
-    localDim_reduce = threads_per_threadgroup().x * grain
-    groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid().x, Rlen)
+    localIdx_reduce = KI.get_local_id().x
+    localDim_reduce = KI.get_local_size().x * grain
+    groupIdx_reduce, groupIdx_other = fldmod1(KI.get_group_id().x, Rlen)
 
     # group-based indexing into the values outside of the reduction dimension
     # (that means we can safely synchronize items within this group)
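Here `fldmod1` splits the flat group index into a reduction-chunk index and an index into the other (non-reduced) dimensions, with the latter cycling fastest. For instance, with `Rlen = 3` slices outside the reduction dimension:

    julia> fldmod1(5, 3)   # flat group 5 → reduction chunk 2, "other" slice 2
    (2, 2)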
@@ -141,7 +141,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
 end
 
 function serial_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
-    grid_idx = thread_position_in_grid().x
+    grid_idx = KI.get_global_id().x
 
     @inbounds if grid_idx <= length(Rother)
         Iother = Rother[grid_idx]
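The rest of this kernel body is collapsed in the diff. Roughly, each thread owns one slice of `Rother` and folds the whole `Rreduce` range serially; an illustrative sketch only (the helper name and the `max`-based index merge are assumptions, not the verbatim source):

    function serial_reduce_sketch(f, op, neutral, Rreduce, Rother, R, A)
        grid_idx = KI.get_global_id().x
        @inbounds if grid_idx <= length(Rother)
            Iother = Rother[grid_idx]        # the output slice this thread is responsible for
            val = neutral
            for Ireduce in Rreduce
                J = max(Iother, Ireduce)     # merge non-reduced and reduced coordinates
                val = op(val, f(A[J]))
            end
            R[Iother] = val
        end
        return
    end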
@@ -166,11 +166,12 @@ end
 
 ## COV_EXCL_STOP
 
-serial_mapreduce_threshold(dev) = dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
+serial_mapreduce_threshold(dev) = KI.max_work_group_size(MetalBackend()) * KI.multiprocessor_count(MetalBackend())
 
 function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
                                  A::Union{AbstractArray,Broadcast.Broadcasted};
                                  init=nothing) where {F, OP, T}
+    backend = MetalBackend()
     Base.check_reducedims(R, A)
     length(A) == 0 && return R  # isempty(::Broadcasted) iterates
 
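The rewritten threshold is a device-wide thread count: the maximum work-group size times the number of GPU cores, both now queried through the KI backend instead of Metal device properties. A rough illustration with made-up device numbers:

    # Hypothetical values for illustration only; the real ones come from the device.
    max_wg    = 1024                 # e.g. KI.max_work_group_size(MetalBackend())
    n_cores   = 10                   # e.g. KI.multiprocessor_count(MetalBackend())
    threshold = max_wg * n_cores     # 10_240 output slices
    # With length(Rother) >= threshold, the simple one-thread-per-slice kernel already
    # saturates the GPU, so the partial-reduction machinery below is skipped.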
@@ -195,10 +196,10 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
 
     # If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
     if length(Rother) >= serial_mapreduce_threshold(device(R))
-        kernel = @metal launch=false serial_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
-        threads = min(length(Rother), kernel.pipeline.maxTotalThreadsPerThreadgroup)
+        kernel = KI.@kikernel backend launch=false serial_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
+        threads = min(length(Rother), kernel.kern.pipeline.maxTotalThreadsPerThreadgroup)
         groups = cld(length(Rother), threads)
-        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; threads, groups)
+        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; workgroupsize=threads, numworkgroups=groups)
         return R
     end
 
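This serial branch also shows the launch pattern used throughout the change: compile with `launch=false`, size the launch from the compiled kernel, then launch with the `workgroupsize`/`numworkgroups` keywords. Condensed into a standalone sketch (the kernel name, its arguments and `n_items` are placeholders):

    backend = MetalBackend()

    # 1. compile without launching
    kernel = KI.@kikernel backend launch=false my_kernel(arg1, arg2)

    # 2. size the launch from the compiled pipeline
    threads = min(n_items, kernel.kern.pipeline.maxTotalThreadsPerThreadgroup)
    groups  = cld(n_items, threads)

    # 3. launch with the KI keyword names used in this diff
    kernel(arg1, arg2; workgroupsize=threads, numworkgroups=groups)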
@@ -223,18 +224,18 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # we might not be able to launch all those threads to reduce each slice in one go.
     # that's why each threads also loops across their inputs, processing multiple values
     # so that we can span the entire reduction dimension using a single item group.
-    kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
+    kernel = KI.@kikernel backend launch=false partial_mapreduce_device(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                                                            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)
 
     # how many threads do we want?
     #
     # threads in a group work together to reduce values across the reduction dimensions;
     # we want as many as possible to improve algorithm efficiency and execution occupancy.
     function compute_threads(kern)
-        max_threads = kern.pipeline.maxTotalThreadsPerThreadgroup
-        wanted_threads = shuffle ? nextwarp(kern.pipeline, length(Rreduce)) : length(Rreduce)
+        max_threads = KI.kernel_max_work_group_size(backend, kern)
+        wanted_threads = shuffle ? nextwarp(kern.kern.pipeline, length(Rreduce)) : length(Rreduce)
         if wanted_threads > max_threads
-            shuffle ? prevwarp(kern.pipeline, max_threads) : max_threads
+            shuffle ? prevwarp(kern.kern.pipeline, max_threads) : max_threads
         else
             wanted_threads
         end
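`compute_threads` clamps the desired thread count to what the compiled kernel supports and, when shuffling, keeps it a multiple of the SIMD-group width. The same decision with a hard-coded width of 32, as a self-contained sketch (`nextwarp32`/`prevwarp32` are stand-ins for the helpers used above):

    # Illustrative only; assumes a fixed SIMD-group width of 32.
    nextwarp32(n) = 32 * cld(n, 32)    # round up to a full SIMD group
    prevwarp32(n) = 32 * fld(n, 32)    # round down to a full SIMD group

    function choose_threads(wanted, max_threads; shuffle=true)
        wanted = shuffle ? nextwarp32(wanted) : wanted
        wanted > max_threads ? (shuffle ? prevwarp32(max_threads) : max_threads) : wanted
    end

    choose_threads(1000, 1024)   # 1024: 1000 rounds up to 1024, which still fits
    choose_threads(1100, 1024)   # 1024: too large, so the maximum is rounded down to a SIMD multiple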
@@ -259,15 +260,15 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
 
         # we can cover the dimensions to reduce using a single group
         kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
-               threads, groups)
+               workgroupsize=threads, numworkgroups=groups)
     else
         # temporary empty array whose type will match the final partial array
         partial = similar(R, ntuple(_ -> 0, Val(ndims(R)+1)))
 
         # NOTE: we can't use the previously-compiled kernel, or its launch configuration,
         # since the type of `partial` might not match the original output container
         # (e.g. if that was a view).
-        partial_kernel = @metal launch=false partial_mapreduce_device(
+        partial_kernel = KI.@kikernel backend launch=false partial_mapreduce_device(
             f, op, init, Val(threads), Val(Rreduce),
             Val(Rother), Val(UInt64(length(Rother))),
             Val(grain), Val(shuffle), partial, A)
@@ -286,7 +287,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         partial_kernel(f, op, init, Val(threads), Val(Rreduce),
                        Val(Rother), Val(UInt64(length(Rother))),
                        Val(grain), Val(shuffle), partial, A;
-                       groups=partial_groups, threads=partial_threads)
+                       numworkgroups=partial_groups, workgroupsize=partial_threads)
 
         GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
     end
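Taken together, the non-serial path reduces in two passes: the partial kernel collapses each slice to one value per group, and the recursive `mapreducedim!` call folds those partials into `R`. The user-facing API is unchanged; reductions like the following still exercise this code path (illustrative usage only):

    using Metal

    a = MtlArray(rand(Float32, 1024, 1024))
    s = sum(a; dims=2)            # dispatches to GPUArrays.mapreducedim! above
    m = mapreduce(abs2, +, a)     # full reduction, same machinery
    Array(s)                      # copy the result back to the CPU for inspection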