@@ -33,8 +33,8 @@ class StarryPyServer:
3333 """
3434 def __init__ (self , reader , writer , config , factory ):
3535 logger .debug ("Initializing connection." )
36- self ._reader = reader # read packets from client
37- self ._writer = writer # writes packets to client
36+ self ._reader = ZstdFrameReader ( reader , Direction . TO_SERVER ) # read packets from client
37+ self ._writer = ZstdFrameWriter ( writer ) # writes packets to client
3838 self ._client_reader = None # read packets from server (acting as client)
3939 self ._client_writer = None # write packets to server
4040 self .factory = factory
@@ -48,17 +48,13 @@ def __init__(self, reader, writer, config, factory):
4848 self ._client_read_future = None
4949 self ._server_write_future = None
5050 self ._client_write_future = None
51- self ._expect_server_loop_death = False
5251 logger .info ("Received connection from {}" .format (self .client_ip ))
5352
5453 def start_zstd (self ):
55- self ._reader = ZstdFrameReader (self ._reader , Direction .TO_SERVER )
56- self ._client_reader = ZstdFrameReader (self ._client_reader , Direction .TO_CLIENT )
57- self ._writer = ZstdFrameWriter (self ._writer , skip_packets = 1 )
58- self ._client_writer = ZstdFrameWriter (self ._client_writer )
59- self ._expect_server_loop_death = True
60- self ._server_loop_future .cancel ()
61- self ._server_loop_future = asyncio .create_task (self .server_loop ())
54+ self ._reader .enable_zstd ()
55+ self ._client_reader .enable_zstd ()
56+ self ._writer .enable_zstd (skip_packets = 1 ) # skip this packet
57+ self ._client_writer .enable_zstd ()
6258 logger .info ("Switched to zstd" )
6359
6460
@@ -95,12 +91,8 @@ async def server_loop(self):
9591 "{}: {}" .format (err .__class__ .__name__ , err ))
9692 logger .error ("Error details and traceback: {}" .format (traceback .format_exc ()))
9793 finally :
98- if not self ._expect_server_loop_death :
99- logger .info ("Server loop ended." )
100- self .die ()
101- else :
102- logger .info ("Restarting server loop for switch to zstd." )
103- self ._expect_server_loop_death = False
94+ logger .info ("Server loop ended." )
95+ self .die ()
10496
10597 async def client_loop (self ):
10698 """
@@ -109,9 +101,11 @@ async def client_loop(self):
109101
110102 :return:
111103 """
112- (self ._client_reader , self ._client_writer ) = \
113- await asyncio .open_connection (self .config ['upstream_host' ],
104+ (reader , writer ) = await asyncio .open_connection (self .config ['upstream_host' ],
114105 self .config ['upstream_port' ])
106+
107+ self ._client_reader = ZstdFrameReader (reader , Direction .TO_CLIENT )
108+ self ._client_writer = ZstdFrameWriter (writer )
115109
116110 try :
117111 while True :
0 commit comments