Backport: Agentic Shell Command Execution (v1.0.0 → v1.0.8) #66

Merged
sarman merged 16 commits from feature/agentic-shell-commands into master 2026-06-05 15:30:28 +00:00
4 changed files with 602 additions and 96 deletions
Showing only changes of commit 117ab390a2 - Show all commits

View File

@ -1,10 +1,14 @@
use async_trait::async_trait; use async_trait::async_trait;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration; use std::time::Duration;
use crate::ai::provider::Provider; use crate::ai::provider::Provider;
use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage}; use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage, ToolCall};
use crate::state::ProviderConfig; use crate::state::ProviderConfig;
// Track if we've already attempted auto-start this session
static AUTO_START_ATTEMPTED: AtomicBool = AtomicBool::new(false);
pub struct OllamaProvider; pub struct OllamaProvider;
#[async_trait] #[async_trait]
@ -18,11 +22,11 @@ impl Provider for OllamaProvider {
name: "Ollama (Local)".to_string(), name: "Ollama (Local)".to_string(),
supports_streaming: true, supports_streaming: true,
models: vec![ models: vec![
"llama3.1".to_string(), "llama3.2:3b".to_string(),
"llama3".to_string(), "phi3.5:3.8b".to_string(),
"mistral".to_string(), "llama3.1:8b".to_string(),
"codellama".to_string(), "qwen2.5:14b".to_string(),
"phi3".to_string(), "gemma2:9b".to_string(),
], ],
} }
} }
@ -31,77 +35,276 @@ impl Provider for OllamaProvider {
&self, &self,
messages: Vec<Message>, messages: Vec<Message>,
config: &ProviderConfig, config: &ProviderConfig,
_tools: Option<Vec<crate::ai::Tool>>, tools: Option<Vec<crate::ai::Tool>>,
) -> anyhow::Result<ChatResponse> { ) -> anyhow::Result<ChatResponse> {
// Longer timeout for tool calling - models need time to generate structured output
let timeout_secs = if tools.is_some() { 180 } else { 60 };
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.timeout(Duration::from_secs(60)) .timeout(Duration::from_secs(timeout_secs))
.connect_timeout(Duration::from_secs(10))
.build()?; .build()?;
let base_url = if config.api_url.is_empty() { let base_url = if config.api_url.is_empty() {
"http://localhost:11434".to_string() "http://localhost:11434".to_string()
} else { } else {
config.api_url.trim_end_matches('/').to_string() config.api_url.trim_end_matches('/').to_string()
}; };
// Auto-start Ollama if using localhost and we haven't tried yet this session
// Only attempt once to avoid recurring latency on every chat() call
if base_url == "http://localhost:11434"
&& !AUTO_START_ATTEMPTED.swap(true, Ordering::Relaxed)
{
// Check if already running before attempting start
let pre_status = crate::ollama::installer::check_ollama().await;
let already_running = pre_status.map(|s| s.running).unwrap_or(false);
if !already_running {
match crate::ollama::installer::start_ollama_service().await {
Ok(true) => {
tracing::info!("Ollama service auto-started successfully");
// Give it a moment to fully initialize
tokio::time::sleep(Duration::from_secs(2)).await;
}
Ok(false) => {
tracing::debug!("Ollama not started (not installed or already running)");
}
Err(e) => {
tracing::warn!("Failed to auto-start Ollama: {}", e);
// Continue anyway - maybe it's already running or will start soon
}
}
} else {
tracing::debug!("Ollama already running, skipping auto-start");
}
}
// Quick health check before attempting chat (short timeout for fast failure)
let health_client = reqwest::Client::builder()
.timeout(Duration::from_secs(2))
.build()?;
let health_check_result = health_client
.get(format!("{base_url}/api/tags"))
.send()
.await;
match health_check_result {
Ok(resp) if resp.status().is_success() => {
tracing::debug!("Ollama health check passed");
}
Ok(resp) => {
let status = resp.status();
tracing::warn!("Ollama health check returned status {status}");
anyhow::bail!(
"Ollama is not ready (status {status}). Please ensure Ollama is running."
);
}
Err(e) => {
tracing::error!("Cannot connect to Ollama at {base_url}: {e}");
anyhow::bail!("Cannot connect to Ollama at {base_url}. Please ensure Ollama is running and accessible.");
}
}
let url = format!("{base_url}/api/chat"); let url = format!("{base_url}/api/chat");
// Ollama expects {model, messages: [{role, content}], stream: false} // Ollama expects {model, messages: [{role, content, tool_calls?, tool_call_id?}], stream: false}
let api_messages: Vec<serde_json::Value> = messages let api_messages: Vec<serde_json::Value> = messages
.iter() .iter()
.map(|m| { .map(|m| {
serde_json::json!({ let mut msg = serde_json::json!({
"role": m.role, "role": m.role,
"content": m.content, "content": m.content,
}) });
// Include tool_calls if present (for assistant messages with tool requests)
if let Some(ref tool_calls) = m.tool_calls {
msg["tool_calls"] = serde_json::json!(tool_calls);
}
// Include tool_call_id if present (for tool result messages)
if let Some(ref tool_call_id) = m.tool_call_id {
msg["tool_call_id"] = serde_json::json!(tool_call_id);
}
msg
}) })
.collect(); .collect();
let body = serde_json::json!({ let mut body = serde_json::json!({
"model": config.model, "model": config.model,
"messages": api_messages, "messages": api_messages,
"stream": false, "stream": false,
}); });
let resp = client // Add tools if provided (Ollama function calling format)
.post(&url) if let Some(tools_list) = tools {
.header("Content-Type", "application/json") let formatted_tools: Vec<serde_json::Value> = tools_list
.json(&body) .iter()
.send() .map(|tool| {
.await?; serde_json::json!({
"type": "function",
if !resp.status().is_success() { "function": {
let status = resp.status(); "name": tool.name,
let text = resp.text().await?; "description": tool.description,
anyhow::bail!("Ollama API error {status}: {text}"); "parameters": tool.parameters
}
})
})
.collect();
body["tools"] = serde_json::Value::from(formatted_tools);
} }
let json: serde_json::Value = resp.json().await?; // Retry logic for transient connection issues
let max_retries = 2;
let mut last_error = None;
// Parse response.message.content for attempt in 0..=max_retries {
let content = json["message"]["content"] if attempt > 0 {
.as_str() tracing::warn!(
.ok_or_else(|| anyhow::anyhow!("No content in Ollama response"))? "Ollama request failed, retrying (attempt {}/{})...",
.to_string(); attempt + 1,
max_retries + 1
);
tokio::time::sleep(Duration::from_secs(2)).await;
}
// Ollama provides eval_count / prompt_eval_count let resp_result = client
let usage = { .post(&url)
let prompt_tokens = json["prompt_eval_count"].as_u64().unwrap_or(0) as u32; .header("Content-Type", "application/json")
let completion_tokens = json["eval_count"].as_u64().unwrap_or(0) as u32; .json(&body)
if prompt_tokens > 0 || completion_tokens > 0 { .send()
Some(TokenUsage { .await;
prompt_tokens,
completion_tokens, let resp = match resp_result {
total_tokens: prompt_tokens + completion_tokens, Ok(r) => r,
}) Err(e) => {
last_error = Some(format!("Connection error: {e}"));
if attempt < max_retries {
continue; // Retry
} else {
anyhow::bail!(
"Failed to connect to Ollama after {} attempts. Last error: {e}",
max_retries + 1
);
}
}
};
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await?;
last_error = Some(format!("API error {status}: {text}"));
if attempt < max_retries && status.is_server_error() {
continue; // Retry on 5xx errors
} else {
anyhow::bail!("Ollama API error {status}: {text}");
}
}
// Success - parse response and return
let json: serde_json::Value = match resp.json().await {
Ok(j) => j,
Err(e) => {
last_error = Some(format!("JSON parse error: {e}"));
if attempt < max_retries {
continue; // Retry
} else {
anyhow::bail!("Failed to parse Ollama response: {e}");
}
}
};
// Parse response.message.content
let content = json["message"]["content"]
.as_str()
.unwrap_or("")
.to_string();
// Parse tool calls from Ollama response
// Ollama returns tool_calls in message.tool_calls array
let tool_calls = if let Some(calls_array) = json["message"]["tool_calls"].as_array() {
let mut parsed_calls = Vec::new();
for (idx, call) in calls_array.iter().enumerate() {
// Generate fallback ID if not provided
let id = call["id"]
.as_str()
.map(|s| s.to_string())
.unwrap_or_else(|| format!("tool_call_{idx}"));
let function = &call["function"];
// Skip malformed tool calls (missing name) instead of failing entire response
let name = match function["name"].as_str() {
Some(n) => n.to_string(),
None => {
tracing::warn!("Skipping tool call with missing name at index {idx}");
continue;
}
};
// Arguments can be either an object or a string
let arguments = if let Some(args_obj) = function["arguments"].as_object() {
match serde_json::to_string(args_obj) {
Ok(s) => s,
Err(e) => {
tracing::warn!(
"Failed to serialize tool call arguments at index {}: {}",
idx,
e
);
continue;
}
}
} else if let Some(args_str) = function["arguments"].as_str() {
args_str.to_string()
} else {
"{}".to_string()
};
parsed_calls.push(ToolCall {
id,
name,
arguments,
});
}
if !parsed_calls.is_empty() {
Some(parsed_calls)
} else {
None
}
} else { } else {
None None
} };
};
Ok(ChatResponse { // Ollama provides eval_count / prompt_eval_count
content, let usage = {
model: config.model.clone(), let prompt_tokens = json["prompt_eval_count"].as_u64().unwrap_or(0) as u32;
usage, let completion_tokens = json["eval_count"].as_u64().unwrap_or(0) as u32;
user_message: None, if prompt_tokens > 0 || completion_tokens > 0 {
tool_calls: None, Some(TokenUsage {
}) prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
})
} else {
None
}
};
return Ok(ChatResponse {
content,
model: config.model.clone(),
usage,
user_message: None,
tool_calls,
});
}
// If we get here, all retries failed
anyhow::bail!(
"Failed to get response from Ollama after {} attempts. Last error: {:?}",
max_retries + 1,
last_error
)
} }
} }

