diff --git a/Cargo.toml b/Cargo.toml index cef0ff45..6d8f5e26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,6 +89,7 @@ lambda-web = ["dep:lambda-web", "odbc-static"] [dev-dependencies] actix-http = "3" +tokio = { version = "1", features = ["rt", "time", "test-util"] } [build-dependencies] awc = { version = "3", features = ["rustls-0_23-webpki-roots"] } diff --git a/src/webserver/oidc.rs b/src/webserver/oidc.rs index ea429acc..a6d83e20 100644 --- a/src/webserver/oidc.rs +++ b/src/webserver/oidc.rs @@ -48,6 +48,7 @@ const SQLPAGE_NONCE_COOKIE_NAME: &str = "sqlpage_oidc_nonce"; const SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX: &str = "sqlpage_oidc_state_"; const OIDC_CLIENT_MAX_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60); const OIDC_CLIENT_MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5); +const OIDC_HTTP_BODY_TIMEOUT: Duration = OIDC_CLIENT_MIN_REFRESH_INTERVAL; const SQLPAGE_OIDC_REDIRECT_COUNT_COOKIE: &str = "sqlpage_oidc_redirect_count"; const MAX_OIDC_REDIRECTS: u8 = 3; const AUTH_COOKIE_EXPIRATION: awc::cookie::time::Duration = @@ -837,7 +838,7 @@ async fn execute_oidc_request_with_awc( req = req.insert_header((name.as_str(), value.to_str()?)); } let (req_head, body) = request.into_parts(); - let mut response = req.send_body(body).await.map_err(|e| { + let response = req.send_body(body).await.map_err(|e| { anyhow!(e.to_string()).context(format!( "Failed to send request: {} {}", &req_head.method, &req_head.uri @@ -849,6 +850,7 @@ async fn execute_oidc_request_with_awc( for (name, value) in head { resp_builder = resp_builder.header(name.as_str(), value.to_str()?); } + let mut response = response.timeout(OIDC_HTTP_BODY_TIMEOUT); let body = response .body() .await diff --git a/tests/oidc/mod.rs b/tests/oidc/mod.rs index 7cf128c1..027fe6ac 100644 --- a/tests/oidc/mod.rs +++ b/tests/oidc/mod.rs @@ -12,6 +12,7 @@ use serde_json::json; use sqlpage::webserver::http::create_app; use std::collections::HashMap; use std::sync::{Arc, Mutex}; +use std::time::Duration; use tokio_util::sync::{CancellationToken, DropGuard}; fn base64url_encode(data: &[u8]) -> String { @@ -50,6 +51,7 @@ struct ProviderState<'a> { client_id: String, auth_codes: HashMap, // code -> nonce jwt_customizer: Option>>, + token_endpoint_delay: Duration, } type ProviderStateWithLifetime<'a> = ProviderState<'a>; @@ -142,6 +144,9 @@ async fn token_endpoint( .map(|customizer| customizer(claims.clone(), &state.secret)) .unwrap_or_else(|| make_jwt(&claims, &state.secret)); + let delay = state.token_endpoint_delay; + drop(state); + let response = TokenResponse { access_token: "test_access_token".to_string(), token_type: "Bearer".to_string(), @@ -149,9 +154,14 @@ async fn token_endpoint( expires_in: 3600, }; + let json_bytes = serde_json::to_vec(&response).unwrap(); + let body = futures_util::stream::once(async move { + tokio::time::sleep(delay).await; + Ok::(web::Bytes::from(json_bytes)) + }); HttpResponse::Ok() .insert_header((header::CONTENT_TYPE, "application/json")) - .json(response) + .streaming(body) } pub struct FakeOidcProvider { @@ -185,6 +195,7 @@ impl FakeOidcProvider { client_id: client_id.clone(), auth_codes: HashMap::new(), jwt_customizer: None, + token_endpoint_delay: Duration::ZERO, })); let state_for_server = Arc::clone(&state); @@ -226,6 +237,10 @@ impl FakeOidcProvider { f(&mut state) } + pub fn set_token_endpoint_delay(&self, delay: Duration) { + self.with_state_mut(|s| s.token_endpoint_delay = delay); + } + pub fn store_auth_code(&self, code: String, nonce: String) { self.with_state_mut(|s| { s.auth_codes.insert(code, nonce); @@ -540,3 +555,50 @@ async fn test_oidc_logout_uses_correct_scheme() { let post_logout = params.get("post_logout_redirect_uri").unwrap(); assert_eq!(post_logout, "https://example.com/logged_out"); } + +/// A slow OIDC provider must not freeze the server. +/// See https://github.com/sqlpage/SQLPage/issues/1231 +#[actix_web::test] +async fn test_slow_token_endpoint_does_not_freeze_server() { + let (app, provider) = setup_oidc_test(|_| {}).await; + let mut cookies: Vec> = Vec::new(); + + let resp = request_with_cookies!(app, test::TestRequest::get().uri("/"), cookies); + assert_eq!(resp.status(), StatusCode::SEE_OTHER); + let auth_url = Url::parse(resp.headers().get("location").unwrap().to_str().unwrap()).unwrap(); + let state_param = get_query_param(&auth_url, "state"); + let nonce = get_query_param(&auth_url, "nonce"); + let redirect_uri = get_query_param(&auth_url, "redirect_uri"); + provider.store_auth_code("test_auth_code".to_string(), nonce); + + provider.set_token_endpoint_delay(Duration::from_secs(999)); + + let callback_uri = format!( + "{}?code=test_auth_code&state={}", + Url::parse(&redirect_uri).unwrap().path(), + state_param + ); + + let handle = tokio::task::spawn_local(async move { + let mut req = test::TestRequest::get().uri(&callback_uri); + for cookie in cookies.iter() { + req = req.cookie(cookie.clone()); + } + test::call_service(&app, req.to_request()).await + }); + + // Let the localhost TCP round-trip complete so awc reads response headers. + tokio::time::sleep(Duration::from_millis(50)).await; + + // Freeze time and advance past the body-read timeout. + tokio::time::pause(); + tokio::time::advance(Duration::from_secs(60)).await; + + // The body timeout should have fired, completing the request with an error + // that SQLPage handles by redirecting to the OIDC provider. + let resp = tokio::time::timeout(Duration::from_secs(1), handle) + .await + .expect("OIDC callback hung on a slow token endpoint") + .unwrap(); + assert_eq!(resp.status(), StatusCode::SEE_OTHER); +}