Commit f10d00a

Merge pull request #13 from plasma-umass/claude_support
Updated Anthropic Claude model.
2 parents 3b8725f + 5f56183 commit f10d00a

File tree

4 files changed: +145 -179 lines changed


src/coverup/coverup.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def default_model():
     if 'ANTHROPIC_API_KEY' in os.environ:
         return "anthropic/claude-3-sonnet-20240229"
     if 'AWS_ACCESS_KEY_ID' in os.environ:
-        return "bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
+        return "anthropic.claude-3-5-sonnet-20241022-v2:0"
 
 ap.add_argument('--model', type=str, default=default_model(),
                 help='OpenAI model to use')
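
The practical effect: with only AWS credentials in the environment, CoverUp now defaults to Claude 3.5 Sonnet v2 instead of the older Claude 3 Sonnet Bedrock model. A minimal sketch of the selection logic after this change (only the two branches visible in this hunk; the full function may consult other providers first):

```python
import os

def default_model():
    # Branches as shown in the diff above; other providers may be checked before these.
    if 'ANTHROPIC_API_KEY' in os.environ:
        return "anthropic/claude-3-sonnet-20240229"
    if 'AWS_ACCESS_KEY_ID' in os.environ:
        # was "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" before this commit
        return "anthropic.claude-3-5-sonnet-20241022-v2:0"
```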

src/coverup/llm.py

Lines changed: 69 additions & 133 deletions
@@ -11,150 +11,100 @@
 with warnings.catch_warnings():
     # ignore pydantic warnings https://github.com/BerriAI/litellm/issues/2832
     warnings.simplefilter('ignore')
-    import litellm # type: ignore
-
+    import litellm # type: ignore
 
 # Turn off most logging
 litellm.set_verbose = False
 litellm.suppress_debug_info = True
 logging.getLogger().setLevel(logging.ERROR)
-
 # Ignore unavailable parameters
-litellm.drop_params=True
-
+litellm.drop_params = True
 
 # Tier 5 rate limits for models; tuples indicate limit and interval in seconds
 # Extracted from https://platform.openai.com/account/limits on 8/30/2024
 MODEL_RATE_LIMITS = {
-    'gpt-3.5-turbo': {
-        'token': (50_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-3.5-turbo-0125': {
-        'token': (50_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-3.5-turbo-1106': {
-        'token': (50_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-3.5-turbo-16k': {
-        'token': (50_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-3.5-turbo-instruct': {
-        'token': (90_000, 60), 'request': (3_500, 60)
-    },
-    'gpt-3.5-turbo-instruct-0914': {
-        'token': (90_000, 60), 'request': (3_500, 60)
-    },
-    'gpt-4': {
-        'token': (1_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4-0314': {
-        'token': (1_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4-0613': {
-        'token': (1_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4-32k-0314': {
-        'token': (150_000, 60), 'request': (1_000, 60)
-    },
-    'gpt-4-turbo': {
-        'token': (2_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4-turbo-2024-04-09': {
-        'token': (2_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4-turbo-preview': {
-        'token': (2_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4-0125-preview': {
-        'token': (2_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4-1106-preview': {
-        'token': (2_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4o': {
-        'token': (30_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4o-2024-05-13': {
-        'token': (30_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4o-2024-08-06': {
-        'token': (30_000_000, 60), 'request': (10_000, 60)
-    },
-    'gpt-4o-mini': {
-        'token': (150_000_000, 60), 'request': (30_000, 60)
-    },
-    'gpt-4o-mini-2024-07-18': {
-        'token': (150_000_000, 60), 'request': (30_000, 60)
-    }
+    'gpt-3.5-turbo': {'token': (50_000_000, 60), 'request': (10_000, 60)},
+    'gpt-3.5-turbo-0125': {'token': (50_000_000, 60), 'request': (10_000, 60)},
+    'gpt-3.5-turbo-1106': {'token': (50_000_000, 60), 'request': (10_000, 60)},
+    'gpt-3.5-turbo-16k': {'token': (50_000_000, 60), 'request': (10_000, 60)},
+    'gpt-3.5-turbo-instruct': {'token': (90_000, 60), 'request': (3_500, 60)},
+    'gpt-3.5-turbo-instruct-0914': {'token': (90_000, 60), 'request': (3_500, 60)},
+    'gpt-4': {'token': (1_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4-0314': {'token': (1_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4-0613': {'token': (1_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4-32k-0314': {'token': (150_000, 60), 'request': (1_000, 60)},
+    'gpt-4-turbo': {'token': (2_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4-turbo-2024-04-09': {'token': (2_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4-turbo-preview': {'token': (2_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4-0125-preview': {'token': (2_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4-1106-preview': {'token': (2_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4o': {'token': (30_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4o-2024-05-13': {'token': (30_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4o-2024-08-06': {'token': (30_000_000, 60), 'request': (10_000, 60)},
+    'gpt-4o-mini': {'token': (150_000_000, 60), 'request': (30_000, 60)},
+    'gpt-4o-mini-2024-07-18': {'token': (150_000_000, 60), 'request': (30_000, 60)},
+    'bedrock/anthropic.claude-3-haiku-20240307-v1:0': {'token': (1_000_000, 60), 'request': (10_000, 60)},
+    'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': {'token': (1_000_000, 60), 'request': (10_000, 60)},
+    'bedrock/anthropic.claude-3-opus-20240229-v1:0': {'token': (1_000_000, 60), 'request': (10_000, 60)},
+    'anthropic.claude-3-5-sonnet-20241022-v2:0': {'token': (400_000, 50), 'request': (10_000, 50)},
+    'us.anthropic.claude-3-5-sonnet-20241022-v2:0': {'token': (400_000, 50), 'request': (10_000, 50)},
+    'us.anthropic.claude-3-7-sonnet-20250219-v1:0': {'token': (400_000, 50), 'request': (10_000, 50)},
 }
 
-
-def token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int]|None:
+def token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int] | None:
     if model_name.startswith('openai/'):
         model_name = model_name[7:]
 
     if (model_limits := MODEL_RATE_LIMITS.get(model_name)):
         limit = model_limits.get('token')
 
-        try:
-            import tiktoken
-            tiktoken.encoding_for_model(model_name)
-        except KeyError:
-            warnings.warn(f"Unable to get encoding for {model_name}; will ignore rate limit")
-            return None
+        if not "anthropic" in model_name:
+            try:
+                import tiktoken
+                tiktoken.encoding_for_model(model_name)
+            except KeyError:
+                warnings.warn(f"Unable to get encoding for {model_name}; will ignore rate limit")
+                return None
 
         return limit
 
     return None
 
-
-def compute_cost(usage: dict, model_name: str) -> float|None:
+def compute_cost(usage: dict, model_name: str) -> float | None:
     from math import ceil
 
     if model_name.startswith('openai/'):
         model_name = model_name[7:]
 
     if 'prompt_tokens' in usage and 'completion_tokens' in usage:
         if (cost := litellm.model_cost.get(model_name)):
-            return usage['prompt_tokens'] * cost['input_cost_per_token'] +\
+            return usage['prompt_tokens'] * cost['input_cost_per_token'] + \
                    usage['completion_tokens'] * cost['output_cost_per_token']
 
     return None
 
-
 _token_encoding_cache: dict[str, T.Any] = dict()
+
 def count_tokens(model_name: str, completion: dict):
     """Counts the number of tokens in a chat completion request."""
-    import tiktoken
-
-    if not (encoding := _token_encoding_cache.get(model_name)):
-        model = model_name
-        if model_name.startswith('openai/'):
-            model = model_name[7:]
-
-        encoding = _token_encoding_cache[model_name] = tiktoken.encoding_for_model(model)
+    from litellm import token_counter
 
-    count = 0
-    for m in completion['messages']:
-        count += len(encoding.encode(m['content']))
+    count = token_counter(model=model_name, messages=completion['messages'])
 
     return count
 
-
 class ChatterError(Exception):
     pass
 
-
 class Chatter:
     """Chats with an LLM."""
-
     def __init__(self, model: str) -> None:
         Chatter._validate_model(model)
 
         self._model = model
-        self._model_temperature: float|None = None
+        self._model_temperature: float | None = None
         self._max_backoff = 64 # seconds
-        self.token_rate_limit: AsyncLimiter|None
+        self.token_rate_limit: AsyncLimiter | None
         self.set_token_rate_limit(token_rate_limit_for_model(model))
         self._add_cost = lambda cost: None
         self._log_msg = lambda ctx, msg: None
@@ -163,80 +113,67 @@ def __init__(self, model: str) -> None:
         self._functions: dict[str, dict[str, T.Any]] = dict()
         self._max_func_calls_per_chat = 50
 
-
     @staticmethod
     def _validate_model(model) -> None:
         try:
             _, provider, _, _ = litellm.get_llm_provider(model)
         except litellm.exceptions.BadRequestError:
-            raise ChatterError(textwrap.dedent("""\
+            raise ChatterError(textwrap.dedent("""
                 Unknown or unsupported model.
-                Please see https://docs.litellm.ai/docs/providers for supported models."""
-            ))
+                Please see https://docs.litellm.ai/docs/providers for supported models."""))
 
         result = litellm.validate_environment(model)
         if result['missing_keys']:
             if provider == 'openai':
-                raise ChatterError(textwrap.dedent("""\
+                raise ChatterError(textwrap.dedent("""
                     You need an OpenAI key to use {model}.
 
                     You can get a key here: https://platform.openai.com/api-keys
                     Set the environment variable OPENAI_API_KEY to your key value
-                    export OPENAI_API_KEY=<your key>"""
-                ))
+                    export OPENAI_API_KEY=<your key>"""))
             elif provider == 'bedrock':
-                raise ChatterError(textwrap.dedent("""\
+                raise ChatterError(textwrap.dedent("""
                     To use Bedrock, you need an AWS account. Set the following environment variables:
                         export AWS_ACCESS_KEY_ID=<your key id>
                         export AWS_SECRET_ACCESS_KEY=<your secret key>
                         export AWS_REGION_NAME=us-west-2
 
                     You also need to request access to Claude:
-                        https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access"""
-                ))
+                        https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access"""))
             else:
-                raise ChatterError(textwrap.dedent(f"""\
+                raise ChatterError(textwrap.dedent(f"""
                     You need a key (or keys) from {provider} to use {model}.
                     Set the following environment variables:
-                        {', '.join(result['missing_keys'])}"""
-                ))
-
+                        {', '.join(result['missing_keys'])}"""))
 
     def set_model_temperature(self, temperature: T.Optional[float]) -> None:
         self._model_temperature = temperature
 
-
     def set_token_rate_limit(self, limit: T.Union[T.Tuple[int, int], None]) -> None:
         if limit:
             self.token_rate_limit = AsyncLimiter(*limit)
         else:
             self.token_rate_limit = None
 
-
     def set_max_backoff(self, max_backoff: int) -> None:
         self._max_backoff = max_backoff
 
-
     def set_add_cost(self, add_cost: T.Callable) -> None:
         """Sets up a callback to indicate additional costs."""
         self._add_cost = add_cost
 
-
     def set_log_msg(self, log_msg: T.Callable[[str, str], None]) -> None:
         """Sets up a callback to write a message to the log."""
         self._log_msg = log_msg
 
-
     def set_log_json(self, log_json: T.Callable[[str, dict], None]) -> None:
         """Sets up a callback to write a json exchange to the log."""
         self._log_json = log_json
 
-
     def set_signal_retry(self, signal_retry: T.Callable) -> None:
         """Sets up a callback to indicate a retry."""
         self._signal_retry = signal_retry
 
-
     def add_function(self, function: T.Callable) -> None:
         """Makes a function availabe to the LLM."""
         if not litellm.supports_function_calling(self._model):
@@ -245,28 +182,29 @@ def add_function(self, function: T.Callable) -> None:
         try:
             schema = json.loads(getattr(function, "__doc__", ""))
             if 'name' not in schema:
-                raise ChatterError("Name missing from function {function} schema.")
+                raise ChatterError(f"Name missing from function {function} schema.")
         except json.decoder.JSONDecodeError as e:
             raise ChatterError(f"Invalid JSON in function docstring: {e}")
 
-        assert schema['name'] not in self._functions, "Duplicated function name {schema['name']}"
+        assert schema['name'] not in self._functions, f"Duplicated function name {schema['name']}"
         self._functions[schema['name']] = {"function": function, "schema": schema}
 
-
     def _request(self, messages: T.List[dict]) -> dict:
-        return {
+        request = {
             'model': self._model,
             **({'temperature': self._model_temperature} if self._model_temperature is not None else {}),
             'messages': messages,
             **({'api_base': "http://localhost:11434"} if "ollama" in self._model else {}),
-            **({'tools': [{'type': 'function', 'function': f['schema']} for f in self._functions.values()]} \
-                if self._functions else {})
+            **({'tools': [{'type': 'function', 'function': f['schema']} for f in self._functions.values()]} if self._functions else {})
         }
 
+        if self._model.startswith("bedrock/anthropic"):
+            request['anthropic_version'] = "bedrock-2023-05-31"
 
-    async def _send_request(self, request: dict, ctx: object) -> litellm.ModelResponse|None:
-        """Sends the LLM chat request, handling common failures and returning the response."""
+        return request
 
+    async def _send_request(self, request: dict, ctx: object) -> litellm.ModelResponse | None:
+        """Sends the LLM chat request, handling common failures and returning the response."""
         sleep = 1
         while True:
             try:
@@ -290,11 +228,12 @@ async def _send_request(self, request: dict, ctx: object) -> litellm.ModelResponse|None:
                     self._log_msg(ctx, f"Failed: {type(e)} {e}")
                     raise
 
-                self._log_msg(ctx, f"Error: {type(e)} {e}")
-
                 import random
-                sleep = min(sleep*2, self._max_backoff)
-                sleep_time = random.uniform(sleep/2, sleep)
+                sleep = min(sleep * 2, self._max_backoff)
+                sleep_time = random.uniform(sleep / 2, sleep)
+
+                self._log_msg(ctx, f"Error: {type(e)} {e} {sleep=} {sleep_time=}")
+
                 self._signal_retry()
                 await asyncio.sleep(sleep_time)
 
@@ -319,27 +258,24 @@ async def _send_request(self, request: dict, ctx: object) -> litellm.ModelResponse|None:
                 self._log_msg(ctx, f"Error: {type(e)} {e}")
                 return None # gives up this segment
 
-
     def _call_function(self, ctx: object, tool_call: litellm.ModelResponse) -> str:
         args = json.loads(tool_call.function.arguments)
         function = self._functions[tool_call.function.name]
 
         try:
             return str(function['function'](ctx=ctx, **args))
         except Exception as e:
-            self._log_msg(ctx, f"""\
-                Error executing function "{tool_call.function.name}": {e}
+            self._log_msg(ctx, f"""
+                Error executing function \"{tool_call.function.name}\": {e}
                 args:{args}
 
                 {traceback.format_exc()}
                 """)
             return f'Error executing function: {e}'
 
-
-    async def chat(self, messages: list, *, ctx: T.Optional[object] = None) -> dict|None:
+    async def chat(self, messages: list, *, ctx: T.Optional[object] = None) -> dict | None:
         """Chats with the LLM, sending the given messages, handling common failures and returning the response.
         Automatically calls any tool functions requested."""
-
         func_calls = 0
         while func_calls <= self._max_func_calls_per_chat:
             request = self._request(messages)
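
A few illustrative sketches of how the changed pieces above behave.

Token counting: count_tokens now delegates to LiteLLM's token_counter, which can count tokens for Anthropic/Bedrock model IDs that tiktoken has no encoding for; token_rate_limit_for_model correspondingly skips the tiktoken check for model names containing "anthropic", so their rate limits are no longer discarded. A minimal usage sketch (model name is illustrative):

```python
from litellm import token_counter

messages = [
    {"role": "system", "content": "You are a test-writing assistant."},
    {"role": "user", "content": "Write a pytest test for default_model()."},
]

# LiteLLM picks a tokenizer appropriate for the given model name.
n = token_counter(model="anthropic.claude-3-5-sonnet-20241022-v2:0", messages=messages)
print(f"prompt is about {n} tokens")
```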
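
Request construction: for Bedrock Anthropic models, _request now adds the anthropic_version field the Bedrock runtime expects alongside the usual LiteLLM arguments. Roughly, the resulting request dict looks like this (placeholder messages; passing it to litellm.acompletion is an assumption, since the send call is outside this diff):

```python
request = {
    'model': "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",   # illustrative model name
    'messages': [{"role": "user", "content": "Write a test for foo()."}],
    # 'temperature', 'api_base', and 'tools' are included only when configured
}
if request['model'].startswith("bedrock/anthropic"):
    request['anthropic_version'] = "bedrock-2023-05-31"

# _send_request then hands this dict to LiteLLM, e.g. litellm.acompletion(**request)  (assumed)
```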
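
Retry backoff: on retriable errors, _send_request doubles its nominal delay up to _max_backoff (64 seconds by default), sleeps a jittered duration drawn uniformly from [sleep/2, sleep], and now logs both values with the error. The sequence of delays looks like this:

```python
import random

max_backoff = 64   # Chatter default, in seconds
sleep = 1
for attempt in range(1, 9):
    sleep = min(sleep * 2, max_backoff)            # 2, 4, 8, 16, 32, 64, 64, 64
    sleep_time = random.uniform(sleep / 2, sleep)  # jittered actual delay
    print(f"attempt {attempt}: sleep={sleep} sleep_time={sleep_time:.1f}")
```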