Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ lambda-web = ["dep:lambda-web", "odbc-static"]

[dev-dependencies]
actix-http = "3"
tokio = { version = "1", features = ["rt", "time", "test-util"] }

[build-dependencies]
awc = { version = "3", features = ["rustls-0_23-webpki-roots"] }
Expand Down
4 changes: 3 additions & 1 deletion src/webserver/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const SQLPAGE_NONCE_COOKIE_NAME: &str = "sqlpage_oidc_nonce";
const SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX: &str = "sqlpage_oidc_state_";
const OIDC_CLIENT_MAX_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60);
const OIDC_CLIENT_MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5);
const OIDC_HTTP_BODY_TIMEOUT: Duration = OIDC_CLIENT_MIN_REFRESH_INTERVAL;
const SQLPAGE_OIDC_REDIRECT_COUNT_COOKIE: &str = "sqlpage_oidc_redirect_count";
const MAX_OIDC_REDIRECTS: u8 = 3;
const AUTH_COOKIE_EXPIRATION: awc::cookie::time::Duration =
Expand Down Expand Up @@ -837,7 +838,7 @@ async fn execute_oidc_request_with_awc(
req = req.insert_header((name.as_str(), value.to_str()?));
}
let (req_head, body) = request.into_parts();
let mut response = req.send_body(body).await.map_err(|e| {
let response = req.send_body(body).await.map_err(|e| {
anyhow!(e.to_string()).context(format!(
"Failed to send request: {} {}",
&req_head.method, &req_head.uri
Expand All @@ -849,6 +850,7 @@ async fn execute_oidc_request_with_awc(
for (name, value) in head {
resp_builder = resp_builder.header(name.as_str(), value.to_str()?);
}
let mut response = response.timeout(OIDC_HTTP_BODY_TIMEOUT);
let body = response
.body()
.await
Expand Down
64 changes: 63 additions & 1 deletion tests/oidc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use serde_json::json;
use sqlpage::webserver::http::create_app;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio_util::sync::{CancellationToken, DropGuard};

fn base64url_encode(data: &[u8]) -> String {
Expand Down Expand Up @@ -50,6 +51,7 @@ struct ProviderState<'a> {
client_id: String,
auth_codes: HashMap<String, String>, // code -> nonce
jwt_customizer: Option<Box<JwtCustomizer<'a>>>,
token_endpoint_delay: Duration,
}

type ProviderStateWithLifetime<'a> = ProviderState<'a>;
Expand Down Expand Up @@ -142,16 +144,24 @@ async fn token_endpoint(
.map(|customizer| customizer(claims.clone(), &state.secret))
.unwrap_or_else(|| make_jwt(&claims, &state.secret));

let delay = state.token_endpoint_delay;
drop(state);

let response = TokenResponse {
access_token: "test_access_token".to_string(),
token_type: "Bearer".to_string(),
id_token,
expires_in: 3600,
};

let json_bytes = serde_json::to_vec(&response).unwrap();
let body = futures_util::stream::once(async move {
tokio::time::sleep(delay).await;
Ok::<web::Bytes, actix_web::Error>(web::Bytes::from(json_bytes))
});
HttpResponse::Ok()
.insert_header((header::CONTENT_TYPE, "application/json"))
.json(response)
.streaming(body)
}

pub struct FakeOidcProvider {
Expand Down Expand Up @@ -185,6 +195,7 @@ impl FakeOidcProvider {
client_id: client_id.clone(),
auth_codes: HashMap::new(),
jwt_customizer: None,
token_endpoint_delay: Duration::ZERO,
}));

let state_for_server = Arc::clone(&state);
Expand Down Expand Up @@ -226,6 +237,10 @@ impl FakeOidcProvider {
f(&mut state)
}

pub fn set_token_endpoint_delay(&self, delay: Duration) {
self.with_state_mut(|s| s.token_endpoint_delay = delay);
}

pub fn store_auth_code(&self, code: String, nonce: String) {
self.with_state_mut(|s| {
s.auth_codes.insert(code, nonce);
Expand Down Expand Up @@ -540,3 +555,50 @@ async fn test_oidc_logout_uses_correct_scheme() {
let post_logout = params.get("post_logout_redirect_uri").unwrap();
assert_eq!(post_logout, "https://example.com/logged_out");
}

/// A slow OIDC provider must not freeze the server.
/// See https://github.com/sqlpage/SQLPage/issues/1231
#[actix_web::test]
async fn test_slow_token_endpoint_does_not_freeze_server() {
let (app, provider) = setup_oidc_test(|_| {}).await;
let mut cookies: Vec<Cookie<'static>> = Vec::new();

let resp = request_with_cookies!(app, test::TestRequest::get().uri("/"), cookies);
assert_eq!(resp.status(), StatusCode::SEE_OTHER);
let auth_url = Url::parse(resp.headers().get("location").unwrap().to_str().unwrap()).unwrap();
let state_param = get_query_param(&auth_url, "state");
let nonce = get_query_param(&auth_url, "nonce");
let redirect_uri = get_query_param(&auth_url, "redirect_uri");
provider.store_auth_code("test_auth_code".to_string(), nonce);

provider.set_token_endpoint_delay(Duration::from_secs(999));

let callback_uri = format!(
"{}?code=test_auth_code&state={}",
Url::parse(&redirect_uri).unwrap().path(),
state_param
);

let handle = tokio::task::spawn_local(async move {
let mut req = test::TestRequest::get().uri(&callback_uri);
for cookie in cookies.iter() {
req = req.cookie(cookie.clone());
}
test::call_service(&app, req.to_request()).await
});

// Let the localhost TCP round-trip complete so awc reads response headers.
tokio::time::sleep(Duration::from_millis(50)).await;

// Freeze time and advance past the body-read timeout.
tokio::time::pause();
tokio::time::advance(Duration::from_secs(60)).await;

// The body timeout should have fired, completing the request with an error
// that SQLPage handles by redirecting to the OIDC provider.
let resp = tokio::time::timeout(Duration::from_secs(1), handle)
.await
.expect("OIDC callback hung on a slow token endpoint")
.unwrap();
assert_eq!(resp.status(), StatusCode::SEE_OTHER);
}