Skip to content

Commit 97d75d3

Browse files
committed
feat(websockets): introduce post-generated SDK patch script
1 parent 3dff08b commit 97d75d3

File tree

3 files changed

+270
-0
lines changed

3 files changed

+270
-0
lines changed

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
.PHONY: patch-websockets
2+
3+
patch-websockets:
4+
poetry run python scripts/patch_websocket_transport.py
5+
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Post-generation script to patch auto-generated client files.
4+
5+
Replaces websockets imports with our wrapper module.
6+
"""
7+
8+
import re
9+
import sys
10+
from pathlib import Path
11+
12+
13+
# Pattern to find websockets sync client imports
14+
WEBSOCKETS_SYNC_IMPORT_PATTERN = r"import websockets\.sync\.client as websockets_sync_client"
15+
16+
# Pattern to find websockets async client imports
17+
WEBSOCKETS_ASYNC_IMPORT_PATTERN = r"from websockets\.legacy\.client import connect as websockets_client_connect|from websockets import connect as websockets_client_connect"
18+
19+
# Pattern to find websockets module imports
20+
WEBSOCKETS_MODULE_PATTERN = r"^import websockets$"
21+
22+
# Pattern to find websockets sync connection imports
23+
WEBSOCKETS_SYNC_CONNECTION_PATTERN = r"import websockets\.sync\.connection as websockets_sync_connection"
24+
25+
# Pattern to find WebSocketClientProtocol imports
26+
WEBSOCKETS_PROTOCOL_PATTERN = (
27+
r"from websockets\.legacy\.client import WebSocketClientProtocol|from websockets import WebSocketClientProtocol"
28+
)
29+
30+
31+
def has_patch_already(content: str) -> bool:
32+
"""Check if file already has our patch."""
33+
return (
34+
"from ...core.websocket_wrapper import" in content or "from deepgram.core.websocket_wrapper import" in content
35+
)
36+
37+
38+
def patch_file(file_path: Path) -> bool:
39+
"""Patch a single file if it needs patching."""
40+
try:
41+
content = file_path.read_text(encoding="utf-8")
42+
except Exception as e:
43+
print(f"Error reading {file_path}: {e}", file=sys.stderr)
44+
return False
45+
46+
# Skip if already patched
47+
if has_patch_already(content):
48+
return False
49+
50+
# Check if file has websockets imports
51+
has_sync_import = bool(re.search(WEBSOCKETS_SYNC_IMPORT_PATTERN, content))
52+
has_async_import = bool(re.search(WEBSOCKETS_ASYNC_IMPORT_PATTERN, content))
53+
has_module_import = bool(re.search(WEBSOCKETS_MODULE_PATTERN, content, re.MULTILINE))
54+
has_sync_connection = bool(re.search(WEBSOCKETS_SYNC_CONNECTION_PATTERN, content))
55+
has_protocol_import = bool(re.search(WEBSOCKETS_PROTOCOL_PATTERN, content))
56+
57+
if not (has_sync_import or has_async_import or has_module_import or has_sync_connection or has_protocol_import):
58+
return False
59+
60+
modified = False
61+
lines = content.split("\n")
62+
63+
# Determine relative import depth
64+
# Files are in listen/v1/, speak/v1/, agent/v1/, etc.
65+
# They use ...core.api_error (3 dots), which means 3 package levels up
66+
# From agent/v1/client.py: ... = deepgram package, so ...core = deepgram.core
67+
# So we always use 3 dots to match existing imports
68+
relative_import = "...core.websocket_wrapper"
69+
70+
# Patch websockets module import
71+
for i, line in enumerate(lines):
72+
if re.search(WEBSOCKETS_MODULE_PATTERN, line):
73+
# Replace: import websockets
74+
# With: from ...core.websocket_wrapper import websockets
75+
indent = re.match(r"(\s*)", line).group(1) if re.match(r"(\s*)", line) else ""
76+
lines[i] = f"{indent}from {relative_import} import websockets # noqa: E402"
77+
modified = True
78+
break
79+
80+
# Patch websockets.exceptions import
81+
for i, line in enumerate(lines):
82+
if "import websockets.exceptions" in line:
83+
# Replace: import websockets.exceptions
84+
# With: from ...core.websocket_wrapper import websockets
85+
# Note: websockets.exceptions is accessible via websockets.exceptions
86+
indent = re.match(r"(\s*)", line).group(1) if re.match(r"(\s*)", line) else ""
87+
# We'll import websockets which includes exceptions, so we can remove this line
88+
# or keep it for clarity - let's replace it
89+
lines[i] = f"{indent}from {relative_import} import websockets # noqa: E402"
90+
modified = True
91+
break
92+
93+
# Patch sync connection import
94+
for i, line in enumerate(lines):
95+
if re.search(WEBSOCKETS_SYNC_CONNECTION_PATTERN, line):
96+
# Replace: import websockets.sync.connection as websockets_sync_connection
97+
# With: from ...core.websocket_wrapper import websockets_sync_connection
98+
indent = re.match(r"(\s*)", line).group(1) if re.match(r"(\s*)", line) else ""
99+
lines[i] = f"{indent}from {relative_import} import websockets_sync_connection # noqa: E402"
100+
modified = True
101+
break
102+
103+
# Patch sync import - replace the import line
104+
for i, line in enumerate(lines):
105+
if re.search(WEBSOCKETS_SYNC_IMPORT_PATTERN, line):
106+
# Replace: import websockets.sync.client as websockets_sync_client
107+
# With: from ...core.websocket_wrapper import websockets_sync_client
108+
indent = re.match(r"(\s*)", line).group(1) if re.match(r"(\s*)", line) else ""
109+
lines[i] = f"{indent}from {relative_import} import websockets_sync_client # noqa: E402"
110+
modified = True
111+
break
112+
113+
# Patch async import - handle try/except block for connect
114+
try_block_start = None
115+
for i, line in enumerate(lines):
116+
if "try:" in line and i + 1 < len(lines) and re.search(WEBSOCKETS_ASYNC_IMPORT_PATTERN, lines[i + 1]):
117+
try_block_start = i
118+
break
119+
120+
if try_block_start is not None:
121+
# Replace the try branch import
122+
indent = (
123+
re.match(r"(\s*)", lines[try_block_start + 1]).group(1)
124+
if re.match(r"(\s*)", lines[try_block_start + 1])
125+
else ""
126+
)
127+
lines[try_block_start + 1] = f"{indent}from {relative_import} import websockets_client_connect # noqa: E402"
128+
129+
# Replace the except branch import if it exists
130+
if try_block_start + 2 < len(lines) and "except ImportError:" in lines[try_block_start + 2]:
131+
if try_block_start + 3 < len(lines):
132+
except_indent = (
133+
re.match(r"(\s*)", lines[try_block_start + 3]).group(1)
134+
if re.match(r"(\s*)", lines[try_block_start + 3])
135+
else indent
136+
)
137+
lines[try_block_start + 3] = (
138+
f"{except_indent}from {relative_import} import websockets_client_connect # noqa: E402"
139+
)
140+
modified = True
141+
142+
# Patch WebSocketClientProtocol import - handle try/except block
143+
protocol_try_block_start = None
144+
for i, line in enumerate(lines):
145+
if "try:" in line and i + 1 < len(lines) and re.search(WEBSOCKETS_PROTOCOL_PATTERN, lines[i + 1]):
146+
protocol_try_block_start = i
147+
break
148+
149+
if protocol_try_block_start is not None:
150+
# Replace the try branch import
151+
indent = (
152+
re.match(r"(\s*)", lines[protocol_try_block_start + 1]).group(1)
153+
if re.match(r"(\s*)", lines[protocol_try_block_start + 1])
154+
else ""
155+
)
156+
# Check if there are multiple imports on the same line or separate lines
157+
if "WebSocketClientProtocol" in lines[protocol_try_block_start + 1]:
158+
lines[protocol_try_block_start + 1] = (
159+
f"{indent}from {relative_import} import WebSocketClientProtocol # type: ignore # noqa: E402"
160+
)
161+
162+
# Replace the except branch import if it exists
163+
if protocol_try_block_start + 2 < len(lines) and "except ImportError:" in lines[protocol_try_block_start + 2]:
164+
if (
165+
protocol_try_block_start + 3 < len(lines)
166+
and "WebSocketClientProtocol" in lines[protocol_try_block_start + 3]
167+
):
168+
except_indent = (
169+
re.match(r"(\s*)", lines[protocol_try_block_start + 3]).group(1)
170+
if re.match(r"(\s*)", lines[protocol_try_block_start + 3])
171+
else indent
172+
)
173+
lines[protocol_try_block_start + 3] = (
174+
f"{except_indent}from {relative_import} import WebSocketClientProtocol # type: ignore # noqa: E402"
175+
)
176+
modified = True
177+
178+
if modified:
179+
try:
180+
file_path.write_text("\n".join(lines), encoding="utf-8")
181+
print(f"Patched: {file_path}")
182+
return True
183+
except Exception as e:
184+
print(f"Error writing {file_path}: {e}", file=sys.stderr)
185+
return False
186+
187+
return False
188+
189+
190+
def find_client_files(root_dir: Path) -> list[Path]:
191+
"""Find all client files that might need patching."""
192+
client_files = []
193+
194+
# Look for client.py files in listen, speak, agent directories
195+
for pattern in [
196+
"**/listen/**/client.py",
197+
"**/speak/**/client.py",
198+
"**/agent/**/client.py",
199+
"**/listen/**/raw_client.py",
200+
"**/speak/**/raw_client.py",
201+
"**/agent/**/raw_client.py",
202+
"**/listen/**/socket_client.py",
203+
"**/speak/**/socket_client.py",
204+
"**/agent/**/socket_client.py",
205+
]:
206+
client_files.extend(root_dir.glob(pattern))
207+
208+
return sorted(set(client_files))
209+
210+
211+
def main():
212+
"""Main entry point."""
213+
if len(sys.argv) > 1:
214+
root_dir = Path(sys.argv[1])
215+
else:
216+
root_dir = Path(__file__).parent.parent / "src" / "deepgram"
217+
218+
if not root_dir.exists():
219+
print(f"Error: Directory {root_dir} does not exist", file=sys.stderr)
220+
sys.exit(1)
221+
222+
client_files = find_client_files(root_dir)
223+
224+
if not client_files:
225+
print("No client files found to patch")
226+
return
227+
228+
patched_count = 0
229+
for file_path in client_files:
230+
if patch_file(file_path):
231+
patched_count += 1
232+
233+
print(f"\nPatched {patched_count} file(s)")
234+
235+
236+
if __name__ == "__main__":
237+
main()
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
WebSocket wrapper that provides drop-in replacements for websockets modules.
3+
4+
This module simply wraps websockets.sync.client and websockets.client to allow
5+
for future transport customization without modifying auto-generated code.
6+
"""
7+
8+
# Import the real websockets modules
9+
import websockets
10+
import websockets.exceptions
11+
import websockets.sync.client as websockets_sync_client
12+
import websockets.sync.connection as websockets_sync_connection
13+
14+
try:
15+
from websockets.legacy.client import WebSocketClientProtocol
16+
from websockets.legacy.client import connect as websockets_client_connect
17+
except ImportError:
18+
from websockets import WebSocketClientProtocol
19+
from websockets import connect as websockets_client_connect
20+
21+
# Re-export everything that might be imported from this module
22+
__all__ = [
23+
"websockets",
24+
"websockets_sync_client",
25+
"websockets_sync_connection",
26+
"websockets_client_connect",
27+
"WebSocketClientProtocol",
28+
]

0 commit comments

Comments
 (0)