diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs
index 6cd6336..7636570 100644
--- a/src/GeneratedEndpoints/MinimalApiGenerator.cs
+++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs
@@ -58,6 +58,7 @@ public sealed class MinimalApiGenerator : IIncrementalGenerator
private const string NameAttributeNamedParameter = "Name";
private const string SummaryAttributeNamedParameter = "Summary";
private const string DescriptionAttributeNamedParameter = "Description";
+ private const string ResponseTypeAttributeNamedParameter = "ResponseType";
private const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute";
private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}";
@@ -355,10 +356,10 @@ namespace {{AttributesNamespace}};
[global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)]
internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute
{
- ///
- /// Gets the response type produced by the endpoint.
- ///
- public global::System.Type ResponseType { get; }
+ ///
+ /// Gets the response type produced by the endpoint.
+ ///
+ public global::System.Type ResponseType { get; init; } = default!;
///
/// Gets the HTTP status code returned by the endpoint.
@@ -378,13 +379,11 @@ internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribu
///
/// Initializes a new instance of the class.
///
- /// The CLR type of the response body.
/// The HTTP status code returned by the endpoint.
/// The primary content type produced by the endpoint.
/// Additional content types produced by the endpoint.
- public {{ProducesResponseAttributeName}}(global::System.Type responseType, int statusCode = 200, string? contentType = null, params string[] additionalContentTypes)
+ public {{ProducesResponseAttributeName}}(int statusCode = 200, string? contentType = null, params string[] additionalContentTypes)
{
- ResponseType = responseType ?? throw new global::System.ArgumentNullException(nameof(responseType));
StatusCode = statusCode;
ContentType = contentType;
AdditionalContentTypes = additionalContentTypes ?? [];
@@ -971,18 +970,17 @@ private static void TryAddProducesMetadata(
? GetStringArrayValues(attribute.ConstructorArguments[2])
: null;
}
- else if (attribute.ConstructorArguments.Length >= 1 &&
- attribute.ConstructorArguments[0].Value is ITypeSymbol responseTypeSymbol)
+ else if (GetNamedTypeSymbol(attribute, ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol)
{
responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
- statusCode = attribute.ConstructorArguments.Length > 1 && attribute.ConstructorArguments[1].Value is int producesStatusCode
+ statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode
? producesStatusCode
: 200;
- contentType = attribute.ConstructorArguments.Length > 2
- ? NormalizeOptionalContentType(attribute.ConstructorArguments[2].Value as string)
+ contentType = attribute.ConstructorArguments.Length > 1
+ ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string)
: null;
- additionalContentTypes = attribute.ConstructorArguments.Length > 3
- ? GetStringArrayValues(attribute.ConstructorArguments[3])
+ additionalContentTypes = attribute.ConstructorArguments.Length > 2
+ ? GetStringArrayValues(attribute.ConstructorArguments[2])
: null;
}
else
@@ -994,6 +992,17 @@ private static void TryAddProducesMetadata(
producesList.Add(new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes));
}
+ private static ITypeSymbol? GetNamedTypeSymbol(AttributeData attribute, string namedParameter)
+ {
+ foreach (var namedArg in attribute.NamedArguments)
+ {
+ if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol)
+ return typeSymbol;
+ }
+
+ return null;
+ }
+
private static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values)
{
var list = new List();