Skip to content

Commit 9613670

Browse files
committed
feat: fire network error when network disconnects during request
1 parent 3d4728c commit 9613670

File tree

4 files changed

+86
-8
lines changed

4 files changed

+86
-8
lines changed

requestmanager/requestmanager.go

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"github.com/hannahhoward/go-pubsub"
8+
"golang.org/x/xerrors"
79
"sync/atomic"
810

911
blocks "github.com/ipfs/go-block-format"
@@ -70,6 +72,7 @@ type RequestManager struct {
7072
peerHandler PeerHandler
7173
rc *responseCollector
7274
asyncLoader AsyncLoader
75+
disconnectNotif *pubsub.PubSub
7376
// dont touch out side of run loop
7477
nextRequestID graphsync.RequestID
7578
inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
@@ -111,6 +114,7 @@ func New(ctx context.Context,
111114
ctx: ctx,
112115
cancel: cancel,
113116
asyncLoader: asyncLoader,
117+
disconnectNotif: pubsub.New(disconnectDispatcher),
114118
rc: newResponseCollector(ctx),
115119
messages: make(chan requestManagerMessage, 16),
116120
inProgressRequestStatuses: make(map[graphsync.RequestID]*inProgressRequestStatus),
@@ -128,6 +132,7 @@ func (rm *RequestManager) SetDelegate(peerHandler PeerHandler) {
128132

129133
type inProgressRequest struct {
130134
requestID graphsync.RequestID
135+
request gsmsg.GraphSyncRequest
131136
incoming chan graphsync.ResponseProgress
132137
incomingError chan error
133138
}
@@ -166,14 +171,46 @@ func (rm *RequestManager) SendRequest(ctx context.Context,
166171
case receivedInProgressRequest = <-inProgressRequestChan:
167172
}
168173

174+
// If the connection to the peer is disconnected, fire an error
175+
unsub := rm.listenForDisconnect(p, func(neterr error) {
176+
rm.networkErrorListeners.NotifyNetworkErrorListeners(p, receivedInProgressRequest.request, neterr)
177+
})
178+
169179
return rm.rc.collectResponses(ctx,
170180
receivedInProgressRequest.incoming,
171181
receivedInProgressRequest.incomingError,
172182
func() {
173183
rm.cancelRequest(receivedInProgressRequest.requestID,
174184
receivedInProgressRequest.incoming,
175185
receivedInProgressRequest.incomingError)
176-
})
186+
},
187+
// Once the request has completed, stop listening for disconnect events
188+
unsub,
189+
)
190+
}
191+
192+
// Dispatch the Disconnect event to subscribers
193+
func disconnectDispatcher(p pubsub.Event, subscriberFn pubsub.SubscriberFn) error {
194+
listener := subscriberFn.(func(peer.ID))
195+
listener(p.(peer.ID))
196+
return nil
197+
}
198+
199+
// Listen for the Disconnect event for the given peer
200+
func (rm *RequestManager) listenForDisconnect(p peer.ID, onDisconnect func(neterr error)) func() {
201+
// Subscribe to Disconnect notifications
202+
return rm.disconnectNotif.Subscribe(func(evtPeer peer.ID) {
203+
// If the peer is the one we're interested in, call the listener
204+
if evtPeer == p {
205+
onDisconnect(xerrors.Errorf("disconnected from peer %s", p))
206+
}
207+
})
208+
}
209+
210+
// Disconnected is called when a peer disconnects
211+
func (rm *RequestManager) Disconnected(p peer.ID) {
212+
// Notify any listeners that a peer has disconnected
213+
rm.disconnectNotif.Publish(p)
177214
}
178215

