
Commit 38732df

[serving] refactor the api_llm_serving and add auto test for it (#434)
* [serving] refactor the api_llm_serving and add auto test for it
* [pytest] ignore the dataflow main dir to avoid scaffold test
1 parent bd256eb commit 38732df

File tree

11 files changed: +573 -118 lines

.github/workflows/test.yml

Lines changed: 4 additions & 3 deletions
```diff
@@ -60,18 +60,19 @@ jobs:
           UV_SYSTEM_PYTHON: "1"
         run: |
           uv pip install --no-cache -r requirements.txt
-          uv pip install -e .
+          uv pip install -e .[test]
 
       - name: Run tests on Windows
         if: startsWith(matrix.os, 'windows')
         env:
           PYTHONDONTWRITEBYTECODE: 1
         run: |
-          pytest -m cpu test
+          pytest -m "cpu or api"
 
       - name: Run tests on Linux / macOS
         if: ${{ !startsWith(matrix.os, 'windows') }}
         env:
           PYTHONDONTWRITEBYTECODE: 1
         run: |
-          pytest -m cpu test
+          pytest -m "cpu or api"
+
```
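With the workflow now running pytest -m "cpu or api" on every OS, the new api-marked tests run in CI. Judging from the flask test extra added to pyproject.toml below, the auto test presumably stubs the HTTP endpoint locally instead of calling a real API. A hypothetical sketch of such a test (the port, names, and response shape are illustrative assumptions, not the repo's actual test code):

```python
import threading
import time

import pytest
import requests
from flask import Flask, jsonify


@pytest.mark.api
def test_chat_completion_roundtrip():
    app = Flask(__name__)

    @app.route("/v1/chat/completions", methods=["POST"])
    def fake_completions():
        # minimal OpenAI-shaped body, enough for response parsing
        return jsonify({"choices": [{"message": {"content": "pong"}}]})

    # run the stub server in the background
    threading.Thread(
        target=lambda: app.run(port=8765, use_reloader=False), daemon=True
    ).start()

    for _ in range(50):  # wait until the stub accepts connections
        try:
            requests.get("http://127.0.0.1:8765/", timeout=0.2)
            break
        except requests.exceptions.ConnectionError:
            time.sleep(0.1)

    resp = requests.post(
        "http://127.0.0.1:8765/v1/chat/completions",
        json={"model": "stub", "messages": [{"role": "user", "content": "ping"}]},
        timeout=5,
    )
    assert resp.json()["choices"][0]["message"]["content"] == "pong"
```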
Lines changed: 148 additions & 105 deletions
```diff
@@ -1,5 +1,7 @@
 import json
+import warnings
 import requests
+from requests.adapters import HTTPAdapter
 import os
 import logging
 from ..logger import get_logger
```
```diff
@@ -20,16 +22,19 @@ def __init__(self,
                  api_url: str = "https://api.openai.com/v1/chat/completions",
                  key_name_of_api_key: str = "DF_API_KEY",
                  model_name: str = "gpt-4o",
+                 temperature: float = 0.0,
                  max_workers: int = 10,
                  max_retries: int = 5,
-                 temperature = 0.0
+                 timeout: tuple[float, float] = (10.0, 120.0),  # connect timeout, read timeout
                  ):
         # Get API key from environment variable or config
         self.api_url = api_url
         self.model_name = model_name
+        self.temperature = temperature
         self.max_workers = max_workers
         self.max_retries = max_retries
-        self.temperature = temperature
+        self.timeout = timeout
+
         self.logger = get_logger()
 
         # config api_key in os.environ globally, for safety reasons
@@ -38,7 +43,25 @@ def __init__(self,
             error_msg = f"Lack of `{key_name_of_api_key}` in environment variables. Please set `{key_name_of_api_key}` as your api-key to {api_url} before using APILLMServing_request."
             self.logger.error(error_msg)
             raise ValueError(error_msg)
+
+
+        self.session = requests.Session()
+        adapter = HTTPAdapter(
+            pool_connections=self.max_workers,
+            pool_maxsize=self.max_workers,
+            max_retries=0,    # retries already live in _api_chat_id_retry; do not retry again here
+            pool_block=True   # block when the pool is full, avoiding resource churn from unbounded new connections
+        )
+        self.session.mount("https://", adapter)
+        self.session.mount("http://", adapter)
 
+        self.headers = {
+            'Authorization': f"Bearer {self.api_key}",
+            'Content-Type': 'application/json',
+            'User-Agent': 'Apifox/1.0.0 (https://apifox.com)'
+        }
+
+
     def format_response(self, response: dict, is_embedding: bool = False) -> str:
         """Format API response, supporting both embedding and chat completion modes"""
 
```
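For context on the block above: a shared requests.Session with a mounted HTTPAdapter reuses TCP/TLS connections across the thread pool instead of re-handshaking on every request, and requests reads a 2-tuple timeout as (connect, read). A minimal standalone sketch of the same pattern, with a placeholder URL and pool size:

```python
import requests
from requests.adapters import HTTPAdapter

MAX_WORKERS = 10  # placeholder, mirrors the max_workers default above

session = requests.Session()
adapter = HTTPAdapter(
    pool_connections=MAX_WORKERS,  # number of connection pools to cache
    pool_maxsize=MAX_WORKERS,      # max live connections kept per pool
    max_retries=0,                 # retry policy lives in the caller, not the adapter
    pool_block=True,               # block when the pool is exhausted instead of opening more sockets
)
session.mount("https://", adapter)
session.mount("http://", adapter)

# requests treats a 2-tuple timeout as (connect, read): fail fast if the
# handshake stalls, but leave room for slow generation before the read times out.
resp = session.get("https://example.com/", timeout=(10.0, 120.0))
print(resp.status_code)
```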
```diff
@@ -62,37 +85,44 @@ def format_response(self, response: dict, is_embedding: bool = False) -> str:
             return f"<think>{reasoning_content}</think>\n<answer>{content}</answer>"
 
         return content
+    # deprecated
+    # def api_chat(self, system_info: str, messages: str, model: str):
+    #     try:
+    #         payload = json.dumps({
+    #             "model": model,
+    #             "messages": [
+    #                 {"role": "system", "content": system_info},
+    #                 {"role": "user", "content": messages}
+    #             ],
+    #             "temperature": self.temperature
+    #         })
 
+    #         headers = {
+    #             'Authorization': f"Bearer {self.api_key}",
+    #             'Content-Type': 'application/json',
+    #             'User-Agent': 'Apifox/1.0.0 (https://apifox.com)'
+    #         }
+    #         # Make a POST request to the API
+    #         response = requests.post(self.api_url, headers=headers, data=payload, timeout=60)
+    #         if response.status_code == 200:
+    #             response_data = response.json()
+    #             return self.format_response(response_data)
+    #         else:
+    #             logging.error(f"API request failed with status {response.status_code}: {response.text}")
+    #             return None
+    #     except Exception as e:
+    #         logging.error(f"API request error: {e}")
+    #         return None
 
-    def api_chat(self, system_info: str, messages: str, model: str):
-        try:
-            payload = json.dumps({
-                "model": model,
-                "messages": [
-                    {"role": "system", "content": system_info},
-                    {"role": "user", "content": messages}
-                ],
-                "temperature": self.temperature
-            })
-
-            headers = {
-                'Authorization': f"Bearer {self.api_key}",
-                'Content-Type': 'application/json',
-                'User-Agent': 'Apifox/1.0.0 (https://apifox.com)'
-            }
-            # Make a POST request to the API
-            response = requests.post(self.api_url, headers=headers, data=payload, timeout=60)
-            if response.status_code == 200:
-                response_data = response.json()
-                return self.format_response(response_data)
-            else:
-                logging.error(f"API request failed with status {response.status_code}: {response.text}")
-                return None
-        except Exception as e:
-            logging.error(f"API request error: {e}")
-            return None
-
-    def _api_chat_with_id(self, id, payload, model, is_embedding: bool = False, json_schema: dict = None):
+    def _api_chat_with_id(
+        self,
+        id: int,
+        payload,
+        model: str,
+        is_embedding: bool = False,
+        json_schema: dict = None
+    ):
+        start = time.time()
         try:
             if is_embedding:
                 payload = json.dumps({
```
```diff
@@ -117,24 +147,38 @@ def _api_chat_with_id(self, id, payload, model, is_embedding: bool = False, json_schema: dict = None):
                     }
                 }
             })
-
-            headers = {
-                'Authorization': f"Bearer {self.api_key}",
-                'Content-Type': 'application/json',
-                'User-Agent': 'Apifox/1.0.0 (https://apifox.com)'
-            }
             # Make a POST request to the API
-            response = requests.post(self.api_url, headers=headers, data=payload, timeout=1800)
+            response = self.session.post(self.api_url, headers=self.headers, data=payload, timeout=self.timeout)
+            cost = time.time() - start
             if response.status_code == 200:
                 # logging.info(f"API request successful")
                 response_data = response.json()
                 # logging.info(f"API response: {response_data['choices'][0]['message']['content']}")
                 return id,self.format_response(response_data, is_embedding)
             else:
-                logging.error(f"API request failed with status {response.status_code}: {response.text}")
+                # self.logger.exception(f"API request failed (id = {id}) with status {response.status_code}: {response.text}")
+                self.logger.error(f"API request failed id={id} status={response.status_code} cost={cost:.2f}s body={response.text[:500]}")
                 return id, None
+        except requests.exceptions.Timeout as e:
+            cost = time.time() - start
+            warnings.warn(f"API timeout (id={id}) cost={cost:.2f}s: {e}", RuntimeWarning)
+            return id, None
+
+        except requests.exceptions.ConnectionError as e:
+            cost = time.time() - start
+            msg = str(e).lower()
+
+            # requests/urllib3 sometimes wraps a ReadTimeout in a ConnectionError
+            if "timed out" in msg or "read timed out" in msg:
+                warnings.warn(f"API timeout (id={id}) cost={cost:.2f}s: {e}", RuntimeWarning)
+                return id, None
+
+            self.logger.error(f"API connection error (id={id}) cost={cost:.2f}s: {e}")
+            raise RuntimeError(f"Cannot connect to LLM server: {e}") from e
+
         except Exception as e:
-            logging.error(f"API request error: {e}")
+            cost = time.time() - start
+            self.logger.exception(f"API request error (id = {id}) cost={cost:.2f}s: {e}")
             return id, None
 
     def _api_chat_id_retry(self, id, payload, model, is_embedding : bool = False, json_schema: dict = None):
```
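The rewritten request path replaces the old blanket 1800 s timeout with the session's (connect, read) pair and splits failures into three classes: timeouts return (id, None) so the retry layer can try again, genuine connection failures raise immediately, and anything else is logged and swallowed. A standalone sketch of that classification (the URL and helper name are hypothetical):

```python
import requests


def classify_failure(url: str) -> str:
    # Hypothetical helper; mirrors the except-ladder above without the class plumbing.
    try:
        requests.get(url, timeout=(3.0, 10.0))
        return "ok"
    except requests.exceptions.Timeout:
        return "soft failure: timeout, let the retry layer handle it"
    except requests.exceptions.ConnectionError as e:
        # urllib3 occasionally surfaces a read timeout wrapped in ConnectionError,
        # so check the message before declaring the server unreachable
        if "timed out" in str(e).lower():
            return "soft failure: timeout (wrapped)"
        raise RuntimeError(f"Cannot connect to server: {e}") from e


print(classify_failure("https://example.com/"))
```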
```diff
@@ -145,81 +189,80 @@ def _api_chat_id_retry(self, id, payload, model, is_embedding : bool = False, json_schema: dict = None):
                 time.sleep(2**i)
         return id, None
 
```
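Only the tail of _api_chat_id_retry appears in this hunk; the visible time.sleep(2**i) implies exponential backoff between attempts. A sketch of how such a wrapper plausibly looks (the loop body is an assumption; only the sleep and the final return are visible in the diff):

```python
import time


def call_with_retry(call, max_retries: int = 5):
    # Hypothetical wrapper: `call` returns None on soft failure; sleep
    # 1 s, 2 s, 4 s, ... between attempts, then give up.
    for i in range(max_retries):
        result = call()
        if result is not None:
            return result
        time.sleep(2**i)
    return None
```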
```diff
-    def generate_from_input(self,
-                            user_inputs: list[str],
-                            system_prompt: str = "You are a helpful assistant",
-                            json_schema: dict = None,
-                            ) -> list[str]:
-
+    def _run_threadpool(self, task_args_list: list[dict], desc: str) -> list:
+        """
+        task_args_list: each element is a kwargs dict for _api_chat_id_retry,
+        e.g. {"id": 0, "payload": [...], "model": "...", "is_embedding": False, "json_schema": None}
+        Results are backfilled into `responses` by id.
+        """
+        responses = [None] * len(task_args_list)
 
-        responses = [None] * len(user_inputs)
-        # -- end of subfunction api_chat_with_id --
-
-        # Use ThreadPoolExecutor to process multiple questions in parallel
-        # logging.info(f"Generating {len(questions)} responses")
         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
             futures = [
-                executor.submit(
-                    self._api_chat_id_retry,
-                    payload = [
-                        {"role": "system", "content": system_prompt},
-                        {"role": "user", "content": question}
-                    ],
-                    model = self.model_name,
-                    json_schema = json_schema,
-                    id = idx,
-                ) for idx, question in enumerate(user_inputs)
+                executor.submit(self._api_chat_id_retry, **task_args)
+                for task_args in task_args_list
             ]
-            for future in tqdm(as_completed(futures), total=len(futures), desc="Generating......"):
-                response = future.result() # (id, response)
+
+            for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
+                try:
+                    response = future.result()  # (id, response)
+                    # response[0] is the id, response[1] is the actual response; used to backfill `responses` by id
                     responses[response[0]] = response[1]
+                except Exception:
+                    # in theory each worker already try/excepts internally, but a catch-all here is safer
+                    self.logger.exception("Worker crashed unexpectedly in threadpool")
+
         return responses
 
```
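The helper relies on as_completed yielding futures in completion order, not submission order, which is why each worker returns its own id for backfilling. A self-contained toy demonstrating the idea:

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def worker(id: int, text: str):
    time.sleep(random.random() / 10)  # workers finish in arbitrary order
    return id, text.upper()


tasks = [{"id": i, "text": t} for i, t in enumerate(["a", "b", "c", "d"])]
responses = [None] * len(tasks)

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(worker, **t) for t in tasks]
    for future in as_completed(futures):
        rid, result = future.result()
        responses[rid] = result  # backfill by id restores submission order

print(responses)  # always ['A', 'B', 'C', 'D']
```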
```diff
-    def generate_from_conversations(self, conversations: list[list[dict]]) -> list[str]:
+    def generate_from_input(self,
+                            user_inputs: list[str],
+                            system_prompt: str = "You are a helpful assistant",
+                            json_schema: dict = None,
+                            ) -> list[str]:
+        task_args_list = [
+            dict(
+                id=idx,
+                payload=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": question},
+                ],
+                model=self.model_name,
+                json_schema=json_schema,
+            )
+            for idx, question in enumerate(user_inputs)
+        ]
+        return self._run_threadpool(task_args_list, desc="Generating responses from prompts......")
 
```
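All three public generators now reduce to building a task_args_list and delegating to _run_threadpool. Hypothetical usage of the prompt-based entry point (assumes DF_API_KEY is set in the environment; the import path is not shown in this diff, so it is elided here, and the values are illustrative):

```python
# Hypothetical usage; needs DF_API_KEY in the environment and network access.
serving = APILLMServing_request(model_name="gpt-4o", max_workers=4)

answers = serving.generate_from_input(
    user_inputs=["What is 2 + 2?", "Name a prime number."],
    system_prompt="You are a helpful assistant",
)
print(answers)  # list[str], index-aligned with user_inputs
```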
```diff
-        responses = [None] * len(conversations)
-        # -- end of subfunction api_chat_with_id --
+
+    def generate_from_conversations(self, conversations: list[list[dict]]) -> list[str]:
 
-        # Use ThreadPoolExecutor to process multiple questions in parallel
-        # logging.info(f"Generating {len(questions)} responses")
-        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
-            futures = [
-                executor.submit(
-                    self._api_chat_id_retry,
-                    payload = dialogue,
-                    model = self.model_name,
-                    id = idx
-                ) for idx, dialogue in enumerate(conversations)
-            ]
-            for future in tqdm(as_completed(futures), total=len(futures), desc="Generating......"):
-                response = future.result() # (id, response)
-                responses[response[0]] = response[1]
-        return responses
+        task_args_list = [
+            dict(
+                id=idx,
+                payload=dialogue,
+                model=self.model_name,
+            )
+            for idx, dialogue in enumerate(conversations)
+        ]
+        return self._run_threadpool(task_args_list, desc="Generating responses from conversations......")
+
 
```
```diff
     def generate_embedding_from_input(self, texts: list[str]) -> list[list[float]]:
-
-        responses = [None] * len(texts)
-        # -- end of subfunction api_embedding_with_id --
-
-        # Use ThreadPoolExecutor to process multiple questions in parallel
-        # logging.info(f"Generating {len(questions)} responses")
-        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
-            futures = [
-                executor.submit(
-                    self._api_chat_id_retry,
-                    payload = txt,
-                    model = self.model_name,
-                    id = idx,
-                    is_embedding = True
-                ) for idx, txt in enumerate(texts)
-            ]
-            for future in tqdm(as_completed(futures), total=len(futures), desc="Generating embedding......"):
-                response = future.result() # (id, response)
-                responses[response[0]] = response[1]
-        return responses
+        task_args_list = [
+            dict(
+                id=idx,
+                payload=txt,
+                model=self.model_name,
+                is_embedding=True,
+            )
+            for idx, txt in enumerate(texts)
+        ]
+        return self._run_threadpool(task_args_list, desc="Generating embedding......")
 
```
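Continuing the usage sketch above, the embedding path takes the same route with is_embedding=True:

```python
# Continuing the hypothetical sketch: one vector per input text, index-aligned.
vectors = serving.generate_embedding_from_input(["hello world", "dataflow"])
print(len(vectors))
```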
```diff
     def cleanup(self):
-        # Cleanup resources if needed
-        logging.info("Cleaning up resources in APILLMServing_request")
-        # No specific cleanup actions needed for this implementation
-        pass
+        self.logger.info("Cleaning up resources in APILLMServing_request")
+        try:
+            if hasattr(self, "session") and self.session:
+                self.session.close()
+        except Exception:
+            self.logger.exception("Failed to close requests session")
```

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -52,6 +52,7 @@ dependencies = {file = "requirements.txt"}
 
 
 [project.optional-dependencies]
+test = ["flask"]
 vllm =["vllm>=0.7.0,<=0.9.2", "numpy<2.0.0"]
 vllm07 = ["vllm<0.8", "numpy<2.0.0"]
 vllm08 = ["vllm<0.9"]
```
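The new test extra keeps flask out of the base requirements while letting CI pull it in through the uv pip install -e .[test] step added in the workflow above; flask is presumably what the API auto test uses to stand up its local stub endpoint.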

pytest.ini

Lines changed: 13 additions & 2 deletions
```diff
@@ -1,5 +1,16 @@
 [pytest]
-addopts = --ignore=test/test_ragkbcleaning.py
+addopts =
+    --ignore=test/test_ragkbcleaning.py
+    --ignore=test/test_prompt_template_with_reasoning.py
+    --ignore=test/general_text.py
+
+norecursedirs =
+    tests/legacy
+    build
+    dist
+    dataflow
+
 markers =
     cpu: mark test to run on CPU
-    gpu: mark test to run on GPU
+    gpu: mark test to run on GPU
+    api: mark test to run API related tests
```
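Registering the api marker documents the new CI selection (pytest -m "cpu or api"), and the dataflow entry under norecursedirs stops pytest from recursing into the package's main directory; per the commit message, this avoids collecting scaffold code as tests.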
