Skip to content
15 changes: 14 additions & 1 deletion ext/AdaptiveArrayPoolsCUDAExt/AdaptiveArrayPoolsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,20 @@ Loaded automatically when `using CUDA` with AdaptiveArrayPools.
module AdaptiveArrayPoolsCUDAExt

using AdaptiveArrayPools
using AdaptiveArrayPools: AbstractTypedPool, AbstractArrayPool, CACHE_WAYS
using AdaptiveArrayPools: AbstractTypedPool, AbstractArrayPool
using Preferences: @load_preference, @set_preferences!

# N-way view cache configuration (CUDA only — CPU ≥1.11 uses slot-first _claim_slot!).
# GPU view/reshape allocates ~80 bytes on CPU heap, so caching still matters.
const CACHE_WAYS = let
    # Preference key is "cache_ways"; falls back to 4 ways when unset.
    requested = @load_preference("cache_ways", 4)::Int
    if 1 <= requested <= 16
        requested
    else
        @warn "CACHE_WAYS=$requested out of range [1,16], using default 4"
        4
    end
end
using CUDA

# Type definitions
Expand Down
22 changes: 4 additions & 18 deletions ext/AdaptiveArrayPoolsCUDAExt/acquire.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# - Could track recent N sizes to make smarter decisions (avoid shrink if sizes fluctuate)
# ==============================================================================

using AdaptiveArrayPools: get_view!, get_nd_view!, get_nd_array!, allocate_vector, safe_prod,
using AdaptiveArrayPools: get_view!, get_array!, allocate_vector, safe_prod,
_record_type_touch!, _fixed_slot_bit, _checkpoint_typed_pool!,
_MODE_BITS_MASK

Expand Down Expand Up @@ -138,30 +138,16 @@ See module header for "lazy shrink" optimization notes.
end

# ==============================================================================
# CUDA-Specific get_nd_view! - Delegates to unified get_view!
# CUDA-Specific get_array! - Delegates to unified get_view!
# ==============================================================================

"""
    get_nd_view!(tp::CuTypedPool{T}, dims::NTuple{N,Int}) -> CuArray{T,N}

Forward to the unified `get_view!(tp, dims)` path.
Kept so the CUDA pool exposes the same API surface as the base package.
"""
@inline AdaptiveArrayPools.get_nd_view!(tp::CuTypedPool{T}, dims::NTuple{N, Int}) where {T, N} =
    get_view!(tp, dims)

# ==============================================================================
# CUDA-Specific get_nd_array! - Delegates to unified get_view!
# ==============================================================================

"""
get_nd_array!(tp::CuTypedPool{T}, dims::NTuple{N,Int}) -> CuArray{T,N}
get_array!(tp::CuTypedPool{T}, dims::NTuple{N,Int}) -> CuArray{T,N}

Delegates to `get_view!(tp, dims)` for unified caching.
Used by `unsafe_acquire!` - same zero-allocation behavior as `acquire!`.
"""
@inline function AdaptiveArrayPools.get_nd_array!(tp::CuTypedPool{T}, dims::NTuple{N, Int}) where {T, N}
# Thin forwarder: the CUDA pool funnels this call through the same unified
# `get_view!` path used by the 1D case, so both share one caching strategy.
@inline AdaptiveArrayPools.get_array!(tp::CuTypedPool{T}, dims::NTuple{N, Int}) where {T, N} =
    get_view!(tp, dims)

Expand Down
12 changes: 1 addition & 11 deletions ext/AdaptiveArrayPoolsCUDAExt/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# ==============================================================================
# Key dispatch points for GPU-specific allocation and type routing.

using AdaptiveArrayPools: allocate_vector, wrap_array, get_typed_pool!
using AdaptiveArrayPools: allocate_vector, get_typed_pool!

# ==============================================================================
# Allocation Dispatch (single GPU-specific method needed!)
Expand All @@ -13,16 +13,6 @@ using AdaptiveArrayPools: allocate_vector, wrap_array, get_typed_pool!
::AbstractTypedPool{T, CuVector{T}}, n::Int
) where {T} = CuVector{T}(undef, n)

# ==============================================================================
# Array Wrapping Dispatch
# ==============================================================================

# GPU uses reshape which returns CuArray{T,N} via GPUArrays derive()
# (NOT ReshapedArray like CPU - this is simpler for GPU kernels)
@inline function AdaptiveArrayPools.wrap_array(
    ::AbstractTypedPool{T, CuVector{T}}, flat_view, dims::NTuple{N, Int},
) where {T, N}
    # `reshape` on a GPU view produces a device array wrapper rather than a
    # Base.ReshapedArray, per the surrounding comments in this file.
    return reshape(flat_view, dims)
end

