style(mcp): apply rustfmt formatting
Some checks failed
PR Review Automation / review (pull_request) Has been cancelled
Test / rust-fmt-check (pull_request) Successful in 1m46s
Test / frontend-typecheck (pull_request) Successful in 1m39s
Test / frontend-tests (pull_request) Successful in 1m39s
Test / rust-clippy (pull_request) Successful in 3m26s
Test / rust-tests (pull_request) Successful in 4m54s

This commit is contained in:
Shaun Arman 2026-05-23 16:48:26 -05:00
parent 3588399dfd
commit a779756e48
10 changed files with 84 additions and 64 deletions

View File

@ -880,9 +880,7 @@ async fn execute_tool_call(
) )
.await .await
} }
name if name.starts_with("mcp_") => { name if name.starts_with("mcp_") => execute_mcp_tool_call(tool_call, app_state).await,
execute_mcp_tool_call(tool_call, app_state).await
}
_ => { _ => {
let error = format!("Unknown tool: {}", tool_call.name); let error = format!("Unknown tool: {}", tool_call.name);
tracing::warn!("{}", error); tracing::warn!("{}", error);

View File

@ -851,11 +851,24 @@ mod tests {
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.unwrap(); .unwrap();
for col in &[ for col in &[
"id", "name", "url", "transport_type", "transport_config", "id",
"auth_type", "auth_value", "enabled", "last_discovered_at", "name",
"discovery_status", "discovery_error", "created_at", "updated_at", "url",
"transport_type",
"transport_config",
"auth_type",
"auth_value",
"enabled",
"last_discovered_at",
"discovery_status",
"discovery_error",
"created_at",
"updated_at",
] { ] {
assert!(cols.contains(&col.to_string()), "mcp_servers missing column {col}"); assert!(
cols.contains(&col.to_string()),
"mcp_servers missing column {col}"
);
} }
let mut stmt = conn.prepare("PRAGMA table_info(mcp_tools)").unwrap(); let mut stmt = conn.prepare("PRAGMA table_info(mcp_tools)").unwrap();
@ -864,8 +877,18 @@ mod tests {
.unwrap() .unwrap()
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.unwrap(); .unwrap();
for col in &["id", "server_id", "name", "tool_key", "description", "parameters"] { for col in &[
assert!(cols.contains(&col.to_string()), "mcp_tools missing column {col}"); "id",
"server_id",
"name",
"tool_key",
"description",
"parameters",
] {
assert!(
cols.contains(&col.to_string()),
"mcp_tools missing column {col}"
);
} }
let mut stmt = conn.prepare("PRAGMA table_info(mcp_resources)").unwrap(); let mut stmt = conn.prepare("PRAGMA table_info(mcp_resources)").unwrap();
@ -875,7 +898,10 @@ mod tests {
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.unwrap(); .unwrap();
for col in &["id", "server_id", "uri", "name", "description"] { for col in &["id", "server_id", "uri", "name", "description"] {
assert!(cols.contains(&col.to_string()), "mcp_resources missing column {col}"); assert!(
cols.contains(&col.to_string()),
"mcp_resources missing column {col}"
);
} }
} }

View File

@ -33,7 +33,11 @@ pub fn sanitize_name(s: &str) -> String {
/// Build a unique, AI-safe tool key: `mcp_{server_name}_{tool_name}`. /// Build a unique, AI-safe tool key: `mcp_{server_name}_{tool_name}`.
pub fn build_tool_key(server_name: &str, tool_name: &str) -> String { pub fn build_tool_key(server_name: &str, tool_name: &str) -> String {
format!("mcp_{}_{}", sanitize_name(server_name), sanitize_name(tool_name)) format!(
"mcp_{}_{}",
sanitize_name(server_name),
sanitize_name(tool_name)
)
} }
/// Convert stored McpTool records into AI Tool definitions. /// Convert stored McpTool records into AI Tool definitions.
@ -106,9 +110,7 @@ fn parse_parameters(schema_json: &str) -> ToolParameters {
} }
/// Async wrapper — fetch enabled MCP tools from state and convert to AI tools. /// Async wrapper — fetch enabled MCP tools from state and convert to AI tools.
pub async fn get_enabled_mcp_tools( pub async fn get_enabled_mcp_tools(state: &crate::state::AppState) -> Result<Vec<Tool>, String> {
state: &crate::state::AppState,
) -> Result<Vec<Tool>, String> {
let tool_records = { let tool_records = {
let db = state.db.lock().map_err(|e| e.to_string())?; let db = state.db.lock().map_err(|e| e.to_string())?;
crate::mcp::store::get_enabled_tools(&db)? crate::mcp::store::get_enabled_tools(&db)?
@ -155,10 +157,7 @@ mod tests {
build_tool_key("My Weather API", "get_forecast"), build_tool_key("My Weather API", "get_forecast"),
"mcp_my_weather_api_get_forecast" "mcp_my_weather_api_get_forecast"
); );
assert_eq!( assert_eq!(build_tool_key("simple", "ping"), "mcp_simple_ping");
build_tool_key("simple", "ping"),
"mcp_simple_ping"
);
assert_eq!( assert_eq!(
build_tool_key("My Server", "search files"), build_tool_key("My Server", "search files"),
"mcp_my_server_search_files" "mcp_my_server_search_files"

View File

@ -1,5 +1,5 @@
use rmcp::{RoleClient, ServiceExt, service::RunningService};
use rmcp::model::{CallToolRequestParams, Content, RawContent}; use rmcp::model::{CallToolRequestParams, Content, RawContent};
use rmcp::{service::RunningService, RoleClient, ServiceExt};
use serde_json::Map; use serde_json::Map;
use crate::mcp::models::{McpResource, McpTool}; use crate::mcp::models::{McpResource, McpTool};
@ -8,10 +8,7 @@ use crate::mcp::models::{McpResource, McpTool};
pub type McpConnection = RunningService<RoleClient, ()>; pub type McpConnection = RunningService<RoleClient, ()>;
/// Connect to a stdio MCP server. /// Connect to a stdio MCP server.
pub async fn connect_stdio( pub async fn connect_stdio(command: &str, args: &[String]) -> Result<McpConnection, String> {
command: &str,
args: &[String],
) -> Result<McpConnection, String> {
let transport = crate::mcp::transport::stdio::build_stdio_transport(command, args)?; let transport = crate::mcp::transport::stdio::build_stdio_transport(command, args)?;
().serve(transport) ().serve(transport)
.await .await
@ -19,10 +16,7 @@ pub async fn connect_stdio(
} }
/// Connect to an HTTP MCP server. /// Connect to an HTTP MCP server.
pub async fn connect_http( pub async fn connect_http(url: &str, auth_header: Option<&str>) -> Result<McpConnection, String> {
url: &str,
auth_header: Option<&str>,
) -> Result<McpConnection, String> {
let transport = crate::mcp::transport::http::build_http_transport(url, auth_header); let transport = crate::mcp::transport::http::build_http_transport(url, auth_header);
().serve(transport) ().serve(transport)
.await .await
@ -44,8 +38,8 @@ pub async fn list_tools(
.into_iter() .into_iter()
.map(|t| { .map(|t| {
let tool_key = crate::mcp::adapter::build_tool_key(server_name, &t.name); let tool_key = crate::mcp::adapter::build_tool_key(server_name, &t.name);
let parameters = serde_json::to_string(&*t.input_schema) let parameters =
.unwrap_or_else(|_| "{}".to_string()); serde_json::to_string(&*t.input_schema).unwrap_or_else(|_| "{}".to_string());
McpTool { McpTool {
id: uuid::Uuid::now_v7().to_string(), id: uuid::Uuid::now_v7().to_string(),
server_id: server_id.to_string(), server_id: server_id.to_string(),
@ -90,13 +84,10 @@ pub async fn call_tool(
tool_name: &str, tool_name: &str,
arguments: &serde_json::Value, arguments: &serde_json::Value,
) -> Result<String, String> { ) -> Result<String, String> {
let args: Option<Map<String, serde_json::Value>> = arguments let args: Option<Map<String, serde_json::Value>> = arguments.as_object().cloned();
.as_object()
.cloned();
let params = match args { let params = match args {
Some(map) => CallToolRequestParams::new(tool_name.to_string()) Some(map) => CallToolRequestParams::new(tool_name.to_string()).with_arguments(map),
.with_arguments(map),
None => CallToolRequestParams::new(tool_name.to_string()), None => CallToolRequestParams::new(tool_name.to_string()),
}; };

View File

@ -13,9 +13,7 @@ use crate::mcp::store::{
use crate::state::AppState; use crate::state::AppState;
#[tauri::command] #[tauri::command]
pub async fn list_mcp_servers( pub async fn list_mcp_servers(state: State<'_, AppState>) -> Result<Vec<McpServer>, String> {
state: State<'_, AppState>,
) -> Result<Vec<McpServer>, String> {
let db = state.db.lock().map_err(|e| e.to_string())?; let db = state.db.lock().map_err(|e| e.to_string())?;
let mut servers = list_servers(&db)?; let mut servers = list_servers(&db)?;
// Never expose encrypted auth values to the frontend // Never expose encrypted auth values to the frontend
@ -218,8 +216,7 @@ pub async fn initiate_mcp_oauth(
&app_handle, &app_handle,
&window_label, &window_label,
tauri::WebviewUrl::External( tauri::WebviewUrl::External(
url::Url::parse(&auth_url) url::Url::parse(&auth_url).map_err(|e| format!("Invalid OAuth URL: {e}"))?,
.map_err(|e| format!("Invalid OAuth URL: {e}"))?,
), ),
) )
.title(format!("Authenticate: {}", server.name)) .title(format!("Authenticate: {}", server.name))

View File

@ -2,7 +2,7 @@ use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex; use tokio::sync::Mutex as TokioMutex;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::mcp::client::{McpConnection, connect_http, connect_stdio, list_resources, list_tools}; use crate::mcp::client::{connect_http, connect_stdio, list_resources, list_tools, McpConnection};
use crate::mcp::models::McpServer; use crate::mcp::models::McpServer;
use crate::mcp::store::{ use crate::mcp::store::{
get_server_auth_value, list_servers, replace_resources, replace_tools, update_discovery_status, get_server_auth_value, list_servers, replace_resources, replace_tools, update_discovery_status,
@ -82,7 +82,10 @@ pub async fn init_all_servers(app_handle: &tauri::AppHandle) -> Result<(), Strin
let servers: Vec<McpServer> = { let servers: Vec<McpServer> = {
let db = state.db.lock().map_err(|e| e.to_string())?; let db = state.db.lock().map_err(|e| e.to_string())?;
list_servers(&db)?.into_iter().filter(|s| s.enabled).collect() list_servers(&db)?
.into_iter()
.filter(|s| s.enabled)
.collect()
}; };
for server in servers { for server in servers {

View File

@ -6,10 +6,7 @@ use crate::mcp::models::{
CreateMcpServerRequest, McpResource, McpServer, McpTool, UpdateMcpServerRequest, CreateMcpServerRequest, McpResource, McpServer, McpTool, UpdateMcpServerRequest,
}; };
pub fn create_server( pub fn create_server(conn: &Connection, req: &CreateMcpServerRequest) -> Result<McpServer, String> {
conn: &Connection,
req: &CreateMcpServerRequest,
) -> Result<McpServer, String> {
let id = Uuid::now_v7().to_string(); let id = Uuid::now_v7().to_string();
let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(); let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
@ -110,8 +107,7 @@ pub fn update_server(
id: &str, id: &str,
req: &UpdateMcpServerRequest, req: &UpdateMcpServerRequest,
) -> Result<McpServer, String> { ) -> Result<McpServer, String> {
let existing = get_server(conn, id)? let existing = get_server(conn, id)?.ok_or_else(|| format!("Server {id} not found"))?;
.ok_or_else(|| format!("Server {id} not found"))?;
let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(); let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
let new_encrypted_auth = match &req.auth_value { let new_encrypted_auth = match &req.auth_value {
@ -128,11 +124,17 @@ pub fn update_server(
rusqlite::params![ rusqlite::params![
req.name.as_deref().unwrap_or(&existing.name), req.name.as_deref().unwrap_or(&existing.name),
req.url.as_deref().unwrap_or(&existing.url), req.url.as_deref().unwrap_or(&existing.url),
req.transport_type.as_deref().unwrap_or(&existing.transport_type), req.transport_type
req.transport_config.as_deref().unwrap_or(&existing.transport_config), .as_deref()
.unwrap_or(&existing.transport_type),
req.transport_config
.as_deref()
.unwrap_or(&existing.transport_config),
req.auth_type.as_deref().unwrap_or(&existing.auth_type), req.auth_type.as_deref().unwrap_or(&existing.auth_type),
new_encrypted_auth, new_encrypted_auth,
req.enabled.map(|b| b as i32).unwrap_or(existing.enabled as i32), req.enabled
.map(|b| b as i32)
.unwrap_or(existing.enabled as i32),
now, now,
id, id,
], ],
@ -201,8 +203,11 @@ pub fn replace_resources(
server_id: &str, server_id: &str,
resources: &[McpResource], resources: &[McpResource],
) -> Result<(), String> { ) -> Result<(), String> {
conn.execute("DELETE FROM mcp_resources WHERE server_id = ?1", [server_id]) conn.execute(
.map_err(|e| e.to_string())?; "DELETE FROM mcp_resources WHERE server_id = ?1",
[server_id],
)
.map_err(|e| e.to_string())?;
for res in resources { for res in resources {
conn.execute( conn.execute(
@ -266,10 +271,7 @@ pub fn get_tool_by_key(conn: &Connection, tool_key: &str) -> Result<Option<McpTo
.map_err(|e| e.to_string()) .map_err(|e| e.to_string())
} }
pub fn get_server_auth_value( pub fn get_server_auth_value(conn: &Connection, server_id: &str) -> Result<Option<String>, String> {
conn: &Connection,
server_id: &str,
) -> Result<Option<String>, String> {
let encrypted: Option<String> = conn let encrypted: Option<String> = conn
.query_row( .query_row(
"SELECT auth_value FROM mcp_servers WHERE id = ?1", "SELECT auth_value FROM mcp_servers WHERE id = ?1",
@ -395,7 +397,10 @@ mod tests {
) )
.unwrap(); .unwrap();
let raw = raw.unwrap(); let raw = raw.unwrap();
assert_ne!(raw, "super-secret-token", "auth_value must be encrypted in DB"); assert_ne!(
raw, "super-secret-token",
"auth_value must be encrypted in DB"
);
// Decrypted value must match original // Decrypted value must match original
let decrypted = get_server_auth_value(&conn, &server.id).unwrap().unwrap(); let decrypted = get_server_auth_value(&conn, &server.id).unwrap().unwrap();
@ -441,7 +446,10 @@ mod tests {
update_discovery_status(&conn, &server.id, "error", Some("connection refused")).unwrap(); update_discovery_status(&conn, &server.id, "error", Some("connection refused")).unwrap();
let errored = get_server(&conn, &server.id).unwrap().unwrap(); let errored = get_server(&conn, &server.id).unwrap().unwrap();
assert_eq!(errored.discovery_status, "error"); assert_eq!(errored.discovery_status, "error");
assert_eq!(errored.discovery_error.as_deref(), Some("connection refused")); assert_eq!(
errored.discovery_error.as_deref(),
Some("connection refused")
);
} }
#[test] #[test]

View File

@ -1,5 +1,5 @@
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::transport::StreamableHttpClientTransport;
use std::sync::Arc; use std::sync::Arc;
/// Build an HTTP (Streamable HTTP) transport from a URL. /// Build an HTTP (Streamable HTTP) transport from a URL.

View File

@ -4,10 +4,7 @@ use tokio::process::Command;
/// Build a stdio transport from a command path and argument list. /// Build a stdio transport from a command path and argument list.
/// Rejects relative paths to prevent path traversal. /// Rejects relative paths to prevent path traversal.
pub fn build_stdio_transport( pub fn build_stdio_transport(command: &str, args: &[String]) -> Result<TokioChildProcess, String> {
command: &str,
args: &[String],
) -> Result<TokioChildProcess, String> {
if !Path::new(command).is_absolute() { if !Path::new(command).is_absolute() {
return Err(format!( return Err(format!(
"stdio command must be an absolute path, got: {command}" "stdio command must be an absolute path, got: {command}"

View File

@ -75,7 +75,8 @@ pub struct AppState {
/// Track open integration webview windows by service name -> window label /// Track open integration webview windows by service name -> window label
pub integration_webviews: Arc<Mutex<HashMap<String, String>>>, pub integration_webviews: Arc<Mutex<HashMap<String, String>>>,
/// Live MCP server connections: server_id -> connection /// Live MCP server connections: server_id -> connection
pub mcp_connections: Arc<TokioMutex<HashMap<String, Arc<TokioMutex<crate::mcp::client::McpConnection>>>>>, pub mcp_connections:
Arc<TokioMutex<HashMap<String, Arc<TokioMutex<crate::mcp::client::McpConnection>>>>>,
} }
/// Determine the application data directory. /// Determine the application data directory.