feat(mcp): add MCP Server Support #53

Merged
sarman merged 4 commits from feature/mcp-server-support into master 2026-05-23 22:15:11 +00:00
10 changed files with 84 additions and 64 deletions
Showing only changes of commit a779756e48 - Show all commits

View File

@ -880,9 +880,7 @@ async fn execute_tool_call(
)
.await
}
name if name.starts_with("mcp_") => {
execute_mcp_tool_call(tool_call, app_state).await
}
name if name.starts_with("mcp_") => execute_mcp_tool_call(tool_call, app_state).await,
_ => {
let error = format!("Unknown tool: {}", tool_call.name);
tracing::warn!("{}", error);

View File

@ -851,11 +851,24 @@ mod tests {
.collect::<Result<Vec<_>, _>>()
.unwrap();
for col in &[
"id", "name", "url", "transport_type", "transport_config",
"auth_type", "auth_value", "enabled", "last_discovered_at",
"discovery_status", "discovery_error", "created_at", "updated_at",
"id",
"name",
"url",
"transport_type",
"transport_config",
"auth_type",
"auth_value",
"enabled",
"last_discovered_at",
"discovery_status",
"discovery_error",
"created_at",
"updated_at",
] {
assert!(cols.contains(&col.to_string()), "mcp_servers missing column {col}");
assert!(
cols.contains(&col.to_string()),
"mcp_servers missing column {col}"
);
}
let mut stmt = conn.prepare("PRAGMA table_info(mcp_tools)").unwrap();
@ -864,8 +877,18 @@ mod tests {
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();
for col in &["id", "server_id", "name", "tool_key", "description", "parameters"] {
assert!(cols.contains(&col.to_string()), "mcp_tools missing column {col}");
for col in &[
"id",
"server_id",
"name",
"tool_key",
"description",
"parameters",
] {
assert!(
cols.contains(&col.to_string()),
"mcp_tools missing column {col}"
);
}
let mut stmt = conn.prepare("PRAGMA table_info(mcp_resources)").unwrap();
@ -875,7 +898,10 @@ mod tests {
.collect::<Result<Vec<_>, _>>()
.unwrap();
for col in &["id", "server_id", "uri", "name", "description"] {
assert!(cols.contains(&col.to_string()), "mcp_resources missing column {col}");
assert!(
cols.contains(&col.to_string()),
"mcp_resources missing column {col}"
);
}
}

View File

@ -33,7 +33,11 @@ pub fn sanitize_name(s: &str) -> String {
/// Build a unique, AI-safe tool key: `mcp_{server_name}_{tool_name}`.
pub fn build_tool_key(server_name: &str, tool_name: &str) -> String {
format!("mcp_{}_{}", sanitize_name(server_name), sanitize_name(tool_name))
format!(
"mcp_{}_{}",
sanitize_name(server_name),
sanitize_name(tool_name)
)
}
/// Convert stored McpTool records into AI Tool definitions.
@ -106,9 +110,7 @@ fn parse_parameters(schema_json: &str) -> ToolParameters {
}
/// Async wrapper — fetch enabled MCP tools from state and convert to AI tools.
pub async fn get_enabled_mcp_tools(
state: &crate::state::AppState,
) -> Result<Vec<Tool>, String> {
pub async fn get_enabled_mcp_tools(state: &crate::state::AppState) -> Result<Vec<Tool>, String> {
let tool_records = {
let db = state.db.lock().map_err(|e| e.to_string())?;
crate::mcp::store::get_enabled_tools(&db)?
@ -155,10 +157,7 @@ mod tests {
build_tool_key("My Weather API", "get_forecast"),
"mcp_my_weather_api_get_forecast"
);
assert_eq!(
build_tool_key("simple", "ping"),
"mcp_simple_ping"
);
assert_eq!(build_tool_key("simple", "ping"), "mcp_simple_ping");
assert_eq!(
build_tool_key("My Server", "search files"),
"mcp_my_server_search_files"

View File

@ -1,5 +1,5 @@
use rmcp::{RoleClient, ServiceExt, service::RunningService};
use rmcp::model::{CallToolRequestParams, Content, RawContent};
use rmcp::{service::RunningService, RoleClient, ServiceExt};
use serde_json::Map;
use crate::mcp::models::{McpResource, McpTool};
@ -8,10 +8,7 @@ use crate::mcp::models::{McpResource, McpTool};
pub type McpConnection = RunningService<RoleClient, ()>;
/// Connect to a stdio MCP server.
pub async fn connect_stdio(
command: &str,
args: &[String],
) -> Result<McpConnection, String> {
pub async fn connect_stdio(command: &str, args: &[String]) -> Result<McpConnection, String> {
let transport = crate::mcp::transport::stdio::build_stdio_transport(command, args)?;
().serve(transport)
.await
@ -19,10 +16,7 @@ pub async fn connect_stdio(
}
/// Connect to an HTTP MCP server.
pub async fn connect_http(
url: &str,
auth_header: Option<&str>,
) -> Result<McpConnection, String> {
pub async fn connect_http(url: &str, auth_header: Option<&str>) -> Result<McpConnection, String> {
let transport = crate::mcp::transport::http::build_http_transport(url, auth_header);
().serve(transport)
.await
@ -44,8 +38,8 @@ pub async fn list_tools(
.into_iter()
.map(|t| {
let tool_key = crate::mcp::adapter::build_tool_key(server_name, &t.name);
let parameters = serde_json::to_string(&*t.input_schema)
.unwrap_or_else(|_| "{}".to_string());
let parameters =
serde_json::to_string(&*t.input_schema).unwrap_or_else(|_| "{}".to_string());
McpTool {
id: uuid::Uuid::now_v7().to_string(),
server_id: server_id.to_string(),
@ -90,13 +84,10 @@ pub async fn call_tool(
tool_name: &str,
arguments: &serde_json::Value,
) -> Result<String, String> {
let args: Option<Map<String, serde_json::Value>> = arguments
.as_object()
.cloned();
let args: Option<Map<String, serde_json::Value>> = arguments.as_object().cloned();
let params = match args {
Some(map) => CallToolRequestParams::new(tool_name.to_string())
.with_arguments(map),
Some(map) => CallToolRequestParams::new(tool_name.to_string()).with_arguments(map),
None => CallToolRequestParams::new(tool_name.to_string()),
};

View File

@ -13,9 +13,7 @@ use crate::mcp::store::{
use crate::state::AppState;
#[tauri::command]
pub async fn list_mcp_servers(
state: State<'_, AppState>,
) -> Result<Vec<McpServer>, String> {
pub async fn list_mcp_servers(state: State<'_, AppState>) -> Result<Vec<McpServer>, String> {
let db = state.db.lock().map_err(|e| e.to_string())?;
let mut servers = list_servers(&db)?;
// Never expose encrypted auth values to the frontend
@ -218,8 +216,7 @@ pub async fn initiate_mcp_oauth(
&app_handle,
&window_label,
tauri::WebviewUrl::External(
url::Url::parse(&auth_url)
.map_err(|e| format!("Invalid OAuth URL: {e}"))?,
url::Url::parse(&auth_url).map_err(|e| format!("Invalid OAuth URL: {e}"))?,
),
)
.title(format!("Authenticate: {}", server.name))

View File

@ -2,7 +2,7 @@ use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
use tracing::{info, warn};
use crate::mcp::client::{McpConnection, connect_http, connect_stdio, list_resources, list_tools};
use crate::mcp::client::{connect_http, connect_stdio, list_resources, list_tools, McpConnection};
use crate::mcp::models::McpServer;
use crate::mcp::store::{
get_server_auth_value, list_servers, replace_resources, replace_tools, update_discovery_status,
@ -82,7 +82,10 @@ pub async fn init_all_servers(app_handle: &tauri::AppHandle) -> Result<(), Strin
let servers: Vec<McpServer> = {
let db = state.db.lock().map_err(|e| e.to_string())?;
list_servers(&db)?.into_iter().filter(|s| s.enabled).collect()
list_servers(&db)?
.into_iter()
.filter(|s| s.enabled)
.collect()
};
for server in servers {

View File

@ -6,10 +6,7 @@ use crate::mcp::models::{
CreateMcpServerRequest, McpResource, McpServer, McpTool, UpdateMcpServerRequest,
};
pub fn create_server(
conn: &Connection,
req: &CreateMcpServerRequest,
) -> Result<McpServer, String> {
pub fn create_server(conn: &Connection, req: &CreateMcpServerRequest) -> Result<McpServer, String> {
let id = Uuid::now_v7().to_string();
let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
@ -110,8 +107,7 @@ pub fn update_server(
id: &str,
req: &UpdateMcpServerRequest,
) -> Result<McpServer, String> {
let existing = get_server(conn, id)?
.ok_or_else(|| format!("Server {id} not found"))?;
let existing = get_server(conn, id)?.ok_or_else(|| format!("Server {id} not found"))?;
let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
let new_encrypted_auth = match &req.auth_value {
@ -128,11 +124,17 @@ pub fn update_server(
rusqlite::params![
req.name.as_deref().unwrap_or(&existing.name),
req.url.as_deref().unwrap_or(&existing.url),
req.transport_type.as_deref().unwrap_or(&existing.transport_type),
req.transport_config.as_deref().unwrap_or(&existing.transport_config),
req.transport_type
.as_deref()
.unwrap_or(&existing.transport_type),
req.transport_config
.as_deref()
.unwrap_or(&existing.transport_config),
req.auth_type.as_deref().unwrap_or(&existing.auth_type),
new_encrypted_auth,
req.enabled.map(|b| b as i32).unwrap_or(existing.enabled as i32),
req.enabled
.map(|b| b as i32)
.unwrap_or(existing.enabled as i32),
now,
id,
],
@ -201,7 +203,10 @@ pub fn replace_resources(
server_id: &str,
resources: &[McpResource],
) -> Result<(), String> {
conn.execute("DELETE FROM mcp_resources WHERE server_id = ?1", [server_id])
conn.execute(
"DELETE FROM mcp_resources WHERE server_id = ?1",
[server_id],
)
.map_err(|e| e.to_string())?;
for res in resources {
@ -266,10 +271,7 @@ pub fn get_tool_by_key(conn: &Connection, tool_key: &str) -> Result<Option<McpTo
.map_err(|e| e.to_string())
}
pub fn get_server_auth_value(
conn: &Connection,
server_id: &str,
) -> Result<Option<String>, String> {
pub fn get_server_auth_value(conn: &Connection, server_id: &str) -> Result<Option<String>, String> {
let encrypted: Option<String> = conn
.query_row(
"SELECT auth_value FROM mcp_servers WHERE id = ?1",
@ -395,7 +397,10 @@ mod tests {
)
.unwrap();
let raw = raw.unwrap();
assert_ne!(raw, "super-secret-token", "auth_value must be encrypted in DB");
assert_ne!(
raw, "super-secret-token",
"auth_value must be encrypted in DB"
);
// Decrypted value must match original
let decrypted = get_server_auth_value(&conn, &server.id).unwrap().unwrap();
@ -441,7 +446,10 @@ mod tests {
update_discovery_status(&conn, &server.id, "error", Some("connection refused")).unwrap();
let errored = get_server(&conn, &server.id).unwrap().unwrap();
assert_eq!(errored.discovery_status, "error");
assert_eq!(errored.discovery_error.as_deref(), Some("connection refused"));
assert_eq!(
errored.discovery_error.as_deref(),
Some("connection refused")
);
}
#[test]

View File

@ -1,5 +1,5 @@
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::transport::StreamableHttpClientTransport;
use std::sync::Arc;
/// Build an HTTP (Streamable HTTP) transport from a URL.

View File

@ -4,10 +4,7 @@ use tokio::process::Command;
/// Build a stdio transport from a command path and argument list.
/// Rejects relative paths to prevent path traversal.
pub fn build_stdio_transport(
command: &str,
args: &[String],
) -> Result<TokioChildProcess, String> {
pub fn build_stdio_transport(command: &str, args: &[String]) -> Result<TokioChildProcess, String> {
if !Path::new(command).is_absolute() {
return Err(format!(
"stdio command must be an absolute path, got: {command}"

View File

@ -75,7 +75,8 @@ pub struct AppState {
/// Track open integration webview windows by service name -> window label
pub integration_webviews: Arc<Mutex<HashMap<String, String>>>,
/// Live MCP server connections: server_id -> connection
pub mcp_connections: Arc<TokioMutex<HashMap<String, Arc<TokioMutex<crate::mcp::client::McpConnection>>>>>,
pub mcp_connections:
Arc<TokioMutex<HashMap<String, Arc<TokioMutex<crate::mcp::client::McpConnection>>>>>,
}
/// Determine the application data directory.