diff --git a/IsolatedFunctionAuth/IsolatedFunctionAuth.csproj b/IsolatedFunctionAuth/IsolatedFunctionAuth.csproj index 838d865..78dedc4 100644 --- a/IsolatedFunctionAuth/IsolatedFunctionAuth.csproj +++ b/IsolatedFunctionAuth/IsolatedFunctionAuth.csproj @@ -6,11 +6,11 @@ - - - - - + + + + + diff --git a/IsolatedFunctionAuth/Middleware/AuthenticationMiddleware.cs b/IsolatedFunctionAuth/Middleware/AuthenticationMiddleware.cs index 0db49c5..82f5b6f 100644 --- a/IsolatedFunctionAuth/Middleware/AuthenticationMiddleware.cs +++ b/IsolatedFunctionAuth/Middleware/AuthenticationMiddleware.cs @@ -41,14 +41,14 @@ public async Task Invoke( if (!TryGetTokenFromHeaders(context, out var token)) { // Unable to get token from headers - context.SetHttpResponseStatusCode(HttpStatusCode.Unauthorized); + await context.SetHttpResponseStatusCode(HttpStatusCode.Unauthorized); return; } if (!_tokenValidator.CanReadToken(token)) { // Token is malformed - context.SetHttpResponseStatusCode(HttpStatusCode.Unauthorized); + await context.SetHttpResponseStatusCode(HttpStatusCode.Unauthorized); return; } @@ -73,7 +73,7 @@ public async Task Invoke( catch (SecurityTokenException) { // Token is not valid (expired etc.) - context.SetHttpResponseStatusCode(HttpStatusCode.Unauthorized); + await context.SetHttpResponseStatusCode(HttpStatusCode.Unauthorized); return; } } diff --git a/IsolatedFunctionAuth/Middleware/AuthorizationMiddleware.cs b/IsolatedFunctionAuth/Middleware/AuthorizationMiddleware.cs index 5132106..97210e1 100644 --- a/IsolatedFunctionAuth/Middleware/AuthorizationMiddleware.cs +++ b/IsolatedFunctionAuth/Middleware/AuthorizationMiddleware.cs @@ -22,7 +22,7 @@ public async Task Invoke( var principalFeature = context.Features.Get(); if (!AuthorizePrincipal(context, principalFeature.Principal)) { - context.SetHttpResponseStatusCode(HttpStatusCode.Forbidden); + await context.SetHttpResponseStatusCode(HttpStatusCode.Forbidden); return; } diff --git a/IsolatedFunctionAuth/Middleware/FunctionContextExtensions.cs b/IsolatedFunctionAuth/Middleware/FunctionContextExtensions.cs index e90f5b9..45f44da 100644 --- a/IsolatedFunctionAuth/Middleware/FunctionContextExtensions.cs +++ b/IsolatedFunctionAuth/Middleware/FunctionContextExtensions.cs @@ -1,33 +1,25 @@ using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; using System; using System.Linq; using System.Net; using System.Reflection; +using System.Threading.Tasks; namespace IsolatedFunctionAuth.Middleware { public static class FunctionContextExtensions { - public static void SetHttpResponseStatusCode( + public static async Task SetHttpResponseStatusCode( this FunctionContext context, HttpStatusCode statusCode) { - // Terrible reflection code since I haven't found a nicer way to do this... - // For some reason the types are marked as internal - // If there's code that will break in this sample, - // it's probably here. - var coreAssembly = Assembly.Load("Microsoft.Azure.Functions.Worker.Core"); - var featureInterfaceName = "Microsoft.Azure.Functions.Worker.Context.Features.IFunctionBindingsFeature"; - var featureInterfaceType = coreAssembly.GetType(featureInterfaceName); - var bindingsFeature = context.Features.Single( - f => f.Key.FullName == featureInterfaceType.FullName).Value; - var invocationResultProp = featureInterfaceType.GetProperty("InvocationResult"); - - var grpcAssembly = Assembly.Load("Microsoft.Azure.Functions.Worker.Grpc"); - var responseDataType = grpcAssembly.GetType("Microsoft.Azure.Functions.Worker.GrpcHttpResponseData"); - var responseData = Activator.CreateInstance(responseDataType, context, statusCode); - - invocationResultProp.SetMethod.Invoke(bindingsFeature, new object[] { responseData }); + var req = await context.GetHttpRequestDataAsync(); + if (req == null) { return; } + var response = HttpResponseData.CreateResponse(req); + response.StatusCode = statusCode; + var result = context.GetInvocationResult(); + result.Value = response; } public static MethodInfo GetTargetFunctionMethod(this FunctionContext context)