Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.29.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
22 changes: 20 additions & 2 deletions pkg/crypto/ciphersuite/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,31 @@ go test -tags=bench -bench=BenchmarkCBCEncrypt -benchmem
go test -tags=bench -bench=BenchmarkCBCDecrypt -benchmem
```

- All cipgers, with 1KB payloads only
- ChaCha20Poly1305 benchmarks only:

```bash
go test -tags=bench -bench=BenchmarkChaCha20Poly1305 -benchmem
```

- ChaCha20Poly1305 `Encrypt` benchmark only:

```bash
go test -tags=bench -bench=BenchmarkChaCha20Poly1305Encrypt -benchmem
```

- ChaCha20Poly1305 `Decrypt` benchmark only:

```bash
go test -tags=bench -bench=BenchmarkChaCha20Poly1305Decrypt -benchmem
```

- All ciphers, with 1KB payloads only

```bash
go test -tags=bench -bench=/1KB -benchmem
```

- All cipgers, with 16B payloads only
- All ciphers, with 16B payloads only

```bash
go test -tags=bench -bench=/16B -benchmem
Expand Down
104 changes: 103 additions & 1 deletion pkg/crypto/ciphersuite/bench_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,109 @@

package ciphersuite

import "testing"
import (
"testing"

"github.com/pion/dtls/v3/pkg/protocol"
"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
)

type testCipher interface {
Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error)
Decrypt(header recordlayer.Header, in []byte) ([]byte, error)
}

// benchmarkEncrypt benchmarks a cipher's encryption with various payload sizes.
func benchmarkEncrypt(b *testing.B, cipher testCipher) {
b.Helper()

payloadSizes := []int{16, 64, 128, 256, 512, 800, 1024, 1200, 1500, 4096, 8192}

for _, size := range payloadSizes {
b.Run(formatSize(b, size), func(b *testing.B) {
hdr := recordlayer.Header{
ContentType: protocol.ContentTypeApplicationData,
Version: protocol.Version1_2,
Epoch: 1,
SequenceNumber: 12345,
}

headerRaw, err := hdr.Marshal()
if err != nil {
b.Fatal(err)
}

payload := make([]byte, size)
raw := make([]byte, len(headerRaw)+len(payload))
copy(raw, headerRaw)
copy(raw[len(headerRaw):], payload)

pkt := &recordlayer.RecordLayer{Header: hdr}

b.ReportAllocs()
b.SetBytes(int64(size))
b.ResetTimer()

for i := 0; i < b.N; i++ {
rawCopy := make([]byte, len(raw))
copy(rawCopy, raw)

_, err := cipher.Encrypt(pkt, rawCopy)
if err != nil {
b.Fatal(err)
}
}
})
}
}

// benchmarkDecrypt benchmarks a cipher's decryption with various payload sizes.
func benchmarkDecrypt(b *testing.B, cipher testCipher) {
b.Helper()

payloadSizes := []int{16, 64, 256, 512, 1024, 1500}

for _, size := range payloadSizes {
b.Run(formatSize(b, size), func(b *testing.B) {
hdr := recordlayer.Header{
ContentType: protocol.ContentTypeApplicationData,
Version: protocol.Version1_2,
Epoch: 1,
SequenceNumber: 12345,
}

headerRaw, err := hdr.Marshal()
if err != nil {
b.Fatal(err)
}

payload := make([]byte, size)
raw := make([]byte, len(headerRaw)+len(payload))
copy(raw, headerRaw)
copy(raw[len(headerRaw):], payload)

pkt := &recordlayer.RecordLayer{Header: hdr}
encrypted, err := cipher.Encrypt(pkt, raw)
if err != nil {
b.Fatal(err)
}

b.ReportAllocs()
b.SetBytes(int64(size))
b.ResetTimer()

for i := 0; i < b.N; i++ {
encCopy := make([]byte, len(encrypted))
copy(encCopy, encrypted)

_, err := cipher.Decrypt(hdr, encCopy)
if err != nil {
b.Fatal(err)
}
}
})
}
}

func formatSize(b *testing.B, size int) string {
b.Helper()
Expand Down
87 changes: 2 additions & 85 deletions pkg/crypto/ciphersuite/cbc_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ package ciphersuite
import (
"crypto/sha256"
"testing"

"github.com/pion/dtls/v3/pkg/protocol"
"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
)

// BenchmarkCBCEncrypt benchmarks CBC encryption with various payload sizes.
Expand All @@ -26,45 +23,7 @@ func BenchmarkCBCEncrypt(b *testing.B) {
b.Fatal(err)
}

payloadSizes := []int{16, 64, 256, 512, 1024, 1500}

// nolint:dupl
for _, size := range payloadSizes {
b.Run(formatSize(b, size), func(b *testing.B) {
hdr := recordlayer.Header{
ContentType: protocol.ContentTypeApplicationData,
Version: protocol.Version1_2,
Epoch: 1,
SequenceNumber: 12345,
}

headerRaw, err := hdr.Marshal()
if err != nil {
b.Fatal(err)
}

payload := make([]byte, size)
raw := make([]byte, len(headerRaw)+len(payload))
copy(raw, headerRaw)
copy(raw[len(headerRaw):], payload)

pkt := &recordlayer.RecordLayer{Header: hdr}

b.ReportAllocs()
b.SetBytes(int64(size))
b.ResetTimer()

for i := 0; i < b.N; i++ {
rawCopy := make([]byte, len(raw))
copy(rawCopy, raw)

_, err := cbcCipher.Encrypt(pkt, rawCopy)
if err != nil {
b.Fatal(err)
}
}
})
}
benchmarkEncrypt(b, cbcCipher)
}

