Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public class AuthenticationCookieMiddleware(
ILogger<AuthenticationCookieMiddleware> logger
) : IMiddleware
{
private const string? RefreshAuthenticationTokensEndpoint = "/internal-api/account-management/authentication/refresh-authentication-tokens";
private const string RefreshAuthenticationTokensEndpoint = "/internal-api/account-management/authentication/refresh-authentication-tokens";
private const string UnauthorizedReasonItemKey = "UnauthorizedReason";

public async Task InvokeAsync(HttpContext context, RequestDelegate next)
{
Expand All @@ -24,8 +25,32 @@ public async Task InvokeAsync(HttpContext context, RequestDelegate next)
await ValidateAuthenticationCookieAndConvertToHttpBearerHeader(context, refreshTokenCookieValue, accessTokenCookieValue);
}

// If session was revoked during refresh, handle based on request type
if (context.Items.TryGetValue(UnauthorizedReasonItemKey, out var reason) && reason is string unauthorizedReason)
{
if (context.Request.Path.StartsWithSegments("/api"))
{
// For API requests: return 401 immediately so JavaScript can handle it
context.Response.StatusCode = StatusCodes.Status401Unauthorized;
context.Response.Headers[AuthenticationTokenHttpKeys.UnauthorizedReasonHeaderKey] = unauthorizedReason;
return;
}

// For non-API requests (SPA routes): delete cookies and let the page load
// The SPA will load without auth and redirect to login as needed
context.Response.Cookies.Delete(AuthenticationTokenHttpKeys.RefreshTokenCookieName);
context.Response.Cookies.Delete(AuthenticationTokenHttpKeys.AccessTokenCookieName);
}

await next(context);

// Ensure all 401 responses have an unauthorized reason header for consistent frontend handling
if (context.Response.StatusCode == StatusCodes.Status401Unauthorized &&
!context.Response.Headers.ContainsKey(AuthenticationTokenHttpKeys.UnauthorizedReasonHeaderKey))
{
context.Response.Headers[AuthenticationTokenHttpKeys.UnauthorizedReasonHeaderKey] = nameof(UnauthorizedReason.SessionNotFound);
}

if (context.Response.Headers.TryGetValue(AuthenticationTokenHttpKeys.RefreshAuthenticationTokensHeaderKey, out _))
{
logger.LogDebug("Refreshing authentication tokens as requested by endpoint");
Expand Down Expand Up @@ -71,12 +96,24 @@ private async Task ValidateAuthenticationCookieAndConvertToHttpBearerHeader(Http

context.Request.Headers.Authorization = $"Bearer {accessToken}";
}
catch (SessionRevokedException ex)
{
DeleteCookiesForApiRequestsOnly(context);
context.Items[UnauthorizedReasonItemKey] = ex.RevokedReason;
logger.LogWarning(ex, "Session revoked during token refresh. Reason: {Reason}", ex.RevokedReason);
}
catch (SecurityTokenException ex)
{
context.Response.Cookies.Delete(AuthenticationTokenHttpKeys.RefreshTokenCookieName);
context.Response.Cookies.Delete(AuthenticationTokenHttpKeys.AccessTokenCookieName);
DeleteCookiesForApiRequestsOnly(context);
context.Items[UnauthorizedReasonItemKey] = nameof(UnauthorizedReason.SessionNotFound);
logger.LogWarning(ex, "Validating or refreshing the authentication token cookies failed. {Message}", ex.Message);
}
catch (Exception ex)
{
DeleteCookiesForApiRequestsOnly(context);
context.Items[UnauthorizedReasonItemKey] = nameof(UnauthorizedReason.SessionNotFound);
logger.LogError(ex, "Unexpected exception during authentication token validation. Path: {Path}", context.Request.Path);
}
}

private async Task<(string newRefreshToken, string newAccessToken)> RefreshAuthenticationTokensAsync(string refreshToken)
Expand All @@ -91,6 +128,12 @@ private async Task ValidateAuthenticationCookieAndConvertToHttpBearerHeader(Http

if (!response.IsSuccessStatusCode)
{
var unauthorizedReason = GetUnauthorizedReason(response);
if (unauthorizedReason is not null)
{
throw new SessionRevokedException(unauthorizedReason);
}

throw new SecurityTokenException($"Failed to refresh security tokens. Response status code: {response.StatusCode}.");
}

Expand All @@ -105,6 +148,32 @@ private async Task ValidateAuthenticationCookieAndConvertToHttpBearerHeader(Http
return (newRefreshToken, newAccessToken);
}

private static string? GetUnauthorizedReason(HttpResponseMessage response)
{
if (response.Headers.TryGetValues(AuthenticationTokenHttpKeys.UnauthorizedReasonHeaderKey, out var values))
{
return values.FirstOrDefault();
}

return null;
}

/// <summary>
/// Only delete authentication cookies for API requests. For non-API requests (images, static assets),
/// keep the cookies so subsequent API requests can properly detect session issues like replay attacks.
/// The frontend's AuthenticationMiddleware only intercepts API responses, not image/asset errors.
/// </summary>
private static void DeleteCookiesForApiRequestsOnly(HttpContext context)
{
if (!context.Request.Path.StartsWithSegments("/api"))
{
return;
}

context.Response.Cookies.Delete(AuthenticationTokenHttpKeys.RefreshTokenCookieName);
context.Response.Cookies.Delete(AuthenticationTokenHttpKeys.AccessTokenCookieName);
}

private void ReplaceAuthenticationHeaderWithCookie(HttpContext context, string refreshToken, string accessToken)
{
var refreshTokenExpires = ExtractExpirationFromToken(refreshToken);
Expand Down Expand Up @@ -148,3 +217,8 @@ private DateTimeOffset ExtractExpirationFromToken(string token)
return DateTimeOffset.FromUnixTimeSeconds(long.Parse(expires));
}
}

public sealed class SessionRevokedException(string revokedReason) : SecurityTokenException($"Session has been revoked. Reason: {revokedReason}")
{
public string RevokedReason { get; } = revokedReason;
}
6 changes: 6 additions & 0 deletions application/AppGateway/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@
"Path": "/admin/{**catch-all}"
}
},
"account-management-error": {
"ClusterId": "account-management-api",
"Match": {
"Path": "/error"
}
},
"account-management-login": {
"ClusterId": "account-management-api",
"Match": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using PlatformPlatform.AccountManagement.Features.Authentication.Domain;
using PlatformPlatform.AccountManagement.Features.Users.Domain;
using PlatformPlatform.AccountManagement.Features.Users.Shared;
using PlatformPlatform.SharedKernel.Authentication;
using PlatformPlatform.SharedKernel.Authentication.TokenGeneration;
using PlatformPlatform.SharedKernel.Cqrs;
using PlatformPlatform.SharedKernel.Domain;
Expand All @@ -26,45 +27,52 @@ public sealed class RefreshAuthenticationTokensHandler(
ILogger<RefreshAuthenticationTokensHandler> logger
) : IRequestHandler<RefreshAuthenticationTokensCommand, Result>
{
private const string InvalidRefreshTokenMessage = "Invalid refresh token.";

public async Task<Result> Handle(RefreshAuthenticationTokensCommand command, CancellationToken cancellationToken)
{
var httpContext = httpContextAccessor.HttpContext ?? throw new InvalidOperationException("HttpContext is null.");

var invalidTokenHeaders = new Dictionary<string, string>
{
{ AuthenticationTokenHttpKeys.UnauthorizedReasonHeaderKey, nameof(UnauthorizedReason.SessionNotFound) }
};

if (!UserId.TryParse(httpContext.User.FindFirstValue(ClaimTypes.NameIdentifier), out var userId))
{
logger.LogWarning("No valid 'sub' claim found in refresh token");
return Result.Unauthorized("Invalid refresh token.");
return Result.Unauthorized(InvalidRefreshTokenMessage, responseHeaders: invalidTokenHeaders);
}

if (!SessionId.TryParse(httpContext.User.FindFirstValue("sid"), out var sessionId))
{
logger.LogWarning("No valid 'sid' claim found in refresh token");
return Result.Unauthorized("Invalid refresh token.");
return Result.Unauthorized(InvalidRefreshTokenMessage, responseHeaders: invalidTokenHeaders);
}

if (!RefreshTokenJti.TryParse(httpContext.User.FindFirstValue(JwtRegisteredClaimNames.Jti), out var jti))
{
logger.LogWarning("No valid 'jti' claim found in refresh token");
return Result.Unauthorized("Invalid refresh token.");
return Result.Unauthorized(InvalidRefreshTokenMessage, responseHeaders: invalidTokenHeaders);
}

if (!int.TryParse(httpContext.User.FindFirstValue("ver"), out var refreshTokenVersion))
{
logger.LogWarning("No valid 'ver' claim found in refresh token");
return Result.Unauthorized("Invalid refresh token.");
return Result.Unauthorized(InvalidRefreshTokenMessage, responseHeaders: invalidTokenHeaders);
}

var expiresClaim = httpContext.User.FindFirstValue(JwtRegisteredClaimNames.Exp);
if (expiresClaim is null)
{
logger.LogWarning("No 'exp' claim found in refresh token");
return Result.Unauthorized("Invalid refresh token.");
return Result.Unauthorized(InvalidRefreshTokenMessage, responseHeaders: invalidTokenHeaders);
}

if (!long.TryParse(expiresClaim, out var expiresUnixSeconds))
{
logger.LogWarning("Invalid 'exp' claim format in refresh token");
return Result.Unauthorized("Invalid refresh token.");
return Result.Unauthorized(InvalidRefreshTokenMessage, responseHeaders: invalidTokenHeaders);
}

var refreshTokenExpires = DateTimeOffset.FromUnixTimeSeconds(expiresUnixSeconds);
Expand All @@ -74,19 +82,23 @@ public async Task<Result> Handle(RefreshAuthenticationTokensCommand command, Can
if (session is null)
{
logger.LogWarning("No session found for session id '{SessionId}'", sessionId);
return Result.Unauthorized("Invalid refresh token.");
return Result.Unauthorized(InvalidRefreshTokenMessage, responseHeaders: invalidTokenHeaders);
}

if (session.IsRevoked)
{
logger.LogWarning("Session '{SessionId}' has been revoked", session.Id);
return Result.Unauthorized("Session has been revoked.");
logger.LogWarning("Session '{SessionId}' has been revoked with reason '{RevokedReason}'", session.Id, session.RevokedReason);
var unauthorizedHeaders = new Dictionary<string, string>
{
{ AuthenticationTokenHttpKeys.UnauthorizedReasonHeaderKey, session.RevokedReason?.ToString() ?? nameof(UnauthorizedReason.Revoked) }
};
return Result.Unauthorized("Session has been revoked.", responseHeaders: unauthorizedHeaders);
}

if (session.UserId != userId)
{
logger.LogWarning("Session user id '{SessionUserId}' does not match token user id '{TokenUserId}'", session.UserId, userId);
return Result.Unauthorized("Invalid refresh token.");
return Result.Unauthorized(InvalidRefreshTokenMessage, responseHeaders: invalidTokenHeaders);
}

if (!session.IsRefreshTokenValid(jti, refreshTokenVersion, now))
Expand All @@ -95,17 +107,23 @@ public async Task<Result> Handle(RefreshAuthenticationTokensCommand command, Can
"Replay attack detected for session '{SessionId}'. Token JTI '{TokenJti}', current JTI '{CurrentJti}'. Token version '{TokenVersion}', current version '{CurrentVersion}'",
session.Id, jti, session.RefreshTokenJti, refreshTokenVersion, session.RefreshTokenVersion
);
session.Revoke(now, SessionRevokedReason.ReplayAttackDetected);
sessionRepository.Update(session);

// Atomic revocation - only one concurrent request succeeds, but all return ReplayAttackDetected
await sessionRepository.TryRevokeForReplayUnfilteredAsync(sessionId, now, cancellationToken);

events.CollectEvent(new SessionReplayDetected(session.Id, refreshTokenVersion, session.RefreshTokenVersion));
return Result.Unauthorized("Invalid refresh token. Session has been revoked due to potential replay attack.", true);
var unauthorizedHeaders = new Dictionary<string, string>
{
{ AuthenticationTokenHttpKeys.UnauthorizedReasonHeaderKey, nameof(UnauthorizedReason.ReplayAttackDetected) }
};
return Result.Unauthorized("Invalid refresh token. Session has been revoked due to potential replay attack.", true, unauthorizedHeaders);
}

var user = await userRepository.GetByIdAsync(userId, cancellationToken);
if (user is null)
{
logger.LogWarning("No user found with user id '{UserId}'", userId);
return Result.Unauthorized($"No user found with user id '{userId}'.");
return Result.Unauthorized($"No user found with user id '{userId}'.", responseHeaders: invalidTokenHeaders);
}

RefreshTokenJti tokenJti;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ public interface ISessionRepository : ICrudRepository<Session, SessionId>
/// Returns false if another concurrent request already refreshed the session.
/// </summary>
Task<bool> TryRefreshAsync(SessionId sessionId, RefreshTokenJti currentJti, int currentVersion, RefreshTokenJti newJti, DateTimeOffset now, CancellationToken cancellationToken);

/// <summary>
/// Attempts to revoke the session for a replay attack without applying tenant query filters.
/// Uses atomic update to handle concurrent requests - only one will succeed, but all callers
/// can safely return ReplayAttackDetected since the session will be revoked either way.
/// This method should only be used during token refresh where tenant context comes from the token claims.
/// </summary>
Task<bool> TryRevokeForReplayUnfilteredAsync(SessionId sessionId, DateTimeOffset now, CancellationToken cancellationToken);
}

public sealed class SessionRepository(AccountManagementDbContext accountManagementDbContext)
Expand Down Expand Up @@ -75,6 +83,21 @@ UPDATE Sessions
return rowsAffected == 1;
}

public async Task<bool> TryRevokeForReplayUnfilteredAsync(SessionId sessionId, DateTimeOffset now, CancellationToken cancellationToken)
{
var rowsAffected = await DbSet
.IgnoreQueryFilters()
.Where(s => s.Id == sessionId && s.RevokedAt == null)
.ExecuteUpdateAsync(s => s
.SetProperty(x => x.RevokedAt, now)
.SetProperty(x => x.RevokedReason, SessionRevokedReason.ReplayAttackDetected)
.SetProperty(x => x.ModifiedAt, now),
cancellationToken
);

return rowsAffected == 1;
}

public async Task<Session[]> GetActiveSessionsForUserAsync(UserId userId, CancellationToken cancellationToken)
{
var sessions = await DbSet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ public enum DeviceType
Tablet
}

/// <summary>
/// Represents why a session was revoked. This is a domain concept stored in the Session aggregate.
/// For HTTP header reasons (which include additional cases like SessionNotFound), see
/// <see cref="SharedKernel.Authentication.UnauthorizedReason" />.
/// </summary>
[PublicAPI]
[JsonConverter(typeof(JsonStringEnumConverter))]
public enum SessionRevokedReason
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ public async Task RefreshAuthenticationTokens_WhenReplayAttackDetected_ShouldRev

// Assert
response.StatusCode.Should().Be(HttpStatusCode.Unauthorized);
response.Headers.Should().ContainKey("x-unauthorized-reason");
response.Headers.GetValues("x-unauthorized-reason").Single().Should().Be("ReplayAttackDetected");

object[] parameters = [new { id = sessionId.ToString() }];
Connection.ExecuteScalar<string>("SELECT RevokedAt FROM Sessions WHERE Id = @id", parameters).Should().NotBeNull();
Expand All @@ -110,6 +112,8 @@ public async Task RefreshAuthenticationTokens_WhenSessionRevoked_ShouldReturnUna

// Assert
response.StatusCode.Should().Be(HttpStatusCode.Unauthorized);
response.Headers.Should().ContainKey("x-unauthorized-reason");
response.Headers.GetValues("x-unauthorized-reason").Single().Should().Be("Revoked");
TelemetryEventsCollectorSpy.CollectedEvents.Should().BeEmpty();
}

Expand All @@ -128,6 +132,8 @@ public async Task RefreshAuthenticationTokens_WhenSessionNotFound_ShouldReturnUn

// Assert
response.StatusCode.Should().Be(HttpStatusCode.Unauthorized);
response.Headers.Should().ContainKey("x-unauthorized-reason");
response.Headers.GetValues("x-unauthorized-reason").Single().Should().Be("SessionNotFound");
TelemetryEventsCollectorSpy.CollectedEvents.Should().BeEmpty();
}

Expand Down
Loading
Loading