//! Generic OAuth2 provider trait and types. //! //! Defines a common interface for OAuth2 authentication across different //! email providers (Fastmail, Gmail, Outlook, etc.). //! //! Providers can use default implementations for `exchange_code`, `refresh_token`, //! and `get_user_email` by implementing the simpler configuration methods, or //! override them for provider-specific behavior. use async_trait::async_trait; use base64::{engine::general_purpose::{URL_SAFE_NO_PAD, STANDARD}, Engine}; use rand::Rng; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; /// Result of starting an OAuth flow. #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct OAuthStartResult { /// URL to open in the user's browser. pub auth_url: String, /// CSRF state token to verify on callback. pub state: String, /// Local port for the callback server. pub port: u16, /// PKCE code verifier (frontend stores for token exchange). pub code_verifier: String, /// Provider identifier for routing. pub provider: String, } /// Token response from OAuth2 token exchange. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenResult { /// Access token for API calls. pub access_token: String, /// Refresh token for obtaining new access tokens. pub refresh_token: Option, /// Token expiration in seconds. pub expires_in: Option, /// Token type (usually "Bearer"). pub token_type: String, /// ID token (for OpenID Connect providers like Google/Microsoft). pub id_token: Option, /// Email address (extracted from token or discovery). #[serde(skip_deserializing)] pub email: Option, } /// Configuration for an OAuth2 provider. #[derive(Debug, Clone)] pub struct OAuthProviderConfig { /// Authorization endpoint URL. pub auth_url: String, /// Token exchange endpoint URL. pub token_url: String, /// Scopes required for email access. pub scopes: Vec, /// Whether this provider uses JMAP (vs IMAP with XOAUTH2). 
pub uses_jmap: bool, /// JMAP session discovery URL (if uses_jmap). pub jmap_session_url: Option, /// IMAP server hostname (if not uses_jmap). pub imap_server: Option, /// IMAP server port (if not uses_jmap). pub imap_port: Option, /// SMTP server hostname (if not uses_jmap). pub smtp_server: Option, /// SMTP server port (if not uses_jmap). pub smtp_port: Option, /// URL for fetching user info (email address). pub userinfo_url: Option, /// JSON path to email in userinfo response (e.g., "email", "mail", "username"). pub email_json_path: Vec<&'static str>, } /// How to send client credentials in token requests. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub enum ClientAuthMethod { /// Send client_id and client_secret in the request body (default). #[default] FormBody, /// Send credentials via HTTP Basic auth header. BasicAuth, /// Only send client_id (no secret required, e.g., PKCE-only flows). ClientIdOnly, } /// Trait for OAuth2 email providers. /// /// Provides default implementations for `exchange_code`, `refresh_token`, and /// `get_user_email` that work for most OAuth2 providers. Override these methods /// only when provider-specific behavior is needed. #[async_trait] pub trait OAuthProvider: Send + Sync + 'static { /// Returns the provider's identifier (e.g., "fastmail", "google", "microsoft"). fn id(&self) -> &'static str; /// Returns a human-readable name for the provider. fn display_name(&self) -> &'static str; /// Returns the provider configuration. fn config(&self) -> &OAuthProviderConfig; /// Returns the OAuth2 client ID. fn client_id(&self) -> &str; /// Returns the OAuth2 client secret (if required). fn client_secret(&self) -> Option<&str> { None } /// Returns how client credentials should be sent in token requests. fn client_auth_method(&self) -> ClientAuthMethod { ClientAuthMethod::FormBody } /// Starts the OAuth2 authorization flow. 
/// /// Returns the authorization URL to open in the browser and the data needed /// to complete the flow after the user authorizes. fn start_auth(&self, redirect_port: u16) -> OAuthStartResult { let code_verifier = generate_code_verifier(); let code_challenge = generate_code_challenge(&code_verifier); let state = generate_state(); let redirect_uri = format!("http://127.0.0.1:{}/", redirect_port); let config = self.config(); // Build authorization URL with PKCE let mut auth_url = format!( "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}&code_challenge={}&code_challenge_method=S256", config.auth_url, urlencoding::encode(self.client_id()), urlencoding::encode(&redirect_uri), urlencoding::encode(&config.scopes.join(" ")), urlencoding::encode(&state), urlencoding::encode(&code_challenge), ); // Add provider-specific parameters self.customize_auth_url(&mut auth_url); OAuthStartResult { auth_url, state, port: redirect_port, code_verifier, provider: self.id().to_string(), } } /// Hook to customize the authorization URL with provider-specific parameters. fn customize_auth_url(&self, _url: &mut String) {} /// Exchanges an authorization code for access and refresh tokens. /// /// Default implementation handles standard OAuth2 token exchange with PKCE. /// Respects `client_auth_method()` for credential handling. 
async fn exchange_code( &self, code: &str, code_verifier: &str, redirect_port: u16, ) -> Result { let redirect_uri = format!("http://127.0.0.1:{}/", redirect_port); let config = self.config(); let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(15)) .connect_timeout(std::time::Duration::from_secs(10)) .build() .map_err(|e| format!("Failed to build HTTP client: {}", e))?; let mut request = client.post(&config.token_url); // Build form params based on auth method let mut form_params: Vec<(&str, &str)> = vec![ ("grant_type", "authorization_code"), ("code", code), ("redirect_uri", &redirect_uri), ("code_verifier", code_verifier), ]; match self.client_auth_method() { ClientAuthMethod::BasicAuth => { // Send credentials via Basic auth header if let Some(secret) = self.client_secret() { let credentials = format!("{}:{}", self.client_id(), secret); request = request.header( "Authorization", format!("Basic {}", STANDARD.encode(credentials.as_bytes())) ); } } ClientAuthMethod::FormBody => { // Send credentials in form body form_params.push(("client_id", self.client_id())); if let Some(secret) = self.client_secret() { form_params.push(("client_secret", secret)); } } ClientAuthMethod::ClientIdOnly => { // Only client_id, no secret form_params.push(("client_id", self.client_id())); } } let response = request .form(&form_params) .send() .await .map_err(|e| format!("Token request failed: {}", e))?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); return Err(format!("Token exchange failed ({}): {}", status, body)); } response .json() .await .map_err(|e| format!("Failed to parse token response: {}", e)) } /// Refreshes an expired access token using a refresh token. /// /// Default implementation handles standard OAuth2 token refresh. 
async fn refresh_token(&self, refresh_token: &str) -> Result { let config = self.config(); let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(15)) .connect_timeout(std::time::Duration::from_secs(10)) .build() .map_err(|e| format!("Failed to build HTTP client: {}", e))?; let mut request = client.post(&config.token_url); let mut form_params: Vec<(&str, &str)> = vec![ ("grant_type", "refresh_token"), ("refresh_token", refresh_token), ]; match self.client_auth_method() { ClientAuthMethod::BasicAuth => { if let Some(secret) = self.client_secret() { let credentials = format!("{}:{}", self.client_id(), secret); request = request.header( "Authorization", format!("Basic {}", STANDARD.encode(credentials.as_bytes())) ); } } ClientAuthMethod::FormBody => { form_params.push(("client_id", self.client_id())); if let Some(secret) = self.client_secret() { form_params.push(("client_secret", secret)); } } ClientAuthMethod::ClientIdOnly => { form_params.push(("client_id", self.client_id())); } } let response = request .form(&form_params) .send() .await .map_err(|e| format!("Token refresh request failed: {}", e))?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); return Err(format!("Token refresh failed ({}): {}", status, body)); } response .json() .await .map_err(|e| format!("Failed to parse token response: {}", e)) } /// Extracts the user's email address from the token response or via API call. /// /// Default implementation fetches from `config.userinfo_url` and extracts /// email using `config.email_json_path`. Override for custom behavior. 
async fn get_user_email(&self, access_token: &str) -> Result { let config = self.config(); let userinfo_url = config.userinfo_url.as_ref() .ok_or_else(|| "No userinfo URL configured".to_string())?; let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(15)) .connect_timeout(std::time::Duration::from_secs(10)) .build() .map_err(|e| format!("Failed to build HTTP client: {}", e))?; let response = client .get(userinfo_url) .bearer_auth(access_token) .send() .await .map_err(|e| format!("Userinfo request failed: {}", e))?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); return Err(format!("Userinfo request failed ({}): {}", status, body)); } let userinfo: serde_json::Value = response .json() .await .map_err(|e| format!("Failed to parse userinfo response: {}", e))?; // Try each path in order until we find an email for path in &config.email_json_path { if let Some(email) = userinfo[*path].as_str() { return Ok(email.to_string()); } } Err("No email found in userinfo response".to_string()) } } // ============ Helper Functions ============ /// Generates a cryptographically secure random string for PKCE code verifier. pub fn generate_code_verifier() -> String { let mut rng = rand::rng(); let bytes: Vec = (0..32).map(|_| rng.random()).collect(); URL_SAFE_NO_PAD.encode(bytes) } /// Generates the PKCE code challenge from the verifier. pub fn generate_code_challenge(verifier: &str) -> String { let mut hasher = Sha256::new(); hasher.update(verifier.as_bytes()); let hash = hasher.finalize(); URL_SAFE_NO_PAD.encode(hash) } /// Generates a random state token for CSRF protection. pub fn generate_state() -> String { let mut rng = rand::rng(); let bytes: Vec = (0..16).map(|_| rng.random()).collect(); URL_SAFE_NO_PAD.encode(bytes) } /// URL encoding helper (minimal implementation for OAuth params). 
pub mod urlencoding {
    //! Percent-encodes strings per RFC 3986: every byte outside the unreserved
    //! set (`A-Z a-z 0-9 - . _ ~`) becomes `%XX` with uppercase hex digits, so
    //! multi-byte UTF-8 characters are encoded one byte at a time.

