tftsr-devops_investigation/src-tauri/src/ai/openai.rs
Shaun Arman 9d8bdd383c feat: add MSI GenAI custom provider support
- Extended ProviderConfig with optional custom fields for non-OpenAI APIs (see the sketch below this summary)
- Added custom_endpoint_path, custom_auth_header, custom_auth_prefix fields
- Added api_format field to distinguish between OpenAI and MSI GenAI formats
- Added session_id field for stateful conversation APIs
- Implemented chat_msi_genai() method in OpenAI provider
- MSI GenAI uses a different request format (prompt + sessionId) and a different response shape (msg field)
- Updated TypeScript types to match Rust schema
- Added UI controls in Settings/AIProviders for custom provider configuration
- API format selector auto-populates appropriate defaults (OpenAI vs MSI GenAI)
- Backward compatible: existing providers default to OpenAI format
2026-04-03 15:45:42 -05:00
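
Below, a minimal sketch of the extended ProviderConfig. Field names come from the commit message; types are inferred from how openai.rs reads them (every new field is consumed as Option<String>). The serde derives and the three pre-existing fields (api_url, api_key, model) are assumptions, not the actual src-tauri/src/state.rs definition:

use serde::{Deserialize, Serialize};

// Sketch only; see src-tauri/src/state.rs for the real definition.
#[derive(Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    pub api_url: String,                      // base URL; callers trim a trailing '/'
    pub api_key: String,
    pub model: String,
    pub custom_endpoint_path: Option<String>, // e.g. "/chat/completions"
    pub custom_auth_header: Option<String>,   // e.g. "x-msi-genai-api-key"
    pub custom_auth_prefix: Option<String>,   // e.g. "Bearer " (trailing space) or ""
    pub api_format: Option<String>,           // "openai" (default) or "msi_genai"
    pub session_id: Option<String>,           // stateful MSI GenAI conversations
}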

use async_trait::async_trait;
use crate::ai::provider::Provider;
use crate::ai::{ChatResponse, Message, ProviderInfo, TokenUsage};
use crate::state::ProviderConfig;
pub struct OpenAiProvider;
#[async_trait]
impl Provider for OpenAiProvider {
fn name(&self) -> &str {
"openai"
}
fn info(&self) -> ProviderInfo {
ProviderInfo {
name: "OpenAI Compatible".to_string(),
supports_streaming: true,
models: vec![
"gpt-4o".to_string(),
"gpt-4o-mini".to_string(),
"gpt-4-turbo".to_string(),
],
}
}
async fn chat(
&self,
messages: Vec<Message>,
config: &ProviderConfig,
) -> anyhow::Result<ChatResponse> {
// Check if using MSI GenAI format
let api_format = config.api_format.as_deref().unwrap_or("openai");
if api_format == "msi_genai" {
self.chat_msi_genai(messages, config).await
} else {
self.chat_openai(messages, config).await
}
}
}
impl OpenAiProvider {
/// OpenAI-compatible API format (default)
async fn chat_openai(
&self,
messages: Vec<Message>,
config: &ProviderConfig,
) -> anyhow::Result<ChatResponse> {
let client = reqwest::Client::new();
// Use custom endpoint path if provided, otherwise default to /chat/completions
let endpoint_path = config
.custom_endpoint_path
.as_deref()
.unwrap_or("/chat/completions");
let url = format!("{}{}", config.api_url.trim_end_matches('/'), endpoint_path);
let body = serde_json::json!({
"model": config.model,
"messages": messages,
"max_tokens": 4096,
});
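// Standard OpenAI chat-completions body; note that max_tokens is
// hardcoded to 4096 rather than taken from config.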
// Use custom auth header and prefix if provided
let auth_header = config.custom_auth_header.as_deref().unwrap_or("Authorization");
let auth_prefix = config.custom_auth_prefix.as_deref().unwrap_or("Bearer ");
let auth_value = format!("{}{}", auth_prefix, config.api_key);
let resp = client
.post(&url)
.header(auth_header, auth_value)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await?;
anyhow::bail!("OpenAI API error {status}: {text}");
}
let json: serde_json::Value = resp.json().await?;
let content = json["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("No content in response"))?
.to_string();
let usage = json.get("usage").and_then(|u| {
Some(TokenUsage {
prompt_tokens: u["prompt_tokens"].as_u64()? as u32,
completion_tokens: u["completion_tokens"].as_u64()? as u32,
total_tokens: u["total_tokens"].as_u64()? as u32,
})
});
Ok(ChatResponse {
content,
model: config.model.clone(),
usage,
})
}
/// MSI GenAI custom format
async fn chat_msi_genai(
&self,
messages: Vec<Message>,
config: &ProviderConfig,
) -> anyhow::Result<ChatResponse> {
let client = reqwest::Client::new();
// Use custom endpoint path, default to empty (API URL already includes /api/v2/chat)
let endpoint_path = config.custom_endpoint_path.as_deref().unwrap_or("");
let url = format!("{}{}", config.api_url.trim_end_matches('/'), endpoint_path);
// Extract system message if present
let system_message = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.clone());
// Get last user message as prompt
let prompt = messages
.iter()
.rev()
.find(|m| m.role == "user")
.map(|m| m.content.clone())
.ok_or_else(|| anyhow::anyhow!("No user message found"))?;
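// Note: only the latest user turn is sent as the prompt; earlier turns
// are expected to be carried server-side via sessionId rather than
// replayed in the request body.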
// Build request body
let mut body = serde_json::json!({
"model": config.model,
"prompt": prompt,
"userId": "user@motorolasolutions.com", // Default user ID
});
// Add optional system message
if let Some(system) = system_message {
body["system"] = serde_json::Value::String(system);
}
// Add session ID if available (for conversation continuity)
if let Some(session_id) = &config.session_id {
body["sessionId"] = serde_json::Value::String(session_id.clone());
}
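// Illustrative final payload (optional fields included only when set):
//   {
//     "model": "<config.model>",
//     "prompt": "<last user message>",
//     "userId": "user@motorolasolutions.com",
//     "system": "<system message, if present>",
//     "sessionId": "<config.session_id, if present>"
//   }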
// Use custom auth header and prefix (no prefix for MSI GenAI)
let auth_header = config
.custom_auth_header
.as_deref()
.unwrap_or("x-msi-genai-api-key");
let auth_prefix = config.custom_auth_prefix.as_deref().unwrap_or("");
let auth_value = format!("{}{}", auth_prefix, config.api_key);
let resp = client
.post(&url)
.header(auth_header, auth_value)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await?;
anyhow::bail!("MSI GenAI API error {status}: {text}");
}
let json: serde_json::Value = resp.json().await?;
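// Assumed response shape, per the "msg" read below and the sessionId
// note at the end of this function:
//   { "msg": "<assistant reply>", "sessionId": "<server session id>" }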
// Extract response content from "msg" field
let content = json["msg"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("No 'msg' field in response"))?
.to_string();
// Note: sessionId from response should be stored back to config.session_id
// This would require making config mutable or returning it as part of ChatResponse
// For now, the caller can extract it from the response if needed
// TODO: Consider adding session_id to ChatResponse struct
Ok(ChatResponse {
content,
model: config.model.clone(),
usage: None, // MSI GenAI doesn't provide token usage in response
})
}
}
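
For orientation, a hedged usage sketch of the MSI GenAI path. It assumes the ProviderConfig sketch above and that Message is a plain struct with public role/content String fields (this file only shows those fields being read); the endpoint URL is hypothetical:

async fn demo() -> anyhow::Result<()> {
    let provider = OpenAiProvider;
    let config = ProviderConfig {
        api_url: "https://genai.example.com/api/v2/chat".into(), // hypothetical endpoint
        api_key: std::env::var("MSI_GENAI_API_KEY")?,
        model: "gpt-4o".into(),
        custom_endpoint_path: None,           // base URL already ends in /api/v2/chat
        custom_auth_header: None,             // falls back to "x-msi-genai-api-key"
        custom_auth_prefix: None,             // falls back to "" (bare key, no prefix)
        api_format: Some("msi_genai".into()), // routes chat() to chat_msi_genai()
        session_id: None,                     // first turn; no prior conversation
    };
    let messages = vec![Message {
        role: "user".into(),
        content: "Summarize the failing pipeline".into(),
    }];
    let reply = provider.chat(messages, &config).await?;
    println!("{}", reply.content);
    Ok(())
}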