Skip to content

Commit 6ee5247

Browse files
committed
Code refactoring for panic: send on closed channel
1 parent 066d559 commit 6ee5247

File tree

2 files changed

+61
-13
lines changed

2 files changed

+61
-13
lines changed

pkg/core/track.go

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package core
33
import (
44
"encoding/json"
55
"errors"
6+
67
"github.com/pion/rtp"
78
)
89

@@ -70,9 +71,8 @@ type Sender struct {
7071
Packets int `json:"packets,omitempty"`
7172
Drops int `json:"drops,omitempty"`
7273

73-
buf chan *Packet
74-
done chan struct{}
75-
isClosed bool
74+
buf chan *Packet
75+
done chan struct{}
7676
}
7777

7878
func NewSender(media *Media, codec *Codec) *Sender {
@@ -99,11 +99,6 @@ func NewSender(media *Media, codec *Codec) *Sender {
9999
s.Input = func(packet *Packet) {
100100
// writing to nil chan - OK, writing to closed chan - panic
101101
s.mu.Lock()
102-
if s.isClosed {
103-
s.Drops++
104-
s.mu.Unlock()
105-
return
106-
}
107102
select {
108103
case s.buf <- packet:
109104
s.Bytes += len(packet.Payload)
@@ -145,6 +140,7 @@ func (s *Sender) Start() {
145140
s.done = make(chan struct{})
146141

147142
go func() {
143+
// for range on nil chan is OK
148144
for packet := range s.buf {
149145
s.Output(packet)
150146
}
@@ -153,7 +149,7 @@ func (s *Sender) Start() {
153149
}
154150

155151
func (s *Sender) Wait() {
156-
if done := s.done; s.done != nil {
152+
if done := s.done; done != nil {
157153
<-done
158154
}
159155
}
@@ -171,10 +167,9 @@ func (s *Sender) State() string {
171167
func (s *Sender) Close() {
172168
// close buffer if exists
173169
s.mu.Lock()
174-
if buf := s.buf; buf != nil && !s.isClosed {
175-
s.isClosed = true
176-
s.buf = nil
177-
defer close(buf)
170+
if s.buf != nil {
171+
close(s.buf) // exit from for range loop
172+
s.buf = nil // prevent writing to closed chan
178173
}
179174
s.mu.Unlock()
180175

pkg/core/track_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package core
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestSenser(t *testing.T) {
10+
recv := make(chan *Packet) // blocking receiver
11+
12+
sender := NewSender(nil, &Codec{})
13+
sender.Output = func(packet *Packet) {
14+
recv <- packet
15+
}
16+
require.Equal(t, "new", sender.State())
17+
18+
sender.Start()
19+
require.Equal(t, "connected", sender.State())
20+
21+
sender.Input(&Packet{})
22+
sender.Input(&Packet{})
23+
24+
require.Equal(t, 2, sender.Packets)
25+
require.Equal(t, 0, sender.Drops)
26+
27+
// important to read one before close
28+
// because goroutine in Start() can run with nil chan
29+
// it's OK in real life, but bad for test
30+
_, ok := <-recv
31+
require.True(t, ok)
32+
33+
sender.Close()
34+
require.Equal(t, "closed", sender.State())
35+
36+
sender.Input(&Packet{})
37+
38+
require.Equal(t, 2, sender.Packets)
39+
require.Equal(t, 1, sender.Drops)
40+
41+
// read 2nd
42+
_, ok = <-recv
43+
require.True(t, ok)
44+
45+
// read 3rd
46+
select {
47+
case <-recv:
48+
ok = true
49+
default:
50+
ok = false
51+
}
52+
require.False(t, ok)
53+
}

0 commit comments

Comments
 (0)