    /// Uppercase hexadecimal digits indexed by nibble value.
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    /// Returns `s` with every non-unreserved byte percent-encoded.
    ///
    /// Operates on the UTF-8 bytes directly. Every unreserved character is
    /// ASCII, so no byte of a multi-byte character is ever passed through
    /// unescaped. No per-character allocation is performed.
    pub fn encode(s: &str) -> String {
        // Worst case: every byte expands to "%XX".
        let mut result = String::with_capacity(s.len() * 3);
        for byte in s.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                result.push(char::from(byte));
            } else {
                result.push('%');
                result.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                result.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
        }
        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // ============ PKCE Code Verifier Tests ============

    #[test]
    fn code_verifier_is_base64url_encoded() {
        let verifier = generate_code_verifier();
        // base64url-no-pad uses only: A-Z, a-z, 0-9, -, _
        assert!(verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn code_verifier_has_expected_length() {
        let verifier = generate_code_verifier();
        // 32 random bytes -> base64url(32) = ceil(32*4/3) = 43 chars (no padding)
        assert_eq!(verifier.len(), 43);
    }

    #[test]
    fn code_verifier_is_unique() {
        let v1 = generate_code_verifier();
        let v2 = generate_code_verifier();
        assert_ne!(v1, v2, "Two generated verifiers should be different");
    }

    // ============ PKCE Code Challenge Tests ============

    #[test]
    fn code_challenge_is_base64url_encoded() {
        let verifier = generate_code_verifier();
        let challenge = generate_code_challenge(&verifier);
        assert!(challenge
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn code_challenge_has_sha256_length() {
        let verifier = generate_code_verifier();
        let challenge = generate_code_challenge(&verifier);
        // SHA-256 = 32 bytes -> base64url(32) = 43 chars
        assert_eq!(challenge.len(), 43);
    }

    #[test]
    fn code_challenge_is_deterministic() {
        let verifier = "test_verifier_1234567890abcdef";
        let c1 = generate_code_challenge(verifier);
        let c2 = generate_code_challenge(verifier);
        assert_eq!(c1, c2);
    }

    #[test]
    fn code_challenge_differs_for_different_verifiers() {
        let c1 = generate_code_challenge("verifier_a");
        let c2 = generate_code_challenge("verifier_b");
        assert_ne!(c1, c2);
    }

    #[test]
    fn code_challenge_matches_known_value() {
        // RFC 7636 Appendix B test vector:
        // verifier: "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
        // expected challenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        let challenge = generate_code_challenge(verifier);
        assert_eq!(challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
    }

    // ============ State Token Tests ============

    #[test]
    fn state_is_base64url_encoded() {
        let state = generate_state();
        assert!(state
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn state_has_expected_length() {
        let state = generate_state();
        // 16 random bytes -> base64url(16) = ceil(16*4/3) = 22 chars
        assert_eq!(state.len(), 22);
    }

    #[test]
    fn state_is_unique() {
        let s1 = generate_state();
        let s2 = generate_state();
        assert_ne!(s1, s2, "Two generated state tokens should be different");
    }

    // ============ URL Encoding Tests ============

    #[test]
    fn encode_plain_string() {
        assert_eq!(urlencoding::encode("hello"), "hello");
    }

    #[test]
    fn encode_spaces() {
        assert_eq!(urlencoding::encode("hello world"), "hello%20world");
    }

    #[test]
    fn encode_special_characters() {
        assert_eq!(urlencoding::encode("a=b&c=d"), "a%3Db%26c%3Dd");
    }

    #[test]
    fn encode_preserves_unreserved_characters() {
        // RFC 3986: unreserved = A-Z a-z 0-9 - . _ ~
        let unreserved = "abcXYZ012-._~";
        assert_eq!(urlencoding::encode(unreserved), unreserved);
    }

    #[test]
    fn encode_colons_and_slashes() {
        assert_eq!(
            urlencoding::encode("http://example.com"),
            "http%3A%2F%2Fexample.com"
        );
    }

    #[test]
    fn encode_empty_string() {
        assert_eq!(urlencoding::encode(""), "");
    }

    #[test]
    fn encode_multibyte_utf8_bytewise() {
        // U+00E9 is C3 A9 in UTF-8; U+65E5 is E6 97 A5.
        assert_eq!(urlencoding::encode("é"), "%C3%A9");
        assert_eq!(urlencoding::encode("日"), "%E6%97%A5");
    }

    #[test]
    fn encode_pads_low_bytes_with_uppercase_hex() {
        assert_eq!(urlencoding::encode("\n"), "%0A");
        assert_eq!(urlencoding::encode("+/"), "%2B%2F");
    }

    #[test]
    fn encode_scope_string() {
        let scopes = "urn:ietf:params:jmap:core urn:ietf:params:jmap:mail";
        let encoded = urlencoding::encode(scopes);
        assert!(encoded.contains("%3A"));
        assert!(encoded.contains("%20"));
        assert!(!encoded.contains(' '));
        assert!(!encoded.contains(':'));
    }

    // ============ start_auth URL Construction Tests ============

    /// Minimal provider implementation for testing start_auth.
    struct TestProvider {
        id: &'static str,
        client_id: String,
        config: OAuthProviderConfig,
    }

    impl TestProvider {
        fn new() -> Self {
            Self {
                id: "test",
                client_id: "test_client_id".to_string(),
                config: OAuthProviderConfig {
                    auth_url: "https://auth.example.com/authorize".to_string(),
                    token_url: "https://auth.example.com/token".to_string(),
                    scopes: vec!["scope1".to_string(), "scope2".to_string()],
                    uses_jmap: false,
                    jmap_session_url: None,
                    imap_server: Some("imap.example.com".to_string()),
                    imap_port: Some(993),
                    smtp_server: Some("smtp.example.com".to_string()),
                    smtp_port: Some(587),
                    userinfo_url: Some("https://auth.example.com/userinfo".to_string()),
                    email_json_path: vec!["email"],
                },
            }
        }
    }

    #[async_trait]
    impl OAuthProvider for TestProvider {
        fn id(&self) -> &'static str {
            self.id
        }

        fn display_name(&self) -> &'static str {
            "Test Provider"
        }

        fn config(&self) -> &OAuthProviderConfig {
            &self.config
        }

        fn client_id(&self) -> &str {
            &self.client_id
        }
    }

    #[test]
    fn start_auth_returns_correct_provider() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert_eq!(result.provider, "test");
    }

    #[test]
    fn start_auth_returns_correct_port() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert_eq!(result.port, 12345);
    }

    #[test]
    fn start_auth_url_contains_auth_endpoint() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert!(result
            .auth_url
            .starts_with("https://auth.example.com/authorize?"));
    }

    #[test]
    fn start_auth_url_contains_client_id() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert!(result.auth_url.contains("client_id=test_client_id"));
    }

    #[test]
    fn start_auth_url_contains_redirect_uri() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        // redirect_uri=http://127.0.0.1:12345/ (URL-encoded)
        let encoded_redirect = urlencoding::encode("http://127.0.0.1:12345/");
        assert!(
            result
                .auth_url
                .contains(&format!("redirect_uri={}", encoded_redirect)),
            "Auth URL should contain redirect_uri with correct port. URL: {}",
            result.auth_url
        );
    }

