diff --git a/src-tauri/src/ai/ollama.rs b/src-tauri/src/ai/ollama.rs index 1bde9746..f2360d51 100644 --- a/src-tauri/src/ai/ollama.rs +++ b/src-tauri/src/ai/ollama.rs @@ -1,10 +1,14 @@ use async_trait::async_trait; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use crate::ai::provider::Provider; -use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage}; +use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage, ToolCall}; use crate::state::ProviderConfig; +// Track if we've already attempted auto-start this session +static AUTO_START_ATTEMPTED: AtomicBool = AtomicBool::new(false); + pub struct OllamaProvider; #[async_trait] @@ -18,11 +22,11 @@ impl Provider for OllamaProvider { name: "Ollama (Local)".to_string(), supports_streaming: true, models: vec![ - "llama3.1".to_string(), - "llama3".to_string(), - "mistral".to_string(), - "codellama".to_string(), - "phi3".to_string(), + "llama3.2:3b".to_string(), + "phi3.5:3.8b".to_string(), + "llama3.1:8b".to_string(), + "qwen2.5:14b".to_string(), + "gemma2:9b".to_string(), ], } } @@ -31,77 +35,276 @@ impl Provider for OllamaProvider { &self, messages: Vec, config: &ProviderConfig, - _tools: Option>, + tools: Option>, ) -> anyhow::Result { + // Longer timeout for tool calling - models need time to generate structured output + let timeout_secs = if tools.is_some() { 180 } else { 60 }; + let client = reqwest::Client::builder() - .timeout(Duration::from_secs(60)) + .timeout(Duration::from_secs(timeout_secs)) + .connect_timeout(Duration::from_secs(10)) .build()?; let base_url = if config.api_url.is_empty() { "http://localhost:11434".to_string() } else { config.api_url.trim_end_matches('/').to_string() }; + + // Auto-start Ollama if using localhost and we haven't tried yet this session + // Only attempt once to avoid recurring latency on every chat() call + if base_url == "http://localhost:11434" + && !AUTO_START_ATTEMPTED.swap(true, Ordering::Relaxed) + { + // Check if already running before attempting start + let pre_status = crate::ollama::installer::check_ollama().await; + let already_running = pre_status.map(|s| s.running).unwrap_or(false); + + if !already_running { + match crate::ollama::installer::start_ollama_service().await { + Ok(true) => { + tracing::info!("Ollama service auto-started successfully"); + // Give it a moment to fully initialize + tokio::time::sleep(Duration::from_secs(2)).await; + } + Ok(false) => { + tracing::debug!("Ollama not started (not installed or already running)"); + } + Err(e) => { + tracing::warn!("Failed to auto-start Ollama: {}", e); + // Continue anyway - maybe it's already running or will start soon + } + } + } else { + tracing::debug!("Ollama already running, skipping auto-start"); + } + } + + // Quick health check before attempting chat (short timeout for fast failure) + let health_client = reqwest::Client::builder() + .timeout(Duration::from_secs(2)) + .build()?; + let health_check_result = health_client + .get(format!("{base_url}/api/tags")) + .send() + .await; + + match health_check_result { + Ok(resp) if resp.status().is_success() => { + tracing::debug!("Ollama health check passed"); + } + Ok(resp) => { + let status = resp.status(); + tracing::warn!("Ollama health check returned status {status}"); + anyhow::bail!( + "Ollama is not ready (status {status}). Please ensure Ollama is running." + ); + } + Err(e) => { + tracing::error!("Cannot connect to Ollama at {base_url}: {e}"); + anyhow::bail!("Cannot connect to Ollama at {base_url}. Please ensure Ollama is running and accessible."); + } + } + let url = format!("{base_url}/api/chat"); - // Ollama expects {model, messages: [{role, content}], stream: false} + // Ollama expects {model, messages: [{role, content, tool_calls?, tool_call_id?}], stream: false} let api_messages: Vec = messages .iter() .map(|m| { - serde_json::json!({ + let mut msg = serde_json::json!({ "role": m.role, "content": m.content, - }) + }); + + // Include tool_calls if present (for assistant messages with tool requests) + if let Some(ref tool_calls) = m.tool_calls { + msg["tool_calls"] = serde_json::json!(tool_calls); + } + + // Include tool_call_id if present (for tool result messages) + if let Some(ref tool_call_id) = m.tool_call_id { + msg["tool_call_id"] = serde_json::json!(tool_call_id); + } + + msg }) .collect(); - let body = serde_json::json!({ + let mut body = serde_json::json!({ "model": config.model, "messages": api_messages, "stream": false, }); - let resp = client - .post(&url) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await?; - - if !resp.status().is_success() { - let status = resp.status(); - let text = resp.text().await?; - anyhow::bail!("Ollama API error {status}: {text}"); + // Add tools if provided (Ollama function calling format) + if let Some(tools_list) = tools { + let formatted_tools: Vec = tools_list + .iter() + .map(|tool| { + serde_json::json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters + } + }) + }) + .collect(); + body["tools"] = serde_json::Value::from(formatted_tools); } - let json: serde_json::Value = resp.json().await?; + // Retry logic for transient connection issues + let max_retries = 2; + let mut last_error = None; - // Parse response.message.content - let content = json["message"]["content"] - .as_str() - .ok_or_else(|| anyhow::anyhow!("No content in Ollama response"))? - .to_string(); + for attempt in 0..=max_retries { + if attempt > 0 { + tracing::warn!( + "Ollama request failed, retrying (attempt {}/{})...", + attempt + 1, + max_retries + 1 + ); + tokio::time::sleep(Duration::from_secs(2)).await; + } - // Ollama provides eval_count / prompt_eval_count - let usage = { - let prompt_tokens = json["prompt_eval_count"].as_u64().unwrap_or(0) as u32; - let completion_tokens = json["eval_count"].as_u64().unwrap_or(0) as u32; - if prompt_tokens > 0 || completion_tokens > 0 { - Some(TokenUsage { - prompt_tokens, - completion_tokens, - total_tokens: prompt_tokens + completion_tokens, - }) + let resp_result = client + .post(&url) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await; + + let resp = match resp_result { + Ok(r) => r, + Err(e) => { + last_error = Some(format!("Connection error: {e}")); + if attempt < max_retries { + continue; // Retry + } else { + anyhow::bail!( + "Failed to connect to Ollama after {} attempts. Last error: {e}", + max_retries + 1 + ); + } + } + }; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await?; + last_error = Some(format!("API error {status}: {text}")); + if attempt < max_retries && status.is_server_error() { + continue; // Retry on 5xx errors + } else { + anyhow::bail!("Ollama API error {status}: {text}"); + } + } + + // Success - parse response and return + let json: serde_json::Value = match resp.json().await { + Ok(j) => j, + Err(e) => { + last_error = Some(format!("JSON parse error: {e}")); + if attempt < max_retries { + continue; // Retry + } else { + anyhow::bail!("Failed to parse Ollama response: {e}"); + } + } + }; + + // Parse response.message.content + let content = json["message"]["content"] + .as_str() + .unwrap_or("") + .to_string(); + + // Parse tool calls from Ollama response + // Ollama returns tool_calls in message.tool_calls array + let tool_calls = if let Some(calls_array) = json["message"]["tool_calls"].as_array() { + let mut parsed_calls = Vec::new(); + for (idx, call) in calls_array.iter().enumerate() { + // Generate fallback ID if not provided + let id = call["id"] + .as_str() + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("tool_call_{idx}")); + + let function = &call["function"]; + + // Skip malformed tool calls (missing name) instead of failing entire response + let name = match function["name"].as_str() { + Some(n) => n.to_string(), + None => { + tracing::warn!("Skipping tool call with missing name at index {idx}"); + continue; + } + }; + + // Arguments can be either an object or a string + let arguments = if let Some(args_obj) = function["arguments"].as_object() { + match serde_json::to_string(args_obj) { + Ok(s) => s, + Err(e) => { + tracing::warn!( + "Failed to serialize tool call arguments at index {}: {}", + idx, + e + ); + continue; + } + } + } else if let Some(args_str) = function["arguments"].as_str() { + args_str.to_string() + } else { + "{}".to_string() + }; + + parsed_calls.push(ToolCall { + id, + name, + arguments, + }); + } + if !parsed_calls.is_empty() { + Some(parsed_calls) + } else { + None + } } else { None - } - }; + }; - Ok(ChatResponse { - content, - model: config.model.clone(), - usage, - user_message: None, - tool_calls: None, - }) + // Ollama provides eval_count / prompt_eval_count + let usage = { + let prompt_tokens = json["prompt_eval_count"].as_u64().unwrap_or(0) as u32; + let completion_tokens = json["eval_count"].as_u64().unwrap_or(0) as u32; + if prompt_tokens > 0 || completion_tokens > 0 { + Some(TokenUsage { + prompt_tokens, + completion_tokens, + total_tokens: prompt_tokens + completion_tokens, + }) + } else { + None + } + }; + + return Ok(ChatResponse { + content, + model: config.model.clone(), + usage, + user_message: None, + tool_calls, + }); + } + + // If we get here, all retries failed + anyhow::bail!( + "Failed to get response from Ollama after {} attempts. Last error: {:?}", + max_retries + 1, + last_error + ) } } diff --git a/src-tauri/src/ai/openai.rs b/src-tauri/src/ai/openai.rs index 4d83bb31..5b4ffc6d 100644 --- a/src-tauri/src/ai/openai.rs +++ b/src-tauri/src/ai/openai.rs @@ -7,8 +7,8 @@ use crate::state::ProviderConfig; pub struct OpenAiProvider; -fn is_custom_rest_format(api_format: Option<&str>) -> bool { - matches!(api_format, Some("custom_rest")) +fn is_msi_genai_format(api_format: Option<&str>) -> bool { + matches!(api_format, Some("msi-genai") | Some("custom_rest")) // custom_rest for backward compatibility } #[async_trait] @@ -38,8 +38,8 @@ impl Provider for OpenAiProvider { // Check if using custom REST format let api_format = config.api_format.as_deref().unwrap_or("openai"); - if is_custom_rest_format(Some(api_format)) { - self.chat_custom_rest(messages, config, tools).await + if is_msi_genai_format(Some(api_format)) { + self.chat_msi_genai(messages, config, tools).await } else { self.chat_openai(messages, config, tools).await } @@ -48,17 +48,109 @@ impl Provider for OpenAiProvider { #[cfg(test)] mod tests { - use super::is_custom_rest_format; + use super::{is_msi_genai_format, OpenAiProvider}; #[test] - fn custom_rest_format_is_recognized() { - assert!(is_custom_rest_format(Some("custom_rest"))); + fn msi_genai_format_is_recognized() { + assert!(is_msi_genai_format(Some("msi-genai"))); } #[test] - fn openai_format_is_not_custom_rest() { - assert!(!is_custom_rest_format(Some("openai"))); - assert!(!is_custom_rest_format(None)); + fn custom_rest_format_backward_compatible() { + // Keep backward compatibility with old format name + assert!(is_msi_genai_format(Some("custom_rest"))); + } + + #[test] + fn openai_format_is_not_msi_genai() { + assert!(!is_msi_genai_format(Some("openai"))); + assert!(!is_msi_genai_format(None)); + } + + #[test] + fn parse_msigenai_chatgpt_tool_calls_from_json_text() { + // MSIGenAI ChatGPT format: returns tool calls as JSON object in msg + let content = r#"{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"execute_shell_command","arguments":{"command":"kubectl get namespaces"}}}]}"#; + + let result = OpenAiProvider::parse_tool_calls_from_text(content); + assert!(result.is_some()); + + let calls = result.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].id, "call_1"); + assert_eq!(calls[0].name, "execute_shell_command"); + assert!(calls[0].arguments.contains("kubectl get namespaces")); + } + + #[test] + fn parse_msigenai_claude_tool_calls_from_xml_wrapper() { + // MSIGenAI Claude format: XML wrapper around JSON array + let content = r#" +[{"id":"call_1","type":"function","function":{"name":"execute_shell_command","arguments":{"command":"kubectl get pods"}}}] +"#; + + let result = OpenAiProvider::parse_tool_calls_from_text(content); + assert!(result.is_some()); + + let calls = result.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].id, "call_1"); + assert_eq!(calls[0].name, "execute_shell_command"); + assert!(calls[0].arguments.contains("kubectl get pods")); + } + + #[test] + fn parse_multiple_tool_calls_from_text() { + let content = r#"{"tool_calls":[ + {"id":"call_1","function":{"name":"kubectl_get","arguments":{"resource":"pods"}}}, + {"id":"call_2","function":{"name":"kubectl_describe","arguments":{"resource":"svc/nginx"}}} + ]}"#; + + let result = OpenAiProvider::parse_tool_calls_from_text(content); + assert!(result.is_some()); + + let calls = result.unwrap(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "kubectl_get"); + assert_eq!(calls[1].name, "kubectl_describe"); + } + + #[test] + fn parse_tool_calls_returns_none_for_normal_text() { + let content = "Hello, I found 5 pods running in the cluster."; + let result = OpenAiProvider::parse_tool_calls_from_text(content); + assert!(result.is_none()); + } + + #[test] + fn parse_tool_calls_handles_arguments_as_string() { + // Some providers return arguments as string, not object + let content = r#"{"tool_calls":[{"id":"call_1","function":{"name":"test","arguments":"{\"key\":\"value\"}"}}]}"#; + + let result = OpenAiProvider::parse_tool_calls_from_text(content); + assert!(result.is_some()); + + let calls = result.unwrap(); + assert_eq!(calls[0].arguments, r#"{"key":"value"}"#); + } + + #[test] + fn parse_tool_calls_generates_fallback_id_when_missing() { + // Some providers may omit id field - generate fallback to prevent silent drop + let content = r#"{"tool_calls":[ + {"function":{"name":"kubectl_get","arguments":{"resource":"pods"}}}, + {"id":"call_2","function":{"name":"kubectl_describe","arguments":{"resource":"svc"}}} + ]}"#; + + let result = OpenAiProvider::parse_tool_calls_from_text(content); + assert!(result.is_some()); + + let calls = result.unwrap(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].id, "tool_call_0"); // Fallback generated + assert_eq!(calls[0].name, "kubectl_get"); + assert_eq!(calls[1].id, "call_2"); // Original preserved + assert_eq!(calls[1].name, "kubectl_describe"); } } @@ -202,8 +294,13 @@ impl OpenAiProvider { }) } - /// Custom REST format (non-OpenAI payload contract) - async fn chat_custom_rest( + /// MSI GenAI format (non-OpenAI payload contract) + /// + /// MSI GenAI uses a custom API format with 'prompt' field instead of 'messages', + /// and has a known bug where tool calls are returned as JSON text in the 'msg' + /// field instead of structured 'tool_calls' array. This implementation includes + /// workaround parsing to extract tool calls from text. + async fn chat_msi_genai( &self, messages: Vec, config: &ProviderConfig, @@ -284,7 +381,7 @@ impl OpenAiProvider { body["tools"] = serde_json::Value::from(formatted_tools); body["tool_choice"] = serde_json::Value::from("auto"); - tracing::info!("Custom REST: Sending {} tools in request", tool_count); + tracing::info!("MSI GenAI: Sending {} tools in request", tool_count); } // Use custom auth header and prefix (no default prefix for custom REST) @@ -306,13 +403,13 @@ impl OpenAiProvider { if !resp.status().is_success() { let status = resp.status(); let text = resp.text().await?; - anyhow::bail!("Custom REST API error {status}: {text}"); + anyhow::bail!("MSI GenAI API error {status}: {text}"); } let json: serde_json::Value = resp.json().await?; tracing::debug!( - "Custom REST response: {}", + "MSI GenAI response: {}", serde_json::to_string_pretty(&json).unwrap_or_else(|_| "invalid JSON".to_string()) ); @@ -323,7 +420,7 @@ impl OpenAiProvider { .to_string(); // Parse tool_calls if present (check multiple possible field names) - let tool_calls = json + let mut tool_calls = json .get("tool_calls") .or_else(|| json.get("toolCalls")) .or_else(|| json.get("function_calls")) @@ -331,57 +428,84 @@ impl OpenAiProvider { if let Some(arr) = tc.as_array() { let calls: Vec = arr .iter() - .filter_map(|call| { + .enumerate() + .filter_map(|(index, call)| { // Try OpenAI format first - if let (Some(id), Some(name), Some(args)) = ( + if let (Some(id), Some(name)) = ( call.get("id").and_then(|v| v.as_str()), call.get("function") .and_then(|f| f.get("name")) .and_then(|n| n.as_str()) .or_else(|| call.get("name").and_then(|n| n.as_str())), - call.get("function") - .and_then(|f| f.get("arguments")) - .and_then(|a| a.as_str()) - .or_else(|| call.get("arguments").and_then(|a| a.as_str())), ) { - tracing::info!("Custom REST: Parsed tool call: {} ({})", name, id); - return Some(crate::ai::ToolCall { - id: id.to_string(), - name: name.to_string(), - arguments: args.to_string(), - }); + // Accept arguments as either string or object (MSI GenAI returns both) + let arguments = call + .get("function") + .and_then(|f| f.get("arguments")) + .or_else(|| call.get("arguments")) + .and_then(|args| { + if let Some(s) = args.as_str() { + Some(s.to_string()) + } else { + // Serialize object to JSON string + serde_json::to_string(args).ok() + } + }); + + if let Some(args) = arguments { + tracing::info!( + "MSI GenAI: Parsed tool call: {} ({})", + name, + id + ); + return Some(crate::ai::ToolCall { + id: id.to_string(), + name: name.to_string(), + arguments: args, + }); + } } // Try simpler format - if let (Some(name), Some(args)) = ( - call.get("name").and_then(|n| n.as_str()), - call.get("arguments").and_then(|a| a.as_str()), - ) { - let id = call - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or("tool_call_0") - .to_string(); - tracing::info!( - "Custom REST: Parsed tool call (simple format): {} ({})", - name, - id - ); - return Some(crate::ai::ToolCall { - id, - name: name.to_string(), - arguments: args.to_string(), + if let Some(name) = call.get("name").and_then(|n| n.as_str()) { + // Accept arguments as either string or object + let arguments = call.get("arguments").and_then(|args| { + if let Some(s) = args.as_str() { + Some(s.to_string()) + } else { + // Serialize object to JSON string + serde_json::to_string(args).ok() + } }); + + if let Some(args) = arguments { + // Generate unique ID if missing (avoids duplicates) + let id = call + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("tool_call_{index}")); + tracing::info!( + "MSI GenAI: Parsed tool call (simple format): {} ({})", + name, + id + ); + return Some(crate::ai::ToolCall { + id, + name: name.to_string(), + arguments: args, + }); + } } - tracing::warn!("Custom REST: Failed to parse tool call: {:?}", call); + tracing::warn!("MSI GenAI: Failed to parse tool call: {:?}", call); None }) .collect(); if calls.is_empty() { None } else { - tracing::info!("Custom REST: Found {} tool calls", calls.len()); + tracing::info!("MSI GenAI: Found {} tool calls", calls.len()); Some(calls) } } else { @@ -389,6 +513,20 @@ impl OpenAiProvider { } }); + // WORKAROUND: MSIGenAI gateway bug - tool calls returned as JSON text in 'msg' field + // Expected: {"tool_calls": [...]} + // Actual: {"msg": '{"tool_calls":[...]}'} or {"msg": '[...]'} + if tool_calls.is_none() { + // Try parsing tool calls from msg content (MSIGenAI workaround) + if let Some(parsed_calls) = Self::parse_tool_calls_from_text(&content) { + tracing::warn!( + "MSI GenAI: MSIGenAI workaround - parsed {} tool calls from msg text (gateway should return structured tool_calls field)", + parsed_calls.len() + ); + tool_calls = Some(parsed_calls); + } + } + // Note: sessionId from response should be stored back to config.session_id // This would require making config mutable or returning it as part of ChatResponse // For now, the caller can extract it from the response if needed @@ -402,4 +540,95 @@ impl OpenAiProvider { tool_calls, }) } + + /// Parse tool calls from text content (MSIGenAI gateway workaround) + /// + /// MSIGenAI returns tool calls as JSON text in the 'msg' field instead of structured data: + /// - ChatGPT models: `{"tool_calls":[...]}` + /// - Claude models: `[...]` + fn parse_tool_calls_from_text(content: &str) -> Option> { + // Try parsing as direct JSON object + if let Ok(parsed) = serde_json::from_str::(content) { + if let Some(calls) = parsed.get("tool_calls").and_then(|v| v.as_array()) { + return Self::extract_tool_calls_from_array(calls); + } + } + + // Try finding JSON in text (handle Claude XML wrapper: [...]) + if let Some(start) = content.find("") { + if let Some(end) = content.find("") { + let json_str = &content[start + 12..end].trim(); + if let Ok(parsed) = serde_json::from_str::(json_str) { + if let Some(calls) = parsed.as_array() { + return Self::extract_tool_calls_from_array(calls); + } + } + } + } + + // Try finding raw JSON array in text + if let Some(start) = content.find("[{") { + if let Some(end) = content.rfind("}]") { + let json_str = &content[start..=end + 1]; + if let Ok(parsed) = serde_json::from_str::(json_str) { + if let Some(calls) = parsed.as_array() { + return Self::extract_tool_calls_from_array(calls); + } + } + } + } + + None + } + + /// Extract ToolCall structs from JSON array + fn extract_tool_calls_from_array( + calls: &[serde_json::Value], + ) -> Option> { + let parsed: Vec = calls + .iter() + .enumerate() + .filter_map(|(index, call)| { + // Generate fallback ID if missing (consistent with earlier parsing logic in this file) + let id = call + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("tool_call_{index}")); + + // Try nested function.name format (OpenAI style) + let name = call + .get("function") + .and_then(|f| f.get("name")) + .and_then(|n| n.as_str()) + .or_else(|| call.get("name").and_then(|n| n.as_str()))? + .to_string(); + + // Arguments can be string or object + let arguments = call + .get("function") + .and_then(|f| f.get("arguments")) + .or_else(|| call.get("arguments")) + .and_then(|args| { + if let Some(s) = args.as_str() { + Some(s.to_string()) + } else { + serde_json::to_string(args).ok() + } + })?; + + Some(crate::ai::ToolCall { + id, + name, + arguments, + }) + }) + .collect(); + + if parsed.is_empty() { + None + } else { + Some(parsed) + } + } } diff --git a/src-tauri/src/commands/ai.rs b/src-tauri/src/commands/ai.rs index eef7572a..69f55e05 100644 --- a/src-tauri/src/commands/ai.rs +++ b/src-tauri/src/commands/ai.rs @@ -552,6 +552,79 @@ pub async fn test_provider_connection( }) } +#[tauri::command] +pub async fn detect_tool_calling_support(provider_config: ProviderConfig) -> Result { + use crate::ai::{Tool, ToolParameters}; + use std::collections::HashMap; + use tracing::info; + + // Create a simple test tool + let test_tool = Tool { + name: "test_tool".to_string(), + description: "A test tool that returns 'success'. Call this tool with no arguments." + .to_string(), + parameters: ToolParameters { + param_type: "object".to_string(), + properties: HashMap::new(), + required: vec![], + }, + }; + + // Override config with detection-optimized settings + let mut detection_config = provider_config.clone(); + detection_config.max_tokens = Some(100); // Small budget for capability check + detection_config.temperature = Some(0.0); // Deterministic for reliability + + let provider = create_provider(&detection_config); + let messages = vec![Message { + role: "user".into(), + content: "Please call the test_tool function.".into(), + tool_call_id: None, + tool_calls: None, + }]; + + match provider + .chat(messages, &detection_config, Some(vec![test_tool])) + .await + { + Ok(response) => { + // Check if response contains tool_calls + if let Some(tool_calls) = response.tool_calls { + if tool_calls.iter().any(|tc| tc.name == "test_tool") { + info!( + "Tool calling support detected for provider {}", + provider_config.name + ); + return Ok(true); + } + } + // Provider responded but didn't use tool calls + info!( + "Provider {} responded but did not call tool", + provider_config.name + ); + Ok(false) + } + Err(e) => { + // Check if error indicates tool calling is not supported + let error_msg = e.to_string().to_lowercase(); + if error_msg.contains("tool") + || error_msg.contains("function") + || error_msg.contains("503") + { + info!( + "Tool calling not supported for provider {}: {}", + provider_config.name, e + ); + Ok(false) + } else { + // Connection or other error + Err(format!("Failed to test tool calling support: {e}")) + } + } + } +} + #[tauri::command] pub async fn list_providers() -> Result, String> { Ok(vec![ diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 794693fa..d43cc29e 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -104,6 +104,7 @@ pub fn run() { commands::ai::analyze_logs, commands::ai::chat_message, commands::ai::test_provider_connection, + commands::ai::detect_tool_calling_support, commands::ai::list_providers, commands::system::save_ai_provider, commands::system::load_ai_providers,