diff --git a/mem/buffer_pool.go b/mem/buffer_pool.go index e37afdd1981d..516e52793663 100644 --- a/mem/buffer_pool.go +++ b/mem/buffer_pool.go @@ -19,6 +19,9 @@ package mem import ( + "fmt" + "math/bits" + "slices" "sort" "sync" @@ -38,20 +41,28 @@ type BufferPool interface { Put(*[]byte) } -const goPageSize = 4 << 10 // 4KiB. N.B. this must be a power of 2. +const goPageSizeExponent = 12 +const goPageSize = 1 << goPageSizeExponent // 4KiB. N.B. this must be a power of 2. -var defaultBufferPoolSizes = []int{ - 256, - goPageSize, - 16 << 10, // 16KB (max HTTP/2 frame size used by gRPC) - 32 << 10, // 32KB (default buffer size for io.Copy) - 1 << 20, // 1MB +var defaultBufferPoolSizeExponents = []uint8{ + 8, + goPageSizeExponent, + 14, // 16KB (max HTTP/2 frame size used by gRPC) + 15, // 32KB (default buffer size for io.Copy) + 20, // 1MB } -var defaultBufferPool BufferPool +var ( + defaultBufferPool BufferPool + uintSize = bits.UintSize // use a variable for mocking during tests. +) func init() { - defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...) + var err error + defaultBufferPool, err = NewBinaryTieredBufferPool(defaultBufferPoolSizeExponents...) + if err != nil { + panic(fmt.Sprintf("Failed to create default buffer pool: %v", err)) + } internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) { defaultBufferPool = pool @@ -109,6 +120,131 @@ func (p *tieredBufferPool) getPool(size int) BufferPool { return p.sizedPools[poolIdx] } +type binaryTieredBufferPool struct { + // exponentToNextLargestPoolMap maps a power-of-two exponent (e.g., 12 for + // 4KB) to the index of the next largest sizedBufferPool. This is used by + // Get() to find the smallest pool that can satisfy a request for a given + // size. + exponentToNextLargestPoolMap []int + // exponentToPreviousLargestPoolMap maps a power-of-two exponent to the + // index of the previous largest sizedBufferPool. This is used by Put() + // to return a buffer to the most appropriate pool based on its capacity. + exponentToPreviousLargestPoolMap []int + sizedPools []*sizedBufferPool + fallbackPool simpleBufferPool + maxPoolCap int // Optimization: Cache max capacity +} + +// NewBinaryTieredBufferPool returns a BufferPool backed by multiple sub-pools. +// This structure enables O(1) lookup time for Get and Put operations. +// +// The arguments provided are the exponents for the buffer capacities (powers +// of 2), not the raw byte sizes. For example, to create a pool of 16KB buffers +// (2^14 bytes), pass 14 as the argument. +func NewBinaryTieredBufferPool(powerOfTwoExponents ...uint8) (BufferPool, error) { + slices.Sort(powerOfTwoExponents) + + // Determine the maximum exponent we need to support. This depends on the + // word size (32-bit vs 64-bit). + maxExponent := uintSize - 1 + indexOfNextLargestBit := slices.Repeat([]int{-1}, maxExponent+1) + indexOfPreviousLargestBit := slices.Repeat([]int{-1}, maxExponent+1) + + maxTier := 0 + pools := make([]*sizedBufferPool, 0, len(powerOfTwoExponents)) + + for i, exp := range powerOfTwoExponents { + // Allocating slices of size > 2^maxExponent isn't possible on + // maxExponent-bit machines. + if int(exp) > maxExponent { + return nil, fmt.Errorf("allocating slice of size 2^%d is not possible", exp) + } + tierSize := 1 << exp + pools = append(pools, newSizedBufferPool(tierSize)) + maxTier = max(maxTier, tierSize) + + // Map the exact power of 2 to this pool index. + indexOfNextLargestBit[exp] = i + indexOfPreviousLargestBit[exp] = i + } + + // Fill gaps for Get() (Next Largest) + // We iterate backwards. If current is empty, take the value from the right (larger). + for i := maxExponent - 1; i >= 0; i-- { + if indexOfNextLargestBit[i] == -1 { + indexOfNextLargestBit[i] = indexOfNextLargestBit[i+1] + } + } + + // Fill gaps for Put() (Previous Largest) + // We iterate forwards. If current is empty, take the value from the left (smaller). + for i := 1; i <= maxExponent; i++ { + if indexOfPreviousLargestBit[i] == -1 { + indexOfPreviousLargestBit[i] = indexOfPreviousLargestBit[i-1] + } + } + + return &binaryTieredBufferPool{ + exponentToNextLargestPoolMap: indexOfNextLargestBit, + exponentToPreviousLargestPoolMap: indexOfPreviousLargestBit, + sizedPools: pools, + maxPoolCap: maxTier, + }, nil +} + +func (b *binaryTieredBufferPool) Get(size int) *[]byte { + return b.poolForGet(size).Get(size) +} + +func (b *binaryTieredBufferPool) poolForGet(size int) BufferPool { + if size == 0 || size > b.maxPoolCap { + return &b.fallbackPool + } + + // Calculate the exponent of the smallest power of 2 >= size. + // We subtract 1 from size to handle exact powers of 2 correctly. + // + // Examples: + // size=16 (0b10000) -> size-1=15 (0b01111) -> bits.Len=4 -> Pool for 2^4 + // size=17 (0b10001) -> size-1=16 (0b10000) -> bits.Len=5 -> Pool for 2^5 + querySize := uint(size - 1) + poolIdx := b.exponentToNextLargestPoolMap[bits.Len(querySize)] + + return b.sizedPools[poolIdx] +} + +func (b *binaryTieredBufferPool) Put(buf *[]byte) { + b.poolForPut(cap(*buf)).Put(buf) +} + +func (b *binaryTieredBufferPool) poolForPut(bCap int) BufferPool { + if bCap == 0 { + return NopBufferPool{} + } + if bCap > b.maxPoolCap { + return &b.fallbackPool + } + // Find the pool with the largest capacity <= bCap. + // + // We calculate the exponent of the largest power of 2 <= bCap. + // bits.Len(x) returns the minimum number of bits required to represent x; + // i.e. the number of bits up to and including the most significant bit. + // Subtracting 1 gives the 0-based index of the most significant bit, + // which is the exponent of the largest power of 2 <= bCap. + // + // Examples: + // cap=16 (0b10000) -> Len=5 -> 5-1=4 -> 2^4 + // cap=15 (0b01111) -> Len=4 -> 4-1=3 -> 2^3 + largestPowerOfTwo := bits.Len(uint(bCap)) - 1 + poolIdx := b.exponentToPreviousLargestPoolMap[largestPowerOfTwo] + // The buffer is smaller than the smallest power of 2, discard it. + if poolIdx == -1 { + // Buffer is smaller than our smallest pool bucket. + return NopBufferPool{} + } + return b.sizedPools[poolIdx] +} + // sizedBufferPool is a BufferPool implementation that is optimized for specific // buffer sizes. For example, HTTP/2 frames within gRPC have a default max size // of 16kb and a sizedBufferPool can be configured to only return buffers with a diff --git a/mem/buffer_pool_internal_test.go b/mem/buffer_pool_internal_test.go new file mode 100644 index 000000000000..8faab7689faf --- /dev/null +++ b/mem/buffer_pool_internal_test.go @@ -0,0 +1,114 @@ +/* + * + * Copyright 2026 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package mem + +import ( + "testing" +) + +func TestNewBinaryTieredBufferPool_WordSize(t *testing.T) { + origUintSize := uintSize + defer func() { uintSize = origUintSize }() + + tests := []struct { + name string + wordSize int + exponents []uint8 + wantErr bool + }{ + { + name: "32-bit valid exponent", + wordSize: 32, + exponents: []uint8{31}, + wantErr: false, + }, + { + name: "32-bit invalid exponent", + wordSize: 32, + exponents: []uint8{32}, + wantErr: true, + }, + { + name: "64-bit valid exponent", + wordSize: 64, + exponents: []uint8{63}, + wantErr: false, + }, + { + name: "64-bit invalid exponent", + wordSize: 64, + exponents: []uint8{64}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uintSize = tt.wordSize + pool, err := NewBinaryTieredBufferPool(tt.exponents...) + if (err != nil) != tt.wantErr { + t.Errorf("NewBinaryTieredBufferPool() error = %t, wantErr %t", err, tt.wantErr) + return + } + if err == nil { + bp := pool.(*binaryTieredBufferPool) + if len(bp.exponentToNextLargestPoolMap) != tt.wordSize { + t.Errorf("exponentToNextLargestPoolMap length = %d, want %d", len(bp.exponentToNextLargestPoolMap), tt.wordSize) + } + if len(bp.exponentToPreviousLargestPoolMap) != tt.wordSize { + t.Errorf("exponentToPreviousLargestPoolMap length = %d, want %d", len(bp.exponentToPreviousLargestPoolMap), tt.wordSize) + } + } + }) + } +} + +// BenchmarkTieredPool benchmarks the performance of the tiered buffer pool +// implementations, specifically focusing on the overhead of selecting the +// correct bucket for a given size. +func BenchmarkTieredPool(b *testing.B) { + defaultBufferPoolSizes := make([]int, len(defaultBufferPoolSizeExponents)) + for i, exp := range defaultBufferPoolSizeExponents { + defaultBufferPoolSizes[i] = 1 << exp + } + b.Run("pool=Tiered", func(b *testing.B) { + p := NewTieredBufferPool(defaultBufferPoolSizes...).(*tieredBufferPool) + for b.Loop() { + for size := range 1 << 19 { + // One for get, one for put. + _ = p.getPool(size) + _ = p.getPool(size) + } + } + }) + + b.Run("pool=BinaryTiered", func(b *testing.B) { + pool, err := NewBinaryTieredBufferPool(defaultBufferPoolSizeExponents...) + if err != nil { + b.Fatalf("Failed to create buffer pool: %v", err) + } + p := pool.(*binaryTieredBufferPool) + for b.Loop() { + for size := range 1 << 19 { + _ = p.poolForGet(size) + _ = p.poolForPut(size) + } + } + }) +} diff --git a/mem/buffer_pool_test.go b/mem/buffer_pool_test.go index 6086805674ba..7d5024e97e30 100644 --- a/mem/buffer_pool_test.go +++ b/mem/buffer_pool_test.go @@ -20,6 +20,7 @@ package mem_test import ( "bytes" + "fmt" "testing" "unsafe" @@ -105,3 +106,43 @@ func (s) TestBufferPoolIgnoresShortBuffers(t *testing.T) { // pool, it could cause a panic. pool.Get(10) } + +func TestBinaryBufferPool(t *testing.T) { + poolSizes := []uint8{0, 2, 3, 4} + + testCases := []struct { + requestSize int + wantCapacity int + }{ + {requestSize: 0, wantCapacity: 0}, + {requestSize: 1, wantCapacity: 1}, + {requestSize: 2, wantCapacity: 4}, + {requestSize: 3, wantCapacity: 4}, + {requestSize: 4, wantCapacity: 4}, + {requestSize: 5, wantCapacity: 8}, + {requestSize: 6, wantCapacity: 8}, + {requestSize: 7, wantCapacity: 8}, + {requestSize: 8, wantCapacity: 8}, + {requestSize: 9, wantCapacity: 16}, + {requestSize: 15, wantCapacity: 16}, + {requestSize: 16, wantCapacity: 16}, + {requestSize: 17, wantCapacity: 4096}, // fallback pool returns sizes in multiples of 4096. + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("requestSize=%d", tc.requestSize), func(t *testing.T) { + pool, err := mem.NewBinaryTieredBufferPool(poolSizes...) + if err != nil { + t.Fatalf("Failed to create buffer pool: %v", err) + } + buf := pool.Get(tc.requestSize) + if cap(*buf) != tc.wantCapacity { + t.Errorf("Get(%d) returned buffer with capacity: %d, want %d", tc.requestSize, cap(*buf), tc.wantCapacity) + } + if len(*buf) != tc.requestSize { + t.Errorf("Get(%d) returned buffer with length: %d, want %d", tc.requestSize, len(*buf), tc.requestSize) + } + pool.Put(buf) + }) + } +}