feat: implement OAuth2 callback server with automatic token exchange
Some checks are pending
Auto Tag / auto-tag (push) Waiting to run
Test / rust-fmt-check (push) Waiting to run
Test / rust-clippy (push) Waiting to run
Test / rust-tests (push) Waiting to run
Test / frontend-typecheck (push) Waiting to run
Test / frontend-tests (push) Waiting to run

Phase 2.2: OAuth2 flow - Part 3 (Callback server) COMPLETE 

Implemented:
- Local HTTP server on localhost:8765 using warp
  * GET /callback?code=...&state=... - OAuth redirect handler
  * GET /health - Health check endpoint
  * Graceful shutdown with oneshot channel

- Automatic callback handling
  * Server auto-starts on first initiate_oauth call
  * Background task listens for OAuth redirects
  * Automatically exchanges code for token
  * Stores encrypted token in database
  * Logs audit event for each successful OAuth

- Updated initiate_oauth command
  * Starts callback server if not running
  * Stores (service, verifier) tuple in OAuth state
  * Returns auth URL to open in browser/webview

- Updated handle_oauth_callback_internal
  * Accepts AppState reference (not State)
  * Called automatically by callback server
  * Exchanges code, encrypts token, stores in DB

- Beautiful success/error HTML pages
  * Green checkmark on success
  * Auto-closes window after 3 seconds
  * Clear error messages on failure

- Global state management
  * OAUTH_STATE: Maps state key -> (service, verifier)
  * CALLBACK_SERVER_SHUTDOWN: Holds shutdown channel
  * Thread-safe with Mutex wrappers

Dependencies added:
- warp 0.3 - Lightweight HTTP framework

TDD tests (7 passing with --test-threads=1):
Callback server tests:
  * Health endpoint verification
  * Callback parameter parsing
  * Missing/partial parameter handling
  * Graceful shutdown

Integration command tests:
  * OAuth state storage and retrieval
  * Multiple key management
  * OAuthInitResponse serialization

COMPLETE OAUTH2 FLOW:
1. User calls initiate_oauth("confluence")
2. Callback server starts (if not running)
3. Frontend receives auth URL
4. User opens URL in browser/webview
5. User authorizes, redirected to localhost:8765/callback?code=...
6. Callback server receives redirect
7. Token exchanged automatically
8. Token encrypted and stored in DB
9. Success page shown to user
10. Window auto-closes

Next: Frontend components (AuthWindow, Settings UI, CSP updates)
This commit is contained in:
Shaun Arman 2026-04-03 14:59:39 -05:00
parent 75302a1cc7
commit 05f4b51370
5 changed files with 697 additions and 62 deletions

316
src-tauri/Cargo.lock generated
View File