    #[test]
    fn start_auth_url_contains_response_type_code() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert!(result.auth_url.contains("response_type=code"));
    }

    #[test]
    fn start_auth_url_contains_scopes() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        // "scope1 scope2" URL-encoded as "scope1%20scope2"
        assert!(result.auth_url.contains("scope=scope1%20scope2"));
    }

    #[test]
    fn start_auth_url_contains_pkce_challenge() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert!(result.auth_url.contains("code_challenge="));
        assert!(result.auth_url.contains("code_challenge_method=S256"));
    }

    #[test]
    fn start_auth_url_contains_state() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert!(result
            .auth_url
            .contains(&format!("state={}", urlencoding::encode(&result.state))));
    }

    #[test]
    fn start_auth_code_verifier_is_nonempty() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert!(!result.code_verifier.is_empty());
    }

    #[test]
    fn start_auth_state_is_nonempty() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        assert!(!result.state.is_empty());
    }

    #[test]
    fn start_auth_challenge_matches_verifier() {
        let provider = TestProvider::new();
        let result = provider.start_auth(12345);
        // The code_challenge in the URL should match SHA256(code_verifier)
        let expected_challenge = generate_code_challenge(&result.code_verifier);
        assert!(
            result.auth_url.contains(&format!(
                "code_challenge={}",
                urlencoding::encode(&expected_challenge)
            )),
            "code_challenge in URL should match SHA256 of code_verifier"
        );
    }

    // ============ ClientAuthMethod Tests ============

    #[test]
    fn client_auth_method_default_is_form_body() {
        let method = ClientAuthMethod::default();
        assert_eq!(method, ClientAuthMethod::FormBody);
    }

    // ============ OAuthStartResult Tests ============

    #[test]
    fn oauth_start_result_fields() {
        let result = OAuthStartResult {
            auth_url: "https://example.com/auth".to_string(),
            state: "abc123".to_string(),
            port: 8080,
            code_verifier: "verifier_xyz".to_string(),
            provider: "test".to_string(),
        };
        assert_eq!(result.auth_url, "https://example.com/auth");
        assert_eq!(result.state, "abc123");
        assert_eq!(result.port, 8080);
        assert_eq!(result.code_verifier, "verifier_xyz");
        assert_eq!(result.provider, "test");
    }

    // ============ TokenResult Tests ============

    #[test]
    fn token_result_deserialization() {
        let json = r#"{
            "access_token": "ya29.xxx",
            "refresh_token": "1//xxx",
            "expires_in": 3600,
            "token_type": "Bearer",
            "id_token": null
        }"#;
        let result: TokenResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.access_token, "ya29.xxx");
        assert_eq!(result.refresh_token.as_deref(), Some("1//xxx"));
        assert_eq!(result.expires_in, Some(3600));
        assert_eq!(result.token_type, "Bearer");
        assert!(result.id_token.is_none());
        // email is skip_deserializing, so always None from JSON
        assert!(result.email.is_none());
    }

    #[test]
    fn token_result_minimal_deserialization() {
        let json = r#"{ "access_token": "tok", "token_type": "Bearer" }"#;
        let result: TokenResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.access_token, "tok");
        assert!(result.refresh_token.is_none());
        assert!(result.expires_in.is_none());
    }
}