Skip to content

Commit 0bc7e82

Browse files
authored
[Serving] Add customized generation configs (#442)
1 parent 529a9fe commit 0bc7e82

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

dataflow/serving/api_llm_serving_request.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,14 +26,17 @@ def __init__(self,
2626
max_workers: int = 10,
2727
max_retries: int = 5,
2828
timeout: tuple[float, float] = (10.0, 120.0), # connect timeout, read timeout
29+
**configs
2930
):
3031
# Get API key from environment variable or config
3132
self.api_url = api_url
3233
self.model_name = model_name
33-
self.temperature = temperature
34+
# self.temperature = temperature
3435
self.max_workers = max_workers
3536
self.max_retries = max_retries
3637
self.timeout = timeout
38+
self.configs = configs
39+
self.configs.update({"temperature": temperature})
3740

3841
self.logger = get_logger()
3942

@@ -125,17 +128,17 @@ def _api_chat_with_id(
125128
start = time.time()
126129
try:
127130
if is_embedding:
128-
payload = json.dumps({
131+
payload = {
129132
"model": model,
130133
"input": payload
131-
})
134+
}
132135
elif json_schema is None:
133-
payload = json.dumps({
136+
payload = {
134137
"model": model,
135138
"messages": payload
136-
})
139+
}
137140
else:
138-
payload = json.dumps({
141+
payload = {
139142
"model": model,
140143
"messages": payload,
141144
"response_format": {
@@ -146,7 +149,10 @@ def _api_chat_with_id(
146149
"schema": json_schema
147150
}
148151
}
149-
})
152+
}
153+
154+
payload.update(self.configs)
155+
payload = json.dumps(payload)
150156
# Make a POST request to the API
151157
response = self.session.post(self.api_url, headers=self.headers, data=payload, timeout=self.timeout)
152158
cost = time.time() - start

0 commit comments

Comments
 (0)