
Commit 8cb2240

add initial bedrock models

committed · 1 parent 30d992f · commit 8cb2240

File tree: 12 files changed (+1869 / -258 lines)

Cargo.lock

Lines changed: 1558 additions & 188 deletions
(Generated file; diff not rendered by default.)

Cargo.toml

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@ serde_json = "1"
 toml = "0"
 env_logger = "0"
 reqwest = { version = "0", default-features = false, features = ["http2", "json", "blocking", "multipart", "rustls-tls"] }
+aws-config = "1.8.3"
+aws-sdk-bedrockruntime = "1.99.0"
+tokio = "1.47.0"

 [dev-dependencies]
 tempfile = "3"

src/client/aws.rs

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
use crate::client::request_schemas::AnthropicPrompt;
use crate::config::api::{ApiClient, ApiConfig, ApiError};
use crate::config::prompt::{Message, Prompt};
use crate::Api;

use aws_config::BehaviorVersion;
use aws_sdk_bedrockruntime::{operation::converse::ConverseOutput, Client as BedrockClient};
use tokio::runtime::Runtime;

pub struct AwsClient {
    api_config: ApiConfig,
    client: BedrockClient,
    prompt: Prompt,
    runtime: Runtime,
}

impl AwsClient {
    pub fn new(api_config: ApiConfig, prompt: Prompt) -> Self {
        let runtime = match tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
        {
            Err(e) => panic!("AwsClient failed to initialize tokio runtime: {e}"),
            Ok(v) => v,
        };
        let config = runtime
            .block_on(async { aws_config::load_defaults(BehaviorVersion::v2025_01_17()).await });
        println!("config: {:?}", config.region());
        let client = BedrockClient::new(&config);

        AwsClient {
            api_config,
            client,
            prompt,
            runtime,
        }
    }

    fn get_converse_output_text(&self, output: ConverseOutput) -> Result<String, ApiError> {
        let text = output
            .output()
            .ok_or(ApiError::new(
                self.prompt.model.clone(),
                "no output".to_string(),
            ))?
            .as_message()
            .map_err(|_| {
                ApiError::new(
                    self.prompt.model.clone(),
                    "output not a message".to_string(),
                )
            })?
            .content()
            .first()
            .ok_or(ApiError::new(
                self.prompt.model.clone(),
                "no content in message".to_string(),
            ))?
            .as_text()
            .map_err(|_| {
                ApiError::new(self.prompt.model.clone(), "content is not text".to_string())
            })?
            .to_string();
        Ok(text)
    }
}

impl ApiClient for AwsClient {
    fn do_request(&self) -> Result<Message, ApiError> {
        let prompt_format = match self.prompt.api {
            Api::AWSBedrock => AnthropicPrompt::from(self.prompt.clone()),
            Api::AnotherApiForTests => panic!("This api is not made for actual use."),
            _ => unreachable!(),
        };

        let result = self.runtime.block_on(async {
            let response = self
                .client
                .converse()
                .model_id(self.prompt.model.as_ref().unwrap())
                .set_messages(Some(prompt_format.into()))
                .send()
                .await;

            match response {
                Ok(output) => {
                    let text = self.get_converse_output_text(output)?;
                    Ok(text)
                }
                Err(e) => {
                    use aws_sdk_bedrockruntime::error::DisplayErrorContext;
                    println!("error: {}", DisplayErrorContext(&e));

                    Err(e
                        .as_service_error()
                        .map(|e| ApiError::new(self.prompt.model.clone(), e.to_string()))
                        .unwrap_or_else(|| {
                            ApiError::new(
                                self.prompt.model.clone(),
                                "Unknown service error".to_string(),
                            )
                        }))
                }
            }
        });

        match result {
            Ok(response) => Ok(Message::assistant(response.as_str())),
            Err(e) => Err(e),
        }
    }
}
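A hypothetical call site for the new client (not part of this commit), assuming the crate's existing get_api_config helper and a Prompt whose api field is Api::AWSBedrock. Note that the ApiConfig url and api_key are not used by the Bedrock path; authentication and region come from the AWS default credential chain instead:

use crate::client::aws::AwsClient;
use crate::config::api::{get_api_config, ApiClient, ApiError};
use crate::config::prompt::{Message, Prompt};
use crate::Api;

fn ask_bedrock(prompt: Prompt) -> Result<Message, ApiError> {
    // Looks up the "awsbedrock" entry generated in .api_configs.toml.
    let api_config = get_api_config(&Api::AWSBedrock.to_string());
    let client = AwsClient::new(api_config, prompt);
    // Blocks on the async Converse call internally and returns an assistant Message.
    client.do_request()
}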

src/client/mod.rs

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
pub mod aws;
mod request_schemas;
pub mod reqwest;
mod response_schemas;
src/client/request_schemas.rs

Lines changed: 14 additions & 0 deletions

@@ -65,3 +65,17 @@ impl From<Prompt> for AnthropicPrompt
         }
     }
 }
+
+#[derive(Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum PromptFormat {
+    OpenAi(OpenAiPrompt),
+    Anthropic(AnthropicPrompt),
+    AWSBedrock(AnthropicPrompt),
+}
+
+impl Into<Vec<aws_sdk_bedrockruntime::types::Message>> for AnthropicPrompt {
+    fn into(self) -> Vec<aws_sdk_bedrockruntime::types::Message> {
+        self.messages.iter().cloned().map(|m| m.into()).collect()
+    }
+}
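The new Into impl lets the Bedrock client reuse the existing Anthropic prompt shape and only convert at the edge. An illustrative helper (not in the commit; it would have to live inside the client module, since request_schemas is private):

use crate::client::request_schemas::AnthropicPrompt;
use crate::config::prompt::Prompt;

// Prompt -> AnthropicPrompt (existing From impl) -> Vec of Bedrock messages (new Into impl).
fn to_bedrock_messages(prompt: Prompt) -> Vec<aws_sdk_bedrockruntime::types::Message> {
    AnthropicPrompt::from(prompt).into()
}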

src/client/reqwest.rs

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
use std::time::Duration;

use super::request_schemas::{AnthropicPrompt, OpenAiPrompt, PromptFormat};
use super::response_schemas::{AnthropicResponse, OllamaResponse, OpenAiResponse};
use crate::config::api::{ApiClient, ApiConfig, ApiError};
use crate::config::prompt::{Message, Prompt};
use crate::utils::handle_api_response;
use crate::Api;

pub struct ReqwestClient {
    api_config: ApiConfig,
    client: reqwest::blocking::Client,
    prompt: Prompt,
}

impl ReqwestClient {
    pub fn new(api_config: ApiConfig, prompt: Prompt) -> Self {
        let client = reqwest::blocking::Client::builder()
            .timeout(
                api_config
                    .timeout_seconds
                    .map(|t| Duration::from_secs(t.into())),
            )
            .build()
            .expect("Unable to initialize reqwest HTTP client");

        ReqwestClient {
            api_config,
            client,
            prompt,
        }
    }
}

impl ApiClient for ReqwestClient {
    fn do_request(&self) -> Result<Message, ApiError> {
        let prompt_format = match self.prompt.api {
            Api::Ollama
            | Api::Openai
            | Api::AzureOpenai
            | Api::Mistral
            | Api::Groq
            | Api::Cerebras => PromptFormat::OpenAi(OpenAiPrompt::from(self.prompt.clone())),
            Api::Anthropic => PromptFormat::Anthropic(AnthropicPrompt::from(self.prompt.clone())),
            Api::AWSBedrock => PromptFormat::AWSBedrock(AnthropicPrompt::from(self.prompt.clone())),
            Api::AnotherApiForTests => panic!("This api is not made for actual use."),
        };

        let request = self
            .client
            .post(&self.api_config.url)
            .header("Content-Type", "application/json")
            .json(&prompt_format);

        // https://stackoverflow.com/questions/77862683/rust-reqwest-cant-make-a-request
        let request = match self.prompt.api {
            Api::Cerebras => request.header("User-Agent", "CUSTOM_NAME/1.0"),
            _ => request,
        };

        // Add auth if necessary
        let request = match self.prompt.api {
            Api::Openai | Api::Mistral | Api::Groq | Api::Cerebras => request.header(
                "Authorization",
                &format!("Bearer {}", &self.api_config.get_api_key()),
            ),
            Api::AzureOpenai => request.header("api-key", &self.api_config.get_api_key()),
            Api::Anthropic => request
                .header("x-api-key", &self.api_config.get_api_key())
                .header(
                    "anthropic-version",
                    self.api_config.version.as_ref().expect(
                        "version required for Anthropic, please add version key to your api config",
                    ),
                ),
            _ => request,
        };

        let response_text: String = match self.prompt.api {
            Api::Ollama => handle_api_response::<OllamaResponse>(
                request
                    .send()
                    .map_err(|e| ApiError::new(self.prompt.model.clone(), e.to_string()))?,
            ),
            Api::Openai | Api::AzureOpenai | Api::Mistral | Api::Groq | Api::Cerebras => {
                handle_api_response::<OpenAiResponse>(
                    request
                        .send()
                        .map_err(|e| ApiError::new(self.prompt.model.clone(), e.to_string()))?,
                )
            }
            Api::Anthropic => handle_api_response::<AnthropicResponse>(
                request
                    .send()
                    .map_err(|e| ApiError::new(self.prompt.model.clone(), e.to_string()))?,
            ),
            Api::AWSBedrock | Api::AnotherApiForTests => unreachable!(),
        };

        Ok(Message::assistant(&response_text))
    }
}

src/config/api.rs

Lines changed: 39 additions & 1 deletion
@@ -7,7 +7,7 @@ use std::io::Write;
 use std::path::PathBuf;
 use std::str::FromStr;

-use super::{prompt::Prompt, resolve_config_path};
+use super::{prompt::Message, prompt::Prompt, resolve_config_path};

 const API_KEYS_FILE: &str = ".api_configs.toml";

@@ -20,6 +20,7 @@ pub enum Api {
     Groq,
     Mistral,
     Openai,
+    AWSBedrock,
     AzureOpenai,
     Cerebras,
 }
@@ -46,6 +47,7 @@ impl ToString for Api {
         match self {
             Api::Ollama => "ollama".to_string(),
             Api::Openai => "openai".to_string(),
+            Api::AWSBedrock => "awsbedrock".to_string(),
             Api::AzureOpenai => "azureopenai".to_string(),
             Api::Mistral => "mistral".to_string(),
             Api::Groq => "groq".to_string(),
@@ -138,6 +140,17 @@ impl ApiConfig {
         }
     }

+    pub(super) fn awsbedrock() -> Self {
+        ApiConfig {
+            api_key_command: None,
+            api_key: None,
+            url: String::from(""),
+            default_model: Some(String::from("us.anthropic.claude-3-7-sonnet-20250219-v1:0")),
+            version: None,
+            timeout_seconds: None,
+        }
+    }
+
     pub(super) fn azureopenai() -> Self {
         ApiConfig {
             api_key_command: None,
@@ -202,6 +215,7 @@ pub(super) fn generate_api_keys_file() -> std::io::Result<()> {
     let mut api_config = HashMap::new();
     api_config.insert(Api::Ollama.to_string(), ApiConfig::ollama());
     api_config.insert(Api::Openai.to_string(), ApiConfig::openai());
+    api_config.insert(Api::AWSBedrock.to_string(), ApiConfig::awsbedrock());
     api_config.insert(Api::AzureOpenai.to_string(), ApiConfig::azureopenai());
     api_config.insert(Api::Mistral.to_string(), ApiConfig::mistral());
     api_config.insert(Api::Groq.to_string(), ApiConfig::groq());
@@ -244,3 +258,27 @@ pub fn get_api_config(api: &str) -> ApiConfig {
         )
     })
 }
+
+pub trait ApiClient {
+    fn do_request(&self) -> Result<Message, ApiError>;
+}
+
+#[derive(Debug)]
+pub struct ApiError {
+    pub model: Option<String>,
+    pub error: String,
+}
+
+impl std::fmt::Display for ApiError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Can't invoke '{:?}'. Reason: {}", self.model, self.error)
+    }
+}
+
+impl std::error::Error for ApiError {}
+
+impl ApiError {
+    pub fn new(model: Option<String>, error: String) -> Self {
+        ApiError { model, error }
+    }
+}
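A small illustrative test (not in the commit) pinning down the Display format of the new ApiError:

#[test]
fn api_error_display_mentions_model_and_reason() {
    let err = ApiError::new(Some("example-model-id".to_string()), "no output".to_string());
    // The Option is rendered with {:?}, so the model shows up as Some("...").
    assert_eq!(
        err.to_string(),
        "Can't invoke 'Some(\"example-model-id\")'. Reason: no output"
    );
}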

src/config/prompt.rs

Lines changed: 18 additions & 0 deletions
@@ -7,6 +7,7 @@ use std::io::Write;
 use std::path::PathBuf;

 use crate::config::{api::Api, resolve_config_path};
+use aws_sdk_bedrockruntime::types::ConversationRole;

 const PROMPT_FILE: &str = "prompts.toml";
 const CONVERSATION_FILE: &str = "conversation.toml";
@@ -92,6 +93,23 @@ impl Message {
     }
 }

+impl Into<aws_sdk_bedrockruntime::types::Message> for Message {
+    fn into(self) -> aws_sdk_bedrockruntime::types::Message {
+        let role = match self.role.as_str() {
+            "assistant" => ConversationRole::Assistant,
+            "user" => ConversationRole::User,
+            _ => panic!("system role not supported for bedrock messages"),
+        };
+        aws_sdk_bedrockruntime::types::Message::builder()
+            .role(role)
+            .content(aws_sdk_bedrockruntime::types::ContentBlock::Text(
+                self.content,
+            ))
+            .build()
+            .unwrap()
+    }
+}
+
 pub(super) fn prompts_path() -> PathBuf {
     resolve_config_path().join(PROMPT_FILE)
 }
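Only the user and assistant roles map onto Bedrock's ConversationRole; a system message hits the panic!, consistent with the Converse API keeping system prompts out of the message list. A hedged sketch using the Message::assistant constructor the commit already relies on (the role() accessor on the SDK type is an assumption):

use aws_sdk_bedrockruntime::types::{ConversationRole, Message as BedrockMessage};

use crate::config::prompt::Message;

fn assistant_to_bedrock(text: &str) -> BedrockMessage {
    // Uses the Into impl above; would panic if the role were "system".
    let bedrock: BedrockMessage = Message::assistant(text).into();
    debug_assert_eq!(bedrock.role(), &ConversationRole::Assistant);
    bedrock
}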

src/main.rs

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+mod client;
 mod config;
 mod prompt_customization;
 mod text;
