@@ -264,57 +264,131 @@ function vml_prefix(t::DataType)
264264 error (" unknown type $t " )
265265end
266266
267+ if isdefined (Base, :_checkcontiguous )
268+ alldense (@nospecialize (x)) = Base. _checkcontiguous (Bool, x)
269+ else
270+ alldense (x) = x isa DenseArray
271+ alldense (x:: Base.ReshapedArray ) = alldense (parent (x))
272+ alldense (x:: Base.FastContiguousSubArray ) = alldense (parent (x))
273+ alldense (x:: Base.ReinterpretArray ) = alldense (parent (x))
274+ end
275+ alldense (x, y, z... ) = alldense (x) && alldense (y, z... )
276+
277+ if isdefined (Base, :merge_adjacent_dim )
278+ const merge_adjacent_dim = Base. merge_adjacent_dim
279+ else
280+ merge_adjacent_dim (:: Dims{0} , :: Dims{0} ) = 1 , 1 , 0
281+ merge_adjacent_dim (apsz:: Dims{1} , apst:: Dims{1} ) = apsz[1 ], apst[1 ], 1
282+ function merge_adjacent_dim (apsz:: Dims{N} , apst:: Dims{N} , n:: Int = 1 ) where {N}
283+ sz, st = apsz[n], apst[n]
284+ while n < N
285+ szₙ, stₙ = apsz[n+ 1 ], apst[n+ 1 ]
286+ if sz == 1
287+ sz, st = szₙ, stₙ
288+ elseif stₙ == st * sz || szₙ == 1
289+ sz *= szₙ
290+ else
291+ break
292+ end
293+ n += 1
294+ end
295+ return sz, st, n
296+ end
297+ end
298+
299+ getstrides (x... ) = map (stride1, x)
300+ function stride1 (x:: AbstractArray )
301+ alldense (x) && return 1
302+ ndims (x) == 1 && return stride (x, 1 )
303+ szs:: Dims = size (x)
304+ sts:: Dims = strides (x)
305+ _, st, n = merge_adjacent_dim (szs, sts)
306+ n === ndims (x) && return st
307+ throw (ArgumentError (" only support vector like inputs" ))
308+ end
309+
267310function def_unary_op (tin, tout, jlname, jlname!, mklname;
268311 vmltype = tin)
269- mklfn = Base. Meta. quot (Symbol (" $(vml_prefix (vmltype))$mklname " ))
312+ mklfn = Base. Meta. quot (Symbol (" $(vml_prefix (vmltype))$(mklname) I" ))
313+ mklfndense = Base. Meta. quot (Symbol (" $(vml_prefix (vmltype))$mklname " ))
270314 exports = Symbol[]
271315 (@isdefined jlname) || push! (exports, jlname)
272316 (@isdefined jlname!) || push! (exports, jlname!)
273317 @eval begin
274- function ($ jlname!)(out:: Array {$tout} , A:: Array {$tin} )
318+ function ($ jlname!)(out:: AbstractArray {$tout} , A:: AbstractArray {$tin} )
275319 size (out) == size (A) || throw (DimensionMismatch ())
276- ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tout}), length (A), A, out)
320+ if alldense (out, A) || ((sts = getstrides (out, A)) == (1 , 1 ))
321+ ccall (($ mklfndense, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tout}), length (A), A, out)
322+ else
323+ stᵒ, stᴬ = sts
324+ ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Int, Ptr{$ tout}, Int), length (A), A, stᴬ, out, stᵒ)
325+ end
277326 vml_check_error ()
278327 return out
279328 end
280329 $ (if tin == tout
281330 quote
282- function $ (jlname!)(A:: Array{$tin} )
283- ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tout}), length (A), A, A)
331+ function $ (jlname!)(A:: AbstractArray{$tin} )
332+ if alldense (A) || ((sts = getstrides (A)) == (1 ,))
333+ ccall (($ mklfndense, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tout}), length (A), A, A)
334+ else
335+ (stᴬ,) = sts
336+ ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Int, Ptr{$ tout}, Int), length (A), A, stᴬ, A, stᴬ)
337+ end
284338 vml_check_error ()
285339 return A
286340 end
287341 end
288342 end )
289- function ($ jlname)(A:: Array{$tin} )
290- out = similar (A, $ tout)
291- ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tout}), length (A), A, out)
292- vml_check_error ()
293- return out
294- end
343+ ($ jlname)(A:: AbstractArray{$tin} ) = $ (jlname!)(similar (A, $ tout), A)
295344 $ (isempty (exports) ? nothing : Expr (:export , exports... ))
296345 end
297346end
298347
299348function def_binary_op (tin, tout, jlname, jlname!, mklname, broadcast)
300- mklfn = Base. Meta. quot (Symbol (" $(vml_prefix (tin))$mklname " ))
349+ mklfndense = Base. Meta. quot (Symbol (" $(vml_prefix (tin))$mklname " ))
350+ mklfn = Base. Meta. quot (Symbol (" $(vml_prefix (tin))$(mklname) I" ))
301351 exports = Symbol[]
302352 (@isdefined jlname) || push! (exports, jlname)
303353 (@isdefined jlname!) || push! (exports, jlname!)
304354 @eval begin
305355 $ (isempty (exports) ? nothing : Expr (:export , exports... ))
306- function ($ jlname!)(out:: Array{$tout} , A:: Array{$tin} , B:: Array{$tin} )
307- size (out) == size (A) == size (B) || throw (DimensionMismatch (" Input arrays and output array need to have the same size" ))
308- ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tin}, Ptr{$ tout}), length (A), A, B, out)
356+ function ($ jlname!)(out:: AbstractArray{$tout} , A:: AbstractArray{$tin} , B:: AbstractArray{$tin} )
357+ size (A) == size (B) || throw (DimensionMismatch (" Input arrays need to have the same size" ))
358+ size (out) == size (A) || throw (DimensionMismatch (" Output array need to have the same size with input" ))
359+ if alldense (out, A, B) || ((sts = getstrides (out, A, B)) == (1 , 1 , 1 ))
360+ ccall (($ mklfndense, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tin}, Ptr{$ tout}), length (A), A, B, out)
361+ else
362+ stᵒ, stᴬ, stᴮ = sts
363+ ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Int, Ptr{$ tin}, Int, Ptr{$ tout}, Int), length (A), A, stᴬ, B, stᴮ, out, stᵒ)
364+ end
309365 vml_check_error ()
310366 return out
311367 end
312- function ($ jlname)(A:: Array{$tout} , B:: Array{$tin} )
313- size (A) == size (B) || throw (DimensionMismatch (" Input arrays need to have the same size" ))
314- out = similar (A)
315- ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tin}, Ptr{$ tout}), length (A), A, B, out)
368+ ($ jlname)(A:: AbstractArray{$tin} , B:: AbstractArray{$tin} ) = ($ jlname!)(similar (A, $ tout), A, B)
369+ end
370+ end
371+
372+ function def_one2two_op (tin, tout, jlname, jlname!, mklname)
373+ mklfndense = Base. Meta. quot (Symbol (" $(vml_prefix (tin))$mklname " ))
374+ mklfn = Base. Meta. quot (Symbol (" $(vml_prefix (tin))$(mklname) I" ))
375+ exports = Symbol[]
376+ (@isdefined jlname) || push! (exports, jlname)
377+ (@isdefined jlname!) || push! (exports, jlname!)
378+ @eval begin
379+ $ (isempty (exports) ? nothing : Expr (:export , exports... ))
380+ function ($ jlname!)(out1:: AbstractArray{$tout} , out2:: AbstractArray{$tout} , A:: AbstractArray{$tin} )
381+ size (out1) == size (out2) || throw (DimensionMismatch (" Output arrays need to have the same size" ))
382+ size (A) == size (out2) || throw (DimensionMismatch (" Output array need to have the same size with input" ))
383+ if alldense (out1, out2, A) || ((sts = getstrides (out1, out2, A)) == (1 , 1 , 1 ))
384+ ccall (($ mklfndense, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Ptr{$ tin}, Ptr{$ tout}), length (A), A, out1, out2)
385+ else
386+ st¹, st², stᴬ = sts
387+ ccall (($ mklfn, MKL_jll. libmkl_rt), Nothing, (Int, Ptr{$ tin}, Int, Ptr{$ tin}, Int, Ptr{$ tout}, Int), length (A), A, stᴬ, out1, st¹, out2, st²)
388+ end
316389 vml_check_error ()
317- return out
390+ return out1, out2
318391 end
392+ ($ jlname)(A:: AbstractArray{$tin} ) = ($ jlname!)(similar (A, $ tout), similar (A, $ tout), A)
319393 end
320394end
0 commit comments