Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions src/webserver/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const OIDC_CLIENT_MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5);
const OIDC_HTTP_BODY_TIMEOUT: Duration = OIDC_CLIENT_MIN_REFRESH_INTERVAL;
const SQLPAGE_OIDC_REDIRECT_COUNT_COOKIE: &str = "sqlpage_oidc_redirect_count";
const MAX_OIDC_REDIRECTS: u8 = 3;
const MAX_OIDC_PARALLEL_LOGIN_FLOWS: usize = 8;
const AUTH_COOKIE_EXPIRATION: awc::cookie::time::Duration =
actix_web::cookie::time::Duration::days(7);
const LOGIN_FLOW_STATE_COOKIE_EXPIRATION: awc::cookie::time::Duration =
Expand Down Expand Up @@ -455,7 +456,8 @@ fn handle_unauthenticated_request(

let initial_url = request.uri().to_string();
let redirect_count = get_redirect_count(&request);
let response = build_auth_provider_redirect_response(oidc_state, &initial_url, redirect_count);
let response =
build_auth_provider_redirect_response(oidc_state, &request, &initial_url, redirect_count);
MiddlewareResponse::Respond(request.into_response(response))
}

Expand Down Expand Up @@ -487,7 +489,7 @@ fn handle_oidc_callback_error(
if let Ok(http_client) = get_http_client_from_appdata(&request) {
oidc_state.maybe_refresh(http_client, OIDC_CLIENT_MIN_REFRESH_INTERVAL);
}
let resp = build_auth_provider_redirect_response(oidc_state, "/", redirect_count);
let resp = build_auth_provider_redirect_response(oidc_state, &request, "/", redirect_count);
request.into_response(resp)
}

Expand Down Expand Up @@ -585,7 +587,6 @@ fn process_oidc_logout(
.path("/")
.finish(),
)?;

log::debug!("User logged out successfully");
Ok(response)
}
Expand Down Expand Up @@ -736,6 +737,7 @@ fn set_auth_cookie(response: &mut HttpResponse, id_token: &OidcToken) {

fn build_auth_provider_redirect_response(
oidc_state: &OidcState,
request: &ServiceRequest,
initial_url: &str,
redirect_count: u8,
) -> HttpResponse {
Expand All @@ -750,11 +752,17 @@ fn build_auth_provider_redirect_response(
.same_site(actix_web::cookie::SameSite::Lax)
.max_age(LOGIN_FLOW_STATE_COOKIE_EXPIRATION)
.finish();
HttpResponse::SeeOther()
.append_header((header::LOCATION, url.to_string()))
.cookie(tmp_login_flow_state_cookie)
.cookie(redirect_count_cookie)
.body("Redirecting...")
let mut response = HttpResponse::SeeOther();
response.append_header((header::LOCATION, url.to_string()));
if let Ok(cookies) = request.cookies() {
for mut cookie in get_tmp_login_flow_state_cookies_to_evict(&cookies).cloned() {
cookie.make_removal();
response.cookie(cookie);
}
}
response.cookie(tmp_login_flow_state_cookie);
response.cookie(redirect_count_cookie);
response.body("Redirecting...")
}

fn build_redirect_response(target_url: String) -> HttpResponse {
Expand Down Expand Up @@ -1078,6 +1086,15 @@ fn get_tmp_login_flow_state_cookie(
.with_context(|| format!("No {cookie_name} cookie found"))
}

fn get_tmp_login_flow_state_cookies_to_evict<'a>(
cookies: &'a [Cookie<'static>],
) -> impl Iterator<Item = &'a Cookie<'static>> + 'a {
let is_state = &|c: &Cookie<'_>| c.name().starts_with(SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX);
let login_state_count = cookies.iter().filter(|c| is_state(c)).count();
let to_evict = login_state_count.saturating_sub(MAX_OIDC_PARALLEL_LOGIN_FLOWS - 1);
cookies.iter().filter(|c| is_state(c)).take(to_evict)
}

#[derive(Debug, Serialize, Deserialize, Clone)]
struct LoginFlowState<'a> {
#[serde(rename = "n")]
Expand Down Expand Up @@ -1127,6 +1144,7 @@ fn validate_redirect_url(url: String, redirect_uri: &str) -> String {
mod tests {
use super::*;
use actix_web::http::StatusCode;
use actix_web::{cookie::Cookie, test::TestRequest};
use openidconnect::url::Url;

#[test]
Expand Down Expand Up @@ -1182,4 +1200,25 @@ mod tests {
.expect("generated URL should parse");
verify_logout_params(&params, secret).expect("generated URL should validate");
}

#[test]
fn evicts_excess_tmp_login_flow_state_cookies() {
let request = (0..MAX_OIDC_PARALLEL_LOGIN_FLOWS)
.fold(TestRequest::default(), |request, i| {
request.cookie(Cookie::new(
format!("{SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX}{i}"),
format!("value-{i}"),
))
})
.to_srv_request();

let cookies = request.cookies().unwrap();
let cookies_to_evict: Vec<_> =
get_tmp_login_flow_state_cookies_to_evict(&cookies).collect();

assert_eq!(cookies_to_evict.len(), 1);
assert!(cookies_to_evict[0]
.name()
.starts_with(SQLPAGE_TMP_LOGIN_STATE_COOKIE_PREFIX));
}
}
Loading