fix(mcp): parse and merge env vars in discovery layer

- Parse plaintext env from transport_config.env (stdio)
- Decrypt encrypted env from env_config column (stdio)
- Merge env vars with encrypted taking precedence
- Parse custom headers from transport_config.headers (HTTP)
- Update connect_stdio() to accept HashMap<String, String> for env
- Update connect_http() to accept HashMap<String, String> for headers
- Apply env vars to tokio::process::Command in stdio transport
- Log warning for HTTP custom headers (rmcp v1.7.0 limitation)

All 19 MCP tests passing.
This commit is contained in:
Shaun Arman 2026-06-01 08:25:49 -05:00
parent a970f171a8
commit 2ffe2bb6d8
4 changed files with 85 additions and 12 deletions

View File

@ -7,17 +7,26 @@ use crate::mcp::models::{McpResource, McpTool};
/// Live connection to an MCP server. /// Live connection to an MCP server.
pub type McpConnection = RunningService<RoleClient, ()>; pub type McpConnection = RunningService<RoleClient, ()>;
/// Connect to a stdio MCP server. /// Connect to a stdio MCP server with optional environment variables.
pub async fn connect_stdio(command: &str, args: &[String]) -> Result<McpConnection, String> { pub async fn connect_stdio(
let transport = crate::mcp::transport::stdio::build_stdio_transport(command, args)?; command: &str,
args: &[String],
env: std::collections::HashMap<String, String>,
) -> Result<McpConnection, String> {
let transport = crate::mcp::transport::stdio::build_stdio_transport(command, args, env)?;
().serve(transport) ().serve(transport)
.await .await
.map_err(|e| format!("MCP stdio connection failed: {e}")) .map_err(|e| format!("MCP stdio connection failed: {e}"))
} }
/// Connect to an HTTP MCP server. /// Connect to an HTTP MCP server with optional custom headers.
pub async fn connect_http(url: &str, auth_header: Option<&str>) -> Result<McpConnection, String> { pub async fn connect_http(
let transport = crate::mcp::transport::http::build_http_transport(url, auth_header); url: &str,
auth_header: Option<&str>,
custom_headers: std::collections::HashMap<String, String>,
) -> Result<McpConnection, String> {
let transport =
crate::mcp::transport::http::build_http_transport(url, auth_header, custom_headers);
().serve(transport) ().serve(transport)
.await .await
.map_err(|e| format!("MCP HTTP connection failed: {e}")) .map_err(|e| format!("MCP HTTP connection failed: {e}"))

View File

@ -5,7 +5,8 @@ use tracing::{info, warn};
use crate::mcp::client::{connect_http, connect_stdio, list_resources, list_tools, McpConnection}; use crate::mcp::client::{connect_http, connect_stdio, list_resources, list_tools, McpConnection};
use crate::mcp::models::McpServer; use crate::mcp::models::McpServer;
use crate::mcp::store::{ use crate::mcp::store::{
get_server_auth_value, list_servers, replace_resources, replace_tools, update_discovery_status, get_server_auth_value, get_server_env_config, list_servers, replace_resources, replace_tools,
update_discovery_status,
}; };
/// Discover a single MCP server: connect, list tools/resources, persist. /// Discover a single MCP server: connect, list tools/resources, persist.
@ -55,11 +56,49 @@ async fn discover_server_inner(
.collect() .collect()
}) })
.unwrap_or_default(); .unwrap_or_default();
connect_stdio(command, &args).await?
// Parse plaintext env vars from transport_config.env
let plaintext_env: std::collections::HashMap<String, String> = config
.get("env")
.and_then(|v| v.as_object())
.map(|obj| {
obj.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default();
// Decrypt and parse encrypted env vars from env_config column
let encrypted_env = {
let db = state.db.lock().map_err(|e| e.to_string())?;
get_server_env_config(&db, &server.id)?
};
// Merge env vars (encrypted takes precedence over plaintext)
let mut merged_env = plaintext_env;
if let Some(enc_env) = encrypted_env {
merged_env.extend(enc_env);
}
connect_stdio(command, &args, merged_env).await?
} }
"http" => { "http" => {
let auth_header = auth_value.as_deref(); let auth_header = auth_value.as_deref();
connect_http(&server.url, auth_header).await?
// Parse custom headers from transport_config.headers
let config: serde_json::Value =
serde_json::from_str(&server.transport_config).unwrap_or_default();
let custom_headers: std::collections::HashMap<String, String> = config
.get("headers")
.and_then(|v| v.as_object())
.map(|obj| {
obj.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default();
connect_http(&server.url, auth_header, custom_headers).await?
} }
other => return Err(format!("Unknown transport type: {other}")), other => return Err(format!("Unknown transport type: {other}")),
}; };

