feat: add Ollama function calling and tool calling auto-detection
Enable Ollama models to execute shell commands. Add detection command for provider capability testing.

- Replace ai/ollama.rs with function calling support (180s timeout, retry logic)
- Update ai/openai.rs with tool calling improvements
- Add detect_tool_calling_support() command to commands/ai.rs

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
e5593cbfe2
commit
117ab390a2
@ -1,10 +1,14 @@
|
||||
use async_trait::async_trait;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::ai::provider::Provider;
|
||||
use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage};
|
||||
use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage, ToolCall};
|
||||
use crate::state::ProviderConfig;
|
||||
|
||||
// Track if we've already attempted auto-start this session
|
||||
static AUTO_START_ATTEMPTED: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
pub struct OllamaProvider;
|
||||
|
||||
#[async_trait]
|
||||
@ -18,11 +22,11 @@ impl Provider for OllamaProvider {
|
||||
name: "Ollama (Local)".to_string(),
|
||||
supports_streaming: true,
|
||||
models: vec![
|
||||
"llama3.1".to_string(),
|
||||
"llama3".to_string(),
|
||||
"mistral".to_string(),
|
||||
"codellama".to_string(),
|
||||
"phi3".to_string(),
|
||||
"llama3.2:3b".to_string(),
|
||||
"phi3.5:3.8b".to_string(),
|
||||
"llama3.1:8b".to_string(),
|
||||
"qwen2.5:14b".to_string(),
|
||||
"gemma2:9b".to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
@ -31,77 +35,276 @@ impl Provider for OllamaProvider {
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
config: &ProviderConfig,
|
||||
_tools: Option<Vec<crate::ai::Tool>>,
|
||||
tools: Option<Vec<crate::ai::Tool>>,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
// Longer timeout for tool calling - models need time to generate structured output
|
||||
let timeout_secs = if tools.is_some() { 180 } else { 60 };
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(60))
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.connect_timeout(Duration::from_secs(10))
|
||||
.build()?;
|
||||
let base_url = if config.api_url.is_empty() {
|
||||
"http://localhost:11434".to_string()
|
||||
} else {
|
||||
config.api_url.trim_end_matches('/').to_string()
|
||||
};
|
||||
|
||||
// Auto-start Ollama if using localhost and we haven't tried yet this session
|
||||
// Only attempt once to avoid recurring latency on every chat() call
|
||||
if base_url == "http://localhost:11434"
|
||||
&& !AUTO_START_ATTEMPTED.swap(true, Ordering::Relaxed)
|
||||
{
|
||||
// Check if already running before attempting start
|
||||
let pre_status = crate::ollama::installer::check_ollama().await;
|
||||
let already_running = pre_status.map(|s| s.running).unwrap_or(false);
|
||||
|
||||
if !already_running {
|
||||
match crate::ollama::installer::start_ollama_service().await {
|
||||
Ok(true) => {
|
||||
tracing::info!("Ollama service auto-started successfully");
|
||||
// Give it a moment to fully initialize
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
}
|
||||
Ok(false) => {
|
||||
tracing::debug!("Ollama not started (not installed or already running)");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to auto-start Ollama: {}", e);
|
||||
// Continue anyway - maybe it's already running or will start soon
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("Ollama already running, skipping auto-start");
|
||||
}
|
||||
}
|
||||
|
||||
// Quick health check before attempting chat (short timeout for fast failure)
|
||||
let health_client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(2))
|
||||
.build()?;
|
||||
let health_check_result = health_client
|
||||
.get(format!("{base_url}/api/tags"))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match health_check_result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
tracing::debug!("Ollama health check passed");
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
tracing::warn!("Ollama health check returned status {status}");
|
||||
anyhow::bail!(
|
||||
"Ollama is not ready (status {status}). Please ensure Ollama is running."
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Cannot connect to Ollama at {base_url}: {e}");
|
||||
anyhow::bail!("Cannot connect to Ollama at {base_url}. Please ensure Ollama is running and accessible.");
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{base_url}/api/chat");
|
||||
|
||||
// Ollama expects {model, messages: [{role, content}], stream: false}
|
||||
// Ollama expects {model, messages: [{role, content, tool_calls?, tool_call_id?}], stream: false}
|
||||
let api_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
serde_json::json!({
|
||||
let mut msg = serde_json::json!({
|
||||
"role": m.role,
|
||||
"content": m.content,
|
||||
})
|
||||
});
|
||||
|
||||
// Include tool_calls if present (for assistant messages with tool requests)
|
||||
if let Some(ref tool_calls) = m.tool_calls {
|
||||
msg["tool_calls"] = serde_json::json!(tool_calls);
|
||||
}
|
||||
|
||||
// Include tool_call_id if present (for tool result messages)
|
||||
if let Some(ref tool_call_id) = m.tool_call_id {
|
||||
msg["tool_call_id"] = serde_json::json!(tool_call_id);
|
||||
}
|
||||
|
||||
msg
|
||||
})
|
||||
.collect();
|
||||
|
||||
let body = serde_json::json!({
|
||||
let mut body = serde_json::json!({
|
||||
"model": config.model,
|
||||
"messages": api_messages,
|
||||
"stream": false,
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
anyhow::bail!("Ollama API error {status}: {text}");
|
||||
// Add tools if provided (Ollama function calling format)
|
||||
if let Some(tools_list) = tools {
|
||||
let formatted_tools: Vec<serde_json::Value> = tools_list
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
serde_json::json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
body["tools"] = serde_json::Value::from(formatted_tools);
|
||||
}
|
||||
|
||||
let json: serde_json::Value = resp.json().await?;
|
||||
// Retry logic for transient connection issues
|
||||
let max_retries = 2;
|
||||
let mut last_error = None;
|
||||
|
||||
// Parse response.message.content
|
||||
let content = json["message"]["content"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("No content in Ollama response"))?
|
||||
.to_string();
|
||||
for attempt in 0..=max_retries {
|
||||
if attempt > 0 {
|
||||
tracing::warn!(
|
||||
"Ollama request failed, retrying (attempt {}/{})...",
|
||||
attempt + 1,
|
||||
max_retries + 1
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
}
|
||||
|
||||
// Ollama provides eval_count / prompt_eval_count
|
||||
let usage = {
|
||||
let prompt_tokens = json["prompt_eval_count"].as_u64().unwrap_or(0) as u32;
|
||||
let completion_tokens = json["eval_count"].as_u64().unwrap_or(0) as u32;
|
||||
if prompt_tokens > 0 || completion_tokens > 0 {
|
||||
Some(TokenUsage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens: prompt_tokens + completion_tokens,
|
||||
})
|
||||
let resp_result = client
|
||||
.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let resp = match resp_result {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
last_error = Some(format!("Connection error: {e}"));
|
||||
if attempt < max_retries {
|
||||
continue; // Retry
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"Failed to connect to Ollama after {} attempts. Last error: {e}",
|
||||
max_retries + 1
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
last_error = Some(format!("API error {status}: {text}"));
|
||||
if attempt < max_retries && status.is_server_error() {
|
||||
continue; // Retry on 5xx errors
|
||||
} else {
|
||||
anyhow::bail!("Ollama API error {status}: {text}");
|
||||
}
|
||||
}
|
||||
|
||||
// Success - parse response and return
|
||||
let json: serde_json::Value = match resp.json().await {
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
last_error = Some(format!("JSON parse error: {e}"));
|
||||
if attempt < max_retries {
|
||||
continue; // Retry
|
||||
} else {
|
||||
anyhow::bail!("Failed to parse Ollama response: {e}");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Parse response.message.content
|
||||
let content = json["message"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// Parse tool calls from Ollama response
|
||||
// Ollama returns tool_calls in message.tool_calls array
|
||||
let tool_calls = if let Some(calls_array) = json["message"]["tool_calls"].as_array() {
|
||||
let mut parsed_calls = Vec::new();
|
||||
for (idx, call) in calls_array.iter().enumerate() {
|
||||
// Generate fallback ID if not provided
|
||||
let id = call["id"]
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| format!("tool_call_{idx}"));
|
||||
|
||||
let function = &call["function"];
|
||||
|
||||
// Skip malformed tool calls (missing name) instead of failing entire response
|
||||
let name = match function["name"].as_str() {
|
||||
Some(n) => n.to_string(),
|
||||
None => {
|
||||
tracing::warn!("Skipping tool call with missing name at index {idx}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Arguments can be either an object or a string
|
||||
let arguments = if let Some(args_obj) = function["arguments"].as_object() {
|
||||
match serde_json::to_string(args_obj) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to serialize tool call arguments at index {}: {}",
|
||||
idx,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else if let Some(args_str) = function["arguments"].as_str() {
|
||||
args_str.to_string()
|
||||
} else {
|
||||
"{}".to_string()
|
||||
};
|
||||
|
||||
parsed_calls.push(ToolCall {
|
||||
id,
|
||||
name,
|
||||
arguments,
|
||||
});
|
||||
}
|
||||
if !parsed_calls.is_empty() {
|
||||
Some(parsed_calls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
Ok(ChatResponse {
|
||||
content,
|
||||
model: config.model.clone(),
|
||||
usage,
|
||||
user_message: None,
|
||||
tool_calls: None,
|
||||
})
|
||||
// Ollama provides eval_count / prompt_eval_count
|
||||
let usage = {
|
||||
let prompt_tokens = json["prompt_eval_count"].as_u64().unwrap_or(0) as u32;
|
||||
let completion_tokens = json["eval_count"].as_u64().unwrap_or(0) as u32;
|
||||
if prompt_tokens > 0 || completion_tokens > 0 {
|
||||
Some(TokenUsage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens: prompt_tokens + completion_tokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
return Ok(ChatResponse {
|
||||
content,
|
||||
model: config.model.clone(),
|
||||
usage,
|
||||
user_message: None,
|
||||
tool_calls,
|
||||
});
|
||||
}
|
||||
|
||||
// If we get here, all retries failed
|
||||
anyhow::bail!(
|
||||
"Failed to get response from Ollama after {} attempts. Last error: {:?}",
|
||||
max_retries + 1,
|
||||
last_error
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,8 +7,8 @@ use crate::state::ProviderConfig;
|
||||
|
||||
pub struct OpenAiProvider;
|
||||
|
||||
fn is_custom_rest_format(api_format: Option<&str>) -> bool {
|
||||
matches!(api_format, Some("custom_rest"))
|
||||
/// Returns `true` when `api_format` selects the MSI GenAI payload contract.
///
/// Two spellings are accepted: the current name `"msi-genai"` and the legacy
/// name `"custom_rest"`, which is kept for backward compatibility. `None` and
/// any other value yield `false`.
fn is_msi_genai_format(api_format: Option<&str>) -> bool {
    match api_format {
        Some(format) => format == "msi-genai" || format == "custom_rest",
        None => false,
    }
}
|
||||
|
||||
#[async_trait]
|
||||
@ -38,8 +38,8 @@ impl Provider for OpenAiProvider {
|
||||
// Check if using custom REST format
|
||||
let api_format = config.api_format.as_deref().unwrap_or("openai");
|
||||
|
||||
if is_custom_rest_format(Some(api_format)) {
|
||||
self.chat_custom_rest(messages, config, tools).await
|
||||
if is_msi_genai_format(Some(api_format)) {
|
||||
self.chat_msi_genai(messages, config, tools).await
|
||||
} else {
|
||||
self.chat_openai(messages, config, tools).await
|
||||
}
|
||||
@ -48,17 +48,109 @@ impl Provider for OpenAiProvider {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::is_custom_rest_format;
|
||||
use super::{is_msi_genai_format, OpenAiProvider};
|
||||
|
||||
#[test]
|
||||
fn custom_rest_format_is_recognized() {
|
||||
assert!(is_custom_rest_format(Some("custom_rest")));
|
||||
fn msi_genai_format_is_recognized() {
|
||||
assert!(is_msi_genai_format(Some("msi-genai")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openai_format_is_not_custom_rest() {
|
||||
assert!(!is_custom_rest_format(Some("openai")));
|
||||
assert!(!is_custom_rest_format(None));
|
||||
fn custom_rest_format_backward_compatible() {
|
||||
// Keep backward compatibility with old format name
|
||||
assert!(is_msi_genai_format(Some("custom_rest")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openai_format_is_not_msi_genai() {
|
||||
assert!(!is_msi_genai_format(Some("openai")));
|
||||
assert!(!is_msi_genai_format(None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_msigenai_chatgpt_tool_calls_from_json_text() {
|
||||
// MSIGenAI ChatGPT format: returns tool calls as JSON object in msg
|
||||
let content = r#"{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"execute_shell_command","arguments":{"command":"kubectl get namespaces"}}}]}"#;
|
||||
|
||||
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||
assert!(result.is_some());
|
||||
|
||||
let calls = result.unwrap();
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].id, "call_1");
|
||||
assert_eq!(calls[0].name, "execute_shell_command");
|
||||
assert!(calls[0].arguments.contains("kubectl get namespaces"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_msigenai_claude_tool_calls_from_xml_wrapper() {
|
||||
// MSIGenAI Claude format: XML wrapper around JSON array
|
||||
let content = r#"<tool_calls>
|
||||
[{"id":"call_1","type":"function","function":{"name":"execute_shell_command","arguments":{"command":"kubectl get pods"}}}]
|
||||
</tool_calls>"#;
|
||||
|
||||
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||
assert!(result.is_some());
|
||||
|
||||
let calls = result.unwrap();
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].id, "call_1");
|
||||
assert_eq!(calls[0].name, "execute_shell_command");
|
||||
assert!(calls[0].arguments.contains("kubectl get pods"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_multiple_tool_calls_from_text() {
|
||||
let content = r#"{"tool_calls":[
|
||||
{"id":"call_1","function":{"name":"kubectl_get","arguments":{"resource":"pods"}}},
|
||||
{"id":"call_2","function":{"name":"kubectl_describe","arguments":{"resource":"svc/nginx"}}}
|
||||
]}"#;
|
||||
|
||||
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||
assert!(result.is_some());
|
||||
|
||||
let calls = result.unwrap();
|
||||
assert_eq!(calls.len(), 2);
|
||||
assert_eq!(calls[0].name, "kubectl_get");
|
||||
assert_eq!(calls[1].name, "kubectl_describe");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_returns_none_for_normal_text() {
|
||||
let content = "Hello, I found 5 pods running in the cluster.";
|
||||
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_handles_arguments_as_string() {
|
||||
// Some providers return arguments as string, not object
|
||||
let content = r#"{"tool_calls":[{"id":"call_1","function":{"name":"test","arguments":"{\"key\":\"value\"}"}}]}"#;
|
||||
|
||||
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||
assert!(result.is_some());
|
||||
|
||||
let calls = result.unwrap();
|
||||
assert_eq!(calls[0].arguments, r#"{"key":"value"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_generates_fallback_id_when_missing() {
|
||||
// Some providers may omit id field - generate fallback to prevent silent drop
|
||||
let content = r#"{"tool_calls":[
|
||||
{"function":{"name":"kubectl_get","arguments":{"resource":"pods"}}},
|
||||
{"id":"call_2","function":{"name":"kubectl_describe","arguments":{"resource":"svc"}}}
|
||||
]}"#;
|
||||
|
||||
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||
assert!(result.is_some());
|
||||
|
||||
let calls = result.unwrap();
|
||||
assert_eq!(calls.len(), 2);
|
||||
assert_eq!(calls[0].id, "tool_call_0"); // Fallback generated
|
||||
assert_eq!(calls[0].name, "kubectl_get");
|
||||
assert_eq!(calls[1].id, "call_2"); // Original preserved
|
||||
assert_eq!(calls[1].name, "kubectl_describe");
|
||||
}
|
||||
}
|
||||
|
||||
@ -202,8 +294,13 @@ impl OpenAiProvider {
|
||||
})
|
||||
}
|
||||
|
||||
/// Custom REST format (non-OpenAI payload contract)
|
||||
async fn chat_custom_rest(
|
||||
/// MSI GenAI format (non-OpenAI payload contract)
|
||||
///
|
||||
/// MSI GenAI uses a custom API format with 'prompt' field instead of 'messages',
|
||||
/// and has a known bug where tool calls are returned as JSON text in the 'msg'
|
||||
/// field instead of structured 'tool_calls' array. This implementation includes
|
||||
/// workaround parsing to extract tool calls from text.
|
||||
async fn chat_msi_genai(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
config: &ProviderConfig,
|
||||
@ -284,7 +381,7 @@ impl OpenAiProvider {
|
||||
body["tools"] = serde_json::Value::from(formatted_tools);
|
||||
body["tool_choice"] = serde_json::Value::from("auto");
|
||||
|
||||
tracing::info!("Custom REST: Sending {} tools in request", tool_count);
|
||||
tracing::info!("MSI GenAI: Sending {} tools in request", tool_count);
|
||||
}
|
||||
|
||||
// Use custom auth header and prefix (no default prefix for custom REST)
|
||||
@ -306,13 +403,13 @@ impl OpenAiProvider {
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
anyhow::bail!("Custom REST API error {status}: {text}");
|
||||
anyhow::bail!("MSI GenAI API error {status}: {text}");
|
||||
}
|
||||
|
||||
let json: serde_json::Value = resp.json().await?;
|
||||
|
||||
tracing::debug!(
|
||||
"Custom REST response: {}",
|
||||
"MSI GenAI response: {}",
|
||||
serde_json::to_string_pretty(&json).unwrap_or_else(|_| "invalid JSON".to_string())
|
||||
);
|
||||
|
||||
@ -323,7 +420,7 @@ impl OpenAiProvider {
|
||||
.to_string();
|
||||
|
||||
// Parse tool_calls if present (check multiple possible field names)
|
||||
let tool_calls = json
|
||||
let mut tool_calls = json
|
||||
.get("tool_calls")
|
||||
.or_else(|| json.get("toolCalls"))
|
||||
.or_else(|| json.get("function_calls"))
|
||||
@ -331,57 +428,84 @@ impl OpenAiProvider {
|
||||
if let Some(arr) = tc.as_array() {
|
||||
let calls: Vec<crate::ai::ToolCall> = arr
|
||||
.iter()
|
||||
.filter_map(|call| {
|
||||
.enumerate()
|
||||
.filter_map(|(index, call)| {
|
||||
// Try OpenAI format first
|
||||
if let (Some(id), Some(name), Some(args)) = (
|
||||
if let (Some(id), Some(name)) = (
|
||||
call.get("id").and_then(|v| v.as_str()),
|
||||
call.get("function")
|
||||
.and_then(|f| f.get("name"))
|
||||
.and_then(|n| n.as_str())
|
||||
.or_else(|| call.get("name").and_then(|n| n.as_str())),
|
||||
call.get("function")
|
||||
.and_then(|f| f.get("arguments"))
|
||||
.and_then(|a| a.as_str())
|
||||
.or_else(|| call.get("arguments").and_then(|a| a.as_str())),
|
||||
) {
|
||||
tracing::info!("Custom REST: Parsed tool call: {} ({})", name, id);
|
||||
return Some(crate::ai::ToolCall {
|
||||
id: id.to_string(),
|
||||
name: name.to_string(),
|
||||
arguments: args.to_string(),
|
||||
});
|
||||
// Accept arguments as either string or object (MSI GenAI returns both)
|
||||
let arguments = call
|
||||
.get("function")
|
||||
.and_then(|f| f.get("arguments"))
|
||||
.or_else(|| call.get("arguments"))
|
||||
.and_then(|args| {
|
||||
if let Some(s) = args.as_str() {
|
||||
Some(s.to_string())
|
||||
} else {
|
||||
// Serialize object to JSON string
|
||||
serde_json::to_string(args).ok()
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(args) = arguments {
|
||||
tracing::info!(
|
||||
"MSI GenAI: Parsed tool call: {} ({})",
|
||||
name,
|
||||
id
|
||||
);
|
||||
return Some(crate::ai::ToolCall {
|
||||
id: id.to_string(),
|
||||
name: name.to_string(),
|
||||
arguments: args,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Try simpler format
|
||||
if let (Some(name), Some(args)) = (
|
||||
call.get("name").and_then(|n| n.as_str()),
|
||||
call.get("arguments").and_then(|a| a.as_str()),
|
||||
) {
|
||||
let id = call
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("tool_call_0")
|
||||
.to_string();
|
||||
tracing::info!(
|
||||
"Custom REST: Parsed tool call (simple format): {} ({})",
|
||||
name,
|
||||
id
|
||||
);
|
||||
return Some(crate::ai::ToolCall {
|
||||
id,
|
||||
name: name.to_string(),
|
||||
arguments: args.to_string(),
|
||||
if let Some(name) = call.get("name").and_then(|n| n.as_str()) {
|
||||
// Accept arguments as either string or object
|
||||
let arguments = call.get("arguments").and_then(|args| {
|
||||
if let Some(s) = args.as_str() {
|
||||
Some(s.to_string())
|
||||
} else {
|
||||
// Serialize object to JSON string
|
||||
serde_json::to_string(args).ok()
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(args) = arguments {
|
||||
// Generate unique ID if missing (avoids duplicates)
|
||||
let id = call
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| format!("tool_call_{index}"));
|
||||
tracing::info!(
|
||||
"MSI GenAI: Parsed tool call (simple format): {} ({})",
|
||||
name,
|
||||
id
|
||||
);
|
||||
return Some(crate::ai::ToolCall {
|
||||
id,
|
||||
name: name.to_string(),
|
||||
arguments: args,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("Custom REST: Failed to parse tool call: {:?}", call);
|
||||
tracing::warn!("MSI GenAI: Failed to parse tool call: {:?}", call);
|
||||
None
|
||||
})
|
||||
.collect();
|
||||
if calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
tracing::info!("Custom REST: Found {} tool calls", calls.len());
|
||||
tracing::info!("MSI GenAI: Found {} tool calls", calls.len());
|
||||
Some(calls)
|
||||
}
|
||||
} else {
|
||||
@ -389,6 +513,20 @@ impl OpenAiProvider {
|
||||
}
|
||||
});
|
||||
|
||||
// WORKAROUND: MSIGenAI gateway bug - tool calls returned as JSON text in 'msg' field
|
||||
// Expected: {"tool_calls": [...]}
|
||||
// Actual: {"msg": '{"tool_calls":[...]}'} or {"msg": '<tool_calls>[...]</tool_calls>'}
|
||||
if tool_calls.is_none() {
|
||||
// Try parsing tool calls from msg content (MSIGenAI workaround)
|
||||
if let Some(parsed_calls) = Self::parse_tool_calls_from_text(&content) {
|
||||
tracing::warn!(
|
||||
"MSI GenAI: MSIGenAI workaround - parsed {} tool calls from msg text (gateway should return structured tool_calls field)",
|
||||
parsed_calls.len()
|
||||
);
|
||||
tool_calls = Some(parsed_calls);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: sessionId from response should be stored back to config.session_id
|
||||
// This would require making config mutable or returning it as part of ChatResponse
|
||||
// For now, the caller can extract it from the response if needed
|
||||
@ -402,4 +540,95 @@ impl OpenAiProvider {
|
||||
tool_calls,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse tool calls from text content (MSIGenAI gateway workaround)
|
||||
///
|
||||
/// MSIGenAI returns tool calls as JSON text in the 'msg' field instead of structured data:
|
||||
/// - ChatGPT models: `{"tool_calls":[...]}`
|
||||
/// - Claude models: `<tool_calls>[...]</tool_calls>`
|
||||
fn parse_tool_calls_from_text(content: &str) -> Option<Vec<crate::ai::ToolCall>> {
|
||||
// Try parsing as direct JSON object
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||
if let Some(calls) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
|
||||
return Self::extract_tool_calls_from_array(calls);
|
||||
}
|
||||
}
|
||||
|
||||
// Try finding JSON in text (handle Claude XML wrapper: <tool_calls>[...]</tool_calls>)
|
||||
if let Some(start) = content.find("<tool_calls>") {
|
||||
if let Some(end) = content.find("</tool_calls>") {
|
||||
let json_str = &content[start + 12..end].trim();
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
|
||||
if let Some(calls) = parsed.as_array() {
|
||||
return Self::extract_tool_calls_from_array(calls);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try finding raw JSON array in text
|
||||
if let Some(start) = content.find("[{") {
|
||||
if let Some(end) = content.rfind("}]") {
|
||||
let json_str = &content[start..=end + 1];
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
|
||||
if let Some(calls) = parsed.as_array() {
|
||||
return Self::extract_tool_calls_from_array(calls);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract ToolCall structs from JSON array
|
||||
fn extract_tool_calls_from_array(
|
||||
calls: &[serde_json::Value],
|
||||
) -> Option<Vec<crate::ai::ToolCall>> {
|
||||
let parsed: Vec<crate::ai::ToolCall> = calls
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, call)| {
|
||||
// Generate fallback ID if missing (consistent with earlier parsing logic in this file)
|
||||
let id = call
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| format!("tool_call_{index}"));
|
||||
|
||||
// Try nested function.name format (OpenAI style)
|
||||
let name = call
|
||||
.get("function")
|
||||
.and_then(|f| f.get("name"))
|
||||
.and_then(|n| n.as_str())
|
||||
.or_else(|| call.get("name").and_then(|n| n.as_str()))?
|
||||
.to_string();
|
||||
|
||||
// Arguments can be string or object
|
||||
let arguments = call
|
||||
.get("function")
|
||||
.and_then(|f| f.get("arguments"))
|
||||
.or_else(|| call.get("arguments"))
|
||||
.and_then(|args| {
|
||||
if let Some(s) = args.as_str() {
|
||||
Some(s.to_string())
|
||||
} else {
|
||||
serde_json::to_string(args).ok()
|
||||
}
|
||||
})?;
|
||||
|
||||
Some(crate::ai::ToolCall {
|
||||
id,
|
||||
name,
|
||||
arguments,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
if parsed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(parsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -552,6 +552,79 @@ pub async fn test_provider_connection(
|
||||
})
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn detect_tool_calling_support(provider_config: ProviderConfig) -> Result<bool, String> {
|
||||
use crate::ai::{Tool, ToolParameters};
|
||||
use std::collections::HashMap;
|
||||
use tracing::info;
|
||||
|
||||
// Create a simple test tool
|
||||
let test_tool = Tool {
|
||||
name: "test_tool".to_string(),
|
||||
description: "A test tool that returns 'success'. Call this tool with no arguments."
|
||||
.to_string(),
|
||||
parameters: ToolParameters {
|
||||
param_type: "object".to_string(),
|
||||
properties: HashMap::new(),
|
||||
required: vec![],
|
||||
},
|
||||
};
|
||||
|
||||
// Override config with detection-optimized settings
|
||||
let mut detection_config = provider_config.clone();
|
||||
detection_config.max_tokens = Some(100); // Small budget for capability check
|
||||
detection_config.temperature = Some(0.0); // Deterministic for reliability
|
||||
|
||||
let provider = create_provider(&detection_config);
|
||||
let messages = vec![Message {
|
||||
role: "user".into(),
|
||||
content: "Please call the test_tool function.".into(),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}];
|
||||
|
||||
match provider
|
||||
.chat(messages, &detection_config, Some(vec![test_tool]))
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
// Check if response contains tool_calls
|
||||
if let Some(tool_calls) = response.tool_calls {
|
||||
if tool_calls.iter().any(|tc| tc.name == "test_tool") {
|
||||
info!(
|
||||
"Tool calling support detected for provider {}",
|
||||
provider_config.name
|
||||
);
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
// Provider responded but didn't use tool calls
|
||||
info!(
|
||||
"Provider {} responded but did not call tool",
|
||||
provider_config.name
|
||||
);
|
||||
Ok(false)
|
||||
}
|
||||
Err(e) => {
|
||||
// Check if error indicates tool calling is not supported
|
||||
let error_msg = e.to_string().to_lowercase();
|
||||
if error_msg.contains("tool")
|
||||
|| error_msg.contains("function")
|
||||
|| error_msg.contains("503")
|
||||
{
|
||||
info!(
|
||||
"Tool calling not supported for provider {}: {}",
|
||||
provider_config.name, e
|
||||
);
|
||||
Ok(false)
|
||||
} else {
|
||||
// Connection or other error
|
||||
Err(format!("Failed to test tool calling support: {e}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_providers() -> Result<Vec<ProviderInfo>, String> {
|
||||
Ok(vec![
|
||||
|
||||
@ -104,6 +104,7 @@ pub fn run() {
|
||||
commands::ai::analyze_logs,
|
||||
commands::ai::chat_message,
|
||||
commands::ai::test_provider_connection,
|
||||
commands::ai::detect_tool_calling_support,
|
||||
commands::ai::list_providers,
|
||||
commands::system::save_ai_provider,
|
||||
commands::system::load_ai_providers,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user