diff --git a/src/webserver/oidc.rs b/src/webserver/oidc.rs index 21a40e24..da4f669b 100644 --- a/src/webserver/oidc.rs +++ b/src/webserver/oidc.rs @@ -51,6 +51,7 @@ const OIDC_CLIENT_MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5); const OIDC_HTTP_BODY_TIMEOUT: Duration = OIDC_CLIENT_MIN_REFRESH_INTERVAL; const SQLPAGE_OIDC_REDIRECT_COUNT_COOKIE: &str = "sqlpage_oidc_redirect_count"; const MAX_OIDC_REDIRECTS: u8 = 3; +const MAX_OIDC_PARALLEL_LOGIN_FLOWS: usize = 8; const AUTH_COOKIE_EXPIRATION: awc::cookie::time::Duration = actix_web::cookie::time::Duration::days(7); const LOGIN_FLOW_STATE_COOKIE_EXPIRATION: awc::cookie::time::Duration = @@ -455,7 +456,8 @@ fn handle_unauthenticated_request( let initial_url = request.uri().to_string(); let redirect_count = get_redirect_count(&request); - let response = build_auth_provider_redirect_response(oidc_state, &initial_url, redirect_count); + let response = + build_auth_provider_redirect_response(oidc_state, &request, &initial_url, redirect_count); MiddlewareResponse::Respond(request.into_response(response)) } @@ -487,7 +489,7 @@ fn handle_oidc_callback_error( if let Ok(http_client) = get_http_client_from_appdata(&request) { oidc_state.maybe_refresh(http_client, OIDC_CLIENT_MIN_REFRESH_INTERVAL); } - let resp = build_auth_provider_redirect_response(oidc_state, "/", redirect_count); + let resp = build_auth_provider_redirect_response(oidc_state, &request, "/", redirect_count); request.into_response(resp) } @@ -585,7 +587,6 @@ fn process_oidc_logout( .path("/") .finish(), )?; - log::debug!("User logged out successfully"); Ok(response) } @@ -736,6 +737,7 @@ fn set_auth_cookie(response: &mut HttpResponse, id_token: &OidcToken) { fn build_auth_provider_redirect_response( oidc_state: &OidcState, + request: &ServiceRequest, initial_url: &str, redirect_count: u8, ) -> HttpResponse { @@ -750,11 +752,17 @@ fn build_auth_provider_redirect_response( .same_site(actix_web::cookie::SameSite::Lax) .max_age(LOGIN_FLOW_STATE_COOKIE_EXPIRATION) .finish(); - HttpResponse::SeeOther() - .append_header((header::LOCATION, url.to_string())) - .cookie(tmp_login_flow_state_cookie) - .cookie(redirect_count_cookie) - .body("Redirecting...") + let mut response = HttpResponse::SeeOther(); + response.append_header((header::LOCATION, url.to_string())); + if let Ok(cookies) = request.cookies() { + for mut cookie in get_tmp_login_flow_state_cookies_to_evict(&cookies).cloned() { + cookie.make_removal(); + response.cookie(cookie); + } + } + response.cookie(tmp_login_flow_state_cookie); + response.cookie(redirect_count_cookie); + response.body("Redirecting...") } fn build_redirect_response(target_url: String) -> HttpResponse { @@ -1078,6 +1086,15 @@ fn get_tmp_login_flow_state_cookie( .with_context(|| format!("No {cookie_name} cookie found")) } +fn get_tmp_login_flow_state_cookies_to_evict<'a>( + cookies: &'a [Cookie<'static>], +) -> impl Iterator> + 'a { + let is_state = &|c: &Cookie<'_>| c.name().starts_with(SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX); + let login_state_count = cookies.iter().filter(|c| is_state(c)).count(); + let to_evict = login_state_count.saturating_sub(MAX_OIDC_PARALLEL_LOGIN_FLOWS - 1); + cookies.iter().filter(|c| is_state(c)).take(to_evict) +} + #[derive(Debug, Serialize, Deserialize, Clone)] struct LoginFlowState<'a> { #[serde(rename = "n")] @@ -1127,6 +1144,7 @@ fn validate_redirect_url(url: String, redirect_uri: &str) -> String { mod tests { use super::*; use actix_web::http::StatusCode; + use actix_web::{cookie::Cookie, test::TestRequest}; use openidconnect::url::Url; #[test] @@ -1182,4 +1200,25 @@ mod tests { .expect("generated URL should parse"); verify_logout_params(¶ms, secret).expect("generated URL should validate"); } + + #[test] + fn evicts_excess_tmp_login_flow_state_cookies() { + let request = (0..MAX_OIDC_PARALLEL_LOGIN_FLOWS) + .fold(TestRequest::default(), |request, i| { + request.cookie(Cookie::new( + format!("{SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX}{i}"), + format!("value-{i}"), + )) + }) + .to_srv_request(); + + let cookies = request.cookies().unwrap(); + let cookies_to_evict: Vec<_> = + get_tmp_login_flow_state_cookies_to_evict(&cookies).collect(); + + assert_eq!(cookies_to_evict.len(), 1); + assert!(cookies_to_evict[0] + .name() + .starts_with(SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX)); + } }