Skip to content

Commit 5335360

Browse files
authored
Polish the API (#82)
* Rename ModelId to ModelName Signed-off-by: kerthcet <kerthcet@gmail.com> * use async-openai fields wrapper Signed-off-by: kerthcet <kerthcet@gmail.com> * polish API Signed-off-by: kerthcet <kerthcet@gmail.com> * update example Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 284a869 commit 5335360

File tree

22 files changed

+335
-196
lines changed

22 files changed

+335
-196
lines changed

.env.integration-test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
AMRS_API_KEY=your_amrs_api_key_here
22
OPENAI_API_KEY=your_openai_api_key_here
3-
FAKE_API_KEY=your_fake_api_key_here
3+
FAKER_API_KEY=your_faker_api_key_here

Cargo.lock

Lines changed: 59 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ lazy_static = "1.5.0"
1212
rand = "0.9.2"
1313
reqwest = "0.12.26"
1414
serde = "1.0.228"
15-
tokio = "1.48.0"
15+
tokio = { version = "1.48.0", features = ["full"] }

README.md

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Thanks to [async-openai](https://github.com/64bit/async-openai), AMRS builds on
1111
- Flexible routing strategies, including:
1212
- **Random**: Randomly selects a model from the available models.
1313
- **WRR**: Weighted Round Robin selects models based on predefined weights.
14-
- **UCB**: Upper Confidence Bound based model selection (coming soon).
14+
- **UCB1**: Upper Confidence Bound based model selection (coming soon).
1515
- **Adaptive**: Dynamically selects models based on performance metrics (coming soon).
1616

1717
- Broad provider support:
@@ -27,30 +27,31 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode:
2727
// Before running the code, make sure to set your OpenAI API key in the environment variable:
2828
// export OPENAI_API_KEY="your_openai_api_key"
2929

30-
use arms::{Client, Config, ModelConfig, CreateResponseArgs, RoutingMode};
30+
use arms::client;
31+
use arms::types::responses;
3132

32-
let config = Config::builder()
33+
let config = client::Config::builder()
3334
.provider("openai")
34-
.routing_mode(RoutingMode::WRR)
35+
.routing_mode(client::RoutingMode::WRR)
3536
.model(
36-
ModelConfig::builder()
37-
.id("gpt-3.5-turbo")
37+
client::ModelConfig::builder()
38+
.name("gpt-3.5-turbo")
3839
.weight(2)
3940
.build()
4041
.unwrap(),
4142
)
4243
.model(
43-
ModelConfig::builder()
44-
.id("gpt-4")
44+
client::ModelConfig::builder()
45+
.name("gpt-4")
4546
.weight(1)
4647
.build()
4748
.unwrap(),
4849
)
4950
.build()
5051
.unwrap();
5152

52-
let mut client = Client::new(config);
53-
let request = CreateResponseArgs::default()
53+
let mut client = client::Client::new(config);
54+
let request = responses::CreateResponseArgs::default()
5455
.input("give me a poem about nature")
5556
.build()
5657
.unwrap();

bindings/python/amrs/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ class BasicModelConfig(BaseModel):
4747
)
4848

4949

50-
type ModelID = str
50+
type ModelName = str
5151

5252
class ModelConfig(BasicModelConfig):
53-
id: ModelID = Field(
53+
id: ModelName = Field(
5454
description="ID of the model to be used."
5555
)
5656
weight: Optional[int] = Field(
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import random
22

3-
from amrs.config import ModelID
3+
from amrs.config import ModelName
44
from amrs.router.router import Router
55

66
class RandomRouter(Router):
7-
def __init__(self, model_list: list[ModelID]):
7+
def __init__(self, model_list: list[ModelName]):
88
super().__init__(model_list)
99

10-
def sample(self, _: str) -> ModelID:
10+
def sample(self, _: str) -> ModelName:
1111
return random.choice(self._model_list)

bindings/python/amrs/router/router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ class ModelInfo:
77
average_latency: float = 0.0
88

99
class Router(abc.ABC):
10-
def __init__(self, model_list: list[config.ModelID]):
10+
def __init__(self, model_list: list[config.ModelName]):
1111
self._model_list = model_list
1212

1313
@abc.abstractmethod
14-
def sample(self, content: str) -> config.ModelID:
14+
def sample(self, content: str) -> config.ModelName:
1515
pass
1616

1717
def new_router(model_cfgs: list[config.ModelConfig], mode: config.RoutingMode) -> Router:

src/client/client.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
use std::collections::HashMap;
22

3-
use crate::config::{Config, ModelId};
3+
use crate::client::config::{Config, ModelName};
44
use crate::provider::provider;
55
use crate::router::router;
6+
use crate::types::error::OpenAIError;
7+
use crate::types::responses::{CreateResponse, Response};
68

79
pub struct Client {
8-
providers: HashMap<ModelId, Box<dyn provider::Provider>>,
10+
providers: HashMap<ModelName, Box<dyn provider::Provider>>,
911
router: Box<dyn router::Router>,
1012
}
1113

@@ -17,7 +19,7 @@ impl Client {
1719
let providers = cfg
1820
.models
1921
.iter()
20-
.map(|m| (m.id.clone(), provider::construct_provider(m.clone())))
22+
.map(|m| (m.name.clone(), provider::construct_provider(m.clone())))
2123
.collect();
2224

2325
Self {
@@ -28,18 +30,18 @@ impl Client {
2830

2931
pub async fn create_response(
3032
&mut self,
31-
request: provider::CreateResponseReq,
32-
) -> Result<provider::CreateResponseRes, provider::APIError> {
33-
let model_id = self.router.sample(&request);
34-
let provider = self.providers.get(&model_id).unwrap();
33+
request: CreateResponse,
34+
) -> Result<Response, OpenAIError> {
35+
let candidate = self.router.sample(&request);
36+
let provider = self.providers.get(&candidate).unwrap();
3537
provider.create_response(request).await
3638
}
3739
}
3840

3941
#[cfg(test)]
4042
mod tests {
4143
use super::*;
42-
use crate::config::{Config, ModelConfig, RoutingMode};
44+
use crate::client::config::{Config, ModelConfig, RoutingMode};
4345
use dotenvy::from_filename;
4446

4547
#[test]
@@ -58,7 +60,7 @@ mod tests {
5860
config: Config::builder()
5961
.models(vec![
6062
ModelConfig::builder()
61-
.id("model_c".to_string())
63+
.name("model_c".to_string())
6264
.build()
6365
.unwrap(),
6466
])
@@ -71,15 +73,15 @@ mod tests {
7173
config: Config::builder()
7274
.routing_mode(RoutingMode::WRR)
7375
.models(vec![
74-
crate::config::ModelConfig::builder()
75-
.id("model_a".to_string())
76+
crate::client::config::ModelConfig::builder()
77+
.name("model_a".to_string())
7678
.provider(Some("openai".to_string()))
7779
.base_url(Some("https://api.openai.com/v1".to_string()))
7880
.weight(1)
7981
.build()
8082
.unwrap(),
81-
crate::config::ModelConfig::builder()
82-
.id("model_b".to_string())
83+
crate::client::config::ModelConfig::builder()
84+
.name("model_b".to_string())
8385
.provider(Some("openai".to_string()))
8486
.base_url(Some("https://api.openai.com/v1".to_string()))
8587
.weight(3)
@@ -95,13 +97,13 @@ mod tests {
9597
config: Config::builder()
9698
.models(vec![
9799
ModelConfig::builder()
98-
.id("model_a".to_string())
100+
.name("model_a".to_string())
99101
.provider(Some("openai".to_string()))
100102
.base_url(Some("https://api.openai.com/v1".to_string()))
101103
.build()
102104
.unwrap(),
103105
ModelConfig::builder()
104-
.id("model_b".to_string())
106+
.name("model_b".to_string())
105107
.provider(Some("openai".to_string()))
106108
.base_url(Some("https://api.openai.com/v1".to_string()))
107109
.build()

0 commit comments

Comments (0)