@@ -18,7 +18,7 @@ use tokio::{spawn, task::JoinHandle, time::Duration};
1818use crate :: {
1919 archive:: Archive ,
2020 state:: State ,
21- types:: { AuthorizedTransaction , Hash , Tip , Version , hash, schema} ,
21+ types:: { AuthorizedTransaction , Hash , Network , Tip , Version , hash, schema} ,
2222} ;
2323
2424mod channel_pool;
@@ -140,6 +140,7 @@ where
140140#[ derive( Clone ) ]
141141pub struct Connection {
142142 pub ( in crate :: net) inner : quinn:: Connection ,
143+ pub network : Network ,
143144}
144145
145146impl Connection {
@@ -154,14 +155,25 @@ impl Connection {
154155 self . inner . remote_address ( )
155156 }
156157
157- pub async fn new (
158+ pub fn new ( connection : quinn:: Connection , network : Network ) -> Self {
159+ Self {
160+ inner : connection,
161+ network,
162+ }
163+ }
164+
165+ pub async fn from_connecting (
158166 connecting : quinn:: Connecting ,
167+ network : Network ,
159168 ) -> Result < Self , quinn:: ConnectionError > {
160169 let addr = connecting. remote_address ( ) ;
161170 tracing:: trace!( %addr, "connecting to peer" ) ;
162171 let connection = connecting. await ?;
163172 tracing:: info!( %addr, "connected successfully to peer" ) ;
164- Ok ( Self { inner : connection } )
173+ Ok ( Self {
174+ inner : connection,
175+ network,
176+ } )
165177 }
166178
167179 async fn receive_request (
@@ -170,6 +182,15 @@ impl Connection {
170182 {
171183 let ( tx, mut rx) = self . inner . accept_bi ( ) . await ?;
172184 tracing:: trace!( recv_id = %rx. id( ) , "Receiving request" ) ;
185+ let mut magic_bytes = [ 0u8 ; message:: MAGIC_BYTES_LEN ] ;
186+ rx. read_exact ( & mut magic_bytes)
187+ . await
188+ . map_err ( error:: connection:: Receive :: ReadMagic ) ?;
189+ if magic_bytes != message:: magic_bytes ( self . network ) {
190+ return Err (
191+ error:: connection:: Receive :: BadMagic ( magic_bytes) . into ( )
192+ ) ;
193+ }
173194 let msg_bytes = rx. read_to_end ( Connection :: READ_REQUEST_LIMIT ) . await ?;
174195 let msg: RequestMessage = bincode:: deserialize ( & msg_bytes) ?;
175196 tracing:: trace!(
@@ -191,8 +212,9 @@ impl Connection {
191212 "Sending heartbeat"
192213 ) ;
193214 let message = RequestMessageRef :: from ( heartbeat) ;
194- let message = bincode:: serialize ( & message) ?;
195- send. write_all ( & message) . await . map_err ( |err| {
215+ let mut message_buf = message:: magic_bytes ( self . network ) . to_vec ( ) ;
216+ bincode:: serialize_into :: < & mut Vec < _ > , _ > ( & mut message_buf, & message) ?;
217+ send. write_all ( & message_buf) . await . map_err ( |err| {
196218 error:: connection:: Send :: Write {
197219 stream_id : send. id ( ) ,
198220 source : err,
@@ -203,10 +225,20 @@ impl Connection {
203225 }
204226
205227 async fn receive_response (
228+ network : Network ,
206229 mut recv : RecvStream ,
207230 read_response_limit : NonZeroUsize ,
208231 ) -> ResponseResult {
209232 tracing:: trace!( recv_id = %recv. id( ) , "Receiving response" ) ;
233+ let mut magic_bytes = [ 0u8 ; message:: MAGIC_BYTES_LEN ] ;
234+ recv. read_exact ( & mut magic_bytes)
235+ . await
236+ . map_err ( error:: connection:: Receive :: ReadMagic ) ?;
237+ if magic_bytes != message:: magic_bytes ( network) {
238+ return Err (
239+ error:: connection:: Receive :: BadMagic ( magic_bytes) . into ( )
240+ ) ;
241+ }
210242 let response_bytes =
211243 recv. read_to_end ( read_response_limit. get ( ) ) . await ?;
212244 let response: ResponseMessage = bincode:: deserialize ( & response_bytes) ?;
@@ -230,40 +262,52 @@ impl Connection {
230262 "Sending request"
231263 ) ;
232264 let message = RequestMessageRef :: from ( request) ;
233- let message = bincode:: serialize ( & message) ?;
234- send. write_all ( & message) . await . map_err ( |err| {
265+ let mut message_buf = message:: magic_bytes ( self . network ) . to_vec ( ) ;
266+ bincode:: serialize_into :: < & mut Vec < _ > , _ > ( & mut message_buf, & message) ?;
267+ send. write_all ( & message_buf) . await . map_err ( |err| {
235268 error:: connection:: Send :: Write {
236269 stream_id : send. id ( ) ,
237270 source : err,
238271 }
239272 } ) ?;
240273 send. finish ( ) ?;
241- Ok ( Self :: receive_response ( recv, read_response_limit) . await )
274+ Ok (
275+ Self :: receive_response ( self . network , recv, read_response_limit)
276+ . await ,
277+ )
242278 }
243279
280+ // Send a pre-serialized response, where the response does not include
281+ // magic bytes
244282 async fn send_serialized_response (
283+ network : Network ,
245284 mut response_tx : SendStream ,
246285 serialized_response : & [ u8 ] ,
247286 ) -> Result < ( ) , error:: connection:: SendResponse > {
248287 tracing:: trace!(
249288 send_id = %response_tx. id( ) ,
250289 "Sending response"
251290 ) ;
252- response_tx
253- . write_all ( serialized_response)
254- . await
255- . map_err ( |err| {
256- {
257- error:: connection:: Send :: Write {
258- stream_id : response_tx. id ( ) ,
259- source : err,
260- }
291+ async {
292+ response_tx
293+ . write_all ( & message:: magic_bytes ( network) )
294+ . await ?;
295+ response_tx. write_all ( serialized_response) . await
296+ }
297+ . await
298+ . map_err ( |err| {
299+ {
300+ error:: connection:: Send :: Write {
301+ stream_id : response_tx. id ( ) ,
302+ source : err,
261303 }
262- . into ( )
263- } )
304+ }
305+ . into ( )
306+ } )
264307 }
265308
266309 async fn send_response (
310+ network : Network ,
267311 mut response_tx : SendStream ,
268312 response : ResponseMessage ,
269313 ) -> Result < ( ) , error:: connection:: SendResponse > {
@@ -272,8 +316,9 @@ impl Connection {
272316 send_id = %response_tx. id( ) ,
273317 "Sending response"
274318 ) ;
275- let response_bytes = bincode:: serialize ( & response) ?;
276- response_tx. write_all ( & response_bytes) . await . map_err ( |err| {
319+ let mut message_buf = message:: magic_bytes ( network) . to_vec ( ) ;
320+ bincode:: serialize_into :: < & mut Vec < _ > , _ > ( & mut message_buf, & response) ?;
321+ response_tx. write_all ( & message_buf) . await . map_err ( |err| {
277322 {
278323 error:: connection:: Send :: Write {
279324 stream_id : response_tx. id ( ) ,
@@ -285,15 +330,10 @@ impl Connection {
285330 }
286331}
287332
288- impl From < quinn:: Connection > for Connection {
289- fn from ( inner : quinn:: Connection ) -> Self {
290- Self { inner }
291- }
292- }
293-
294333pub struct ConnectionContext {
295334 pub env : sneed:: Env ,
296335 pub archive : Archive ,
336+ pub network : Network ,
297337 pub state : State ,
298338}
299339
@@ -427,7 +467,8 @@ pub fn connect(
427467 let status_repr = status_repr. clone ( ) ;
428468 let info_tx = info_tx. clone ( ) ;
429469 move || async move {
430- let connection = Connection :: new ( connecting) . await ?;
470+ let connection =
471+ Connection :: from_connecting ( connecting, ctxt. network ) . await ?;
431472 status_repr. store (
432473 PeerConnectionStatus :: Connected . as_repr ( ) ,
433474 atomic:: Ordering :: SeqCst ,
0 commit comments