Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/GeneratedEndpoints/MinimalApiGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public sealed class MinimalApiGenerator : IIncrementalGenerator
private const string NameAttributeNamedParameter = "Name";
private const string SummaryAttributeNamedParameter = "Summary";
private const string DescriptionAttributeNamedParameter = "Description";
private const string ResponseTypeAttributeNamedParameter = "ResponseType";

private const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute";
private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}";
Expand Down Expand Up @@ -355,10 +356,10 @@ namespace {{AttributesNamespace}};
[global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)]
internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute
{
/// <summary>
/// Gets the response type produced by the endpoint.
/// </summary>
public global::System.Type ResponseType { get; }
/// <summary>
/// Gets the response type produced by the endpoint.
/// </summary>
public global::System.Type ResponseType { get; init; } = default!;

/// <summary>
/// Gets the HTTP status code returned by the endpoint.
Expand All @@ -378,13 +379,11 @@ internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribu
/// <summary>
/// Initializes a new instance of the <see cref="{{ProducesResponseAttributeName}}"/> class.
/// </summary>
/// <param name="responseType">The CLR type of the response body.</param>
/// <param name="statusCode">The HTTP status code returned by the endpoint.</param>
/// <param name="contentType">The primary content type produced by the endpoint.</param>
/// <param name="additionalContentTypes">Additional content types produced by the endpoint.</param>
public {{ProducesResponseAttributeName}}(global::System.Type responseType, int statusCode = 200, string? contentType = null, params string[] additionalContentTypes)
public {{ProducesResponseAttributeName}}(int statusCode = 200, string? contentType = null, params string[] additionalContentTypes)
{
ResponseType = responseType ?? throw new global::System.ArgumentNullException(nameof(responseType));
StatusCode = statusCode;
ContentType = contentType;
AdditionalContentTypes = additionalContentTypes ?? [];
Expand Down Expand Up @@ -971,18 +970,17 @@ private static void TryAddProducesMetadata(
? GetStringArrayValues(attribute.ConstructorArguments[2])
: null;
}
else if (attribute.ConstructorArguments.Length >= 1 &&
attribute.ConstructorArguments[0].Value is ITypeSymbol responseTypeSymbol)
else if (GetNamedTypeSymbol(attribute, ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol)
{
responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
statusCode = attribute.ConstructorArguments.Length > 1 && attribute.ConstructorArguments[1].Value is int producesStatusCode
statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode
? producesStatusCode
: 200;
contentType = attribute.ConstructorArguments.Length > 2
? NormalizeOptionalContentType(attribute.ConstructorArguments[2].Value as string)
contentType = attribute.ConstructorArguments.Length > 1
? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string)
: null;
additionalContentTypes = attribute.ConstructorArguments.Length > 3
? GetStringArrayValues(attribute.ConstructorArguments[3])
additionalContentTypes = attribute.ConstructorArguments.Length > 2
? GetStringArrayValues(attribute.ConstructorArguments[2])
: null;
}
else
Expand All @@ -994,6 +992,17 @@ private static void TryAddProducesMetadata(
producesList.Add(new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes));
}

private static ITypeSymbol? GetNamedTypeSymbol(AttributeData attribute, string namedParameter)
{
foreach (var namedArg in attribute.NamedArguments)
{
if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol)
return typeSymbol;
}

return null;
}

private static EquatableImmutableArray<string> MergeUnion(EquatableImmutableArray<string>? existing, IEnumerable<string> values)
{
var list = new List<string>();
Expand Down
Loading