179216
func (rm *RequestManager) emptyResponse() (chan graphsync.ResponseProgress, chan error) {
@@ -311,17 +348,19 @@ type terminateRequestMessage struct {
311348
requestID graphsync.RequestID
312349
}
313350

314-
func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (chan graphsync.ResponseProgress, chan error) {
351+
func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (gsmsg.GraphSyncRequest, chan graphsync.ResponseProgress, chan error) {
315352
request, hooksResult, err := rm.validateRequest(requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions)
316353
if err != nil {
317-
return rm.singleErrorResponse(err)
354+
rp, err := rm.singleErrorResponse(err)
355+
return request, rp, err
318356
}
319357
doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs)
320358
var doNotSendCids *cid.Set
321359
if has {
322360
doNotSendCids, err = cidset.DecodeCidSet(doNotSendCidsData)
323361
if err != nil {
324-
return rm.singleErrorResponse(err)
362+
rp, err := rm.singleErrorResponse(err)
363+
return request, rp, err
325364
}
326365
} else {
327366
doNotSendCids = cid.NewSet()
@@ -355,14 +394,14 @@ func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *Re
355394
ResumeMessages: resumeMessages,
356395
PauseMessages: pauseMessages,
357396
})
358-
return incoming, incomingError
397+
return request, incoming, incomingError
359398
}
360399

361400
func (nrm *newRequestMessage) handle(rm *RequestManager) {
362401
var ipr inProgressRequest
363402
ipr.requestID = rm.nextRequestID
364403
rm.nextRequestID++
365-
ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm)
404+
ipr.request, ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm)
366405

367406
select {
368407
case nrm.inProgressRequestChan <- ipr:

requestmanager/requestmanager_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,42 @@ func TestRequestReturnsMissingBlocks(t *testing.T) {
352352
require.NotEqual(t, len(errs), 0, "did not send errors")
353353
}
354354

355+
func TestDisconnectNotification(t *testing.T) {
356+
ctx := context.Background()
357+
td := newTestData(ctx, t)
358+
requestCtx, cancel := context.WithTimeout(ctx, time.Second)
359+
defer cancel()
360+
peers := testutil.GeneratePeers(2)
361+
362+
// Listen for network errors
363+
networkErrors := make(chan peer.ID, 1)
364+
td.networkErrorListeners.Register(func(p peer.ID, request graphsync.RequestData, err error) {
365+
networkErrors <- p
366+
})
367+
368+
// Send a request to the target peer
369+
targetPeer := peers[0]
370+
td.requestManager.SendRequest(requestCtx, targetPeer, td.blockChain.TipLink, td.blockChain.Selector())
371+
372+
// Disconnect a random peer, should not fire any events
373+
randomPeer := peers[1]
374+
td.requestManager.Disconnected(randomPeer)
375+
select {
376+
case <-networkErrors:
377+
t.Fatal("should not fire network error when unrelated peer disconnects")
378+
default:
379+
}
380+
381+
// Disconnect the target peer, should fire a network error
382+
td.requestManager.Disconnected(targetPeer)
383+
select {
384+
case p:= <-networkErrors:
385+
require.Equal(t, p, targetPeer)
386+
default:
387+
t.Fatal("should fire network error when peer disconnects")
388+
}
389+
}
390+
355391
func TestEncodingExtensions(t *testing.T) {
356392
ctx := context.Background()
357393
td := newTestData(ctx, t)

requestmanager/responsecollector.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@ func (rc *responseCollector) collectResponses(
1818
requestCtx context.Context,
1919
incomingResponses <-chan graphsync.ResponseProgress,
2020
incomingErrors <-chan error,
21-
cancelRequest func()) (<-chan graphsync.ResponseProgress, <-chan error) {
21+
cancelRequest func(),
22+
onComplete func(),
23+
) (<-chan graphsync.ResponseProgress, <-chan error) {
2224

2325
returnedResponses := make(chan graphsync.ResponseProgress)
2426
returnedErrors := make(chan error)
2527

2628
go func() {
2729
var receivedResponses []graphsync.ResponseProgress
2830
defer close(returnedResponses)
31+
defer onComplete()
2932
outgoingResponses := func() chan<- graphsync.ResponseProgress {
3033
if len(receivedResponses) == 0 {
3134
return nil

requestmanager/responsecollector_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestBufferingResponseProgress(t *testing.T) {
2626
cancelRequest := func() {}
2727

2828
outgoingResponses, outgoingErrors := rc.collectResponses(
29-
requestCtx, incomingResponses, incomingErrors, cancelRequest)
29+
requestCtx, incomingResponses, incomingErrors, cancelRequest, func(){})
3030

3131
blockStore := make(map[ipld.Link][]byte)
3232
persistence := testutil.NewTestStore(blockStore)

0 commit comments

Comments
 (0)