Backport: Agentic Shell Command Execution (v1.0.0 → v1.0.8) #66
@ -1,10 +1,14 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::ai::provider::Provider;
|
use crate::ai::provider::Provider;
|
||||||
use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage};
|
use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage, ToolCall};
|
||||||
use crate::state::ProviderConfig;
|
use crate::state::ProviderConfig;
|
||||||
|
|
||||||
|
// Track if we've already attempted auto-start this session
|
||||||
|
static AUTO_START_ATTEMPTED: AtomicBool = AtomicBool::new(false);
|
||||||
|
|
||||||
pub struct OllamaProvider;
|
pub struct OllamaProvider;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -18,11 +22,11 @@ impl Provider for OllamaProvider {
|
|||||||
name: "Ollama (Local)".to_string(),
|
name: "Ollama (Local)".to_string(),
|
||||||
supports_streaming: true,
|
supports_streaming: true,
|
||||||
models: vec![
|
models: vec![
|
||||||
"llama3.1".to_string(),
|
"llama3.2:3b".to_string(),
|
||||||
"llama3".to_string(),
|
"phi3.5:3.8b".to_string(),
|
||||||
"mistral".to_string(),
|
"llama3.1:8b".to_string(),
|
||||||
"codellama".to_string(),
|
"qwen2.5:14b".to_string(),
|
||||||
"phi3".to_string(),
|
"gemma2:9b".to_string(),
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -31,77 +35,276 @@ impl Provider for OllamaProvider {
|
|||||||
&self,
|
&self,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
config: &ProviderConfig,
|
config: &ProviderConfig,
|
||||||
_tools: Option<Vec<crate::ai::Tool>>,
|
tools: Option<Vec<crate::ai::Tool>>,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<ChatResponse> {
|
||||||
|
// Longer timeout for tool calling - models need time to generate structured output
|
||||||
|
let timeout_secs = if tools.is_some() { 180 } else { 60 };
|
||||||
|
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(Duration::from_secs(60))
|
.timeout(Duration::from_secs(timeout_secs))
|
||||||
|
.connect_timeout(Duration::from_secs(10))
|
||||||
.build()?;
|
.build()?;
|
||||||
let base_url = if config.api_url.is_empty() {
|
let base_url = if config.api_url.is_empty() {
|
||||||
"http://localhost:11434".to_string()
|
"http://localhost:11434".to_string()
|
||||||
} else {
|
} else {
|
||||||
config.api_url.trim_end_matches('/').to_string()
|
config.api_url.trim_end_matches('/').to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Auto-start Ollama if using localhost and we haven't tried yet this session
|
||||||
|
// Only attempt once to avoid recurring latency on every chat() call
|
||||||
|
if base_url == "http://localhost:11434"
|
||||||
|
&& !AUTO_START_ATTEMPTED.swap(true, Ordering::Relaxed)
|
||||||
|
{
|
||||||
|
// Check if already running before attempting start
|
||||||
|
let pre_status = crate::ollama::installer::check_ollama().await;
|
||||||
|
let already_running = pre_status.map(|s| s.running).unwrap_or(false);
|
||||||
|
|
||||||
|
if !already_running {
|
||||||
|
match crate::ollama::installer::start_ollama_service().await {
|
||||||
|
Ok(true) => {
|
||||||
|
tracing::info!("Ollama service auto-started successfully");
|
||||||
|
// Give it a moment to fully initialize
|
||||||
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||||
|
}
|
||||||
|
Ok(false) => {
|
||||||
|
tracing::debug!("Ollama not started (not installed or already running)");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Failed to auto-start Ollama: {}", e);
|
||||||
|
// Continue anyway - maybe it's already running or will start soon
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tracing::debug!("Ollama already running, skipping auto-start");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quick health check before attempting chat (short timeout for fast failure)
|
||||||
|
let health_client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(2))
|
||||||
|
.build()?;
|
||||||
|
let health_check_result = health_client
|
||||||
|
.get(format!("{base_url}/api/tags"))
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match health_check_result {
|
||||||
|
Ok(resp) if resp.status().is_success() => {
|
||||||
|
tracing::debug!("Ollama health check passed");
|
||||||
|
}
|
||||||
|
Ok(resp) => {
|
||||||
|
let status = resp.status();
|
||||||
|
tracing::warn!("Ollama health check returned status {status}");
|
||||||
|
anyhow::bail!(
|
||||||
|
"Ollama is not ready (status {status}). Please ensure Ollama is running."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Cannot connect to Ollama at {base_url}: {e}");
|
||||||
|
anyhow::bail!("Cannot connect to Ollama at {base_url}. Please ensure Ollama is running and accessible.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let url = format!("{base_url}/api/chat");
|
let url = format!("{base_url}/api/chat");
|
||||||
|
|
||||||
// Ollama expects {model, messages: [{role, content}], stream: false}
|
// Ollama expects {model, messages: [{role, content, tool_calls?, tool_call_id?}], stream: false}
|
||||||
let api_messages: Vec<serde_json::Value> = messages
|
let api_messages: Vec<serde_json::Value> = messages
|
||||||
.iter()
|
.iter()
|
||||||
.map(|m| {
|
.map(|m| {
|
||||||
serde_json::json!({
|
let mut msg = serde_json::json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": m.content,
|
"content": m.content,
|
||||||
})
|
});
|
||||||
|
|
||||||
|
// Include tool_calls if present (for assistant messages with tool requests)
|
||||||
|
if let Some(ref tool_calls) = m.tool_calls {
|
||||||
|
msg["tool_calls"] = serde_json::json!(tool_calls);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include tool_call_id if present (for tool result messages)
|
||||||
|
if let Some(ref tool_call_id) = m.tool_call_id {
|
||||||
|
msg["tool_call_id"] = serde_json::json!(tool_call_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
msg
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let body = serde_json::json!({
|
let mut body = serde_json::json!({
|
||||||
"model": config.model,
|
"model": config.model,
|
||||||
"messages": api_messages,
|
"messages": api_messages,
|
||||||
"stream": false,
|
"stream": false,
|
||||||
});
|
});
|
||||||
|
|
||||||
let resp = client
|
// Add tools if provided (Ollama function calling format)
|
||||||
.post(&url)
|
if let Some(tools_list) = tools {
|
||||||
.header("Content-Type", "application/json")
|
let formatted_tools: Vec<serde_json::Value> = tools_list
|
||||||
.json(&body)
|
.iter()
|
||||||
.send()
|
.map(|tool| {
|
||||||
.await?;
|
serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
if !resp.status().is_success() {
|
"function": {
|
||||||
let status = resp.status();
|
"name": tool.name,
|
||||||
let text = resp.text().await?;
|
"description": tool.description,
|
||||||
anyhow::bail!("Ollama API error {status}: {text}");
|
"parameters": tool.parameters
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
body["tools"] = serde_json::Value::from(formatted_tools);
|
||||||
}
|
}
|
||||||
|
|
||||||
let json: serde_json::Value = resp.json().await?;
|
// Retry logic for transient connection issues
|
||||||
|
let max_retries = 2;
|
||||||
|
let mut last_error = None;
|
||||||
|
|
||||||
// Parse response.message.content
|
for attempt in 0..=max_retries {
|
||||||
let content = json["message"]["content"]
|
if attempt > 0 {
|
||||||
.as_str()
|
tracing::warn!(
|
||||||
.ok_or_else(|| anyhow::anyhow!("No content in Ollama response"))?
|
"Ollama request failed, retrying (attempt {}/{})...",
|
||||||
.to_string();
|
attempt + 1,
|
||||||
|
max_retries + 1
|
||||||
|
);
|
||||||
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||||
|
}
|
||||||
|
|
||||||
// Ollama provides eval_count / prompt_eval_count
|
let resp_result = client
|
||||||
let usage = {
|
.post(&url)
|
||||||
let prompt_tokens = json["prompt_eval_count"].as_u64().unwrap_or(0) as u32;
|
.header("Content-Type", "application/json")
|
||||||
let completion_tokens = json["eval_count"].as_u64().unwrap_or(0) as u32;
|
.json(&body)
|
||||||
if prompt_tokens > 0 || completion_tokens > 0 {
|
.send()
|
||||||
Some(TokenUsage {
|
.await;
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
let resp = match resp_result {
|
||||||
total_tokens: prompt_tokens + completion_tokens,
|
Ok(r) => r,
|
||||||
})
|
Err(e) => {
|
||||||
|
last_error = Some(format!("Connection error: {e}"));
|
||||||
|
if attempt < max_retries {
|
||||||
|
continue; // Retry
|
||||||
|
} else {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Failed to connect to Ollama after {} attempts. Last error: {e}",
|
||||||
|
max_retries + 1
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let status = resp.status();
|
||||||
|
let text = resp.text().await?;
|
||||||
|
last_error = Some(format!("API error {status}: {text}"));
|
||||||
|
if attempt < max_retries && status.is_server_error() {
|
||||||
|
continue; // Retry on 5xx errors
|
||||||
|
} else {
|
||||||
|
anyhow::bail!("Ollama API error {status}: {text}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success - parse response and return
|
||||||
|
let json: serde_json::Value = match resp.json().await {
|
||||||
|
Ok(j) => j,
|
||||||
|
Err(e) => {
|
||||||
|
last_error = Some(format!("JSON parse error: {e}"));
|
||||||
|
if attempt < max_retries {
|
||||||
|
continue; // Retry
|
||||||
|
} else {
|
||||||
|
anyhow::bail!("Failed to parse Ollama response: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse response.message.content
|
||||||
|
let content = json["message"]["content"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// Parse tool calls from Ollama response
|
||||||
|
// Ollama returns tool_calls in message.tool_calls array
|
||||||
|
let tool_calls = if let Some(calls_array) = json["message"]["tool_calls"].as_array() {
|
||||||
|
let mut parsed_calls = Vec::new();
|
||||||
|
for (idx, call) in calls_array.iter().enumerate() {
|
||||||
|
// Generate fallback ID if not provided
|
||||||
|
let id = call["id"]
|
||||||
|
.as_str()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.unwrap_or_else(|| format!("tool_call_{idx}"));
|
||||||
|
|
||||||
|
let function = &call["function"];
|
||||||
|
|
||||||
|
// Skip malformed tool calls (missing name) instead of failing entire response
|
||||||
|
let name = match function["name"].as_str() {
|
||||||
|
Some(n) => n.to_string(),
|
||||||
|
None => {
|
||||||
|
tracing::warn!("Skipping tool call with missing name at index {idx}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Arguments can be either an object or a string
|
||||||
|
let arguments = if let Some(args_obj) = function["arguments"].as_object() {
|
||||||
|
match serde_json::to_string(args_obj) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to serialize tool call arguments at index {}: {}",
|
||||||
|
idx,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if let Some(args_str) = function["arguments"].as_str() {
|
||||||
|
args_str.to_string()
|
||||||
|
} else {
|
||||||
|
"{}".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
parsed_calls.push(ToolCall {
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if !parsed_calls.is_empty() {
|
||||||
|
Some(parsed_calls)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
};
|
||||||
};
|
|
||||||
|
|
||||||
Ok(ChatResponse {
|
// Ollama provides eval_count / prompt_eval_count
|
||||||
content,
|
let usage = {
|
||||||
model: config.model.clone(),
|
let prompt_tokens = json["prompt_eval_count"].as_u64().unwrap_or(0) as u32;
|
||||||
usage,
|
let completion_tokens = json["eval_count"].as_u64().unwrap_or(0) as u32;
|
||||||
user_message: None,
|
if prompt_tokens > 0 || completion_tokens > 0 {
|
||||||
tool_calls: None,
|
Some(TokenUsage {
|
||||||
})
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
total_tokens: prompt_tokens + completion_tokens,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return Ok(ChatResponse {
|
||||||
|
content,
|
||||||
|
model: config.model.clone(),
|
||||||
|
usage,
|
||||||
|
user_message: None,
|
||||||
|
tool_calls,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we get here, all retries failed
|
||||||
|
anyhow::bail!(
|
||||||
|
"Failed to get response from Ollama after {} attempts. Last error: {:?}",
|
||||||
|
max_retries + 1,
|
||||||
|
last_error
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,8 +7,8 @@ use crate::state::ProviderConfig;
|
|||||||
|
|
||||||
pub struct OpenAiProvider;
|
pub struct OpenAiProvider;
|
||||||
|
|
||||||
fn is_custom_rest_format(api_format: Option<&str>) -> bool {
|
fn is_msi_genai_format(api_format: Option<&str>) -> bool {
|
||||||
matches!(api_format, Some("custom_rest"))
|
matches!(api_format, Some("msi-genai") | Some("custom_rest")) // custom_rest for backward compatibility
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -38,8 +38,8 @@ impl Provider for OpenAiProvider {
|
|||||||
// Check if using custom REST format
|
// Check if using custom REST format
|
||||||
let api_format = config.api_format.as_deref().unwrap_or("openai");
|
let api_format = config.api_format.as_deref().unwrap_or("openai");
|
||||||
|
|
||||||
if is_custom_rest_format(Some(api_format)) {
|
if is_msi_genai_format(Some(api_format)) {
|
||||||
self.chat_custom_rest(messages, config, tools).await
|
self.chat_msi_genai(messages, config, tools).await
|
||||||
} else {
|
} else {
|
||||||
self.chat_openai(messages, config, tools).await
|
self.chat_openai(messages, config, tools).await
|
||||||
}
|
}
|
||||||
@ -48,17 +48,109 @@ impl Provider for OpenAiProvider {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::is_custom_rest_format;
|
use super::{is_msi_genai_format, OpenAiProvider};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn custom_rest_format_is_recognized() {
|
fn msi_genai_format_is_recognized() {
|
||||||
assert!(is_custom_rest_format(Some("custom_rest")));
|
assert!(is_msi_genai_format(Some("msi-genai")));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn openai_format_is_not_custom_rest() {
|
fn custom_rest_format_backward_compatible() {
|
||||||
assert!(!is_custom_rest_format(Some("openai")));
|
// Keep backward compatibility with old format name
|
||||||
assert!(!is_custom_rest_format(None));
|
assert!(is_msi_genai_format(Some("custom_rest")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn openai_format_is_not_msi_genai() {
|
||||||
|
assert!(!is_msi_genai_format(Some("openai")));
|
||||||
|
assert!(!is_msi_genai_format(None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_msigenai_chatgpt_tool_calls_from_json_text() {
|
||||||
|
// MSIGenAI ChatGPT format: returns tool calls as JSON object in msg
|
||||||
|
let content = r#"{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"execute_shell_command","arguments":{"command":"kubectl get namespaces"}}}]}"#;
|
||||||
|
|
||||||
|
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||||
|
assert!(result.is_some());
|
||||||
|
|
||||||
|
let calls = result.unwrap();
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].id, "call_1");
|
||||||
|
assert_eq!(calls[0].name, "execute_shell_command");
|
||||||
|
assert!(calls[0].arguments.contains("kubectl get namespaces"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_msigenai_claude_tool_calls_from_xml_wrapper() {
|
||||||
|
// MSIGenAI Claude format: XML wrapper around JSON array
|
||||||
|
let content = r#"<tool_calls>
|
||||||
|
[{"id":"call_1","type":"function","function":{"name":"execute_shell_command","arguments":{"command":"kubectl get pods"}}}]
|
||||||
|
</tool_calls>"#;
|
||||||
|
|
||||||
|
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||||
|
assert!(result.is_some());
|
||||||
|
|
||||||
|
let calls = result.unwrap();
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].id, "call_1");
|
||||||
|
assert_eq!(calls[0].name, "execute_shell_command");
|
||||||
|
assert!(calls[0].arguments.contains("kubectl get pods"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_multiple_tool_calls_from_text() {
|
||||||
|
let content = r#"{"tool_calls":[
|
||||||
|
{"id":"call_1","function":{"name":"kubectl_get","arguments":{"resource":"pods"}}},
|
||||||
|
{"id":"call_2","function":{"name":"kubectl_describe","arguments":{"resource":"svc/nginx"}}}
|
||||||
|
]}"#;
|
||||||
|
|
||||||
|
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||||
|
assert!(result.is_some());
|
||||||
|
|
||||||
|
let calls = result.unwrap();
|
||||||
|
assert_eq!(calls.len(), 2);
|
||||||
|
assert_eq!(calls[0].name, "kubectl_get");
|
||||||
|
assert_eq!(calls[1].name, "kubectl_describe");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_returns_none_for_normal_text() {
|
||||||
|
let content = "Hello, I found 5 pods running in the cluster.";
|
||||||
|
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||||
|
assert!(result.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_handles_arguments_as_string() {
|
||||||
|
// Some providers return arguments as string, not object
|
||||||
|
let content = r#"{"tool_calls":[{"id":"call_1","function":{"name":"test","arguments":"{\"key\":\"value\"}"}}]}"#;
|
||||||
|
|
||||||
|
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||||
|
assert!(result.is_some());
|
||||||
|
|
||||||
|
let calls = result.unwrap();
|
||||||
|
assert_eq!(calls[0].arguments, r#"{"key":"value"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_generates_fallback_id_when_missing() {
|
||||||
|
// Some providers may omit id field - generate fallback to prevent silent drop
|
||||||
|
let content = r#"{"tool_calls":[
|
||||||
|
{"function":{"name":"kubectl_get","arguments":{"resource":"pods"}}},
|
||||||
|
{"id":"call_2","function":{"name":"kubectl_describe","arguments":{"resource":"svc"}}}
|
||||||
|
]}"#;
|
||||||
|
|
||||||
|
let result = OpenAiProvider::parse_tool_calls_from_text(content);
|
||||||
|
assert!(result.is_some());
|
||||||
|
|
||||||
|
let calls = result.unwrap();
|
||||||
|
assert_eq!(calls.len(), 2);
|
||||||
|
assert_eq!(calls[0].id, "tool_call_0"); // Fallback generated
|
||||||
|
assert_eq!(calls[0].name, "kubectl_get");
|
||||||
|
assert_eq!(calls[1].id, "call_2"); // Original preserved
|
||||||
|
assert_eq!(calls[1].name, "kubectl_describe");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -202,8 +294,13 @@ impl OpenAiProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Custom REST format (non-OpenAI payload contract)
|
/// MSI GenAI format (non-OpenAI payload contract)
|
||||||
async fn chat_custom_rest(
|
///
|
||||||
|
/// MSI GenAI uses a custom API format with 'prompt' field instead of 'messages',
|
||||||
|
/// and has a known bug where tool calls are returned as JSON text in the 'msg'
|
||||||
|
/// field instead of structured 'tool_calls' array. This implementation includes
|
||||||
|
/// workaround parsing to extract tool calls from text.
|
||||||
|
async fn chat_msi_genai(
|
||||||
&self,
|
&self,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
config: &ProviderConfig,
|
config: &ProviderConfig,
|
||||||
@ -284,7 +381,7 @@ impl OpenAiProvider {
|
|||||||
body["tools"] = serde_json::Value::from(formatted_tools);
|
body["tools"] = serde_json::Value::from(formatted_tools);
|
||||||
body["tool_choice"] = serde_json::Value::from("auto");
|
body["tool_choice"] = serde_json::Value::from("auto");
|
||||||
|
|
||||||
tracing::info!("Custom REST: Sending {} tools in request", tool_count);
|
tracing::info!("MSI GenAI: Sending {} tools in request", tool_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use custom auth header and prefix (no default prefix for custom REST)
|
// Use custom auth header and prefix (no default prefix for custom REST)
|
||||||
@ -306,13 +403,13 @@ impl OpenAiProvider {
|
|||||||
if !resp.status().is_success() {
|
if !resp.status().is_success() {
|
||||||
let status = resp.status();
|
let status = resp.status();
|
||||||
let text = resp.text().await?;
|
let text = resp.text().await?;
|
||||||
anyhow::bail!("Custom REST API error {status}: {text}");
|
anyhow::bail!("MSI GenAI API error {status}: {text}");
|
||||||
}
|
}
|
||||||
|
|
||||||
let json: serde_json::Value = resp.json().await?;
|
let json: serde_json::Value = resp.json().await?;
|
||||||
|
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"Custom REST response: {}",
|
"MSI GenAI response: {}",
|
||||||
serde_json::to_string_pretty(&json).unwrap_or_else(|_| "invalid JSON".to_string())
|
serde_json::to_string_pretty(&json).unwrap_or_else(|_| "invalid JSON".to_string())
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -323,7 +420,7 @@ impl OpenAiProvider {
|
|||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
// Parse tool_calls if present (check multiple possible field names)
|
// Parse tool_calls if present (check multiple possible field names)
|
||||||
let tool_calls = json
|
let mut tool_calls = json
|
||||||
.get("tool_calls")
|
.get("tool_calls")
|
||||||
.or_else(|| json.get("toolCalls"))
|
.or_else(|| json.get("toolCalls"))
|
||||||
.or_else(|| json.get("function_calls"))
|
.or_else(|| json.get("function_calls"))
|
||||||
@ -331,57 +428,84 @@ impl OpenAiProvider {
|
|||||||
if let Some(arr) = tc.as_array() {
|
if let Some(arr) = tc.as_array() {
|
||||||
let calls: Vec<crate::ai::ToolCall> = arr
|
let calls: Vec<crate::ai::ToolCall> = arr
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|call| {
|
.enumerate()
|
||||||
|
.filter_map(|(index, call)| {
|
||||||
// Try OpenAI format first
|
// Try OpenAI format first
|
||||||
if let (Some(id), Some(name), Some(args)) = (
|
if let (Some(id), Some(name)) = (
|
||||||
call.get("id").and_then(|v| v.as_str()),
|
call.get("id").and_then(|v| v.as_str()),
|
||||||
call.get("function")
|
call.get("function")
|
||||||
.and_then(|f| f.get("name"))
|
.and_then(|f| f.get("name"))
|
||||||
.and_then(|n| n.as_str())
|
.and_then(|n| n.as_str())
|
||||||
.or_else(|| call.get("name").and_then(|n| n.as_str())),
|
.or_else(|| call.get("name").and_then(|n| n.as_str())),
|
||||||
call.get("function")
|
|
||||||
.and_then(|f| f.get("arguments"))
|
|
||||||
.and_then(|a| a.as_str())
|
|
||||||
.or_else(|| call.get("arguments").and_then(|a| a.as_str())),
|
|
||||||
) {
|
) {
|
||||||
tracing::info!("Custom REST: Parsed tool call: {} ({})", name, id);
|
// Accept arguments as either string or object (MSI GenAI returns both)
|
||||||
return Some(crate::ai::ToolCall {
|
let arguments = call
|
||||||
id: id.to_string(),
|
.get("function")
|
||||||
name: name.to_string(),
|
.and_then(|f| f.get("arguments"))
|
||||||
arguments: args.to_string(),
|
.or_else(|| call.get("arguments"))
|
||||||
});
|
.and_then(|args| {
|
||||||
|
if let Some(s) = args.as_str() {
|
||||||
|
Some(s.to_string())
|
||||||
|
} else {
|
||||||
|
// Serialize object to JSON string
|
||||||
|
serde_json::to_string(args).ok()
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(args) = arguments {
|
||||||
|
tracing::info!(
|
||||||
|
"MSI GenAI: Parsed tool call: {} ({})",
|
||||||
|
name,
|
||||||
|
id
|
||||||
|
);
|
||||||
|
return Some(crate::ai::ToolCall {
|
||||||
|
id: id.to_string(),
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments: args,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try simpler format
|
// Try simpler format
|
||||||
if let (Some(name), Some(args)) = (
|
if let Some(name) = call.get("name").and_then(|n| n.as_str()) {
|
||||||
call.get("name").and_then(|n| n.as_str()),
|
// Accept arguments as either string or object
|
||||||
call.get("arguments").and_then(|a| a.as_str()),
|
let arguments = call.get("arguments").and_then(|args| {
|
||||||
) {
|
if let Some(s) = args.as_str() {
|
||||||
let id = call
|
Some(s.to_string())
|
||||||
.get("id")
|
} else {
|
||||||
.and_then(|v| v.as_str())
|
// Serialize object to JSON string
|
||||||
.unwrap_or("tool_call_0")
|
serde_json::to_string(args).ok()
|
||||||
.to_string();
|
}
|
||||||
tracing::info!(
|
|
||||||
"Custom REST: Parsed tool call (simple format): {} ({})",
|
|
||||||
name,
|
|
||||||
id
|
|
||||||
);
|
|
||||||
return Some(crate::ai::ToolCall {
|
|
||||||
id,
|
|
||||||
name: name.to_string(),
|
|
||||||
arguments: args.to_string(),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if let Some(args) = arguments {
|
||||||
|
// Generate unique ID if missing (avoids duplicates)
|
||||||
|
let id = call
|
||||||
|
.get("id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.unwrap_or_else(|| format!("tool_call_{index}"));
|
||||||
|
tracing::info!(
|
||||||
|
"MSI GenAI: Parsed tool call (simple format): {} ({})",
|
||||||
|
name,
|
||||||
|
id
|
||||||
|
);
|
||||||
|
return Some(crate::ai::ToolCall {
|
||||||
|
id,
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments: args,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::warn!("Custom REST: Failed to parse tool call: {:?}", call);
|
tracing::warn!("MSI GenAI: Failed to parse tool call: {:?}", call);
|
||||||
None
|
None
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
if calls.is_empty() {
|
if calls.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
tracing::info!("Custom REST: Found {} tool calls", calls.len());
|
tracing::info!("MSI GenAI: Found {} tool calls", calls.len());
|
||||||
Some(calls)
|
Some(calls)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -389,6 +513,20 @@ impl OpenAiProvider {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// WORKAROUND: MSIGenAI gateway bug - tool calls returned as JSON text in 'msg' field
|
||||||
|
// Expected: {"tool_calls": [...]}
|
||||||
|
// Actual: {"msg": '{"tool_calls":[...]}'} or {"msg": '<tool_calls>[...]</tool_calls>'}
|
||||||
|
if tool_calls.is_none() {
|
||||||
|
// Try parsing tool calls from msg content (MSIGenAI workaround)
|
||||||
|
if let Some(parsed_calls) = Self::parse_tool_calls_from_text(&content) {
|
||||||
|
tracing::warn!(
|
||||||
|
"MSI GenAI: MSIGenAI workaround - parsed {} tool calls from msg text (gateway should return structured tool_calls field)",
|
||||||
|
parsed_calls.len()
|
||||||
|
);
|
||||||
|
tool_calls = Some(parsed_calls);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Note: sessionId from response should be stored back to config.session_id
|
// Note: sessionId from response should be stored back to config.session_id
|
||||||
// This would require making config mutable or returning it as part of ChatResponse
|
// This would require making config mutable or returning it as part of ChatResponse
|
||||||
// For now, the caller can extract it from the response if needed
|
// For now, the caller can extract it from the response if needed
|
||||||
@ -402,4 +540,95 @@ impl OpenAiProvider {
|
|||||||
tool_calls,
|
tool_calls,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse tool calls from text content (MSIGenAI gateway workaround)
|
||||||
|
///
|
||||||
|
/// MSIGenAI returns tool calls as JSON text in the 'msg' field instead of structured data:
|
||||||
|
/// - ChatGPT models: `{"tool_calls":[...]}`
|
||||||
|
/// - Claude models: `<tool_calls>[...]</tool_calls>`
|
||||||
|
fn parse_tool_calls_from_text(content: &str) -> Option<Vec<crate::ai::ToolCall>> {
|
||||||
|
// Try parsing as direct JSON object
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
|
if let Some(calls) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
|
||||||
|
return Self::extract_tool_calls_from_array(calls);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try finding JSON in text (handle Claude XML wrapper: <tool_calls>[...]</tool_calls>)
|
||||||
|
if let Some(start) = content.find("<tool_calls>") {
|
||||||
|
if let Some(end) = content.find("</tool_calls>") {
|
||||||
|
let json_str = &content[start + 12..end].trim();
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
|
||||||
|
if let Some(calls) = parsed.as_array() {
|
||||||
|
return Self::extract_tool_calls_from_array(calls);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try finding raw JSON array in text
|
||||||
|
if let Some(start) = content.find("[{") {
|
||||||
|
if let Some(end) = content.rfind("}]") {
|
||||||
|
let json_str = &content[start..=end + 1];
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
|
||||||
|
if let Some(calls) = parsed.as_array() {
|
||||||
|
return Self::extract_tool_calls_from_array(calls);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract ToolCall structs from JSON array
|
||||||
|
fn extract_tool_calls_from_array(
|
||||||
|
calls: &[serde_json::Value],
|
||||||
|
) -> Option<Vec<crate::ai::ToolCall>> {
|
||||||
|
let parsed: Vec<crate::ai::ToolCall> = calls
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(index, call)| {
|
||||||
|
// Generate fallback ID if missing (consistent with earlier parsing logic in this file)
|
||||||
|
let id = call
|
||||||
|
.get("id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.unwrap_or_else(|| format!("tool_call_{index}"));
|
||||||
|
|
||||||
|
// Try nested function.name format (OpenAI style)
|
||||||
|
let name = call
|
||||||
|
.get("function")
|
||||||
|
.and_then(|f| f.get("name"))
|
||||||
|
.and_then(|n| n.as_str())
|
||||||
|
.or_else(|| call.get("name").and_then(|n| n.as_str()))?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// Arguments can be string or object
|
||||||
|
let arguments = call
|
||||||
|
.get("function")
|
||||||
|
.and_then(|f| f.get("arguments"))
|
||||||
|
.or_else(|| call.get("arguments"))
|
||||||
|
.and_then(|args| {
|
||||||
|
if let Some(s) = args.as_str() {
|
||||||
|
Some(s.to_string())
|
||||||
|
} else {
|
||||||
|
serde_json::to_string(args).ok()
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Some(crate::ai::ToolCall {
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if parsed.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(parsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -552,6 +552,79 @@ pub async fn test_provider_connection(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn detect_tool_calling_support(provider_config: ProviderConfig) -> Result<bool, String> {
|
||||||
|
use crate::ai::{Tool, ToolParameters};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
// Create a simple test tool
|
||||||
|
let test_tool = Tool {
|
||||||
|
name: "test_tool".to_string(),
|
||||||
|
description: "A test tool that returns 'success'. Call this tool with no arguments."
|
||||||
|
.to_string(),
|
||||||
|
parameters: ToolParameters {
|
||||||
|
param_type: "object".to_string(),
|
||||||
|
properties: HashMap::new(),
|
||||||
|
required: vec![],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Override config with detection-optimized settings
|
||||||
|
let mut detection_config = provider_config.clone();
|
||||||
|
detection_config.max_tokens = Some(100); // Small budget for capability check
|
||||||
|
detection_config.temperature = Some(0.0); // Deterministic for reliability
|
||||||
|
|
||||||
|
let provider = create_provider(&detection_config);
|
||||||
|
let messages = vec![Message {
|
||||||
|
role: "user".into(),
|
||||||
|
content: "Please call the test_tool function.".into(),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: None,
|
||||||
|
}];
|
||||||
|
|
||||||
|
match provider
|
||||||
|
.chat(messages, &detection_config, Some(vec![test_tool]))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(response) => {
|
||||||
|
// Check if response contains tool_calls
|
||||||
|
if let Some(tool_calls) = response.tool_calls {
|
||||||
|
if tool_calls.iter().any(|tc| tc.name == "test_tool") {
|
||||||
|
info!(
|
||||||
|
"Tool calling support detected for provider {}",
|
||||||
|
provider_config.name
|
||||||
|
);
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Provider responded but didn't use tool calls
|
||||||
|
info!(
|
||||||
|
"Provider {} responded but did not call tool",
|
||||||
|
provider_config.name
|
||||||
|
);
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Check if error indicates tool calling is not supported
|
||||||
|
let error_msg = e.to_string().to_lowercase();
|
||||||
|
if error_msg.contains("tool")
|
||||||
|
|| error_msg.contains("function")
|
||||||
|
|| error_msg.contains("503")
|
||||||
|
{
|
||||||
|
info!(
|
||||||
|
"Tool calling not supported for provider {}: {}",
|
||||||
|
provider_config.name, e
|
||||||
|
);
|
||||||
|
Ok(false)
|
||||||
|
} else {
|
||||||
|
// Connection or other error
|
||||||
|
Err(format!("Failed to test tool calling support: {e}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub async fn list_providers() -> Result<Vec<ProviderInfo>, String> {
|
pub async fn list_providers() -> Result<Vec<ProviderInfo>, String> {
|
||||||
Ok(vec![
|
Ok(vec![
|
||||||
|
|||||||
@ -104,6 +104,7 @@ pub fn run() {
|
|||||||
commands::ai::analyze_logs,
|
commands::ai::analyze_logs,
|
||||||
commands::ai::chat_message,
|
commands::ai::chat_message,
|
||||||
commands::ai::test_provider_connection,
|
commands::ai::test_provider_connection,
|
||||||
|
commands::ai::detect_tool_calling_support,
|
||||||
commands::ai::list_providers,
|
commands::ai::list_providers,
|
||||||
commands::system::save_ai_provider,
|
commands::system::save_ai_provider,
|
||||||
commands::system::load_ai_providers,
|
commands::system::load_ai_providers,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user