Skip to content

Commit 7286c61

Browse files
authored
Support chat completion (#83)
* Support completion Signed-off-by: kerthcet <kerthcet@gmail.com> * Replace completion with chat completion Signed-off-by: kerthcet <kerthcet@gmail.com> * update Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 5335360 commit 7286c61

15 files changed

Lines changed: 213 additions & 96 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ version = "0.1.0"
44
edition = "2024"
55

66
[dependencies]
7-
async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses",] }
7+
async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses", "chat-completion"] }
88
async-trait = "0.1.89"
99
derive_builder = "0.20.2"
1010
dotenvy = "0.15.7"

README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode:
2727
// Before running the code, make sure to set your OpenAI API key in the environment variable:
2828
// export OPENAI_API_KEY="your_openai_api_key"
2929

30+
use tokio::runtime::Runtime;
3031
use arms::client;
31-
use arms::types::responses;
32+
use arms::types::chat;
3233

3334
let config = client::Config::builder()
3435
.provider("openai")
@@ -51,12 +52,15 @@ let config = client::Config::builder()
5152
.unwrap();
5253

5354
let mut client = client::Client::new(config);
54-
let request = responses::CreateResponseArgs::default()
55-
.input("give me a poem about nature")
55+
let request = chat::CreateChatCompletionRequestArgs::default()
56+
.messages([
57+
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
58+
chat::ChatCompletionRequestUserMessage::from("How is the weather today?").into(),
59+
])
5660
.build()
5761
.unwrap();
5862

59-
let response = client.create_response(request).await.unwrap();
63+
let result = Runtime::new().unwrap().block_on(client.create_completion(request));
6064
```
6165

6266
## Contributing

src/client/client.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::client::config::{Config, ModelName};
44
use crate::provider::provider;
55
use crate::router::router;
66
use crate::types::error::OpenAIError;
7-
use crate::types::responses::{CreateResponse, Response};
7+
use crate::types::{chat, responses};
88

99
pub struct Client {
1010
providers: HashMap<ModelName, Box<dyn provider::Provider>>,
@@ -30,12 +30,22 @@ impl Client {
3030

3131
pub async fn create_response(
3232
&mut self,
33-
request: CreateResponse,
34-
) -> Result<Response, OpenAIError> {
35-
let candidate = self.router.sample(&request);
33+
request: responses::CreateResponse,
34+
) -> Result<responses::Response, OpenAIError> {
35+
let candidate = self.router.sample();
3636
let provider = self.providers.get(&candidate).unwrap();
3737
provider.create_response(request).await
3838
}
39+
40+
// This is chat completion endpoint.
41+
pub async fn create_completion(
42+
&mut self,
43+
request: chat::CreateChatCompletionRequest,
44+
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
45+
let candidate = self.router.sample();
46+
let provider = self.providers.get(&candidate).unwrap();
47+
provider.create_completion(request).await
48+
}
3949
}
4050

4151
#[cfg(test)]

src/client/config.rs

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ pub struct ModelConfig {
3838
pub(crate) base_url: Option<String>,
3939
#[builder(default = "None", setter(custom))]
4040
pub(crate) provider: Option<String>,
41-
#[builder(default = "None")]
42-
pub(crate) temperature: Option<f32>,
43-
#[builder(default = "None")]
44-
pub(crate) max_output_tokens: Option<usize>,
4541

4642
#[builder(setter(custom))]
4743
pub(crate) name: ModelName,
@@ -86,10 +82,6 @@ pub struct Config {
8682
pub(crate) base_url: Option<String>,
8783
#[builder(default = "DEFAULT_PROVIDER.to_string()", setter(custom))]
8884
pub(crate) provider: String,
89-
#[builder(default = "0.8")]
90-
pub(crate) temperature: f32,
91-
#[builder(default = "1024")]
92-
pub(crate) max_output_tokens: usize,
9385

9486
#[builder(default = "RoutingMode::Random")]
9587
pub(crate) routing_mode: RoutingMode,
@@ -131,13 +123,6 @@ impl Config {
131123
if model.provider.is_none() {
132124
model.provider = Some(self.provider.clone());
133125
}
134-
135-
if model.temperature.is_none() {
136-
model.temperature = Some(self.temperature);
137-
}
138-
if model.max_output_tokens.is_none() {
139-
model.max_output_tokens = Some(self.max_output_tokens);
140-
}
141126
}
142127
self
143128
}
@@ -176,24 +161,6 @@ impl ConfigBuilder {
176161
));
177162
}
178163

179-
if let Some(max_output_tokens) = model.max_output_tokens {
180-
if max_output_tokens <= 0 {
181-
return Err(format!(
182-
"Model '{}' max_output_tokens must be positive.",
183-
model.name
184-
));
185-
}
186-
}
187-
188-
if let Some(temperature) = model.temperature {
189-
if temperature < 0.0 || temperature > 1.0 {
190-
return Err(format!(
191-
"Model '{}' temperature must be between 0.0 and 1.0.",
192-
model.name
193-
));
194-
}
195-
}
196-
197164
// check the existence of API key in environment variables
198165
if let Some(provider) = &model.provider {
199166
let env_var = format!("{}_API_KEY", provider);
@@ -251,20 +218,10 @@ mod tests {
251218
assert!(valid_simplest_models_cfg.is_ok());
252219
assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == DEFAULT_PROVIDER);
253220
assert!(valid_simplest_models_cfg.as_ref().unwrap().base_url == None);
254-
assert!(valid_simplest_models_cfg.as_ref().unwrap().temperature == 0.8);
255-
assert!(
256-
valid_simplest_models_cfg
257-
.as_ref()
258-
.unwrap()
259-
.max_output_tokens
260-
== 1024
261-
);
262221
assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RoutingMode::Random);
263222
assert!(valid_simplest_models_cfg.as_ref().unwrap().models.len() == 1);
264223
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].base_url == None);
265224
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].provider == None);
266-
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].temperature == None);
267-
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].max_output_tokens == None);
268225
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].weight == -1);
269226

270227
// case 2:
@@ -299,7 +256,6 @@ mod tests {
299256
// AMRS_API_KEY is set in .env.test already.
300257
let valid_cfg_with_customized_provider = Config::builder()
301258
.base_url("http://example.ai")
302-
.max_output_tokens(2048)
303259
.model(
304260
ModelConfig::builder()
305261
.name("custom-model")
@@ -325,8 +281,6 @@ mod tests {
325281
from_filename(".env.test").ok();
326282

327283
let mut valid_cfg = Config::builder()
328-
.temperature(0.5)
329-
.max_output_tokens(1500)
330284
.model(
331285
ModelConfig::builder()
332286
.name("model-1".to_string())
@@ -338,8 +292,6 @@ mod tests {
338292

339293
assert!(valid_cfg.is_ok());
340294
assert!(valid_cfg.as_ref().unwrap().models.len() == 1);
341-
assert!(valid_cfg.as_ref().unwrap().models[0].temperature == Some(0.5));
342-
assert!(valid_cfg.as_ref().unwrap().models[0].max_output_tokens == Some(1500));
343295
assert!(valid_cfg.as_ref().unwrap().models[0].provider == Some("OPENAI".to_string()));
344296
assert!(
345297
valid_cfg.as_ref().unwrap().models[0].base_url

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ mod router {
88
}
99

1010
mod provider {
11+
mod common;
1112
mod faker;
1213
mod openai;
1314
pub mod provider;
1415
}
1516
pub mod types {
17+
pub mod chat;
1618
pub mod error;
1719
pub mod responses;
1820
}

src/main.rs

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use tokio::runtime::Runtime;
22

33
use arms::client;
4-
use arms::types::responses;
4+
use arms::types::chat;
55

66
fn main() {
7+
// case 1: completion with DeepInfra provider.
78
let config = client::Config::builder()
89
.provider("deepinfra")
910
.routing_mode(client::RoutingMode::WRR)
@@ -26,27 +27,77 @@ fn main() {
2627

2728
let mut client = client::Client::new(config);
2829

29-
let request = responses::CreateResponseArgs::default()
30-
.input(responses::InputParam::Items(vec![
31-
responses::InputItem::EasyMessage(responses::EasyInputMessage {
32-
r#type: responses::MessageType::Message,
33-
role: responses::Role::User,
34-
content: responses::EasyInputContent::Text("What is AGI?".to_string()),
35-
}),
36-
]))
30+
let request = chat::CreateChatCompletionRequestArgs::default()
31+
.messages([
32+
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
33+
chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?")
34+
.into(),
35+
chat::ChatCompletionRequestAssistantMessage::from(
36+
"The Los Angeles Dodgers won the World Series in 2020.",
37+
)
38+
.into(),
39+
chat::ChatCompletionRequestUserMessage::from("Where was it played?").into(),
40+
])
3741
.build()
3842
.unwrap();
3943

4044
let result = Runtime::new()
4145
.unwrap()
42-
.block_on(client.create_response(request));
46+
.block_on(client.create_completion(request));
4347

4448
match result {
4549
Ok(response) => {
46-
println!("Response ID: {}", response.id);
50+
println!("Response: {:?}", response);
4751
}
4852
Err(e) => {
4953
eprintln!("Error: {}", e);
5054
}
5155
}
56+
57+
// case 2: response with DeepInfra provider.
58+
// let config = client::Config::builder()
59+
// .provider("deepinfra")
60+
// .routing_mode(client::RoutingMode::WRR)
61+
// .model(
62+
// client::ModelConfig::builder()
63+
// .name("nvidia/Nemotron-3-Nano-30B-A3B")
64+
// .weight(1)
65+
// .build()
66+
// .unwrap(),
67+
// )
68+
// .model(
69+
// client::ModelConfig::builder()
70+
// .name("deepseek-ai/DeepSeek-V3.2")
71+
// .weight(2)
72+
// .build()
73+
// .unwrap(),
74+
// )
75+
// .build()
76+
// .unwrap();
77+
78+
// let mut client = client::Client::new(config);
79+
80+
// let request = responses::CreateResponseArgs::default()
81+
// .input(responses::InputParam::Items(vec![
82+
// responses::InputItem::EasyMessage(responses::EasyInputMessage {
83+
// r#type: responses::MessageType::Message,
84+
// role: responses::Role::User,
85+
// content: responses::EasyInputContent::Text("What is AGI?".to_string()),
86+
// }),
87+
// ]))
88+
// .build()
89+
// .unwrap();
90+
91+
// let result = Runtime::new()
92+
// .unwrap()
93+
// .block_on(client.create_response(request));
94+
95+
// match result {
96+
// Ok(response) => {
97+
// println!("Response ID: {}", response.id);
98+
// }
99+
// Err(e) => {
100+
// eprintln!("Error: {}", e);
101+
// }
102+
// }
52103
}

src/provider/common.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
use crate::types::error::OpenAIError;
2+
use crate::types::{chat, responses};
3+
4+
pub fn validate_completion_request(
5+
request: &chat::CreateChatCompletionRequest,
6+
) -> Result<(), OpenAIError> {
7+
if request.model != "" {
8+
return Err(OpenAIError::InvalidArgument(
9+
"Model must be specified in the client.Config".to_string(),
10+
));
11+
}
12+
Ok(())
13+
}
14+
15+
pub fn validate_response_request(request: &responses::CreateResponse) -> Result<(), OpenAIError> {
16+
if request.model.is_some() {
17+
return Err(OpenAIError::InvalidArgument(
18+
"Model must be specified in the client.Config".to_string(),
19+
));
20+
}
21+
Ok(())
22+
}

src/provider/faker.rs

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use async_trait::async_trait;
22

33
use crate::client::config::{ModelConfig, ModelName};
4-
use crate::provider::provider;
4+
use crate::provider::{common, provider};
5+
use crate::types::chat;
56
use crate::types::error::OpenAIError;
67
use crate::types::responses::{
78
AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
@@ -26,8 +27,8 @@ impl provider::Provider for FakerProvider {
2627
"FakeProvider"
2728
}
2829

29-
async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
30-
provider::validate_responses_request(&request)?;
30+
async fn create_response(&self, _request: CreateResponse) -> Result<Response, OpenAIError> {
31+
common::validate_response_request(&_request)?;
3132

3233
Ok(Response {
3334
id: "fake-response-id".to_string(),
@@ -71,4 +72,34 @@ impl provider::Provider for FakerProvider {
7172
truncation: None,
7273
})
7374
}
75+
76+
async fn create_completion(
77+
&self,
78+
request: chat::CreateChatCompletionRequest,
79+
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
80+
common::validate_completion_request(&request)?;
81+
Ok(chat::CreateChatCompletionResponse {
82+
id: "fake-completion-id".to_string(),
83+
object: "text_completion".to_string(),
84+
created: 1_600_000_000,
85+
model: self.model.clone(),
86+
usage: None,
87+
service_tier: None,
88+
choices: vec![chat::ChatChoice {
89+
index: 0,
90+
message: chat::ChatCompletionResponseMessage {
91+
role: chat::Role::Assistant,
92+
content: Some("This is a fake chat completion.".to_string()),
93+
refusal: None,
94+
tool_calls: None,
95+
annotations: None,
96+
function_call: None,
97+
audio: None,
98+
},
99+
finish_reason: None,
100+
logprobs: None,
101+
}],
102+
system_fingerprint: None,
103+
})
104+
}
74105
}

0 commit comments

Comments (0)