# ==============================================================================
# get_typed_pool! Dispatches for CuAdaptiveArrayPool
# ==============================================================================
Expand Down
2 changes: 1 addition & 1 deletion src/AdaptiveArrayPools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ export Bit # Sentinel type for BitArray (use with acquire!, trues!, falses!)
export @with_pool, @maybe_with_pool
export USE_POOLING, MAYBE_POOLING_ENABLED, POOL_DEBUG
export checkpoint!, rewind!, reset!
export CACHE_WAYS, set_cache_ways! # N-way cache configuration
export get_task_local_cuda_pool, get_task_local_cuda_pools # CUDA (stubs, overridden by extension)

# Extension API (for GPU backends)
Expand All @@ -30,6 +29,7 @@ export DisabledPool, DISABLED_CPU, pooling_enabled # Disabled pool support
include("task_local_pool.jl")
include("macros.jl")
else
export CACHE_WAYS, set_cache_ways! # N-way cache configuration (legacy only)
include("legacy/types.jl")
include("utils.jl")
include("legacy/acquire.jl")
Expand Down
163 changes: 73 additions & 90 deletions src/acquire.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,6 @@
"""
    allocate_vector(tp::AbstractTypedPool{T,Vector{T}}, n::Int) -> Vector{T}

Allocate an uninitialized CPU backing vector of length `n`.
Dispatch point: GPU extensions add methods for their own storage types
(e.g. the CUDA extension allocates a `CuVector{T}` instead).
"""
@inline function allocate_vector(::AbstractTypedPool{T, Vector{T}}, n::Int) where {T}
    return Vector{T}(undef, n)
end

# Wrap flat view into N-D array (dispatch point for extensions)
# CPU path: builds an `Array{T,N}` directly over the view's memory via
# `unsafe_wrap` (the CUDA extension overrides this with `reshape`).
# NOTE(review): the wrapper obtained from `pointer(flat_view)` does not root
# the pool's backing vector — assumes the pool keeps that vector alive for as
# long as the returned array is used; confirm against caller lifetimes.
@inline function wrap_array(
    ::AbstractTypedPool{T, Vector{T}},
    flat_view, dims::NTuple{N, Int}
) where {T, N}
    return unsafe_wrap(Array{T, N}, pointer(flat_view), dims)
end

# ==============================================================================
# Helper: Overflow-Safe Product
# ==============================================================================
Expand Down Expand Up @@ -61,73 +53,78 @@ end
end

# ==============================================================================
# Get 1D View (Internal - Zero-Allocation Cache)
# Slot Claim — Shared Primitive for All Acquisition Paths
# ==============================================================================

"""
get_view!(tp::AbstractTypedPool{T}, n::Int)
_claim_slot!(tp::TypedPool{T}, n::Int) -> Int

Get a 1D vector view of size `n` from the typed pool.
Returns cached view on hit (zero allocation), creates new on miss.
Claim the next slot, ensuring the backing vector exists and has capacity >= `n`.
Returns the slot index. This is the shared primitive for all acquisition paths
(`get_view!`, `get_array!`).
"""
function get_view!(tp::AbstractTypedPool{T}, n::Int) where {T}
@inline function _claim_slot!(tp::TypedPool{T}, n::Int) where {T}
tp.n_active += 1
idx = tp.n_active

# 1. Need to expand pool (new slot)
if idx > length(tp.vectors)
push!(tp.vectors, allocate_vector(tp, n))
new_view = view(tp.vectors[idx], 1:n)
push!(tp.views, new_view)
push!(tp.view_lengths, n)
_check_pool_growth(tp, idx)

return new_view
end

# 2. Cache hit: same size requested -> return cached view (ZERO ALLOC)
@inbounds cached_len = tp.view_lengths[idx]
if cached_len == n
return @inbounds tp.views[idx]
end

# 3. Cache miss: different size -> update cache
@inbounds vec = tp.vectors[idx]
if length(vec) < n
resize!(vec, n)
else
@inbounds vec = tp.vectors[idx]
if length(vec) < n
resize!(vec, n)
end
end

new_view = view(vec, 1:n)
@inbounds tp.views[idx] = new_view
@inbounds tp.view_lengths[idx] = n

return new_view
return idx
end

# ==============================================================================
# Slot Claim (for reshape! — wrapper-only, no backing memory)
# ==============================================================================

"""
_claim_slot!(tp::TypedPool{T}) -> Int

Claim the next slot index by incrementing `n_active`.
Ensures the slot exists in vectors/views/view_lengths arrays.
The backing vector at this slot is unused — this is for wrapper-only caching
(e.g., `reshape!` uses the slot index for `nd_wrapper` storage only).
Claim the next slot without provisioning memory (zero-length backing vector).
Used by `reshape!` which only needs the slot index for `nd_wrapper` caching —
the wrapper points to a different array's memory via `setfield!(:ref)`.
"""
@inline function _claim_slot!(tp::TypedPool{T}) where {T}
    slot = (tp.n_active += 1)
    # First use of this slot: register zero-length placeholders so that
    # vectors/views/view_lengths stay index-aligned with n_active. The
    # backing vector is intentionally empty — callers of this arity only
    # need the slot index, never the memory.
    if slot > length(tp.vectors)
        push!(tp.vectors, Vector{T}(undef, 0))
        push!(tp.views, view(tp.vectors[slot], 1:0))
        push!(tp.view_lengths, 0)
        _check_pool_growth(tp, slot)
    end
    return slot