@ -869,6 +869,12 @@ version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04"
[[package]]
name = "data-encoding"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea"
[[package]]
name = "data-url"
version = "0.3.2"
@ -1874,6 +1880,25 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "h2"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http 0.2.12",
"indexmap 2.13.0",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "h2"
version = "0.4.13"
@ -1885,7 +1910,7 @@ dependencies = [
"fnv",
"futures-core",
"futures-sink",
"http",
"http 1.4.0",
"indexmap 2.13.0",
"slab",
"tokio",
@ -1948,6 +1973,30 @@ dependencies = [
"hashbrown 0.14.5",
]
[[package]]
name = "headers"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270"
dependencies = [
"base64 0.21.7",
"bytes",
"headers-core",
"http 0.2.12",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
dependencies = [
"http 0.2.12",
]
[[package]]
name = "heck"
version = "0.4.1"
@ -2006,6 +2055,17 @@ dependencies = [
"markup5ever 0.36.1",
]
[[package]]
name = "http"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http"
version = "1.4.0"
@ -2016,6 +2076,17 @@ dependencies = [
"itoa",
]
[[package]]
name = "http-body"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
dependencies = [
"bytes",
"http 0.2.12",
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.1"
@ -2023,7 +2094,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
"http 1.4.0",
]
[[package]]
@ -2034,8 +2105,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"http 1.4.0",
"http-body 1.0.1",
"pin-project-lite",
]
@ -2051,6 +2122,30 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "0.14.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7"
dependencies = [
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"h2 0.3.27",
"http 0.2.12",
"http-body 0.4.6",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"socket2 0.5.10",
"tokio",
"tower-service",
"tracing",
"want",
]
[[package]]
name = "hyper"
version = "1.8.1"
@ -2061,9 +2156,9 @@ dependencies = [
"bytes",
"futures-channel",
"futures-core",
"h2",
"http",
"http-body",
"h2 0.4.13",
"http 1.4.0",
"http-body 1.0.1",
"httparse",
"httpdate",
"itoa",
@ -2080,8 +2175,8 @@ version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"http 1.4.0",
"hyper 1.8.1",
"hyper-util",
"rustls",
"rustls-pki-types",
@ -2099,7 +2194,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper 1.8.1",
"hyper-util",
"native-tls",
"tokio",
@ -2117,14 +2212,14 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http",
"http-body",
"hyper",
"http 1.4.0",
"http-body 1.0.1",
"hyper 1.8.1",
"ipnet",
"libc",
"percent-encoding",
"pin-project-lite",
"socket2",
"socket2 0.6.3",
"system-configuration",
"tokio",
"tower-service",
@ -2824,6 +2919,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minisign-verify"
version = "0.2.5"
@ -2861,10 +2966,10 @@ dependencies = [
"bytes",
"colored",
"futures-core",
"http",
"http-body",
"http 1.4.0",
"http-body 1.0.1",
"http-body-util",
"hyper",
"hyper 1.8.1",
"hyper-util",
"log",
"pin-project-lite",
@ -2907,6 +3012,24 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "multer"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
dependencies = [
"bytes",
"encoding_rs",
"futures-util",
"http 0.2.12",
"httparse",
"log",
"memchr",
"mime",
"spin",
"version_check",
]
[[package]]
name = "native-tls"
version = "0.2.18"
@ -3513,6 +3636,26 @@ dependencies = [
"siphasher 1.0.2",
]
[[package]]
name = "pin-project"
version = "1.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "pin-project-lite"
version = "0.2.17"
@ -3798,7 +3941,7 @@ dependencies = [
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"socket2 0.6.3",
"thiserror 2.0.18",
"tokio",
"tracing",
@ -3835,7 +3978,7 @@ dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"socket2 0.6.3",
"tracing",
"windows-sys 0.60.2",
]
@ -4079,11 +4222,11 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"h2 0.4.13",
"http 1.4.0",
"http-body 1.0.1",
"http-body-util",
"hyper",
"hyper 1.8.1",
"hyper-rustls",
"hyper-tls",
"hyper-util",
@ -4125,10 +4268,10 @@ dependencies = [
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http 1.4.0",
"http-body 1.0.1",
"http-body-util",
"hyper",
"hyper 1.8.1",
"hyper-util",
"js-sys",
"log",
@ -4393,6 +4536,12 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "scoped-tls"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -4673,6 +4822,17 @@ dependencies = [
"stable_deref_trait",
]
[[package]]
name = "sha1"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sha2"
version = "0.10.9"
@ -4787,6 +4947,16 @@ version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "socket2"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "socket2"
version = "0.6.3"
@ -4845,6 +5015,12 @@ dependencies = [
"system-deps",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "spki"
version = "0.7.3"
@ -5149,7 +5325,7 @@ dependencies = [
"glob",
"gtk",
"heck 0.5.0",
"http",
"http 1.4.0",
"jni",
"libc",
"log",
@ -5313,7 +5489,7 @@ dependencies = [
"bytes",
"cookie_store 0.21.1",
"data-url",
"http",
"http 1.4.0",
"regex",
"reqwest 0.12.28",
"schemars 0.8.22",
@ -5379,7 +5555,7 @@ dependencies = [
"cookie",
"dpi",
"gtk",
"http",
"http 1.4.0",
"jni",
"objc2",
"objc2-ui-kit",
@ -5402,7 +5578,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e11ea2e6f801d275fdd890d6c9603736012742a1c33b96d0db788c9cdebf7f9e"
dependencies = [
"gtk",
"http",
"http 1.4.0",
"jni",
"log",
"objc2",
@ -5434,7 +5610,7 @@ dependencies = [
"dunce",
"glob",
"html5ever 0.29.1",
"http",
"http 1.4.0",
"infer",
"json-patch",
"kuchikiki",
@ -5531,6 +5707,7 @@ dependencies = [
"tracing",
"tracing-subscriber",
"uuid",
"warp",
]
[[package]]
@ -5673,7 +5850,7 @@ dependencies = [
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"socket2 0.6.3",
"tokio-macros",
"windows-sys 0.61.2",
]
@ -5731,6 +5908,18 @@ dependencies = [
"tokio-stream",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.18"
@ -5873,8 +6062,8 @@ dependencies = [
"bitflags 2.11.0",
"bytes",
"futures-util",
"http",
"http-body",
"http 1.4.0",
"http-body 1.0.1",
"iri-string",
"pin-project-lite",
"tower",
@ -5900,6 +6089,7 @@ version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@ -5989,6 +6179,25 @@ version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49d64318d8311fc2668e48b63969f4343e0a85c4a109aa8460d6672e364b8bd1"
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 1.4.0",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror 1.0.69",
"url",
"utf-8",
]
[[package]]
name = "typed-path"
version = "0.12.3"
@ -6048,6 +6257,12 @@ dependencies = [
"unic-common",
]
[[package]]
name = "unicase"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
[[package]]
name = "unicode-ident"
version = "1.0.24"
@ -6111,7 +6326,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
dependencies = [
"base64 0.22.1",
"http",
"http 1.4.0",
"httparse",
"log",
]
@ -6228,6 +6443,35 @@ dependencies = [
"try-lock",
]
[[package]]
name = "warp"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4378d202ff965b011c64817db11d5829506d3404edeadb61f190d111da3f231c"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"headers",
"http 0.2.12",
"hyper 0.14.32",
"log",
"mime",
"mime_guess",
"multer",
"percent-encoding",
"pin-project",
"scoped-tls",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-tungstenite",
"tokio-util",
"tower-service",
"tracing",
]
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
@ -7196,7 +7440,7 @@ dependencies = [
"dunce",
"gdkx11",
"gtk",
"http",
"http 1.4.0",
"javascriptcore-rs",
"jni",
"libc",

View File

@ -41,6 +41,7 @@ dirs = "5"
aes-gcm = "0.10"
rand = "0.8"
lazy_static = "1.4"
warp = "0.3"
[dev-dependencies]
tokio-test = "0.4"

View File

@ -3,10 +3,16 @@ use crate::state::AppState;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tauri::State;
use tokio::sync::oneshot;
// Global OAuth state storage (verifier per state key)
// Global OAuth state storage (verifier + service per state key)
lazy_static::lazy_static! {
static ref OAUTH_STATE: Arc<Mutex<HashMap<String, String>>> = Arc::new(Mutex::new(HashMap::new()));
static ref OAUTH_STATE: Arc<Mutex<HashMap<String, (String, String)>>> = Arc::new(Mutex::new(HashMap::new()));
}
// Global callback server shutdown channel
lazy_static::lazy_static! {
static ref CALLBACK_SERVER_SHUTDOWN: Arc<Mutex<Option<oneshot::Sender<()>>>> = Arc::new(Mutex::new(None));
}
#[tauri::command]
@ -68,24 +74,99 @@ pub struct OAuthInitResponse {
}
/// Initiate OAuth2 authorization flow for a service.
/// Returns the authorization URL and a state key.
/// Starts the callback server and returns the authorization URL.
#[tauri::command]
pub async fn initiate_oauth(
service: String,
_state: State<'_, AppState>,
app_state: State<'_, AppState>,
) -> Result<OAuthInitResponse, String> {
// Start callback server if not already running
let server_already_running = {
let shutdown = CALLBACK_SERVER_SHUTDOWN.lock().map_err(|e| e.to_string())?;
shutdown.is_some()
};
if !server_already_running {
tracing::info!("Starting OAuth callback server");
let (mut callback_rx, shutdown_tx) =
crate::integrations::callback_server::start_callback_server(8765)
.await
.map_err(|e| format!("Failed to start callback server: {}", e))?;
// Store shutdown channel
{
let mut shutdown = CALLBACK_SERVER_SHUTDOWN.lock().map_err(|e| e.to_string())?;
*shutdown = Some(shutdown_tx);
}
// Clone the Arc fields from app_state for the spawned task
let db = app_state.db.clone();
let settings = app_state.settings.clone();
let app_data_dir = app_state.app_data_dir.clone();
tokio::spawn(async move {
let app_state_for_callback = AppState {
db,
settings,
app_data_dir,
};
while let Some(callback) = callback_rx.recv().await {
tracing::info!("Received OAuth callback for state: {}", callback.state);
// Retrieve service and verifier
let (service, verifier) = {
let mut oauth_state = match OAUTH_STATE.lock() {
Ok(state) => state,
Err(e) => {
tracing::error!("Failed to lock OAuth state: {}", e);
continue;
}
};
match oauth_state.remove(&callback.state) {
Some((svc, ver)) => (svc, ver),
None => {
tracing::warn!("Unknown OAuth state: {}", callback.state);
continue;
}
}
};
// Call handle_oauth_callback internally
let result = handle_oauth_callback_internal(
service,
callback.code,
verifier,
&app_state_for_callback,
)
.await;
match result {
Ok(_) => tracing::info!("OAuth callback handled successfully"),
Err(e) => tracing::error!("OAuth callback failed: {}", e),
}
}
tracing::info!("OAuth callback listener stopped");
});
}
// Generate PKCE challenge
let pkce = crate::integrations::auth::generate_pkce();
// Generate state key for this OAuth session
let state_key = uuid::Uuid::now_v7().to_string();
// Store verifier temporarily
// Store verifier and service name
{
let mut oauth_state = OAUTH_STATE
.lock()
.map_err(|e| format!("Failed to lock OAuth state: {}", e))?;
oauth_state.insert(state_key.clone(), pkce.code_verifier.clone());
oauth_state.insert(
state_key.clone(),
(service.clone(), pkce.code_verifier.clone()),
);
}
// Build authorization URL based on service
@ -125,25 +206,13 @@ pub async fn initiate_oauth(
})
}
/// Handle OAuth2 callback after user authorization.
/// Exchanges authorization code for access token and stores it.
#[tauri::command]
pub async fn handle_oauth_callback(
/// Internal function to handle OAuth callback (used by callback server).
async fn handle_oauth_callback_internal(
service: String,
code: String,
state_key: String,
app_state: State<'_, AppState>,
verifier: String,
app_state: &AppState,
) -> Result<(), String> {
// Retrieve verifier from temporary state
let verifier = {
let mut oauth_state = OAUTH_STATE
.lock()
.map_err(|e| format!("Failed to lock OAuth state: {}", e))?;
oauth_state
.remove(&state_key)
.ok_or_else(|| "Invalid or expired OAuth state".to_string())?
};
// Get token endpoint and client_id based on service
let (token_endpoint, client_id, redirect_uri) = match service.as_str() {
"confluence" => (
@ -233,6 +302,29 @@ pub async fn handle_oauth_callback(
Ok(())
}
/// Handle OAuth2 callback (Tauri command for external/manual calls).
/// This is rarely used since callbacks are handled automatically by the callback server.
#[tauri::command]
pub async fn handle_oauth_callback(
service: String,
code: String,
state_key: String,
app_state: State<'_, AppState>,
) -> Result<(), String> {
// Retrieve verifier from temporary state
let verifier = {
let mut oauth_state = OAUTH_STATE
.lock()
.map_err(|e| format!("Failed to lock OAuth state: {}", e))?;
oauth_state
.remove(&state_key)
.map(|(_svc, ver)| ver)
.ok_or_else(|| "Invalid or expired OAuth state".to_string())?
};
handle_oauth_callback_internal(service, code, verifier, app_state.inner()).await
}
#[cfg(test)]
mod tests {
use super::*;
@ -240,18 +332,19 @@ mod tests {
#[test]
fn test_oauth_state_storage() {
let key = "test-key".to_string();
let service = "confluence".to_string();
let verifier = "test-verifier".to_string();
// Store
{
let mut state = OAUTH_STATE.lock().unwrap();
state.insert(key.clone(), verifier.clone());
state.insert(key.clone(), (service.clone(), verifier.clone()));
}
// Retrieve
{
let state = OAUTH_STATE.lock().unwrap();
assert_eq!(state.get(&key), Some(&verifier));
assert_eq!(state.get(&key), Some(&(service.clone(), verifier.clone())));
}
// Remove
@ -274,8 +367,14 @@ mod tests {
{
let mut state = OAUTH_STATE.lock().unwrap();
state.insert(key1.clone(), "verifier1".to_string());
state.insert(key2.clone(), "verifier2".to_string());
state.insert(
key1.clone(),
("confluence".to_string(), "verifier1".to_string()),
);
state.insert(
key2.clone(),
("azuredevops".to_string(), "verifier2".to_string()),
);
}
{
@ -286,6 +385,10 @@ mod tests {
let state = OAUTH_STATE.lock().unwrap();
assert!(!state.contains_key(&key1));
assert!(state.contains_key(&key2));
assert_eq!(
state.get(&key2),
Some(&("azuredevops".to_string(), "verifier2".to_string()))
);
}
#[test]

View File

@ -0,0 +1,286 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use warp::Filter;
#[derive(Debug, Clone)]
pub struct OAuthCallback {
pub code: String,
pub state: String,
}
/// Start a local HTTP server to handle OAuth callbacks.
/// Returns a channel to receive callback data and a shutdown signal.
pub async fn start_callback_server(
port: u16,
) -> Result<(mpsc::Receiver<OAuthCallback>, oneshot::Sender<()>), String> {
let (tx, rx) = mpsc::channel::<OAuthCallback>(1);
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let tx = Arc::new(tokio::sync::Mutex::new(tx));
// Callback route: GET /callback?code=...&state=...
let callback_route = warp::path("callback")
.and(warp::query::<HashMap<String, String>>())
.and(warp::any().map(move || tx.clone()))
.and_then(handle_callback);
// Health check route
let health_route = warp::path("health").map(|| warp::reply::html("OK"));
let routes = callback_route.or(health_route);
let addr: std::net::SocketAddr = ([127, 0, 0, 1], port).into();
tracing::info!(
"Starting OAuth callback server on http://127.0.0.1:{}",
port
);
// Spawn server with graceful shutdown
tokio::spawn(async move {
let (_, server) = warp::serve(routes).bind_with_graceful_shutdown(addr, async {
shutdown_rx.await.ok();
});
server.await;
tracing::info!("OAuth callback server stopped");
});
Ok((rx, shutdown_tx))
}
async fn handle_callback(
params: HashMap<String, String>,
tx: Arc<tokio::sync::Mutex<mpsc::Sender<OAuthCallback>>>,
) -> Result<impl warp::Reply, warp::Rejection> {
let code = params.get("code").cloned();
let state = params.get("state").cloned();
match (code, state) {
(Some(code), Some(state)) => {
// Send callback data to channel
let callback = OAuthCallback { code, state };
let tx = tx.lock().await;
if tx.send(callback).await.is_err() {
tracing::error!("Failed to send OAuth callback to channel");
return Ok(warp::reply::html(
r#"
<!DOCTYPE html>
<html>
<head><title>OAuth Error</title></head>
<body>
<h1>Authentication Error</h1>
<p>Failed to process callback. Please try again.</p>
</body>
</html>
"#,
));
}
Ok(warp::reply::html(
r#"
<!DOCTYPE html>
<html>
<head>
<title>Authentication Successful</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.container {
background: white;
padding: 3rem;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
text-align: center;
max-width: 400px;
}
h1 {
color: #2d3748;
margin-bottom: 1rem;
}
p {
color: #4a5568;
line-height: 1.6;
}
.checkmark {
width: 80px;
height: 80px;
border-radius: 50%;
background: #10b981;
margin: 0 auto 1.5rem;
display: flex;
align-items: center;
justify-content: center;
}
.checkmark svg {
width: 50px;
height: 50px;
stroke: white;
stroke-width: 3;
fill: none;
}
</style>
<script>
// Auto-close after 3 seconds
setTimeout(() => {
window.close();
}, 3000);
</script>
</head>
<body>
<div class="container">
<div class="checkmark">
<svg viewBox="0 0 52 52">
<polyline points="14 27 22 35 38 19"/>
</svg>
</div>
<h1>Authentication Successful!</h1>
<p>You have been successfully authenticated. This window will close automatically.</p>
<p><small>You can safely close this window if it doesn't close automatically.</small></p>
</div>
</body>
</html>
"#,
))
}
_ => {
tracing::warn!("OAuth callback missing code or state parameter");
Ok(warp::reply::html(
r#"
<!DOCTYPE html>
<html>
<head><title>OAuth Error</title></head>
<body>
<h1>Authentication Error</h1>
<p>Missing required parameters (code or state).</p>
<p>Please return to the application and try again.</p>
</body>
</html>
"#,
))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_start_callback_server() {
let result = start_callback_server(8766).await;
assert!(result.is_ok());
let (mut rx, shutdown_tx) = result.unwrap();
// Give server time to start
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Test health endpoint
let health_response = reqwest::get("http://127.0.0.1:8766/health").await.unwrap();
assert!(health_response.status().is_success());
// Test callback endpoint with parameters
let callback_response =
reqwest::get("http://127.0.0.1:8766/callback?code=test_code&state=test_state")
.await
.unwrap();
assert!(callback_response.status().is_success());
// Verify callback was received
let callback = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
.await
.expect("Timeout waiting for callback")
.expect("Channel closed");
assert_eq!(callback.code, "test_code");
assert_eq!(callback.state, "test_state");
// Shutdown server
shutdown_tx.send(()).unwrap();
}
#[tokio::test]
async fn test_callback_missing_parameters() {
let result = start_callback_server(8767).await;
assert!(result.is_ok());
let (_rx, shutdown_tx) = result.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Test callback without parameters
let response = reqwest::get("http://127.0.0.1:8767/callback")
.await
.unwrap();
assert!(response.status().is_success());
let body = response.text().await.unwrap();
assert!(body.contains("Missing required parameters"));
shutdown_tx.send(()).unwrap();
}
#[tokio::test]
async fn test_callback_partial_parameters() {
let result = start_callback_server(8768).await;
assert!(result.is_ok());
let (_rx, shutdown_tx) = result.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Test callback with only code
let response = reqwest::get("http://127.0.0.1:8768/callback?code=test_code")
.await
.unwrap();
assert!(response.status().is_success());
let body = response.text().await.unwrap();
assert!(body.contains("Missing required parameters"));
shutdown_tx.send(()).unwrap();
}
#[tokio::test]
async fn test_server_graceful_shutdown() {
// Use a unique port to avoid conflicts
let port = 8770
+ (std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis()
% 100) as u16;
let result = start_callback_server(port).await;
assert!(result.is_ok());
let (_rx, shutdown_tx) = result.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
// Server should be running
let health_url = format!("http://127.0.0.1:{}/health", port);
let health_before = reqwest::get(&health_url).await;
assert!(health_before.is_ok(), "Server should be running");
// Shutdown
shutdown_tx.send(()).unwrap();
// Give server time to shut down
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Server should be stopped
let health_after = reqwest::get(&health_url).await;
assert!(health_after.is_err(), "Server should be stopped");
}
}

View File

@ -1,5 +1,6 @@
pub mod auth;
pub mod azuredevops;
pub mod callback_server;
pub mod confluence;
pub mod servicenow;