use async_trait::async_trait; use crate::ai::provider::Provider; use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage}; use crate::state::ProviderConfig; pub struct OpenAiProvider; #[async_trait] impl Provider for OpenAiProvider { fn name(&self) -> &str { "openai" } fn info(&self) -> ProviderInfo { ProviderInfo { name: "OpenAI Compatible".to_string(), supports_streaming: true, models: vec![ "gpt-4o".to_string(), "gpt-4o-mini".to_string(), "gpt-4-turbo".to_string(), ], } } async fn chat( &self, messages: Vec, config: &ProviderConfig, ) -> anyhow::Result { // Check if using MSI GenAI format let api_format = config.api_format.as_deref().unwrap_or("openai"); if api_format == "msi_genai" { self.chat_msi_genai(messages, config).await } else { self.chat_openai(messages, config).await } } } impl OpenAiProvider { /// OpenAI-compatible API format (default) async fn chat_openai( &self, messages: Vec, config: &ProviderConfig, ) -> anyhow::Result { let client = reqwest::Client::new(); // Use custom endpoint path if provided, otherwise default to /chat/completions let endpoint_path = config .custom_endpoint_path .as_deref() .unwrap_or("/chat/completions"); let api_url = config.api_url.trim_end_matches('/'); let url = format!("{api_url}{endpoint_path}"); let mut body = serde_json::json!({ "model": config.model, "messages": messages, }); // Add max_tokens if provided, otherwise use default 4096 body["max_tokens"] = serde_json::Value::from(config.max_tokens.unwrap_or(4096)); // Add temperature if provided if let Some(temp) = config.temperature { body["temperature"] = serde_json::Value::from(temp); } // Use custom auth header and prefix if provided let auth_header = config .custom_auth_header .as_deref() .unwrap_or("Authorization"); let auth_prefix = config.custom_auth_prefix.as_deref().unwrap_or("Bearer "); let auth_value = format!("{auth_prefix}{api_key}", api_key = config.api_key); let resp = client .post(&url) .header(auth_header, auth_value) .header("Content-Type", "application/json") .json(&body) .send() .await?; if !resp.status().is_success() { let status = resp.status(); let text = resp.text().await?; anyhow::bail!("OpenAI API error {status}: {text}"); } let json: serde_json::Value = resp.json().await?; let content = json["choices"][0]["message"]["content"] .as_str() .ok_or_else(|| anyhow::anyhow!("No content in response"))? .to_string(); let usage = json.get("usage").and_then(|u| { Some(TokenUsage { prompt_tokens: u["prompt_tokens"].as_u64()? as u32, completion_tokens: u["completion_tokens"].as_u64()? as u32, total_tokens: u["total_tokens"].as_u64()? as u32, }) }); Ok(ChatResponse { content, model: config.model.clone(), usage, }) } /// MSI GenAI custom format async fn chat_msi_genai( &self, messages: Vec, config: &ProviderConfig, ) -> anyhow::Result { let client = reqwest::Client::new(); // Use custom endpoint path, default to empty (API URL already includes /api/v2/chat) let endpoint_path = config.custom_endpoint_path.as_deref().unwrap_or(""); let api_url = config.api_url.trim_end_matches('/'); let url = format!("{api_url}{endpoint_path}"); // Extract system message if present let system_message = messages .iter() .find(|m| m.role == "system") .map(|m| m.content.clone()); // Get last user message as prompt let prompt = messages .iter() .rev() .find(|m| m.role == "user") .map(|m| m.content.clone()) .ok_or_else(|| anyhow::anyhow!("No user message found"))?; // Build request body let mut body = serde_json::json!({ "model": config.model, "prompt": prompt, }); // Add userId if provided (CORE ID email) if let Some(user_id) = &config.user_id { body["userId"] = serde_json::Value::String(user_id.clone()); } // Add optional system message if let Some(system) = system_message { body["system"] = serde_json::Value::String(system); } // Add session ID if available (for conversation continuity) if let Some(session_id) = &config.session_id { body["sessionId"] = serde_json::Value::String(session_id.clone()); } // Add modelConfig with temperature and max_tokens if provided let mut model_config = serde_json::json!({}); if let Some(temp) = config.temperature { model_config["temperature"] = serde_json::Value::from(temp); } if let Some(max_tokens) = config.max_tokens { model_config["max_tokens"] = serde_json::Value::from(max_tokens); } if !model_config.is_null() && model_config.as_object().is_some_and(|obj| !obj.is_empty()) { body["modelConfig"] = model_config; } // Use custom auth header and prefix (no prefix for MSI GenAI) let auth_header = config .custom_auth_header .as_deref() .unwrap_or("x-msi-genai-api-key"); let auth_prefix = config.custom_auth_prefix.as_deref().unwrap_or(""); let auth_value = format!("{auth_prefix}{api_key}", api_key = config.api_key); let resp = client .post(&url) .header(auth_header, auth_value) .header("Content-Type", "application/json") .header("X-msi-genai-client", "tftsr-devops-investigation") .json(&body) .send() .await?; if !resp.status().is_success() { let status = resp.status(); let text = resp.text().await?; anyhow::bail!("MSI GenAI API error {status}: {text}"); } let json: serde_json::Value = resp.json().await?; // Extract response content from "msg" field let content = json["msg"] .as_str() .ok_or_else(|| anyhow::anyhow!("No 'msg' field in response"))? .to_string(); // Note: sessionId from response should be stored back to config.session_id // This would require making config mutable or returning it as part of ChatResponse // For now, the caller can extract it from the response if needed // TODO: Consider adding session_id to ChatResponse struct Ok(ChatResponse { content, model: config.model.clone(), usage: None, // MSI GenAI doesn't provide token usage in response }) } }