coven/
oauth.rs

1//! OAuth 2.0 helper for consumer cloud provider authentication.
2//!
3//! Provides PKCE-based authorization code flow with a localhost callback server.
4//! Used by Google Drive, Dropbox, and OneDrive cloud home backends.
5
6use std::collections::HashMap;
7use std::sync::OnceLock;
8
9use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
10use rand::RngCore;
11use serde::{Deserialize, Serialize};
12use sha2::{Digest, Sha256};
13use thiserror::Error;
14use tracing::{info, warn};
15
/// Everything needed to run an OAuth flow against a single provider.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Client identifier sent in the authorization and token requests.
    pub client_id: String,
    /// Client secret; `None` for public clients (PKCE-only, no client secret
    /// needed).
    pub client_secret: Option<String>,
    /// Authorization endpoint the user's browser is sent to.
    pub auth_url: String,
    /// Token endpoint used for the code exchange and for refreshes.
    pub token_url: String,
    /// Requested scopes; joined with a single space when non-empty.
    pub scopes: Vec<String>,
    /// Localhost callback port. Default: 19284.
    pub redirect_port: u16,
    /// Additional query parameters appended to the authorization URL
    /// (e.g. Google's `access_type=offline` or Dropbox's
    /// `token_access_type=offline`).
    pub extra_auth_params: Vec<(String, String)>,
}
31
/// The consuming app's registered OAuth application for one provider.
///
/// coven ships no app credentials of its own; the host registers them at
/// startup via [`set_oauth_client_creds`]. The `Default` value has an empty
/// `client_id` and no secret.
#[derive(Debug, Clone, Default)]
pub struct OAuthClientCreds {
    /// Client identifier issued by the provider.
    pub client_id: String,
    /// Client secret; `None` for public (PKCE-only) clients.
    pub client_secret: Option<String>,
}
41
42static OAUTH_CLIENT_CREDS: OnceLock<HashMap<String, OAuthClientCreds>> = OnceLock::new();
43
44/// Register the host's OAuth client credentials, keyed by provider name
45/// (`"google_drive"`, `"dropbox"`, `"onedrive"`). Call once at startup, before
46/// any OAuth flow. Providers absent from the map get empty credentials.
47pub fn set_oauth_client_creds(creds: HashMap<String, OAuthClientCreds>) {
48    let _ = OAUTH_CLIENT_CREDS.set(creds);
49}
50
51/// The credentials registered for a provider, or empty if none were registered.
52pub fn oauth_client_creds(provider: &str) -> OAuthClientCreds {
53    OAUTH_CLIENT_CREDS
54        .get()
55        .and_then(|m| m.get(provider).cloned())
56        .unwrap_or_default()
57}
58
59/// Tokens returned from an OAuth authorization or refresh.
60#[derive(Clone, Debug, Serialize, Deserialize)]
61pub struct OAuthTokens {
62    pub access_token: String,
63    pub refresh_token: Option<String>,
64    /// Unix timestamp when the access token expires. None if unknown.
65    pub expires_at: Option<i64>,
66}
67
68#[derive(Error, Debug)]
69pub enum OAuthError {
70    #[error("failed to open browser: {0}")]
71    BrowserOpen(String),
72    #[error("callback server error: {0}")]
73    Server(String),
74    #[error("token exchange error: {0}")]
75    TokenExchange(String),
76    #[error("authorization denied: {0}")]
77    Denied(String),
78    #[error("timeout waiting for authorization callback")]
79    Timeout,
80}
81
82struct AbortOnDrop(Option<tokio::task::JoinHandle<()>>);
83
84impl AbortOnDrop {
85    fn new(handle: tokio::task::JoinHandle<()>) -> Self {
86        Self(Some(handle))
87    }
88
89    /// Take the join handle for the success path, where the caller wants to
90    /// await its termination (e.g. with a timeout) so the listener's port is
91    /// released before another flow tries to bind it. Disarms the Drop.
92    fn take_handle(mut self) -> Option<tokio::task::JoinHandle<()>> {
93        self.0.take()
94    }
95}
96
97impl Drop for AbortOnDrop {
98    fn drop(&mut self) {
99        if let Some(h) = self.0.take() {
100            h.abort();
101        }
102    }
103}
104
105/// Token response from the OAuth provider (internal deserialization).
106#[derive(Deserialize)]
107struct TokenResponse {
108    access_token: String,
109    refresh_token: Option<String>,
110    expires_in: Option<i64>,
111    #[serde(default)]
112    error: Option<String>,
113    #[serde(default)]
114    error_description: Option<String>,
115}
116
117/// Generate a random PKCE code verifier (43-128 URL-safe characters).
118pub fn generate_code_verifier() -> String {
119    let mut bytes = [0u8; 32];
120    rand::rng().fill_bytes(&mut bytes);
121    URL_SAFE_NO_PAD.encode(bytes)
122}
123
124/// Compute the S256 PKCE code challenge from a verifier.
125pub fn code_challenge(verifier: &str) -> String {
126    let hash = Sha256::digest(verifier.as_bytes());
127    URL_SAFE_NO_PAD.encode(hash)
128}
129
130/// Open the user's browser, wait for the OAuth callback, and exchange the
131/// authorization code for tokens.
132///
133/// Flow:
134/// 1. Generate PKCE verifier + challenge
135/// 2. Open browser to `auth_url` with the required parameters
136/// 3. Spawn a one-shot HTTP server on `localhost:{redirect_port}`
137/// 4. Wait for the callback with the authorization code
138/// 5. Exchange the code for tokens at `token_url`
139pub async fn authorize(
140    config: &OAuthConfig,
141    cancel: tokio::sync::watch::Receiver<bool>,
142    clock: &dyn crate::clock::Clock,
143) -> Result<OAuthTokens, OAuthError> {
144    let verifier = generate_code_verifier();
145    let challenge = code_challenge(&verifier);
146    let redirect_uri = format!("http://localhost:{}/callback", config.redirect_port);
147
148    let mut auth_params = vec![
149        ("response_type", "code".to_string()),
150        ("client_id", config.client_id.clone()),
151        ("redirect_uri", redirect_uri.clone()),
152        ("code_challenge", challenge),
153        ("code_challenge_method", "S256".to_string()),
154    ];
155
156    for (k, v) in &config.extra_auth_params {
157        auth_params.push((k.as_str(), v.clone()));
158    }
159
160    if !config.scopes.is_empty() {
161        auth_params.push(("scope", config.scopes.join(" ")));
162    }
163
164    let auth_url = format!(
165        "{}?{}",
166        config.auth_url,
167        serde_urlencoded::to_string(&auth_params)
168            .map_err(|e| OAuthError::Server(format!("failed to encode params: {e}")))?
169    );
170
171    // Channel to receive the authorization code from the callback handler
172    let (tx, rx) = tokio::sync::oneshot::channel::<Result<String, String>>();
173    let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx)));
174
175    let tx_for_handler = tx.clone();
176    let app = axum::Router::new().route(
177        "/callback",
178        axum::routing::get(
179            move |axum::extract::Query(params): axum::extract::Query<
180                std::collections::HashMap<String, String>,
181            >| {
182                let tx = tx_for_handler.clone();
183                async move {
184                    let mut guard = tx.lock().await;
185                    let is_error = params.contains_key("error") || !params.contains_key("code");
186                    if let Some(sender) = guard.take() {
187                        if let Some(error) = params.get("error") {
188                            let desc = params
189                                .get("error_description")
190                                .cloned()
191                                .unwrap_or_else(|| error.clone());
192                            let _ = sender.send(Err(desc));
193                        } else if let Some(code) = params.get("code") {
194                            let _ = sender.send(Ok(code.clone()));
195                        } else {
196                            let _ = sender.send(Err("no code in callback".to_string()));
197                        }
198                    }
199                    let html = if is_error {
200                        include_str!("oauth_success.html")
201                            .replace("Authorization complete", "Authorization denied")
202                            .replace(
203                                "You can close this window and return to bae.",
204                                "Authorization was denied. You can close this window and try again in bae.",
205                            )
206                    } else {
207                        include_str!("oauth_success.html").to_string()
208                    };
209                    (
210                        [
211                            (axum::http::header::CACHE_CONTROL, "no-store"),
212                            (axum::http::header::CONNECTION, "close"),
213                        ],
214                        axum::response::Html(html),
215                    )
216                }
217            },
218        ),
219    );
220
221    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", config.redirect_port))
222        .await
223        .map_err(|e| OAuthError::Server(format!("failed to bind port: {e}")))?;
224
225    // Spawn the server. The guard aborts the task on drop so future
226    // cancellation (parent .await dropped) tears the listener down too.
227    let server_guard = AbortOnDrop::new(tokio::spawn(async move {
228        if let Err(e) = axum::serve(listener, app)
229            .with_graceful_shutdown(async {
230                tokio::time::sleep(std::time::Duration::from_secs(300)).await;
231            })
232            .await
233        {
234            warn!("OAuth callback server exited with error: {e}");
235        }
236    }));
237
238    // Open the browser
239    open::that(&auth_url).map_err(|e| OAuthError::BrowserOpen(format!("{e}")))?;
240
241    info!("Opened browser for OAuth authorization, waiting for callback");
242
243    // Wait for the callback, cancellation, or timeout
244    let mut cancel = cancel;
245    let result = tokio::select! {
246        result = rx => {
247            result
248                .map_err(|_| OAuthError::Server("callback channel closed".to_string()))
249                .and_then(|r| r.map_err(OAuthError::Denied))
250        }
251        _ = cancel.wait_for(|&v| v) => {
252            Err(OAuthError::Denied("cancelled".to_string()))
253        }
254        _ = tokio::time::sleep(std::time::Duration::from_secs(300)) => {
255            Err(OAuthError::Timeout)
256        }
257    };
258
259    // Disarm the abort-on-drop and await the listener task's termination
260    // briefly. Bounded timeout because awaiting the abort could otherwise
261    // deadlock on a small thread pool with no idle worker. This matters
262    // for back-to-back sign-in flows: without the wait, the next bind on
263    // the same port can race the not-yet-released listener (no SO_REUSEADDR).
264    if let Some(handle) = server_guard.take_handle() {
265        handle.abort();
266        match tokio::time::timeout(std::time::Duration::from_millis(500), handle).await {
267            Ok(Ok(())) => {}
268            Ok(Err(e)) if e.is_cancelled() => {}
269            Ok(Err(e)) => {
270                warn!("OAuth callback server task panicked on shutdown: {e}");
271            }
272            Err(_) => {
273                warn!(
274                    "OAuth callback server did not exit within 500ms; \
275                     port {} may briefly remain in use",
276                    config.redirect_port
277                );
278            }
279        }
280    }
281
282    let code = result?;
283
284    info!("Received authorization code, exchanging for tokens");
285
286    // Exchange the code for tokens
287    exchange_code(config, &code, &verifier, &redirect_uri, clock).await
288}
289
290/// Exchange an authorization code for tokens.
291pub async fn exchange_code(
292    config: &OAuthConfig,
293    code: &str,
294    verifier: &str,
295    redirect_uri: &str,
296    clock: &dyn crate::clock::Clock,
297) -> Result<OAuthTokens, OAuthError> {
298    let client = reqwest::Client::new();
299    let mut params = vec![
300        ("grant_type", "authorization_code"),
301        ("code", code),
302        ("redirect_uri", redirect_uri),
303        ("client_id", &config.client_id),
304        ("code_verifier", verifier),
305    ];
306
307    let secret_ref;
308    if let Some(ref secret) = config.client_secret {
309        secret_ref = secret.clone();
310        params.push(("client_secret", &secret_ref));
311    }
312
313    let resp = client
314        .post(&config.token_url)
315        .form(&params)
316        .send()
317        .await
318        .map_err(|e| OAuthError::TokenExchange(format!("request failed: {e}")))?;
319
320    let status = resp.status();
321    let body = resp
322        .text()
323        .await
324        .map_err(|e| OAuthError::TokenExchange(format!("read body: {e}")))?;
325
326    let token_resp: TokenResponse = serde_json::from_str(&body)
327        .map_err(|e| OAuthError::TokenExchange(format!("parse response: {e} (body: {body})")))?;
328
329    if let Some(error) = token_resp.error {
330        let desc = token_resp.error_description.unwrap_or(error);
331        return Err(OAuthError::TokenExchange(format!(
332            "provider error (HTTP {status}): {desc}"
333        )));
334    }
335
336    let expires_at = token_resp
337        .expires_in
338        .map(|secs| clock.now().timestamp() + secs);
339
340    Ok(OAuthTokens {
341        access_token: token_resp.access_token,
342        refresh_token: token_resp.refresh_token,
343        expires_at,
344    })
345}
346
347/// Refresh an expired access token using a refresh token.
348pub async fn refresh(
349    config: &OAuthConfig,
350    refresh_token: &str,
351    clock: &dyn crate::clock::Clock,
352) -> Result<OAuthTokens, OAuthError> {
353    let client = reqwest::Client::new();
354    let mut params = vec![
355        ("grant_type", "refresh_token"),
356        ("refresh_token", refresh_token),
357        ("client_id", &config.client_id),
358    ];
359
360    let secret_ref;
361    if let Some(ref secret) = config.client_secret {
362        secret_ref = secret.clone();
363        params.push(("client_secret", &secret_ref));
364    }
365
366    let resp = client
367        .post(&config.token_url)
368        .form(&params)
369        .send()
370        .await
371        .map_err(|e| OAuthError::TokenExchange(format!("refresh request failed: {e}")))?;
372
373    let status = resp.status();
374    let body = resp
375        .text()
376        .await
377        .map_err(|e| OAuthError::TokenExchange(format!("read body: {e}")))?;
378
379    let token_resp: TokenResponse = serde_json::from_str(&body)
380        .map_err(|e| OAuthError::TokenExchange(format!("parse response: {e} (body: {body})")))?;
381
382    if let Some(error) = token_resp.error {
383        let desc = token_resp.error_description.unwrap_or(error);
384        return Err(OAuthError::TokenExchange(format!(
385            "provider error (HTTP {status}): {desc}"
386        )));
387    }
388
389    let expires_at = token_resp
390        .expires_in
391        .map(|secs| clock.now().timestamp() + secs);
392
393    // If the provider doesn't return a new refresh token, keep the old one
394    let new_refresh = token_resp
395        .refresh_token
396        .or_else(|| Some(refresh_token.to_string()));
397
398    Ok(OAuthTokens {
399        access_token: token_resp.access_token,
400        refresh_token: new_refresh,
401        expires_at,
402    })
403}
404
405/// Run an OAuth authorization flow for the given cloud provider.
406///
407/// Returns tokens on success. Only Google Drive, Dropbox, and OneDrive
408/// support OAuth; other providers return an error.
409pub async fn authorize_provider(
410    provider: crate::config::CloudProvider,
411    cancel: tokio::sync::watch::Receiver<bool>,
412    clock: &dyn crate::clock::Clock,
413) -> Result<OAuthTokens, OAuthError> {
414    use crate::config::CloudProvider;
415    use crate::storage::cloud::{dropbox, google_drive, onedrive};
416
417    let config = match provider {
418        CloudProvider::GoogleDrive => google_drive::GoogleDriveCloudHome::oauth_config(),
419        CloudProvider::Dropbox => dropbox::DropboxCloudHome::oauth_config(),
420        CloudProvider::OneDrive => onedrive::OneDriveCloudHome::oauth_config(),
421        other => {
422            return Err(OAuthError::Denied(format!("{other:?} does not use OAuth")));
423        }
424    };
425
426    authorize(&config, cancel, clock).await
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every character belongs to the unpadded base64url alphabet.
    fn is_base64url(s: &str) -> bool {
        s.chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
    }

    #[test]
    fn pkce_verifier_is_url_safe() {
        let verifier = generate_code_verifier();
        // RFC 7636 §4.1 limits the verifier to 43..=128 characters.
        assert!((43..=128).contains(&verifier.len()));
        assert!(is_base64url(&verifier));
    }

    #[test]
    fn pkce_challenge_is_deterministic() {
        let verifier = "test-verifier-string";
        let c1 = code_challenge(verifier);
        let c2 = code_challenge(verifier);
        assert_eq!(c1, c2);
    }

    /// Known-answer check using the example from RFC 7636 Appendix B, so a
    /// wrong hash or encoding cannot pass unnoticed.
    #[test]
    fn pkce_challenge_matches_rfc7636_example() {
        assert_eq!(
            code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn pkce_challenge_is_base64url() {
        let verifier = generate_code_verifier();
        let challenge = code_challenge(&verifier);
        // A 32-byte SHA-256 digest is 43 characters in unpadded base64url.
        assert_eq!(challenge.len(), 43);
        assert!(is_base64url(&challenge));
    }

    #[test]
    fn oauth_tokens_serialization_roundtrip() {
        let tokens = OAuthTokens {
            access_token: "at_123".to_string(),
            refresh_token: Some("rt_456".to_string()),
            expires_at: Some(1700000000),
        };
        let json = serde_json::to_string(&tokens).unwrap();
        let parsed: OAuthTokens = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.access_token, "at_123");
        assert_eq!(parsed.refresh_token, Some("rt_456".to_string()));
        assert_eq!(parsed.expires_at, Some(1700000000));
    }
}