Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .changeset/fix-mcp-server-unwrap.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
"@anthropic/gws": patch
---

Replace unwrap() calls with proper error handling in MCP server

- Replace `.unwrap()` on `get_one("services")` with `.unwrap_or("")`
- Replace `req.get("id").unwrap()` with `if let Some(id)` pattern
- Replace `tools_cache.as_ref().unwrap()` with match expression
- Replace `parts.last().unwrap()` with `.ok_or_else()` returning a proper validation error
- Handle broken stdout pipe by breaking the server loop instead of silently continuing
- Add unit tests for `build_mcp_cli`, `walk_resources`, and `handle_request`
173 changes: 163 additions & 10 deletions src/mcp_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ pub async fn start(args: &[String]) -> Result<(), GwsError> {
tool_mode,
};

let svc_str = matches.get_one::<String>("services").unwrap();
let svc_str = matches
.get_one::<String>("services")
.map(|s| s.as_str())
.unwrap_or("");
if !svc_str.is_empty() {
if svc_str == "all" {
config.services = services::SERVICES
Expand Down Expand Up @@ -121,14 +124,12 @@ pub async fn start(args: &[String]) -> Result<(), GwsError> {

match serde_json::from_str::<Value>(&line) {
Ok(req) => {
let is_notification = req.get("id").is_none();
let method = req.get("method").and_then(|m| m.as_str()).unwrap_or("");
let params = req.get("params").cloned().unwrap_or_else(|| json!({}));

let result = handle_request(method, &params, &config, &mut tools_cache).await;

if !is_notification {
let id = req.get("id").unwrap();
if let Some(id) = req.get("id") {
let response = match result {
Ok(res) => json!({
"jsonrpc": "2.0",
Expand All @@ -153,8 +154,12 @@ pub async fn start(args: &[String]) -> Result<(), GwsError> {
}
};
out.push('\n');
let _ = stdout.write_all(out.as_bytes()).await;
let _ = stdout.flush().await;
if stdout.write_all(out.as_bytes()).await.is_err()
|| stdout.flush().await.is_err()
{
eprintln!("[gws mcp] Stdout pipe broken, shutting down.");
break;
}
}
}
Err(_) => {
Expand All @@ -174,8 +179,12 @@ pub async fn start(args: &[String]) -> Result<(), GwsError> {
}
};
out.push('\n');
let _ = stdout.write_all(out.as_bytes()).await;
let _ = stdout.flush().await;
if stdout.write_all(out.as_bytes()).await.is_err()
|| stdout.flush().await.is_err()
{
eprintln!("[gws mcp] Stdout pipe broken, shutting down.");
break;
}
}
}
}
Expand Down Expand Up @@ -208,8 +217,12 @@ async fn handle_request(
if tools_cache.is_none() {
*tools_cache = Some(build_tools_list(config).await?);
}
let tools = match tools_cache.as_ref() {
Some(t) => json!(t),
None => unreachable!("tools_cache is guaranteed to be Some by the preceding check"),
};
Ok(json!({
"tools": tools_cache.as_ref().unwrap()
"tools": tools
}))
}
"tools/call" => {
Expand Down Expand Up @@ -733,7 +746,9 @@ async fn handle_tools_call(params: &Value, config: &ServerConfig) -> Result<Valu
}
}

let method_name = parts.last().unwrap();
let method_name = parts
.last()
.expect("split() on a &str always yields at least one element");
Comment on lines +749 to +751
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While parts.last() is not expected to be None here due to the parts.len() < 3 check, using expect() still introduces a potential panic. The pull request description mentions replacing this with .ok_or_else(), which would be a more robust approach that aligns better with the goal of removing panics. Returning a GwsError::Validation would be safer if the invariant ever changes.

Suggested change
let method_name = parts
.last()
.expect("split() on a &str always yields at least one element");
let method_name = parts
.last()
.ok_or_else(|| GwsError::Validation("Invalid tool name format".to_string()))?;

let method = if let Some(res) = current_res {
res.methods
.get(*method_name)
Expand Down Expand Up @@ -1121,4 +1136,142 @@ mod tests {
assert_eq!(tools[0]["name"], "workflow_standup_report");
assert_eq!(tools[4]["name"], "workflow_file_announce");
}

// -- walk_resources tests --

#[test]
fn test_walk_resources_simple() {
let doc = mock_doc();
let mut tools = Vec::new();
walk_resources("drive", &doc.resources, &mut tools);
assert!(tools.len() >= 2); // list + get
let names: Vec<&str> = tools.iter().filter_map(|t| t["name"].as_str()).collect();
assert!(names.contains(&"drive_files_list"));
assert!(names.contains(&"drive_files_get"));
}

#[test]
fn test_walk_resources_nested() {
let doc = mock_nested_doc();
let mut tools = Vec::new();
walk_resources("gmail", &doc.resources, &mut tools);
let names: Vec<&str> = tools.iter().filter_map(|t| t["name"].as_str()).collect();
assert!(names.contains(&"gmail_users_messages_list"));
assert!(names.contains(&"gmail_users_messages_get"));
assert!(names.contains(&"gmail_users_getProfile"));
}

#[test]
fn test_walk_resources_empty_description_fallback() {
let mut methods = HashMap::new();
methods.insert(
"delete".to_string(),
RestMethod {
description: None,
..Default::default()
},
);
let mut resources = HashMap::new();
resources.insert(
"items".to_string(),
RestResource {
methods,
resources: HashMap::new(),
},
);

let mut tools = Vec::new();
walk_resources("tasks", &resources, &mut tools);
assert_eq!(tools.len(), 1);
let desc = tools[0]["description"].as_str().unwrap();
assert!(desc.contains("tasks_items_delete"));
}

#[test]
fn test_walk_resources_empty() {
let resources = HashMap::new();
let mut tools = Vec::new();
walk_resources("empty", &resources, &mut tools);
assert!(tools.is_empty());
}

// -- handle_request tests --

#[tokio::test]
async fn test_handle_request_initialize() {
let config = ServerConfig {
services: vec![],
workflows: false,
_helpers: false,
tool_mode: ToolMode::Full,
};
let mut cache = None;
let result = handle_request("initialize", &json!({}), &config, &mut cache)
.await
.unwrap();

assert_eq!(result["protocolVersion"], "2024-11-05");
assert_eq!(result["serverInfo"]["name"], "gws-mcp");
assert!(result["capabilities"]["tools"].is_object());
}

#[tokio::test]
async fn test_handle_request_unsupported_method() {
let config = ServerConfig {
services: vec![],
workflows: false,
_helpers: false,
tool_mode: ToolMode::Full,
};
let mut cache = None;
let result = handle_request("foo/bar", &json!({}), &config, &mut cache).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Method not supported"));
}

#[tokio::test]
async fn test_handle_request_notifications_initialized() {
let config = ServerConfig {
services: vec![],
workflows: false,
_helpers: false,
tool_mode: ToolMode::Full,
};
let mut cache = None;
let result = handle_request("notifications/initialized", &json!({}), &config, &mut cache)
.await
.unwrap();
assert_eq!(result, json!({}));
}

// -- build_mcp_cli tests --

#[test]
fn test_build_mcp_cli_defaults() {
let cmd = build_mcp_cli();
let matches = cmd.get_matches_from(vec!["mcp"]);
let svc = matches
.get_one::<String>("services")
.map(|s| s.as_str())
.unwrap_or("");
assert_eq!(svc, "");
assert!(!matches.get_flag("workflows"));
assert!(!matches.get_flag("helpers"));
}

#[test]
fn test_build_mcp_cli_with_services() {
let cmd = build_mcp_cli();
let matches = cmd.get_matches_from(vec!["mcp", "-s", "drive,gmail"]);
let svc = matches.get_one::<String>("services").unwrap();
assert_eq!(svc, "drive,gmail");
}

#[test]
fn test_build_mcp_cli_with_flags() {
let cmd = build_mcp_cli();
let matches = cmd.get_matches_from(vec!["mcp", "--workflows", "--helpers"]);
assert!(matches.get_flag("workflows"));
assert!(matches.get_flag("helpers"));
}
}