end

# ==============================================================================
# Get View (Internal — Always Fresh, SubArray is Stack-Allocated via SROA)
# ==============================================================================

"""
    get_view!(tp::TypedPool{T}, n::Int) -> SubArray{T,1}
    get_view!(tp::TypedPool{T}, dims::NTuple{N,Int}) -> ReshapedArray{T,N}

Hand out a pooled view backed by the slot claimed via `_claim_slot!`.

A fresh `SubArray` (1D) or `ReshapedArray` (N-D) wrapper is built on every
call; both are small structs that SROA can keep off the heap in compiled
code, so no view cache is required.

Restricted to `TypedPool{T}` because `_claim_slot!` is only defined there;
other pool subtypes supply their own path (e.g. `CuTypedPool` overrides
`get_view!`, `BitTypedPool` goes through `get_bitarray!`).
"""
@inline function get_view!(tp::TypedPool{T}, n::Int) where {T}
    slot = _claim_slot!(tp, n)
    # _claim_slot! guarantees capacity >= n, so bounds checks are elided.
    @inbounds begin
        backing = tp.vectors[slot]
        return view(backing, 1:n)
    end
end

@inline function get_view!(tp::TypedPool{T}, dims::NTuple{N, Int}) where {T, N}
    len = safe_prod(dims)
    slot = _claim_slot!(tp, len)
    @inbounds begin
        flat = view(tp.vectors[slot], 1:len)
        return reshape(flat, dims)
    end
end

# ==============================================================================
# reshape! — Zero-Allocation Reshape (setfield!-based, Julia 1.11+)
# ==============================================================================
Expand Down Expand Up @@ -156,7 +153,7 @@ Zero-allocation reshape using `setfield!`-based wrapper reuse (Julia 1.11+).
)
)

# 0-D reshape: rare edge case, delegate to Base (nd_wrappers is 1-indexed by N)
# 0-D reshape: rare edge case, delegate to Base (arr_wrappers is 1-indexed by N)
N == 0 && return reshape(A, dims)

# Same dimensionality: just update size in-place, no pool interaction
Expand All @@ -170,7 +167,7 @@ Zero-allocation reshape using `setfield!`-based wrapper reuse (Julia 1.11+).
slot = _claim_slot!(tp)

# Look up cached wrapper (direct index, no hash)
wrappers = N <= length(tp.nd_wrappers) ? (@inbounds tp.nd_wrappers[N]) : nothing
wrappers = N <= length(tp.arr_wrappers) ? (@inbounds tp.arr_wrappers[N]) : nothing
if wrappers !== nothing && slot <= length(wrappers)
wrapper = @inbounds wrappers[slot]
if wrapper !== nothing
Expand All @@ -185,7 +182,7 @@ Zero-allocation reshape using `setfield!`-based wrapper reuse (Julia 1.11+).
arr = Array{T, N}(undef, ntuple(_ -> 0, Val(N)))
setfield!(arr, :ref, getfield(A, :ref))
setfield!(arr, :size, dims)
_store_nd_wrapper!(tp, N, slot, arr)
_store_arr_wrapper!(tp, N, slot, arr)
return arr
end

Expand All @@ -198,23 +195,23 @@ end
# unlimited dimension patterns per slot, 0-alloc after warmup for any dims with same N.

"""
_store_nd_wrapper!(tp::AbstractTypedPool, N::Int, slot::Int, wrapper)
_store_arr_wrapper!(tp::AbstractTypedPool, N::Int, slot::Int, wrapper)

Store a cached N-D wrapper for the given slot. Creates the per-N Vector if needed.
"""
function _store_nd_wrapper!(tp::AbstractTypedPool, N::Int, slot::Int, wrapper)
# Grow nd_wrappers vector so index N is valid
if N > length(tp.nd_wrappers)
old_len = length(tp.nd_wrappers)
resize!(tp.nd_wrappers, N)
function _store_arr_wrapper!(tp::AbstractTypedPool, N::Int, slot::Int, wrapper)
# Grow arr_wrappers vector so index N is valid
if N > length(tp.arr_wrappers)
old_len = length(tp.arr_wrappers)
resize!(tp.arr_wrappers, N)
for i in (old_len + 1):N
@inbounds tp.nd_wrappers[i] = nothing
@inbounds tp.arr_wrappers[i] = nothing
end
end
wrappers = @inbounds tp.nd_wrappers[N]
wrappers = @inbounds tp.arr_wrappers[N]
if wrappers === nothing
wrappers = Vector{Any}(nothing, slot)
@inbounds tp.nd_wrappers[N] = wrappers
@inbounds tp.arr_wrappers[N] = wrappers
elseif slot > length(wrappers)
old_len = length(wrappers)
resize!(wrappers, slot)
Expand All @@ -227,26 +224,21 @@ function _store_nd_wrapper!(tp::AbstractTypedPool, N::Int, slot::Int, wrapper)
end

