diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 6cd6336..7636570 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -58,6 +58,7 @@ public sealed class MinimalApiGenerator : IIncrementalGenerator private const string NameAttributeNamedParameter = "Name"; private const string SummaryAttributeNamedParameter = "Summary"; private const string DescriptionAttributeNamedParameter = "Description"; + private const string ResponseTypeAttributeNamedParameter = "ResponseType"; private const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; @@ -355,10 +356,10 @@ namespace {{AttributesNamespace}}; [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute { - /// - /// Gets the response type produced by the endpoint. - /// - public global::System.Type ResponseType { get; } + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type ResponseType { get; init; } = default!; /// /// Gets the HTTP status code returned by the endpoint. @@ -378,13 +379,11 @@ internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribu /// /// Initializes a new instance of the class. /// - /// The CLR type of the response body. /// The HTTP status code returned by the endpoint. /// The primary content type produced by the endpoint. /// Additional content types produced by the endpoint. - public {{ProducesResponseAttributeName}}(global::System.Type responseType, int statusCode = 200, string? contentType = null, params string[] additionalContentTypes) + public {{ProducesResponseAttributeName}}(int statusCode = 200, string? contentType = null, params string[] additionalContentTypes) { - ResponseType = responseType ?? throw new global::System.ArgumentNullException(nameof(responseType)); StatusCode = statusCode; ContentType = contentType; AdditionalContentTypes = additionalContentTypes ?? []; @@ -971,18 +970,17 @@ private static void TryAddProducesMetadata( ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; } - else if (attribute.ConstructorArguments.Length >= 1 && - attribute.ConstructorArguments[0].Value is ITypeSymbol responseTypeSymbol) + else if (GetNamedTypeSymbol(attribute, ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) { responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - statusCode = attribute.ConstructorArguments.Length > 1 && attribute.ConstructorArguments[1].Value is int producesStatusCode + statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode ? producesStatusCode : 200; - contentType = attribute.ConstructorArguments.Length > 2 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[2].Value as string) + contentType = attribute.ConstructorArguments.Length > 1 + ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; - additionalContentTypes = attribute.ConstructorArguments.Length > 3 - ? GetStringArrayValues(attribute.ConstructorArguments[3]) + additionalContentTypes = attribute.ConstructorArguments.Length > 2 + ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; } else @@ -994,6 +992,17 @@ private static void TryAddProducesMetadata( producesList.Add(new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes)); } + private static ITypeSymbol? GetNamedTypeSymbol(AttributeData attribute, string namedParameter) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol) + return typeSymbol; + } + + return null; + } + private static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) { var list = new List();