// BenchmarkCBCDecrypt benchmarks CBC decryption with various payload sizes.
Expand All @@ -80,47 +39,5 @@ func BenchmarkCBCDecrypt(b *testing.B) {
b.Fatal(err)
}

payloadSizes := []int{16, 64, 256, 512, 1024, 1500}

// nolint:dupl
for _, size := range payloadSizes {
b.Run(formatSize(b, size), func(b *testing.B) {
hdr := recordlayer.Header{
ContentType: protocol.ContentTypeApplicationData,
Version: protocol.Version1_2,
Epoch: 1,
SequenceNumber: 12345,
}

headerRaw, err := hdr.Marshal()
if err != nil {
b.Fatal(err)
}

payload := make([]byte, size)
raw := make([]byte, len(headerRaw)+len(payload))
copy(raw, headerRaw)
copy(raw[len(headerRaw):], payload)

pkt := &recordlayer.RecordLayer{Header: hdr}
encrypted, err := cbcCipher.Encrypt(pkt, raw)
if err != nil {
b.Fatal(err)
}

b.ReportAllocs()
b.SetBytes(int64(size))
b.ResetTimer()

for i := 0; i < b.N; i++ {
encCopy := make([]byte, len(encrypted))
copy(encCopy, encrypted)

_, err := cbcCipher.Decrypt(hdr, encCopy)
if err != nil {
b.Fatal(err)
}
}
})
}
benchmarkDecrypt(b, cbcCipher)
}
91 changes: 11 additions & 80 deletions pkg/crypto/ciphersuite/ccm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,8 @@ package ciphersuite

import (
"crypto/aes"
"crypto/rand"
"encoding/binary"
"fmt"
"sync"

"github.com/pion/dtls/v3/pkg/crypto/ccm"
"github.com/pion/dtls/v3/pkg/protocol"
"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
)

Expand All @@ -27,12 +22,7 @@ const (

// CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets.
type CCM struct {
localCCM, remoteCCM ccm.CCM
localWriteIV, remoteWriteIV []byte
tagLen CCMTagLen

// buffer pool for nonces.
nonceBufferPool sync.Pool
aead *aead
}

// NewCCM creates a DTLS GCM Cipher.
Expand All @@ -56,82 +46,23 @@ func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV [
}

return &CCM{
localCCM: localCCM,
localWriteIV: localWriteIV,
remoteCCM: remoteCCM,
remoteWriteIV: remoteWriteIV,
tagLen: tagLen,

nonceBufferPool: sync.Pool{
New: func() any {
b := make([]byte, ccmNonceLength)
return &b // nolint:nlreturn
},
},
aead: newAEAD(
localCCM,
localWriteIV,
remoteCCM,
remoteWriteIV,
ccmNonceLength,
int(tagLen),
),
}, nil
}

// Encrypt encrypt a DTLS RecordLayer message.
func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
payload := raw[pkt.Header.Size():]
raw = raw[:pkt.Header.Size()]

noncePtr := c.nonceBufferPool.Get().(*[]byte) // nolint:forcetypeassert
nonce := *noncePtr
defer c.nonceBufferPool.Put(noncePtr)

copy(nonce[:4], c.localWriteIV[:4])
if _, err := rand.Read(nonce[4:]); err != nil {
return nil, err
}

var additionalData []byte
if pkt.Header.ContentType == protocol.ContentTypeConnectionID {
additionalData = generateAEADAdditionalDataCID(&pkt.Header, len(payload))
} else {
additionalData = generateAEADAdditionalData(&pkt.Header, len(payload))
}

finalSize := len(raw) + 8 + len(payload) + int(c.tagLen)
result := make([]byte, finalSize)
copy(result, raw)
copy(result[len(raw):], nonce[4:])

c.localCCM.Seal(result[len(raw)+8:len(raw)+8], nonce, payload, additionalData)

// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(result[pkt.Header.Size()-2:], uint16(len(result)-pkt.Header.Size())) //nolint:gosec //G115

return result, nil
return c.aead.encrypt(pkt, raw)
}

// Decrypt decrypts a DTLS RecordLayer message.
func (c *CCM) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) {
if err := header.Unmarshal(in); err != nil {
return nil, err
}
switch {
case header.ContentType == protocol.ContentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + header.Size()):
return nil, errNotEnoughRoomForNonce
}

nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[header.Size():header.Size()+8]...)
out := in[header.Size()+8:]

var additionalData []byte
if header.ContentType == protocol.ContentTypeConnectionID {
additionalData = generateAEADAdditionalDataCID(&header, len(out)-int(c.tagLen))
} else {
additionalData = generateAEADAdditionalData(&header, len(out)-int(c.tagLen))
}
var err error
out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData)
if err != nil {
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint
}

return append(in[:header.Size()], out...), nil
return c.aead.decrypt(header, in)
}
Loading
Loading