Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 145 additions & 9 deletions mem/buffer_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
package mem

import (
"fmt"
"math/bits"
"slices"
"sort"
"sync"

Expand All @@ -38,20 +41,28 @@ type BufferPool interface {
Put(*[]byte)
}

const goPageSize = 4 << 10 // 4KiB. N.B. this must be a power of 2.
// goPageSizeExponent is the base-2 exponent of goPageSize (2^12 = 4096).
const goPageSizeExponent = 12
// goPageSize is derived from goPageSizeExponent by a shift, which keeps it a
// power of two by construction.
const goPageSize = 1 << goPageSizeExponent // 4KiB. N.B. this must be a power of 2.

var defaultBufferPoolSizes = []int{
256,
goPageSize,
16 << 10, // 16KB (max HTTP/2 frame size used by gRPC)
32 << 10, // 32KB (default buffer size for io.Copy)
1 << 20, // 1MB
// defaultBufferPoolSizeExponents lists the tiers of the default buffer pool
// as power-of-two exponents (a value e stands for a capacity of 2^e bytes).
// init passes these to NewBinaryTieredBufferPool.
var defaultBufferPoolSizeExponents = []uint8{
8, // 256B
goPageSizeExponent, // 4KiB
14, // 16KB (max HTTP/2 frame size used by gRPC)
15, // 32KB (default buffer size for io.Copy)
20, // 1MB
}

var defaultBufferPool BufferPool
var (
// defaultBufferPool is the package-wide pool. It is assigned in init and
// can be swapped through internal.SetDefaultBufferPoolForTesting.
defaultBufferPool BufferPool
// uintSize is the word size in bits that NewBinaryTieredBufferPool uses
// to bound the accepted exponents and to size its lookup tables.
uintSize = bits.UintSize // use a variable for mocking during tests.
)

func init() {
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...)
var err error
defaultBufferPool, err = NewBinaryTieredBufferPool(defaultBufferPoolSizeExponents...)
if err != nil {
panic(fmt.Sprintf("Failed to create default buffer pool: %v", err))
}

internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) {
defaultBufferPool = pool
Expand Down Expand Up @@ -109,6 +120,131 @@ func (p *tieredBufferPool) getPool(size int) BufferPool {
return p.sizedPools[poolIdx]
}

// binaryTieredBufferPool is a BufferPool built from sub-pools whose buffer
// capacities are powers of two. Sub-pool selection is a table lookup keyed by
// a power-of-two exponent rather than a search over the tiers.
type binaryTieredBufferPool struct {
	// sizedPools holds one sub-pool per configured tier. The two lookup
	// tables below store indices into this slice, or -1 where no tier
	// applies.
	sizedPools []*sizedBufferPool

	// exponentToNextLargestPoolMap is indexed by a power-of-two exponent
	// (e.g. 12 for 4KB) and yields the index of the smallest tier whose
	// capacity is at least 2^exponent. Get() consults it to find the
	// smallest pool able to satisfy a requested size.
	exponentToNextLargestPoolMap []int

	// exponentToPreviousLargestPoolMap is indexed by a power-of-two exponent
	// and yields the index of the largest tier whose capacity is at most
	// 2^exponent. Put() consults it to hand a buffer back to the most
	// appropriate pool for its capacity.
	exponentToPreviousLargestPoolMap []int

	// maxPoolCap caches the capacity of the largest tier so that oversized
	// requests can be routed to fallbackPool without touching the tables.
	maxPoolCap int

	// fallbackPool serves sizes that no tier covers.
	fallbackPool simpleBufferPool
}

// NewBinaryTieredBufferPool returns a BufferPool backed by multiple sub-pools.
// This structure enables O(1) lookup time for Get and Put operations.
//
// The arguments provided are the exponents for the buffer capacities (powers
// of 2), not the raw byte sizes. For example, to create a pool of 16KB buffers
// (2^14 bytes), pass 14 as the argument.
//
// The exponents may be given in any order and may contain duplicates; the
// caller's slice is never modified. An error is returned if any exponent is
// larger than the platform word size in bits minus one.
func NewBinaryTieredBufferPool(powerOfTwoExponents ...uint8) (BufferPool, error) {
	// Operate on a private copy: a call of the form f(s...) shares s's backing
	// array, so sorting the parameter directly would reorder the caller's
	// slice. Compacting after the sort drops duplicate exponents, which would
	// otherwise create sub-pools that no lookup entry refers to.
	exponents := slices.Clone(powerOfTwoExponents)
	slices.Sort(exponents)
	exponents = slices.Compact(exponents)

	// Determine the maximum exponent we need to support. This depends on the
	// word size (32-bit vs 64-bit).
	maxExponent := uintSize - 1
	nextLargest := slices.Repeat([]int{-1}, maxExponent+1)
	previousLargest := slices.Repeat([]int{-1}, maxExponent+1)

	maxTier := 0
	pools := make([]*sizedBufferPool, 0, len(exponents))

	for i, exp := range exponents {
		// Exponents beyond maxExponent cannot index the lookup tables and do
		// not describe an allocatable slice size.
		if int(exp) > maxExponent {
			return nil, fmt.Errorf("allocating slice of size 2^%d is not possible", exp)
		}
		// NOTE(review): the shift is evaluated as int, so when exp equals the
		// native word size minus one the result wraps to a negative value and
		// that tier never raises maxTier. The check above still accepts that
		// exponent; confirm whether it should be rejected instead.
		tierSize := 1 << exp
		pools = append(pools, newSizedBufferPool(tierSize))
		maxTier = max(maxTier, tierSize)

		// An exact power of two maps to its own pool in both tables.
		nextLargest[exp] = i
		previousLargest[exp] = i
	}

	// Fill gaps for Get() (next largest): walk from the top down so every
	// empty slot inherits the nearest larger tier. Slots above the largest
	// tier stay at -1.
	for i := maxExponent - 1; i >= 0; i-- {
		if nextLargest[i] == -1 {
			nextLargest[i] = nextLargest[i+1]
		}
	}

	// Fill gaps for Put() (previous largest): walk from the bottom up so
	// every empty slot inherits the nearest smaller tier. Slots below the
	// smallest tier stay at -1.
	for i := 1; i <= maxExponent; i++ {
		if previousLargest[i] == -1 {
			previousLargest[i] = previousLargest[i-1]
		}
	}

	return &binaryTieredBufferPool{
		exponentToNextLargestPoolMap:     nextLargest,
		exponentToPreviousLargestPoolMap: previousLargest,
		sizedPools:                       pools,
		maxPoolCap:                       maxTier,
	}, nil
}

// Get forwards the request to the pool that poolForGet selects for size.
func (b *binaryTieredBufferPool) Get(size int) *[]byte {
	source := b.poolForGet(size)
	return source.Get(size)
}

// poolForGet picks the pool that should serve a Get of the given size: the
// smallest tier whose capacity is >= size, or the fallback pool when size is
// zero or larger than the biggest tier.
func (b *binaryTieredBufferPool) poolForGet(size int) BufferPool {
	if size != 0 && size <= b.maxPoolCap {
		// bits.Len(size-1) is the exponent of the smallest power of 2 that
		// is >= size; subtracting 1 first keeps exact powers of 2 in their
		// own tier.
		//
		//   size=16 -> size-1=15 (0b01111) -> Len=4 -> tier for 2^4
		//   size=17 -> size-1=16 (0b10000) -> Len=5 -> tier for 2^5
		exponent := bits.Len(uint(size - 1))
		return b.sizedPools[b.exponentToNextLargestPoolMap[exponent]]
	}
	return &b.fallbackPool
}

// Put hands buf to the pool that poolForPut selects for the buffer's capacity.
func (b *binaryTieredBufferPool) Put(buf *[]byte) {
	target := b.poolForPut(cap(*buf))
	target.Put(buf)
}

// poolForPut picks the pool that should receive a buffer of capacity bCap:
// the largest tier whose capacity is <= bCap. Zero-capacity buffers and
// buffers smaller than the smallest tier go to a NopBufferPool; buffers
// larger than the biggest tier go to the fallback pool.
func (b *binaryTieredBufferPool) poolForPut(bCap int) BufferPool {
	switch {
	case bCap == 0:
		return NopBufferPool{}
	case bCap > b.maxPoolCap:
		return &b.fallbackPool
	}

	// bits.Len(x) is the number of bits needed to represent x, so Len-1 is
	// the 0-based position of the most significant bit, i.e. the exponent of
	// the largest power of 2 that is <= bCap.
	//
	//   cap=16 (0b10000) -> Len=5 -> 4 -> 2^4
	//   cap=15 (0b01111) -> Len=4 -> 3 -> 2^3
	msbIndex := bits.Len(uint(bCap)) - 1
	if idx := b.exponentToPreviousLargestPoolMap[msbIndex]; idx != -1 {
		return b.sizedPools[idx]
	}
	// No tier is small enough to take this buffer; discard it.
	return NopBufferPool{}
}

// sizedBufferPool is a BufferPool implementation that is optimized for specific
// buffer sizes. For example, HTTP/2 frames within gRPC have a default max size
// of 16kb and a sizedBufferPool can be configured to only return buffers with a
Expand Down
114 changes: 114 additions & 0 deletions mem/buffer_pool_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
*
* Copyright 2026 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package mem

import (
"testing"
)

// TestNewBinaryTieredBufferPool_WordSize overrides uintSize to simulate 32-bit
// and 64-bit platforms. It checks that NewBinaryTieredBufferPool accepts
// exponents up to wordSize-1, rejects wordSize itself, and sizes both lookup
// tables to exactly wordSize entries.
func TestNewBinaryTieredBufferPool_WordSize(t *testing.T) {
	// Restore the real word size once all sub-tests have run.
	origUintSize := uintSize
	defer func() { uintSize = origUintSize }()

	tests := []struct {
		name      string
		wordSize  int
		exponents []uint8
		wantErr   bool
	}{
		{
			name:      "32-bit valid exponent",
			wordSize:  32,
			exponents: []uint8{31},
			wantErr:   false,
		},
		{
			name:      "32-bit invalid exponent",
			wordSize:  32,
			exponents: []uint8{32},
			wantErr:   true,
		},
		{
			name:      "64-bit valid exponent",
			wordSize:  64,
			exponents: []uint8{63},
			wantErr:   false,
		},
		{
			name:      "64-bit invalid exponent",
			wordSize:  64,
			exponents: []uint8{64},
			wantErr:   true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			uintSize = tt.wordSize
			pool, err := NewBinaryTieredBufferPool(tt.exponents...)
			if (err != nil) != tt.wantErr {
				// err is an error value, so it must be printed with %v; %t is
				// only valid for booleans.
				t.Errorf("NewBinaryTieredBufferPool() error = %v, wantErr %t", err, tt.wantErr)
				return
			}
			if err != nil {
				return
			}
			bp := pool.(*binaryTieredBufferPool)
			if got := len(bp.exponentToNextLargestPoolMap); got != tt.wordSize {
				t.Errorf("exponentToNextLargestPoolMap length = %d, want %d", got, tt.wordSize)
			}
			if got := len(bp.exponentToPreviousLargestPoolMap); got != tt.wordSize {
				t.Errorf("exponentToPreviousLargestPoolMap length = %d, want %d", got, tt.wordSize)
			}
		})
	}
}

// BenchmarkTieredPool measures how long the tiered buffer pool
// implementations take to choose a bucket for a size. It exercises only pool
// selection for every size below 2^19.
func BenchmarkTieredPool(b *testing.B) {
	const sizeLimit = 1 << 19

	// The search-based tiered pool takes raw byte sizes, so expand the
	// default exponents into capacities.
	tierSizes := make([]int, 0, len(defaultBufferPoolSizeExponents))
	for _, exp := range defaultBufferPoolSizeExponents {
		tierSizes = append(tierSizes, 1<<exp)
	}

	b.Run("pool=Tiered", func(b *testing.B) {
		tiered := NewTieredBufferPool(tierSizes...).(*tieredBufferPool)
		for b.Loop() {
			for size := 0; size < sizeLimit; size++ {
				// Two lookups per size: one stands in for Get, one for Put.
				_ = tiered.getPool(size)
				_ = tiered.getPool(size)
			}
		}
	})

	b.Run("pool=BinaryTiered", func(b *testing.B) {
		created, err := NewBinaryTieredBufferPool(defaultBufferPoolSizeExponents...)
		if err != nil {
			b.Fatalf("Failed to create buffer pool: %v", err)
		}
		binary := created.(*binaryTieredBufferPool)
		for b.Loop() {
			for size := 0; size < sizeLimit; size++ {
				_ = binary.poolForGet(size)
				_ = binary.poolForPut(size)
			}
		}
	})
}
41 changes: 41 additions & 0 deletions mem/buffer_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package mem_test

import (
"bytes"
"fmt"
"testing"
"unsafe"

Expand Down Expand Up @@ -105,3 +106,43 @@ func (s) TestBufferPoolIgnoresShortBuffers(t *testing.T) {
// pool, it could cause a panic.
pool.Get(10)
}

// TestBinaryBufferPool checks, for a pool with tiers 2^0, 2^2, 2^3 and 2^4,
// that Get returns a buffer whose length equals the requested size and whose
// capacity is that of the smallest tier able to hold it.
//
// NOTE(review): the neighbouring test in this file is declared as
// `func (s) Test...`; confirm whether this test should follow the same form.
func TestBinaryBufferPool(t *testing.T) {
	exponents := []uint8{0, 2, 3, 4}

	type expectation struct {
		requestSize  int
		wantCapacity int
	}
	cases := []expectation{
		{0, 0},
		{1, 1},
		{2, 4},
		{3, 4},
		{4, 4},
		{5, 8},
		{6, 8},
		{7, 8},
		{8, 8},
		{9, 16},
		{15, 16},
		{16, 16},
		{17, 4096}, // fallback pool returns sizes in multiples of 4096.
	}

	for _, c := range cases {
		name := fmt.Sprintf("requestSize=%d", c.requestSize)
		t.Run(name, func(t *testing.T) {
			// Each case gets a fresh pool so no buffer is reused across cases.
			pool, err := mem.NewBinaryTieredBufferPool(exponents...)
			if err != nil {
				t.Fatalf("Failed to create buffer pool: %v", err)
			}
			buf := pool.Get(c.requestSize)
			if gotCap := cap(*buf); gotCap != c.wantCapacity {
				t.Errorf("Get(%d) returned buffer with capacity: %d, want %d", c.requestSize, gotCap, c.wantCapacity)
			}
			if gotLen := len(*buf); gotLen != c.requestSize {
				t.Errorf("Get(%d) returned buffer with length: %d, want %d", c.requestSize, gotLen, c.requestSize)
			}
			pool.Put(buf)
		})
	}
}