Skip to content

Commit 6600b0a

Browse files
authored
Merge pull request #157 from evonzee/opensb-support-norestart
OpenSB Support - register filters always, use flags to control compression loop
2 parents a61e384 + 89bf87b commit 6600b0a

File tree

3 files changed

+34
-24
lines changed

3 files changed

+34
-24
lines changed

server.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class StarryPyServer:
3333
"""
3434
def __init__(self, reader, writer, config, factory):
3535
logger.debug("Initializing connection.")
36-
self._reader = reader # read packets from client
37-
self._writer = writer # writes packets to client
36+
self._reader = ZstdFrameReader(reader, Direction.TO_SERVER) # read packets from client
37+
self._writer = ZstdFrameWriter(writer) # writes packets to client
3838
self._client_reader = None # read packets from server (acting as client)
3939
self._client_writer = None # write packets to server
4040
self.factory = factory
@@ -48,17 +48,13 @@ def __init__(self, reader, writer, config, factory):
4848
self._client_read_future = None
4949
self._server_write_future = None
5050
self._client_write_future = None
51-
self._expect_server_loop_death = False
5251
logger.info("Received connection from {}".format(self.client_ip))
5352

5453
def start_zstd(self):
55-
self._reader = ZstdFrameReader(self._reader, Direction.TO_SERVER)
56-
self._client_reader= ZstdFrameReader(self._client_reader, Direction.TO_CLIENT)
57-
self._writer = ZstdFrameWriter(self._writer, skip_packets=1)
58-
self._client_writer = ZstdFrameWriter(self._client_writer)
59-
self._expect_server_loop_death = True
60-
self._server_loop_future.cancel()
61-
self._server_loop_future = asyncio.create_task(self.server_loop())
54+
self._reader.enable_zstd()
55+
self._client_reader.enable_zstd()
56+
self._writer.enable_zstd(skip_packets=1) # skip this packet
57+
self._client_writer.enable_zstd()
6258
logger.info("Switched to zstd")
6359

6460

@@ -95,12 +91,8 @@ async def server_loop(self):
9591
"{}: {}".format(err.__class__.__name__, err))
9692
logger.error("Error details and traceback: {}".format(traceback.format_exc()))
9793
finally:
98-
if not self._expect_server_loop_death:
99-
logger.info("Server loop ended.")
100-
self.die()
101-
else:
102-
logger.info("Restarting server loop for switch to zstd.")
103-
self._expect_server_loop_death = False
94+
logger.info("Server loop ended.")
95+
self.die()
10496

10597
async def client_loop(self):
10698
"""
@@ -109,9 +101,11 @@ async def client_loop(self):
109101
110102
:return:
111103
"""
112-
(self._client_reader, self._client_writer) = \
113-
await asyncio.open_connection(self.config['upstream_host'],
104+
(reader, writer) = await asyncio.open_connection(self.config['upstream_host'],
114105
self.config['upstream_port'])
106+
107+
self._client_reader = ZstdFrameReader(reader, Direction.TO_CLIENT)
108+
self._client_writer = ZstdFrameWriter(writer)
115109

116110
try:
117111
while True:

zstd_reader.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ def __init__(self, reader: asyncio.StreamReader, direction: Direction):
1111
self.decompressor = zstd.ZstdDecompressor().stream_writer(self.outputbuffer)
1212
self.raw_reader = reader
1313
self.direction = direction
14+
self.zstd_enabled = False
15+
16+
def enable_zstd(self):
17+
self.zstd_enabled = True
1418

1519
async def readexactly(self, count):
1620
# print(f"Reading exactly {count} bytes")
@@ -31,11 +35,14 @@ async def read_from_network(self, target_count):
3135
# print(f"Read {len(chunk)} bytes from network")
3236
if not chunk:
3337
raise asyncio.CancelledError("Connection closed")
34-
try:
35-
self.decompressor.write(chunk)
36-
except zstd.ZstdError:
37-
print("Zstd error, dropping connection")
38-
raise asyncio.CancelledError("Error in compressed data stream!")
38+
if not self.zstd_enabled:
39+
self.outputbuffer.write(chunk)
40+
else:
41+
try:
42+
self.decompressor.write(chunk)
43+
except zstd.ZstdError:
44+
print("Zstd error, dropping connection")
45+
raise asyncio.CancelledError("Error in compressed data stream!")
3946

4047
class NonSeekableMemoryStream(io.RawIOBase):
4148
def __init__(self):

zstd_writer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@
33
import zstandard as zstd
44

55
class ZstdFrameWriter:
6-
def __init__(self, raw_writer: asyncio.StreamWriter, skip_packets=0):
6+
def __init__(self, raw_writer: asyncio.StreamWriter):
77
self.compressor = zstd.ZstdCompressor()
88
self.raw_writer = raw_writer
9+
self.skip_packets = 0
10+
self.zstd_enabled = False
11+
12+
def enable_zstd(self, skip_packets=0):
13+
self.zstd_enabled = True
914
self.skip_packets = skip_packets
1015

1116
async def drain(self):
@@ -16,6 +21,10 @@ def close(self):
1621
self.compressor = None
1722

1823
def write(self, data):
24+
25+
if not self.zstd_enabled:
26+
self.raw_writer.write(data)
27+
return
1928

2029
if self.skip_packets > 0:
2130
self.skip_packets -= 1

0 commit comments

Comments
 (0)