diff --git a/README.md b/README.md index 1347c16..be337db 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ Common initialization patterns have convenience functions: | `zeros!(pool, 10)` | `acquire!` + `fill!(0)` | | `ones!(pool, Float32, 3, 3)` | `acquire!` + `fill!(1)` | | `similar!(pool, A)` | `acquire!` matching `eltype(A)`, `size(A)` | +| `reshape!(pool, A, 3, 4)` | Reshape sharing memory, zero-alloc (1.11+) | These return views like `acquire!`. For raw `Array` types, use `unsafe_acquire!` or its convenience variants (`unsafe_zeros!`, `unsafe_ones!`, `unsafe_similar!`). See [API Reference](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/usage/api#convenience-functions). diff --git a/docs/src/basics/api-essentials.md b/docs/src/basics/api-essentials.md index 028bbae..5f793e3 100644 --- a/docs/src/basics/api-essentials.md +++ b/docs/src/basics/api-essentials.md @@ -62,6 +62,26 @@ Match existing array properties: end ``` +### Reshaping with `reshape!` + +Reshape an existing array using the pool's wrapper cache. The result shares memory with the original — mutations are visible in both: + +```julia +@with_pool pool function process_grid(data, nx, ny) + M = reshape!(pool, data, nx, ny) # 1D → 2D, shares memory with data + col_sums = zeros!(pool, Float64, ny) + for j in 1:ny, i in 1:nx + col_sums[j] += M[i, j] + end + return sum(col_sums) +end +``` + +On Julia 1.11+, cross-dimensional reshapes are **zero-allocation** after warmup via `setfield!`-based wrapper reuse. On Julia 1.10, falls back to `Base.reshape`. + +!!! warning "DimensionMismatch" + `prod(dims)` must equal `length(A)`, otherwise a `DimensionMismatch` is thrown. + ### Custom Initialization with `fill!` For values other than 0 or 1, use Julia's built-in `fill!`: @@ -117,6 +137,7 @@ end | `zeros!(pool, [T,] dims...)` | View type | 0 bytes | Zero-initialized | | `ones!(pool, [T,] dims...)` | View type | 0 bytes | One-initialized | | `similar!(pool, A)` | View type | 0 bytes | Match existing array | +| `reshape!(pool, A, dims...)` | Reshaped array | 0 bytes (1.11+) | Reshape sharing memory | | `reset!(pool)` | `nothing` | - | Release all memory | | `pooling_enabled(pool)` | `Bool` | - | Check pool status | diff --git a/docs/src/basics/quick-start.md b/docs/src/basics/quick-start.md index 87727c9..d909da9 100644 --- a/docs/src/basics/quick-start.md +++ b/docs/src/basics/quick-start.md @@ -62,6 +62,7 @@ Common initialization patterns have shortcuts: | `zeros!(pool, 10)` | `acquire!` + `fill!(0)` | | `ones!(pool, Float32, 3, 3)` | `acquire!` + `fill!(1)` | | `similar!(pool, A)` | `acquire!` matching `eltype(A)`, `size(A)` | +| `reshape!(pool, A, 3, 4)` | Reshape sharing memory, zero-alloc (1.11+) | ```julia @with_pool pool function example(n) diff --git a/docs/src/reference/api.md b/docs/src/reference/api.md index 2b808ea..a6680a7 100644 --- a/docs/src/reference/api.md +++ b/docs/src/reference/api.md @@ -35,6 +35,7 @@ Default element type is `Float64` (CPU) or `Float32` (CUDA). | `trues!(pool, dims...)` | Bit-packed `BitVector` / `BitArray{N}` filled with `true`. | | `falses!(pool, dims...)` | Bit-packed `BitVector` / `BitArray{N}` filled with `false`. | | `similar!(pool, A)` | View matching `eltype(A)` and `size(A)`. | +| `reshape!(pool, A, dims...)` | Reshape `A` to `dims`, sharing memory. Zero-alloc on Julia 1.11+. | ### Types diff --git a/src/AdaptiveArrayPools.jl b/src/AdaptiveArrayPools.jl index c369ca5..2cb226d 100644 --- a/src/AdaptiveArrayPools.jl +++ b/src/AdaptiveArrayPools.jl @@ -5,7 +5,7 @@ using Printf # Public API export AdaptiveArrayPool, acquire!, unsafe_acquire!, pool_stats, get_task_local_pool export acquire_view!, acquire_array! # Explicit naming aliases -export zeros!, ones!, trues!, falses!, similar!, default_eltype # Convenience functions +export zeros!, ones!, trues!, falses!, similar!, reshape!, default_eltype # Convenience functions export unsafe_zeros!, unsafe_ones!, unsafe_similar! # Unsafe convenience functions export Bit # Sentinel type for BitArray (use with acquire!, trues!, falses!) export @with_pool, @maybe_with_pool diff --git a/src/acquire.jl b/src/acquire.jl index b9c2611..6d9dea1 100644 --- a/src/acquire.jl +++ b/src/acquire.jl @@ -41,6 +41,23 @@ negligible relative to the 100-200 ns cost of the full allocation path. return total end +# ============================================================================== +# Helper: Pool Growth Warning (cold path, kept out of hot loops) +# ============================================================================== + +@noinline function _warn_pool_growing(tp::AbstractTypedPool{T}, idx::Int) where {T} + total_bytes = sum(length, tp.vectors) * sizeof(T) + @warn "$(nameof(typeof(tp))){$T} growing large ($idx arrays, ~$(Base.format_bytes(total_bytes))). Missing rewind!()?" + return nothing +end + +@inline function _check_pool_growth(tp::AbstractTypedPool, idx::Int) + # Warn at every power of 2 from 512 onward (512, 1024, 2048, …) + if idx >= 512 && (idx & (idx - 1)) == 0 + _warn_pool_growing(tp, idx) + end +end + # ============================================================================== # Get 1D View (Internal - Zero-Allocation Cache) # ============================================================================== @@ -61,12 +78,7 @@ function get_view!(tp::AbstractTypedPool{T}, n::Int) where {T} new_view = view(tp.vectors[idx], 1:n) push!(tp.views, new_view) push!(tp.view_lengths, n) - - # Warn at powers of 2 (512, 1024, 2048, ...) - possible missing rewind!() - if idx >= 512 && (idx & (idx - 1)) == 0 - total_bytes = sum(length, tp.vectors) * sizeof(T) - @warn "$(nameof(typeof(tp))){$T} growing large ($idx arrays, ~$(Base.format_bytes(total_bytes))). Missing rewind!()?" - end + _check_pool_growth(tp, idx) return new_view end @@ -90,6 +102,88 @@ function get_view!(tp::AbstractTypedPool{T}, n::Int) where {T} return new_view end +# ============================================================================== +# Slot Claim (for reshape! — wrapper-only, no backing memory) +# ============================================================================== + +""" + _claim_slot!(tp::TypedPool{T}) -> Int + +Claim the next slot index by incrementing `n_active`. +Ensures the slot exists in vectors/views/view_lengths arrays. +The backing vector at this slot is unused — this is for wrapper-only caching +(e.g., `reshape!` uses the slot index for `nd_wrapper` storage only). +""" +@inline function _claim_slot!(tp::TypedPool{T}) where {T} + tp.n_active += 1 + idx = tp.n_active + if idx > length(tp.vectors) + push!(tp.vectors, Vector{T}(undef, 0)) + push!(tp.views, view(tp.vectors[idx], 1:0)) + push!(tp.view_lengths, 0) + _check_pool_growth(tp, idx) + end + return idx +end + +# ============================================================================== +# reshape! — Zero-Allocation Reshape (setfield!-based, Julia 1.11+) +# ============================================================================== + +""" + _reshape_impl!(pool::AdaptiveArrayPool, A::Array{T,M}, dims::NTuple{N,Int}) -> Array{T,N} + +Zero-allocation reshape using `setfield!`-based wrapper reuse (Julia 1.11+). + +- **Same dimensionality (M == N)**: `setfield!(A, :size, dims)` — no pool interaction +- **Different dimensionality (M ≠ N)**: Claims a pool slot via `_claim_slot!`, + reuses cached `Array{T,N}` wrapper with `setfield!(:ref, :size)` pointing to `A`'s memory. + Automatically reclaimed on `rewind!` via `n_active` restoration. +""" +@inline function _reshape_impl!(pool::AdaptiveArrayPool, A::Array{T,M}, dims::NTuple{N,Int}) where {T,M,N} + # Reject negative dimensions (match Base.reshape behavior) + for d in dims + d < 0 && throw(ArgumentError("invalid Array dimensions")) + end + + # Validate before claiming slot + total_len = safe_prod(dims) + length(A) == total_len || throw(DimensionMismatch( + "new dimensions $(dims) must be consistent with array length $(length(A))")) + + # 0-D reshape: rare edge case, delegate to Base (nd_wrappers is 1-indexed by N) + N == 0 && return reshape(A, dims) + + # Same dimensionality: just update size in-place, no pool interaction + if M == N + setfield!(A, :size, dims) + return A + end + + # Different dimensionality: claim slot + reuse cached N-D wrapper + tp = get_typed_pool!(pool, T) + slot = _claim_slot!(tp) + + # Look up cached wrapper (direct index, no hash) + wrappers = N <= length(tp.nd_wrappers) ? (@inbounds tp.nd_wrappers[N]) : nothing + if wrappers !== nothing && slot <= length(wrappers) + wrapper = @inbounds wrappers[slot] + if wrapper !== nothing + arr = wrapper::Array{T,N} + setfield!(arr, :ref, getfield(A, :ref)) + setfield!(arr, :size, dims) + return arr + end + end + + # Cache miss (first call per slot+N): create wrapper, cache forever + arr = Array{T,N}(undef, ntuple(_ -> 0, Val(N))) + setfield!(arr, :ref, getfield(A, :ref)) + setfield!(arr, :size, dims) + _store_nd_wrapper!(tp, N, slot, arr) + return arr +end + # ============================================================================== # Get N-D Array (setfield!-based Wrapper Reuse, Julia 1.11+) # ============================================================================== diff --git a/src/convenience.jl b/src/convenience.jl index 053f8cb..84312e5 100644 --- a/src/convenience.jl +++ b/src/convenience.jl @@ -309,6 +309,67 @@ end _acquire_impl!(pool, T, dims...) end +# ============================================================================== +# reshape! - Reshape arrays using pool's wrapper cache +# ============================================================================== + +""" + reshape!(pool, A, dims...) -> reshaped array + reshape!(pool, A, dims::Tuple) -> reshaped array + +Reshape array `A` to dimensions `dims` using the pool's wrapper cache. + +The returned array shares memory with `A` — mutations are visible in both. +The pool provides cached wrapper objects to reduce allocation on repeated calls. + +On Julia 1.11+: +- If `ndims(A) == length(dims)` (same dimensionality), `reshape!` mutates `A` + in-place by changing its size. This differs from `Base.reshape`, which always + returns a new wrapper. +- For cross-dimensional reshapes (`ndims(A) != length(dims)`), the returned + `Array` wrapper is taken from the pool's internal cache and may be reused + after `rewind!` or pool scope exit. + +As with all pool-backed objects, the reshaped result must not escape the +surrounding `@with_pool` scope. + +On Julia 1.10 and CUDA, falls back to `Base.reshape`. + +Throws `DimensionMismatch` if `prod(dims) != length(A)`. + +## Example +```julia +A = collect(1.0:12.0) +@with_pool pool begin + B = reshape!(pool, A, 3, 4) # 12-element vector → 3×4 matrix + B[1,1] = 999.0 # A[1] is now 999.0 +end +``` + +See also: [`acquire!`](@ref), [`similar!`](@ref) +""" +@inline function reshape!(pool::AbstractArrayPool, A::AbstractArray{T}, dims::Vararg{Int,N}) where {T,N} + _record_type_touch!(pool, T) + _reshape_impl!(pool, A, dims) +end + +@inline function reshape!(pool::AbstractArrayPool, A::AbstractArray{T}, dims::NTuple{N,Int}) where {T,N} + _record_type_touch!(pool, T) + _reshape_impl!(pool, A, dims) +end + +# Internal implementation (fallback: delegates to Base.reshape) +@inline function _reshape_impl!(::AbstractArrayPool, A::AbstractArray, dims::NTuple{N,Int}) where {N} + for d in dims + d < 0 && throw(ArgumentError("invalid Array dimensions")) + end + reshape(A, dims) +end + +# Vararg forwarding (macro transforms reshape!(pool, A, 3, 4) → _reshape_impl!(pool, A, 3, 4)) +@inline _reshape_impl!(pool::AbstractArrayPool, A::AbstractArray, dims::Vararg{Int,N}) where {N} = + _reshape_impl!(pool, A, dims) + # ============================================================================== # unsafe_zeros! - Acquire zero-initialized raw arrays from pool # ============================================================================== @@ -587,6 +648,10 @@ end @inline similar!(::DisabledPool{:cpu}, x::AbstractArray, dims::Vararg{Int,N}) where {N} = similar(x, dims...) @inline similar!(::DisabledPool{:cpu}, x::AbstractArray, ::Type{T}, dims::Vararg{Int,N}) where {T,N} = similar(x, T, dims...) +# --- reshape! for DisabledPool{:cpu} --- +@inline reshape!(::DisabledPool{:cpu}, A::AbstractArray, dims::Vararg{Int,N}) where {N} = reshape(A, dims...) +@inline reshape!(::DisabledPool{:cpu}, A::AbstractArray, dims::NTuple{N,Int}) where {N} = reshape(A, dims) + # --- unsafe_zeros! for DisabledPool{:cpu} --- @inline unsafe_zeros!(::DisabledPool{:cpu}, ::Type{T}, dims::Vararg{Int,N}) where {T,N} = zeros(T, dims...) @inline unsafe_zeros!(p::DisabledPool{:cpu}, dims::Vararg{Int,N}) where {N} = zeros(default_eltype(p), dims...) @@ -614,6 +679,7 @@ end @inline unsafe_zeros!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B) @inline unsafe_ones!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B) @inline unsafe_similar!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B) +@inline reshape!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B) # ============================================================================== # _impl! Delegators for DisabledPool @@ -650,6 +716,9 @@ end @inline _similar_impl!(p::DisabledPool, x::AbstractArray, dims::Vararg{Int,N}) where {N} = similar!(p, x, dims...) @inline _similar_impl!(p::DisabledPool, x::AbstractArray, ::Type{T}, dims::Vararg{Int,N}) where {T,N} = similar!(p, x, T, dims...) +# --- _reshape_impl! --- +@inline _reshape_impl!(p::DisabledPool, A::AbstractArray, dims::NTuple{N,Int}) where {N} = reshape!(p, A, dims) + # --- _unsafe_zeros_impl! --- @inline _unsafe_zeros_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int,N}) where {T,N} = unsafe_zeros!(p, T, dims...) @inline _unsafe_zeros_impl!(p::DisabledPool, dims::Vararg{Int,N}) where {N} = unsafe_zeros!(p, dims...) diff --git a/src/macros.jl b/src/macros.jl index 21a7599..752c13e 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -783,6 +783,12 @@ function _extract_acquire_types(expr, target_pool, types=Set{Any}()) push!(types, Expr(:call, :eltype, expr.args[3])) end end + # reshape! + elseif fn in (:reshape!,) || fn_name in (:reshape!,) + # reshape!(pool, A, dims...) — extract eltype(A) from second arg + if nargs >= 3 + push!(types, Expr(:call, :eltype, expr.args[3])) + end end end end @@ -1015,6 +1021,7 @@ const _SIMILAR_IMPL_REF = GlobalRef(@__MODULE__, :_similar_impl!) const _UNSAFE_ZEROS_IMPL_REF = GlobalRef(@__MODULE__, :_unsafe_zeros_impl!) const _UNSAFE_ONES_IMPL_REF = GlobalRef(@__MODULE__, :_unsafe_ones_impl!) const _UNSAFE_SIMILAR_IMPL_REF = GlobalRef(@__MODULE__, :_unsafe_similar_impl!) +const _RESHAPE_IMPL_REF = GlobalRef(@__MODULE__, :_reshape_impl!) function _transform_acquire_calls(expr, pool_name) if expr isa Expr @@ -1040,6 +1047,8 @@ function _transform_acquire_calls(expr, pool_name) expr = Expr(:call, _FALSES_IMPL_REF, expr.args[2:end]...) elseif fn == :similar! expr = Expr(:call, _SIMILAR_IMPL_REF, expr.args[2:end]...) + elseif fn == :reshape! + expr = Expr(:call, _RESHAPE_IMPL_REF, expr.args[2:end]...) elseif fn == :unsafe_zeros! expr = Expr(:call, _UNSAFE_ZEROS_IMPL_REF, expr.args[2:end]...) elseif fn == :unsafe_ones! @@ -1063,6 +1072,8 @@ function _transform_acquire_calls(expr, pool_name) expr = Expr(:call, _FALSES_IMPL_REF, expr.args[2:end]...) elseif qn == QuoteNode(:similar!) expr = Expr(:call, _SIMILAR_IMPL_REF, expr.args[2:end]...) + elseif qn == QuoteNode(:reshape!) + expr = Expr(:call, _RESHAPE_IMPL_REF, expr.args[2:end]...) elseif qn == QuoteNode(:unsafe_zeros!) expr = Expr(:call, _UNSAFE_ZEROS_IMPL_REF, expr.args[2:end]...) elseif qn == QuoteNode(:unsafe_ones!) diff --git a/test/runtests.jl b/test/runtests.jl index a0d0459..f017cf2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ else include("test_fixed_slots.jl") include("test_backend_macro_expansion.jl") include("test_convenience.jl") + include("test_reshape.jl") include("test_bitarray.jl") include("test_coverage.jl") include("test_allocation.jl") @@ -56,6 +57,7 @@ else include("test_fixed_slots.jl") include("test_backend_macro_expansion.jl") include("test_convenience.jl") + include("test_reshape.jl") include("test_bitarray.jl") include("test_coverage.jl") include("test_allocation.jl") diff --git a/test/test_reshape.jl b/test/test_reshape.jl new file mode 100644 index 0000000..1316d9b --- /dev/null +++ b/test/test_reshape.jl @@ -0,0 +1,453 @@ +# ============================================================================== +# Tests for reshape! — Pool-based zero-allocation array reshaping +# ============================================================================== + +@testset "reshape!" begin + + # ========================================================================== + # Basic reshape (cross-dimensional) + # ========================================================================== + + @testset "Basic reshape (cross-dim)" begin + pool = AdaptiveArrayPool() + + # 1D → 2D + A = collect(1.0:12.0) + checkpoint!(pool) + B = reshape!(pool, A, 3, 4) + @test size(B) == (3, 4) + @test eltype(B) == Float64 + @test B[1, 1] == 1.0 + @test B[3, 4] == 12.0 + rewind!(pool) + + # 2D → 1D + A2d = reshape(collect(1.0:12.0), 3, 4) + checkpoint!(pool) + C = reshape!(pool, A2d, 12) + @test size(C) == (12,) + @test C[1] == 1.0 + @test C[12] == 12.0 + rewind!(pool) + + # 1D → 3D + checkpoint!(pool) + D = reshape!(pool, A, 2, 3, 2) + @test size(D) == (2, 3, 2) + @test D[1, 1, 1] == 1.0 + @test D[2, 3, 2] == 12.0 + rewind!(pool) + end + + # ========================================================================== + # Same-dim reshape + # ========================================================================== + + @testset "Same-dim reshape" begin + pool = AdaptiveArrayPool() + A2d = reshape(collect(1.0:12.0), 3, 4) + + checkpoint!(pool) + E = reshape!(pool, A2d, 4, 3) + @test size(E) == (4, 3) + @test length(E) == 12 + rewind!(pool) + end + + # ========================================================================== + # Data preservation / memory sharing + # ========================================================================== + + @testset "Data preservation / memory sharing" begin + pool = AdaptiveArrayPool() + A = collect(1.0:12.0) + + checkpoint!(pool) + B = reshape!(pool, A, 3, 4) + + # Data identity + @test vec(B) == A + + # Mutation in B visible in A + B[1, 1] = 999.0 + @test A[1] == 999.0 + + # Mutation in A visible in B + A[12] = -1.0 + @test B[3, 4] == -1.0 + rewind!(pool) + end + + # ========================================================================== + # DimensionMismatch + # ========================================================================== + + @testset "DimensionMismatch" begin + pool = AdaptiveArrayPool() + A = collect(1.0:12.0) + @test_throws DimensionMismatch reshape!(pool, A, 5, 5) + @test_throws DimensionMismatch reshape!(pool, A, (7, 3)) + end + + # ========================================================================== + # Negative dimensions (must throw ArgumentError like Base.reshape) + # ========================================================================== + + @testset "Negative dimensions" begin + pool = AdaptiveArrayPool() + A = [1.0] + @test_throws ArgumentError reshape!(pool, A, -1, -1) + @test_throws ArgumentError reshape!(pool, A, (-1,)) + @test_throws ArgumentError reshape!(pool, collect(1.0:4.0), -2, -2) + end + + # ========================================================================== + # Zero-dimensional reshape (N == 0) + # ========================================================================== + + @testset "Zero-dimensional reshape" begin + pool = AdaptiveArrayPool() + A = [42.0] + checkpoint!(pool) + B = reshape!(pool, A, ()) + @test ndims(B) == 0 + @test B[] == 42.0 + rewind!(pool) + end + + # ========================================================================== + # Tuple and vararg syntax + # ========================================================================== + + @testset "Tuple and vararg syntax" begin + pool = AdaptiveArrayPool() + A = collect(1.0:12.0) + + checkpoint!(pool) + B_vararg = reshape!(pool, A, 3, 4) + @test size(B_vararg) == (3, 4) + rewind!(pool) + + checkpoint!(pool) + B_tuple = reshape!(pool, A, (3, 4)) + @test size(B_tuple) == (3, 4) + rewind!(pool) + end + + # ========================================================================== + # Multiple element types + # ========================================================================== + + @testset "Multiple element types" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + A_f64 = collect(1.0:6.0) + B_f64 = reshape!(pool, A_f64, 2, 3) + @test eltype(B_f64) == Float64 + @test size(B_f64) == (2, 3) + + A_i32 = Int32.(1:6) + B_i32 = reshape!(pool, A_i32, 3, 2) + @test eltype(B_i32) == Int32 + @test size(B_i32) == (3, 2) + + A_bool = Bool[true, false, true, false] + B_bool = reshape!(pool, A_bool, 2, 2) + @test eltype(B_bool) == Bool + @test size(B_bool) == (2, 2) + + rewind!(pool) + end + + # ========================================================================== + # checkpoint!/rewind! integration + # ========================================================================== + + @testset "checkpoint!/rewind! integration" begin + pool = AdaptiveArrayPool() + A = collect(1.0:12.0) + + tp = get_typed_pool!(pool, Float64) + n_before = tp.n_active + + checkpoint!(pool) + B = reshape!(pool, A, 3, 4) + @test size(B) == (3, 4) + + rewind!(pool) + @test tp.n_active == n_before # slot reclaimed + end + + # ========================================================================== + # @with_pool integration + # ========================================================================== + + @testset "@with_pool integration" begin + A = collect(1.0:12.0) + result = @with_pool pool begin + B = reshape!(pool, A, 3, 4) + sum(B) + end + @test result == sum(1.0:12.0) + end + + # ========================================================================== + # External arrays (not from pool) + # ========================================================================== + + @testset "External arrays" begin + pool = AdaptiveArrayPool() + A = rand(6) + + checkpoint!(pool) + B = reshape!(pool, A, 2, 3) + @test size(B) == (2, 3) + @test B[1, 1] == A[1] + rewind!(pool) + end + + # ========================================================================== + # Zero allocation (v1.11+ only) + # ========================================================================== + + @static if VERSION >= v"1.11-" + @testset "Zero allocation — cross-dim reshape" begin + function _test_reshape_cross_dim() + pool = AdaptiveArrayPool() + A = collect(1.0:12.0) + + # Warmup (compile + cache) + for _ in 1:3 + checkpoint!(pool) + B = reshape!(pool, A, 3, 4) + _ = sum(B) + rewind!(pool) + end + + alloc = @allocated begin + checkpoint!(pool) + B = reshape!(pool, A, 3, 4) + _ = sum(B) + rewind!(pool) + end + return alloc + end + + _test_reshape_cross_dim() # compile + _test_reshape_cross_dim() # compile again + alloc = _test_reshape_cross_dim() + println(" reshape! cross-dim: $alloc bytes") + @test alloc == 0 + end + + @testset "Zero allocation — same-dim reshape" begin + function _test_reshape_same_dim() + pool = AdaptiveArrayPool() + A = reshape(collect(1.0:12.0), 3, 4) + + for _ in 1:3 + checkpoint!(pool) + B = reshape!(pool, A, 4, 3) + _ = sum(B) + rewind!(pool) + end + + alloc = @allocated begin + checkpoint!(pool) + B = reshape!(pool, A, 4, 3) + _ = sum(B) + rewind!(pool) + end + return alloc + end + + _test_reshape_same_dim() + _test_reshape_same_dim() + alloc = _test_reshape_same_dim() + println(" reshape! same-dim: $alloc bytes") + @test alloc == 0 + end + + @testset "Zero allocation — multiple reshapes in sequence" begin + function _test_reshape_sequence() + pool = AdaptiveArrayPool() + A = collect(1.0:24.0) + + for _ in 1:3 + checkpoint!(pool) + B = reshape!(pool, A, 4, 6) + C = reshape!(pool, A, 2, 3, 4) + _ = sum(B) + sum(C) + rewind!(pool) + end + + alloc = @allocated begin + checkpoint!(pool) + B = reshape!(pool, A, 4, 6) + C = reshape!(pool, A, 2, 3, 4) + _ = sum(B) + sum(C) + rewind!(pool) + end + return alloc + end + + _test_reshape_sequence() + _test_reshape_sequence() + alloc = _test_reshape_sequence() + println(" reshape! sequence: $alloc bytes") + @test alloc == 0 + end + end + + # ========================================================================== + # DisabledPool fallback + # ========================================================================== + + @testset "DisabledPool fallback" begin + A = collect(1.0:12.0) + B = reshape!(DISABLED_CPU, A, 3, 4) + @test size(B) == (3, 4) + @test B[1, 1] == 1.0 + @test B[3, 4] == 12.0 + + # Tuple syntax + C = reshape!(DISABLED_CPU, A, (2, 6)) + @test size(C) == (2, 6) + end + + # ========================================================================== + # @with_pool function — realistic mixed operations + # ========================================================================== + + @testset "@with_pool function — mixed acquire + reshape" begin + src = collect(1.0:24.0) + + @with_pool pool function _test_reshape_mixed_ops(src) + # 1) acquire! a temp buffer, copy and scale + tmp = acquire!(pool, Float64, length(src)) + tmp .= src .* 2.0 + + # 2) reshape! external array → matrix + M = reshape!(pool, src, 4, 6) + + # 3) zeros! for column-sum accumulation + col_sums = zeros!(pool, Float64, 6) + for j in 1:6, i in 1:4 + col_sums[j] += M[i, j] + end + + # 4) memory sharing: mutation through M visible in src + old = src[1] + M[1, 1] = -999.0 + shared = (src[1] == -999.0) + src[1] = old # restore + + # 5) another reshape! of same data → 3D + M3 = reshape!(pool, src, 2, 3, 4) + + return ( + sum(tmp), # 2× sum of 1:24 + sum(col_sums), # sum of 1:24 + shared, # memory sharing + size(M), size(M3), # shapes + M3[1, 1, 1], M3[2, 3, 4], # values in 3D view + ) + end + + s_tmp, s_cols, mem_ok, sz2d, sz3d, v111, v234 = _test_reshape_mixed_ops(src) + @test s_tmp == sum(1.0:24.0) * 2.0 + @test s_cols ≈ sum(1.0:24.0) + @test mem_ok == true + @test sz2d == (4, 6) + @test sz3d == (2, 3, 4) + @test v111 == 1.0 + @test v234 == 24.0 + + # External data integrity: src must be unchanged after call + @test src == collect(1.0:24.0) + @test src isa Vector{Float64} + @test size(src) == (24,) + end + + # ========================================================================== + # Zero allocation — @with_pool function (v1.11+) + # ========================================================================== + + @static if VERSION >= v"1.11-" + @testset "Zero allocation — @with_pool function (acquire + reshape + zeros!)" begin + ext = collect(1.0:24.0) + + @with_pool pool function _test_reshape_func_alloc(data) + tmp = acquire!(pool, Float64, length(data)) + tmp .= data + M = reshape!(pool, data, 4, 6) + buf = zeros!(pool, Float64, 6) + for j in 1:6, i in 1:4 + buf[j] += M[i, j] + end + return sum(tmp) + sum(buf) + end + + # Warmup (compile + cache) + for _ in 1:4; _test_reshape_func_alloc(ext); end + + alloc = @allocated _test_reshape_func_alloc(ext) + println(" @with_pool function (acquire+reshape+zeros!): $alloc bytes") + @test alloc == 0 + end + end + + # ========================================================================== + # @maybe_with_pool — pooling vs no-pooling proves zero-alloc + # ========================================================================== + + @static if VERSION >= v"1.11-" + @testset "@maybe_with_pool — pooling vs no-pooling allocation" begin + ext = collect(1.0:12.0) + + @maybe_with_pool pool function _test_maybe_reshape_alloc(data) + M = reshape!(pool, data, 3, 4) + tmp = acquire!(pool, Float64, 12) + tmp .= data .* 2.0 + return sum(M) + sum(tmp) + end + + function _measure_maybe_reshape(data, enabled) + MAYBE_POOLING_ENABLED[] = enabled + for _ in 1:4; _test_maybe_reshape_alloc(data); end + return @allocated _test_maybe_reshape_alloc(data) + end + + expected = sum(1.0:12.0) * 3.0 + + old_state = MAYBE_POOLING_ENABLED[] + try + # Compile both paths + _measure_maybe_reshape(ext, true) + _measure_maybe_reshape(ext, false) + + # Measure + alloc_pooled = _measure_maybe_reshape(ext, true) + alloc_unpooled = _measure_maybe_reshape(ext, false) + + println(" @maybe_with_pool pooled: $alloc_pooled bytes") + println(" @maybe_with_pool unpooled: $alloc_unpooled bytes") + + # Pool: zero allocation + @test alloc_pooled == 0 + # No pool: must allocate (reshape wrapper + Vector) + @test alloc_unpooled > 0 + + # Both paths produce correct results + MAYBE_POOLING_ENABLED[] = true + @test _test_maybe_reshape_alloc(ext) ≈ expected + MAYBE_POOLING_ENABLED[] = false + @test _test_maybe_reshape_alloc(ext) ≈ expected + finally + MAYBE_POOLING_ENABLED[] = old_state + end + end + end + +end