View File

@ -7,8 +7,8 @@ use crate::state::ProviderConfig;
pub struct OpenAiProvider; pub struct OpenAiProvider;
fn is_custom_rest_format(api_format: Option<&str>) -> bool { fn is_msi_genai_format(api_format: Option<&str>) -> bool {
matches!(api_format, Some("custom_rest")) matches!(api_format, Some("msi-genai") | Some("custom_rest")) // custom_rest for backward compatibility
} }
#[async_trait] #[async_trait]
@ -38,8 +38,8 @@ impl Provider for OpenAiProvider {
// Check if using custom REST format // Check if using custom REST format
let api_format = config.api_format.as_deref().unwrap_or("openai"); let api_format = config.api_format.as_deref().unwrap_or("openai");
if is_custom_rest_format(Some(api_format)) { if is_msi_genai_format(Some(api_format)) {
self.chat_custom_rest(messages, config, tools).await self.chat_msi_genai(messages, config, tools).await
} else { } else {
self.chat_openai(messages, config, tools).await self.chat_openai(messages, config, tools).await
} }
@ -48,17 +48,109 @@ impl Provider for OpenAiProvider {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::is_custom_rest_format; use super::{is_msi_genai_format, OpenAiProvider};
#[test] #[test]
fn custom_rest_format_is_recognized() { fn msi_genai_format_is_recognized() {
assert!(is_custom_rest_format(Some("custom_rest"))); assert!(is_msi_genai_format(Some("msi-genai")));
} }
#[test] #[test]
fn openai_format_is_not_custom_rest() { fn custom_rest_format_backward_compatible() {
assert!(!is_custom_rest_format(Some("openai"))); // Keep backward compatibility with old format name
assert!(!is_custom_rest_format(None)); assert!(is_msi_genai_format(Some("custom_rest")));
}
#[test]
fn openai_format_is_not_msi_genai() {
assert!(!is_msi_genai_format(Some("openai")));
assert!(!is_msi_genai_format(None));
}
#[test]
fn parse_msigenai_chatgpt_tool_calls_from_json_text() {
// MSIGenAI ChatGPT format: returns tool calls as JSON object in msg
let content = r#"{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"execute_shell_command","arguments":{"command":"kubectl get namespaces"}}}]}"#;
let result = OpenAiProvider::parse_tool_calls_from_text(content);
assert!(result.is_some());
let calls = result.unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].id, "call_1");
assert_eq!(calls[0].name, "execute_shell_command");
assert!(calls[0].arguments.contains("kubectl get namespaces"));
}
#[test]
fn parse_msigenai_claude_tool_calls_from_xml_wrapper() {
// MSIGenAI Claude format: XML wrapper around JSON array
let content = r#"<tool_calls>
[{"id":"call_1","type":"function","function":{"name":"execute_shell_command","arguments":{"command":"kubectl get pods"}}}]
</tool_calls>"#;
let result = OpenAiProvider::parse_tool_calls_from_text(content);
assert!(result.is_some());
let calls = result.unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].id, "call_1");
assert_eq!(calls[0].name, "execute_shell_command");
assert!(calls[0].arguments.contains("kubectl get pods"));
}
#[test]
fn parse_multiple_tool_calls_from_text() {
let content = r#"{"tool_calls":[
{"id":"call_1","function":{"name":"kubectl_get","arguments":{"resource":"pods"}}},
{"id":"call_2","function":{"name":"kubectl_describe","arguments":{"resource":"svc/nginx"}}}
]}"#;
let result = OpenAiProvider::parse_tool_calls_from_text(content);
assert!(result.is_some());
let calls = result.unwrap();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].name, "kubectl_get");
assert_eq!(calls[1].name, "kubectl_describe");
}
#[test]
fn parse_tool_calls_returns_none_for_normal_text() {
let content = "Hello, I found 5 pods running in the cluster.";
let result = OpenAiProvider::parse_tool_calls_from_text(content);
assert!(result.is_none());
}
#[test]
fn parse_tool_calls_handles_arguments_as_string() {
// Some providers return arguments as string, not object
let content = r#"{"tool_calls":[{"id":"call_1","function":{"name":"test","arguments":"{\"key\":\"value\"}"}}]}"#;
let result = OpenAiProvider::parse_tool_calls_from_text(content);
assert!(result.is_some());
let calls = result.unwrap();
assert_eq!(calls[0].arguments, r#"{"key":"value"}"#);
}
#[test]
fn parse_tool_calls_generates_fallback_id_when_missing() {
// Some providers may omit id field - generate fallback to prevent silent drop
let content = r#"{"tool_calls":[
{"function":{"name":"kubectl_get","arguments":{"resource":"pods"}}},
{"id":"call_2","function":{"name":"kubectl_describe","arguments":{"resource":"svc"}}}
]}"#;
let result = OpenAiProvider::parse_tool_calls_from_text(content);
assert!(result.is_some());
let calls = result.unwrap();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].id, "tool_call_0"); // Fallback generated
assert_eq!(calls[0].name, "kubectl_get");
assert_eq!(calls[1].id, "call_2"); // Original preserved
assert_eq!(calls[1].name, "kubectl_describe");
} }
} }
@ -202,8 +294,13 @@ impl OpenAiProvider {
}) })
} }
/// Custom REST format (non-OpenAI payload contract) /// MSI GenAI format (non-OpenAI payload contract)
async fn chat_custom_rest( ///
/// MSI GenAI uses a custom API format with 'prompt' field instead of 'messages',
/// and has a known bug where tool calls are returned as JSON text in the 'msg'
/// field instead of structured 'tool_calls' array. This implementation includes
/// workaround parsing to extract tool calls from text.
async fn chat_msi_genai(
&self, &self,
messages: Vec<Message>, messages: Vec<Message>,
config: &ProviderConfig, config: &ProviderConfig,
@ -284,7 +381,7 @@ impl OpenAiProvider {
body["tools"] = serde_json::Value::from(formatted_tools); body["tools"] = serde_json::Value::from(formatted_tools);
body["tool_choice"] = serde_json::Value::from("auto"); body["tool_choice"] = serde_json::Value::from("auto");
tracing::info!("Custom REST: Sending {} tools in request", tool_count); tracing::info!("MSI GenAI: Sending {} tools in request", tool_count);
} }
// Use custom auth header and prefix (no default prefix for custom REST) // Use custom auth header and prefix (no default prefix for custom REST)
@ -306,13 +403,13 @@ impl OpenAiProvider {
if !resp.status().is_success() { if !resp.status().is_success() {
let status = resp.status(); let status = resp.status();
let text = resp.text().await?; let text = resp.text().await?;
anyhow::bail!("Custom REST API error {status}: {text}"); anyhow::bail!("MSI GenAI API error {status}: {text}");
} }
let json: serde_json::Value = resp.json().await?; let json: serde_json::Value = resp.json().await?;
tracing::debug!( tracing::debug!(
"Custom REST response: {}", "MSI GenAI response: {}",
serde_json::to_string_pretty(&json).unwrap_or_else(|_| "invalid JSON".to_string()) serde_json::to_string_pretty(&json).unwrap_or_else(|_| "invalid JSON".to_string())
); );
@ -323,7 +420,7 @@ impl OpenAiProvider {
.to_string(); .to_string();
// Parse tool_calls if present (check multiple possible field names) // Parse tool_calls if present (check multiple possible field names)
let tool_calls = json let mut tool_calls = json
.get("tool_calls") .get("tool_calls")
.or_else(|| json.get("toolCalls")) .or_else(|| json.get("toolCalls"))
.or_else(|| json.get("function_calls")) .or_else(|| json.get("function_calls"))
@ -331,57 +428,84 @@ impl OpenAiProvider {
if let Some(arr) = tc.as_array() { if let Some(arr) = tc.as_array() {
let calls: Vec<crate::ai::ToolCall> = arr let calls: Vec<crate::ai::ToolCall> = arr
.iter() .iter()
.filter_map(|call| { .enumerate()
.filter_map(|(index, call)| {
// Try OpenAI format first // Try OpenAI format first
if let (Some(id), Some(name), Some(args)) = ( if let (Some(id), Some(name)) = (
call.get("id").and_then(|v| v.as_str()), call.get("id").and_then(|v| v.as_str()),
call.get("function") call.get("function")
.and_then(|f| f.get("name")) .and_then(|f| f.get("name"))
.and_then(|n| n.as_str()) .and_then(|n| n.as_str())
.or_else(|| call.get("name").and_then(|n| n.as_str())), .or_else(|| call.get("name").and_then(|n| n.as_str())),
call.get("function")
.and_then(|f| f.get("arguments"))
.and_then(|a| a.as_str())
.or_else(|| call.get("arguments").and_then(|a| a.as_str())),
) { ) {
tracing::info!("Custom REST: Parsed tool call: {} ({})", name, id); // Accept arguments as either string or object (MSI GenAI returns both)
return Some(crate::ai::ToolCall { let arguments = call
id: id.to_string(), .get("function")
name: name.to_string(), .and_then(|f| f.get("arguments"))
arguments: args.to_string(), .or_else(|| call.get("arguments"))
}); .and_then(|args| {
if let Some(s) = args.as_str() {
Some(s.to_string())
} else {
// Serialize object to JSON string
serde_json::to_string(args).ok()
}
});
if let Some(args) = arguments {
tracing::info!(
"MSI GenAI: Parsed tool call: {} ({})",
name,
id
);
return Some(crate::ai::ToolCall {
id: id.to_string(),
name: name.to_string(),
arguments: args,
});
}
} }
// Try simpler format // Try simpler format
if let (Some(name), Some(args)) = ( if let Some(name) = call.get("name").and_then(|n| n.as_str()) {
call.get("name").and_then(|n| n.as_str()), // Accept arguments as either string or object
call.get("arguments").and_then(|a| a.as_str()), let arguments = call.get("arguments").and_then(|args| {
) { if let Some(s) = args.as_str() {
let id = call Some(s.to_string())
.get("id") } else {
.and_then(|v| v.as_str()) // Serialize object to JSON string
.unwrap_or("tool_call_0") serde_json::to_string(args).ok()
.to_string(); }
tracing::info!(
"Custom REST: Parsed tool call (simple format): {} ({})",
name,
id
);
return Some(crate::ai::ToolCall {
id,
name: name.to_string(),
arguments: args.to_string(),
}); });
if let Some(args) = arguments {
// Generate unique ID if missing (avoids duplicates)
let id = call
.get("id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.unwrap_or_else(|| format!("tool_call_{index}"));
tracing::info!(
"MSI GenAI: Parsed tool call (simple format): {} ({})",
name,
id
);
return Some(crate::ai::ToolCall {
id,
name: name.to_string(),
arguments: args,
});
}
} }
tracing::warn!("Custom REST: Failed to parse tool call: {:?}", call); tracing::warn!("MSI GenAI: Failed to parse tool call: {:?}", call);
None None
}) })
.collect(); .collect();
if calls.is_empty() { if calls.is_empty() {
None None
} else { } else {
tracing::info!("Custom REST: Found {} tool calls", calls.len()); tracing::info!("MSI GenAI: Found {} tool calls", calls.len());
Some(calls) Some(calls)
} }
} else { } else {
@ -389,6 +513,20 @@ impl OpenAiProvider {
} }
}); });
// WORKAROUND: MSIGenAI gateway bug - tool calls returned as JSON text in 'msg' field
// Expected: {"tool_calls": [...]}
// Actual: {"msg": '{"tool_calls":[...]}'} or {"msg": '<tool_calls>[...]</tool_calls>'}
if tool_calls.is_none() {
// Try parsing tool calls from msg content (MSIGenAI workaround)
if let Some(parsed_calls) = Self::parse_tool_calls_from_text(&content) {
tracing::warn!(
"MSI GenAI: MSIGenAI workaround - parsed {} tool calls from msg text (gateway should return structured tool_calls field)",
parsed_calls.len()
);
tool_calls = Some(parsed_calls);
}
}
// Note: sessionId from response should be stored back to config.session_id // Note: sessionId from response should be stored back to config.session_id
// This would require making config mutable or returning it as part of ChatResponse // This would require making config mutable or returning it as part of ChatResponse
// For now, the caller can extract it from the response if needed // For now, the caller can extract it from the response if needed
@ -402,4 +540,95 @@ impl OpenAiProvider {
tool_calls, tool_calls,
}) })
} }
/// Parse tool calls from text content (MSIGenAI gateway workaround)
///
/// MSIGenAI returns tool calls as JSON text in the 'msg' field instead of structured data:
/// - ChatGPT models: `{"tool_calls":[...]}`
/// - Claude models: `<tool_calls>[...]</tool_calls>`
fn parse_tool_calls_from_text(content: &str) -> Option<Vec<crate::ai::ToolCall>> {
// Try parsing as direct JSON object
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
if let Some(calls) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
return Self::extract_tool_calls_from_array(calls);
}
}
// Try finding JSON in text (handle Claude XML wrapper: <tool_calls>[...]</tool_calls>)
if let Some(start) = content.find("<tool_calls>") {
if let Some(end) = content.find("</tool_calls>") {
let json_str = &content[start + 12..end].trim();
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
if let Some(calls) = parsed.as_array() {
return Self::extract_tool_calls_from_array(calls);
}
}
}
}
// Try finding raw JSON array in text
if let Some(start) = content.find("[{") {
if let Some(end) = content.rfind("}]") {
let json_str = &content[start..=end + 1];
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
if let Some(calls) = parsed.as_array() {
return Self::extract_tool_calls_from_array(calls);
}
}
}
}
None
}
/// Extract ToolCall structs from JSON array
fn extract_tool_calls_from_array(
calls: &[serde_json::Value],
) -> Option<Vec<crate::ai::ToolCall>> {
let parsed: Vec<crate::ai::ToolCall> = calls
.iter()
.enumerate()
.filter_map(|(index, call)| {
// Generate fallback ID if missing (consistent with earlier parsing logic in this file)
let id = call
.get("id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.unwrap_or_else(|| format!("tool_call_{index}"));
// Try nested function.name format (OpenAI style)
let name = call
.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str())
.or_else(|| call.get("name").and_then(|n| n.as_str()))?
.to_string();
// Arguments can be string or object
let arguments = call
.get("function")
.and_then(|f| f.get("arguments"))
.or_else(|| call.get("arguments"))
.and_then(|args| {
if let Some(s) = args.as_str() {
Some(s.to_string())
} else {
serde_json::to_string(args).ok()
}
})?;
Some(crate::ai::ToolCall {
id,
name,
arguments,
})
})
.collect();
if parsed.is_empty() {
None
} else {
Some(parsed)
}
}
} }