"""
get_nd_array!(tp::AbstractTypedPool{T,Vector{T}}, dims::NTuple{N,Int}) -> Array{T,N}
get_array!(tp::AbstractTypedPool{T,Vector{T}}, dims::NTuple{N,Int}) -> Array{T,N}

Get an N-dimensional `Array` from the pool with `setfield!`-based wrapper reuse.

Uses Julia 1.11+ `setfield!` to mutate cached `Array` wrappers in-place:
- Same N (dimensionality): `setfield!(arr, :size, dims)` — 0 allocation
- Backing memory: `setfield!(arr, :ref, ...)` — always updated, 0 allocation in compiled code
- First call per (slot, N): `unsafe_wrap` once, then cached forever

Unlike the N-way cache (Julia 1.10), this has no eviction limit — unlimited dimension
patterns per slot are supported with zero allocation after warmup.
Uses `_claim_slot!` directly for slot management (independent of view path).
Cache hit: `setfield!(arr, :ref/size)` — 0 allocation.
Cache miss: creates wrapper via `setfield!` pattern, then cached forever.
"""
@inline function get_nd_array!(tp::AbstractTypedPool{T, Vector{T}}, dims::NTuple{N, Int}) where {T, N}
@inline function get_array!(tp::AbstractTypedPool{T, Vector{T}}, dims::NTuple{N, Int}) where {T, N}
total_len = safe_prod(dims)
flat_view = get_view!(tp, total_len) # Increments n_active, ensures backing vec
slot = tp.n_active
slot = _claim_slot!(tp, total_len)
@inbounds vec = tp.vectors[slot]

# Look up cached wrapper for this dimensionality (direct index, no hash)
wrappers = N <= length(tp.nd_wrappers) ? (@inbounds tp.nd_wrappers[N]) : nothing
wrappers = N <= length(tp.arr_wrappers) ? (@inbounds tp.arr_wrappers[N]) : nothing
if wrappers !== nothing && slot <= length(wrappers)
wrapper = @inbounds wrappers[slot]
if wrapper !== nothing
Expand All @@ -261,23 +253,14 @@ patterns per slot are supported with zero allocation after warmup.
end
end

# Cache miss: first call for this (slot, N) — unsafe_wrap once
arr = wrap_array(tp, flat_view, dims)
_store_nd_wrapper!(tp, N, slot, arr)
# Cache miss: first call for this (slot, N) — create via setfield! pattern
arr = Array{T, N}(undef, ntuple(_ -> 0, Val(N)))
setfield!(arr, :ref, getfield(vec, :ref))
setfield!(arr, :size, dims)
_store_arr_wrapper!(tp, N, slot, arr)
return arr
end

"""
get_nd_view!(tp::AbstractTypedPool{T}, dims::NTuple{N,Int})

Get an N-dimensional view via `reshape` (zero creation cost).
"""
@inline function get_nd_view!(tp::AbstractTypedPool{T}, dims::NTuple{N, Int}) where {T, N}
total_len = safe_prod(dims)
flat_view = get_view!(tp, total_len) # 1D view (cached, 0 alloc)
return reshape(flat_view, dims) # ReshapedArray (0 creation cost)
end

# ==============================================================================
# Type Touch Recording (for selective rewind)
# ==============================================================================
Expand Down Expand Up @@ -349,7 +332,7 @@ end

@inline function _acquire_impl!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N}
tp = get_typed_pool!(pool, T)
return get_nd_view!(tp, dims)
return get_view!(tp, dims)
end

@inline function _acquire_impl!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N}
Expand All @@ -366,17 +349,17 @@ Internal implementation of unsafe_acquire!. Called directly by macro-transformed
"""
@inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{T}, n::Int) where {T}
tp = get_typed_pool!(pool, T)
return get_nd_array!(tp, (n,))
return get_array!(tp, (n,))
end

@inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N}
tp = get_typed_pool!(pool, T)
return get_nd_array!(tp, dims)
return get_array!(tp, dims)
end

@inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N}
tp = get_typed_pool!(pool, T)
return get_nd_array!(tp, dims)
return get_array!(tp, dims)
end

# Similar-style
Expand Down
Loading