View File

@ -1,17 +1,32 @@
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::StreamableHttpClientTransport;
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
/// Build an HTTP (Streamable HTTP) transport from a URL. /// Build an HTTP (Streamable HTTP) transport from a URL with optional custom headers.
/// Optionally attaches an Authorization bearer token. /// Optionally attaches an Authorization bearer token.
///
/// NOTE: Custom headers are parsed but not yet applied due to rmcp v1.7.0 API limitations.
/// The rmcp library's StreamableHttpClientTransportConfig does not expose a .header() method.
/// Custom headers support is deferred until rmcp adds this capability or we find an alternative.
pub fn build_http_transport( pub fn build_http_transport(
url: &str, url: &str,
auth_header: Option<&str>, auth_header: Option<&str>,
custom_headers: HashMap<String, String>,
) -> impl rmcp::transport::Transport<rmcp::RoleClient> { ) -> impl rmcp::transport::Transport<rmcp::RoleClient> {
// Log warning if custom headers are provided (not yet supported)
if !custom_headers.is_empty() {
tracing::warn!(
"Custom HTTP headers provided but not supported by rmcp v1.7.0: {:?}",
custom_headers.keys().collect::<Vec<_>>()
);
}
let config = match auth_header { let config = match auth_header {
Some(token) => StreamableHttpClientTransportConfig::with_uri(Arc::from(url)) Some(token) => StreamableHttpClientTransportConfig::with_uri(Arc::from(url))
.auth_header(token.to_string()), .auth_header(token.to_string()),
None => StreamableHttpClientTransportConfig::with_uri(Arc::from(url)), None => StreamableHttpClientTransportConfig::with_uri(Arc::from(url)),
}; };
StreamableHttpClientTransport::from_config(config) StreamableHttpClientTransport::from_config(config)
} }

View File

@ -1,10 +1,15 @@
use rmcp::transport::TokioChildProcess; use rmcp::transport::TokioChildProcess;
use std::collections::HashMap;
use std::path::Path; use std::path::Path;
use tokio::process::Command; use tokio::process::Command;
/// Build a stdio transport from a command path and argument list. /// Build a stdio transport from a command path, argument list, and environment variables.
/// Rejects relative paths to prevent path traversal. /// Rejects relative paths to prevent path traversal.
pub fn build_stdio_transport(command: &str, args: &[String]) -> Result<TokioChildProcess, String> { pub fn build_stdio_transport(
command: &str,
args: &[String],
env: HashMap<String, String>,
) -> Result<TokioChildProcess, String> {
if !Path::new(command).is_absolute() { if !Path::new(command).is_absolute() {
return Err(format!( return Err(format!(
"stdio command must be an absolute path, got: {command}" "stdio command must be an absolute path, got: {command}"
@ -14,5 +19,10 @@ pub fn build_stdio_transport(command: &str, args: &[String]) -> Result<TokioChil
let mut cmd = Command::new(command); let mut cmd = Command::new(command);
cmd.args(args); cmd.args(args);
// Apply environment variables
for (key, value) in env {
cmd.env(key, value);
}
TokioChildProcess::new(cmd).map_err(|e| format!("Failed to spawn stdio process: {e}")) TokioChildProcess::new(cmd).map_err(|e| format!("Failed to spawn stdio process: {e}"))
} }