View File

@ -552,6 +552,79 @@ pub async fn test_provider_connection(
}) })
} }
#[tauri::command]
pub async fn detect_tool_calling_support(provider_config: ProviderConfig) -> Result<bool, String> {
use crate::ai::{Tool, ToolParameters};
use std::collections::HashMap;
use tracing::info;
// Create a simple test tool
let test_tool = Tool {
name: "test_tool".to_string(),
description: "A test tool that returns 'success'. Call this tool with no arguments."
.to_string(),
parameters: ToolParameters {
param_type: "object".to_string(),
properties: HashMap::new(),
required: vec![],
},
};
// Override config with detection-optimized settings
let mut detection_config = provider_config.clone();
detection_config.max_tokens = Some(100); // Small budget for capability check
detection_config.temperature = Some(0.0); // Deterministic for reliability
let provider = create_provider(&detection_config);
let messages = vec![Message {
role: "user".into(),
content: "Please call the test_tool function.".into(),
tool_call_id: None,
tool_calls: None,
}];
match provider
.chat(messages, &detection_config, Some(vec![test_tool]))
.await
{
Ok(response) => {
// Check if response contains tool_calls
if let Some(tool_calls) = response.tool_calls {
if tool_calls.iter().any(|tc| tc.name == "test_tool") {
info!(
"Tool calling support detected for provider {}",
provider_config.name
);
return Ok(true);
}
}
// Provider responded but didn't use tool calls
info!(
"Provider {} responded but did not call tool",
provider_config.name
);
Ok(false)
}
Err(e) => {
// Check if error indicates tool calling is not supported
let error_msg = e.to_string().to_lowercase();
if error_msg.contains("tool")
|| error_msg.contains("function")
|| error_msg.contains("503")
{
info!(
"Tool calling not supported for provider {}: {}",
provider_config.name, e
);
Ok(false)
} else {
// Connection or other error
Err(format!("Failed to test tool calling support: {e}"))
}
}
}
}
#[tauri::command] #[tauri::command]
pub async fn list_providers() -> Result<Vec<ProviderInfo>, String> { pub async fn list_providers() -> Result<Vec<ProviderInfo>, String> {
Ok(vec![ Ok(vec![

View File

@ -104,6 +104,7 @@ pub fn run() {
commands::ai::analyze_logs, commands::ai::analyze_logs,
commands::ai::chat_message, commands::ai::chat_message,
commands::ai::test_provider_connection, commands::ai::test_provider_connection,
commands::ai::detect_tool_calling_support,
commands::ai::list_providers, commands::ai::list_providers,
commands::system::save_ai_provider, commands::system::save_ai_provider,
commands::system::load_ai_providers, commands::system::load_ai_providers,