diff --git a/README.md b/README.md index 0c6326ec..8e1f0b58 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ An Elixir implementation of [gRPC](http://www.grpc.io/). - [Usage](#usage) - [Simple RPC](#simple-rpc) - [HTTP Transcoding](#http-transcoding) + - [CORS](#cors) - [Start Application](#start-application) - [Features](#features) - [Benchmark](#benchmark) @@ -24,13 +25,13 @@ An Elixir implementation of [gRPC](http://www.grpc.io/). The package can be installed as: - ```elixir - def deps do - [ - {:grpc, "~> 0.9"} - ] - end - ``` +```elixir +def deps do + [ + {:grpc, "~> 0.9"} + ] +end +``` ## Usage @@ -96,7 +97,7 @@ end We will use this module [in the gRPC server startup section](#start-application). -**__Note:__** For other types of RPC call like streams see [here](interop/lib/interop/server.ex). +**Note:** For other types of RPC call like streams see [here](interop/lib/interop/server.ex). ### **HTTP Transcoding** @@ -152,6 +153,7 @@ mix protobuf.generate \ ``` 3. Enable http_transcode option in your Server module + ```elixir defmodule Helloworld.Greeter.Server do use GRPC.Server, @@ -167,6 +169,23 @@ end See full application code in [helloworld_transcoding](examples/helloworld_transcoding) example. +### **CORS** + +When accessing gRPC from a browser via HTTP transcoding or gRPC-Web, CORS headers may be required for the browser to allow access to the gRPC endpoint. Adding CORS headers can be done by using `GRPC.Server.Interceptors.CORS` as an interceptor in your `GRPC.Endpoint` module, configuring it as decribed in the module documentation: + +Example: + +```elixir +# Define your endpoint +defmodule Helloworld.Endpoint do + use GRPC.Endpoint + + intercept GRPC.Server.Interceptors.Logger + intercept GRPC.Server.Interceptors.CORS, allow_origin: "mydomain.io" + run Helloworld.Greeter.Server +end +``` + ### **Start Application** 1. Start gRPC Server in your supervisor tree or Application module: @@ -231,7 +250,7 @@ The accepted options for configuration are the ones listed on [Mint.HTTP.connect - [HTTP Transcoding](https://cloud.google.com/endpoints/docs/grpc/transcoding) - [TLS Authentication](https://grpc.io/docs/guides/auth/#supported-auth-mechanisms) - [Error handling](https://grpc.io/docs/guides/error/) -- Interceptors (See [`GRPC.Endpoint`](https://github.com/elixir-grpc/grpc/blob/master/lib/grpc/endpoint.ex)) +- [Interceptors](`GRPC.Endpoint`) - [Connection Backoff](https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md) - Data compression - [gRPC Reflection](https://github.com/elixir-grpc/grpc-reflection) diff --git a/lib/grpc/server.ex b/lib/grpc/server.ex index 461d88e1..3e8e17f3 100644 --- a/lib/grpc/server.ex +++ b/lib/grpc/server.ex @@ -152,7 +152,8 @@ defmodule GRPC.Server do path = "/#{service_name}/#{name}" grpc_type = GRPC.Service.grpc_type(rpc) - def __call_rpc__(unquote(path), :post, stream) do + def __call_rpc__(unquote(path), http_method, stream) + when http_method == :post or http_method == :options do GRPC.Server.call( unquote(service_mod), %{ @@ -178,8 +179,7 @@ defmodule GRPC.Server do | service_name: unquote(service_name), method_name: unquote(to_string(name)), grpc_type: unquote(grpc_type), - http_method: unquote(http_method), - http_transcode: unquote(http_transcode) + http_method: unquote(http_method) }, unquote(Macro.escape(put_elem(rpc, 0, func_name))), unquote(func_name) @@ -252,7 +252,7 @@ defmodule GRPC.Server do codec: codec, adapter: adapter, payload: payload, - http_transcode: true + access_mode: :http_transcoding } = stream, func_name ) do @@ -271,6 +271,10 @@ defmodule GRPC.Server do end end + defp do_handle_request(false, res_stream, %{is_preflight?: true} = stream, func_name) do + call_with_interceptors(res_stream, func_name, stream, []) + end + defp do_handle_request( false, res_stream, @@ -339,7 +343,8 @@ defmodule GRPC.Server do ) do GRPC.Telemetry.server_span(server, endpoint, func_name, stream, fn -> last = fn r, s -> - reply = apply(server, func_name, [r, s]) + # no response is rquired for preflight requests + reply = if stream.is_preflight?, do: [], else: apply(server, func_name, [r, s]) if res_stream do {:ok, stream} diff --git a/lib/grpc/server/adapters/cowboy/handler.ex b/lib/grpc/server/adapters/cowboy/handler.ex index eb1e3d88..13e9e3f7 100644 --- a/lib/grpc/server/adapters/cowboy/handler.ex +++ b/lib/grpc/server/adapters/cowboy/handler.ex @@ -28,7 +28,8 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do @type stream_state :: %{ pid: server_rpc_pid :: pid, handling_timer: timeout_timer_ref :: reference, - pending_reader: nil | pending_reader + pending_reader: nil | pending_reader, + access_mode: GRPC.Server.Stream.access_mode() } @type init_result :: {:cowboy_loop, :cowboy_req.req(), stream_state} | {:ok, :cowboy_req.req(), init_state} @@ -56,10 +57,12 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do |> String.downcase() |> String.to_existing_atom() - with {:ok, sub_type, content_type} <- find_content_type_subtype(req), + with {:ok, access_mode, sub_type, content_type} <- find_content_type_subtype(req), {:ok, codec} <- find_codec(sub_type, content_type, server), {:ok, compressor} <- find_compressor(req, server) do stream_pid = self() + http_transcode = access_mode == :http_transcoding + request_headers = :cowboy_req.headers(req) stream = %GRPC.Server.Stream{ server: server, @@ -69,8 +72,11 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do local: opts[:local], codec: codec, http_method: http_method, + http_request_headers: request_headers, + http_transcode: http_transcode, compressor: compressor, - http_transcode: transcode?(req) + is_preflight?: preflight?(req), + access_mode: access_mode } server_rpc_pid = :proc_lib.spawn_link(__MODULE__, :call_rpc, [server, route, stream]) @@ -78,7 +84,7 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do req = :cowboy_req.set_resp_headers(HTTP2.server_headers(stream), req) - timeout = :cowboy_req.header("grpc-timeout", req) + timeout = Map.get(request_headers, "grpc-timeout") timer_ref = if is_binary(timeout) do @@ -89,7 +95,16 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do ) end - {:cowboy_loop, req, %{pid: server_rpc_pid, handling_timer: timer_ref, pending_reader: nil}} + { + :cowboy_loop, + req, + %{ + pid: server_rpc_pid, + handling_timer: timer_ref, + pending_reader: nil, + access_mode: access_mode + } + } else {:error, error} -> Logger.error(fn -> inspect(error) end) @@ -121,12 +136,9 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do content_type end - find_subtype(content_type) - end - - defp find_subtype(content_type) do - {:ok, subtype} = extract_subtype(content_type) - {:ok, subtype, content_type} + {:ok, access_mode, subtype} = extract_subtype(content_type) + access_mode = resolve_access_mode(req, access_mode, subtype) + {:ok, access_mode, subtype, content_type} end defp find_compressor(req, server) do @@ -600,38 +612,43 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do end end - defp extract_subtype("application/json"), do: {:ok, "json"} - defp extract_subtype("application/grpc"), do: {:ok, "proto"} - defp extract_subtype("application/grpc+"), do: {:ok, "proto"} - defp extract_subtype("application/grpc;"), do: {:ok, "proto"} - defp extract_subtype(<<"application/grpc+", rest::binary>>), do: {:ok, rest} - defp extract_subtype(<<"application/grpc;", rest::binary>>), do: {:ok, rest} + defp extract_subtype("application/json"), do: {:ok, :http_transcoding, "json"} + defp extract_subtype("application/grpc"), do: {:ok, :grpc, "proto"} + defp extract_subtype("application/grpc+"), do: {:ok, :grpc, "proto"} + defp extract_subtype("application/grpc;"), do: {:ok, :grpc, "proto"} + defp extract_subtype(<<"application/grpc+", rest::binary>>), do: {:ok, :grpc, rest} + defp extract_subtype(<<"application/grpc;", rest::binary>>), do: {:ok, :grpc, rest} - defp extract_subtype("application/grpc-web"), do: {:ok, "proto"} - defp extract_subtype("application/grpc-web+"), do: {:ok, "proto"} - defp extract_subtype("application/grpc-web;"), do: {:ok, "proto"} - defp extract_subtype("application/grpc-web-text"), do: {:ok, "text"} - defp extract_subtype("application/grpc-web+" <> rest), do: {:ok, rest} - defp extract_subtype("application/grpc-web-text+" <> rest), do: {:ok, rest} + defp extract_subtype("application/grpc-web"), do: {:ok, :grpcweb, "proto"} + defp extract_subtype("application/grpc-web+"), do: {:ok, :grpcweb, "proto"} + defp extract_subtype("application/grpc-web;"), do: {:ok, :grpcweb, "proto"} + defp extract_subtype("application/grpc-web-text"), do: {:ok, :grpcweb, "text"} + defp extract_subtype("application/grpc-web+" <> rest), do: {:ok, :grpcweb, rest} + defp extract_subtype("application/grpc-web-text+" <> rest), do: {:ok, :grpcweb, rest} defp extract_subtype(type) do Logger.warning("Got unknown content-type #{type}, please create an issue.") - {:ok, "proto"} + {:ok, :grpc, "proto"} end - defp transcode?(%{version: "HTTP/1.1"}), do: true + defp resolve_access_mode(%{version: "HTTP/1.1"}, _detected_access_mode, _type_subtype), + do: :http_transcoding - defp transcode?(req) do - case find_content_type_subtype(req) do - {:ok, "json", _} -> true - _ -> false - end - end + defp resolve_access_mode(%{method: "OPTIONS"}, _detected_access_mode, _type_subtype), + do: :grpcweb + + defp resolve_access_mode(_req, detected_access_mode, _type_subtype), do: detected_access_mode + + defp preflight?(%{method: "OPTIONS"}), do: true + defp preflight?(_), do: false defp send_error(req, error, state, reason) do trailers = HTTP2.server_trailers(error.status, error.message) - status = if transcode?(req), do: GRPC.Status.http_code(error.status), else: 200 + status = + if state.access_mode == :http_transcoding, + do: GRPC.Status.http_code(error.status), + else: 200 if pid = Map.get(state, :pid) do exit_handler(pid, reason) diff --git a/lib/grpc/server/interceptors/cors.ex b/lib/grpc/server/interceptors/cors.ex new file mode 100644 index 00000000..4b5b5d48 --- /dev/null +++ b/lib/grpc/server/interceptors/cors.ex @@ -0,0 +1,129 @@ +defmodule GRPC.Server.Interceptors.CORS do + @moduledoc """ + Sends CORS headers when the client is using RPC via Web transcoding or gRPC-web. + + ## Options + + * `:allow_origin` - Required. A string containing the allowed origin, or a function capture (e.g. `&MyApp.MyModule.function/2)`) which takes a `req` and a `stream` and returns a string. + * `:allow_headers` - A string containing the allowed headers, or a function capture + (e.g. `&MyApp.MyModule.function/2)`) which takes a `req` and a `stream` and returns a string. Defaults to `nil`. + If defined as `nil`, the value of the `"access-control-request-headers"` request header from the client will be used in the response. + + ## Usage + + defmodule Your.Endpoint do + use GRPC.Endpoint + + intercept GRPC.Server.Interceptors.CORS + end + + defmodule Your.Endpoint do + use GRPC.Endpoint + + intercept GRPC.Server.Interceptors.CORS, allow_origin: "some.origin" + end + + + defmodule Your.Endpoint do + use GRPC.Endpoint + + def allow_origin(req, stream), do: "calculated.origin" + intercept GRPC.Server.Interceptors.CORS, allow: &Your.Endpoint.allow_origin/2 + end + """ + + @behaviour GRPC.Server.Interceptor + @impl true + def init(opts \\ []) do + # Function captures are represented as their AST in this step + # because of a Macro.escape call in the __before_compile__ step + # in endpoint.ex. + # This is not a full-on Macro context, so binary concatenations and + # variables are handled before this step. + + # TODO: use Keyword.validate! once we drop support for Elixir < 1.13 + + {allow_origin, opts} = Keyword.pop(opts, :allow_origin) + {allow_headers, opts} = Keyword.pop(opts, :allow_headers) + + if opts != [] do + raise ArgumentError, + "valid keys are [:allow_origin, :allow_headers], got: #{inspect(opts)}" + end + + allow_origin = + case allow_origin do + {:&, [], [{:/, [], [_signature, 2]}]} = fun -> + fun + + binary when is_binary(binary) -> + binary + + other -> + raise ArgumentError, + "allow_origin must be a string or a 2-arity remote function, got: #{inspect(other)}" + end + + allow_headers = + case allow_headers do + {:&, [], [{:/, [], [_signature, 2]}]} = fun -> + fun + + binary when is_binary(binary) -> + binary + + nil -> + nil + + other -> + raise ArgumentError, + ":allow_headers must be a string, a 2-arity remote function, or nil, got: #{inspect(other)}" + end + + {allow_origin, allow_headers} + end + + @impl true + def call(req, stream, next, {allow_origin, allow_headers}) do + if stream.access_mode != :grpc and + Map.get(stream.http_request_headers, "sec-fetch-mode") == "cors" do + headers = + %{} + |> add_allowed_origins(req, stream, allow_origin) + |> add_allowed_headers(req, stream, allow_headers) + + stream.adapter.set_headers(stream.payload, headers) + end + + next.(req, stream) + end + + defp add_allowed_origins(headers, req, stream, allow) do + value = + case allow do + allow when is_function(allow, 2) -> allow.(req, stream) + allow -> allow + end + + Map.put(headers, "access-control-allow-origin", value) + end + + defp add_allowed_headers( + headers, + req, + %{http_request_headers: %{"access-control-request-headers" => requested}} = stream, + allow + ) do + # include an access-control-allow-headers header only when a request headers is sent + value = + case allow do + nil -> requested + allow when is_function(allow, 2) -> allow.(req, stream) + allow -> allow + end + + Map.put(headers, "access-control-allow-headers", value) + end + + defp add_allowed_headers(headers, _req, _stream, _allowed), do: headers +end diff --git a/lib/grpc/server/stream.ex b/lib/grpc/server/stream.ex index 1c0df8c9..fa0a7c19 100644 --- a/lib/grpc/server/stream.ex +++ b/lib/grpc/server/stream.ex @@ -16,6 +16,7 @@ defmodule GRPC.Server.Stream do * `:payload` - the payload needed by the adapter * `:local` - local data initialized by user """ + @type access_mode :: :grpc | :grpcweb | :http_transcoding @type t :: %__MODULE__{ server: atom(), @@ -31,11 +32,15 @@ defmodule GRPC.Server.Stream do payload: any(), adapter: atom(), local: any(), + access_mode: access_mode, # compressor mainly is used in client decompressing, responses compressing should be set by # `GRPC.Server.set_compressor` compressor: module() | nil, + # notes that this is a preflight request, and not an actual request for data (e.g. in grpcweb) + is_preflight?: boolean(), # For http transcoding http_method: GRPC.Server.Router.http_method(), + http_request_headers: map(), http_transcode: boolean(), __interface__: map() } @@ -53,13 +58,21 @@ defmodule GRPC.Server.Stream do payload: nil, adapter: nil, local: nil, + access_mode: :grpc, compressor: nil, + is_preflight?: false, http_method: :post, + http_request_headers: %{}, http_transcode: false, __interface__: %{send_reply: &__MODULE__.send_reply/3} + def send_reply(%{is_preflight?: true} = stream, _reply, opts) do + do_send_reply(stream, [], opts) + end + def send_reply( - %{grpc_type: :server_stream, codec: codec, http_transcode: true, rpc: rpc} = stream, + %{grpc_type: :server_stream, codec: codec, access_mode: :http_transcoding, rpc: rpc} = + stream, reply, opts ) do @@ -69,7 +82,7 @@ defmodule GRPC.Server.Stream do do_send_reply(stream, [codec.encode(response), "\n"], opts) end - def send_reply(%{codec: codec, http_transcode: true, rpc: rpc} = stream, reply, opts) do + def send_reply(%{codec: codec, access_mode: :http_transcoding, rpc: rpc} = stream, reply, opts) do rule = GRPC.Service.rpc_options(rpc, :http) || %{value: %{}} response = GRPC.Server.Transcode.map_response_body(rule.value, reply) @@ -81,14 +94,14 @@ defmodule GRPC.Server.Stream do end defp do_send_reply( - %{adapter: adapter, codec: codec, http_transcode: http_transcode} = stream, + %{adapter: adapter, codec: codec, access_mode: access_mode} = stream, data, opts ) do opts = opts |> Keyword.put(:codec, codec) - |> Keyword.put(:http_transcode, http_transcode) + |> Keyword.put(:http_transcode, access_mode == :http_transcoding) adapter.send_reply(stream.payload, data, opts) diff --git a/test/grpc/server/interceptors/cors_test.exs b/test/grpc/server/interceptors/cors_test.exs new file mode 100644 index 00000000..7805ab7b --- /dev/null +++ b/test/grpc/server/interceptors/cors_test.exs @@ -0,0 +1,314 @@ +defmodule GRPC.Server.Interceptors.CORSTest.Endpoint.FunctionCapture do + use GRPC.Endpoint + + intercept(GRPC.Server.Interceptors.CORS, + allow_origin: &GRPC.Server.Interceptors.CORSTest.allow_origin/2, + allow_headers: &GRPC.Server.Interceptors.CORSTest.allow_headers/2 + ) +end + +defmodule GRPC.Server.Interceptors.CORSTest.Endpoint.BinaryConcatenation do + use GRPC.Endpoint + + origin1 = "https://subdomain1.domain.com" + origin2 = "https://subdomain2.domain.com" + + intercept( + GRPC.Server.Interceptors.CORS, + allow_origin: origin1 <> "," <> origin2, + allow_headers: "MySpecialHeader,AndAnother" + ) +end + +defmodule GRPC.Server.Interceptors.CORSTest do + use ExUnit.Case, async: false + + alias GRPC.Server.Interceptors.CORS, as: CORSInterceptor + alias GRPC.Server.Stream + + defmodule FakeRequest do + defstruct [] + end + + @server_name :server + @rpc {1, 2, 3} + @adaptor GRPC.Test.ServerAdapter + @function_header_value "from-function" + @default_http_headers %{ + "accept" => "application/grpc-web-text", + "accept-encoding" => "gzip, deflate, br, zstd", + "accept-language" => "en-US,en;q=0.5", + "connection" => "keep-alive", + "content-length" => "20", + "content-type" => "application/grpc-web-text", + "dnt" => "1", + "host" => "http://myhost:4100", + "priority" => "u=0", + "referer" => "http://localhost:3000/", + "sec-fetch-dest" => "empty", + "sec-fetch-mode" => "cors", + "sec-fetch-site" => "same-site", + "user-agent" => "Mozilla/5.0 (X11; Linux x86_64; rv:128.0) Gecko/20100101 Firefox/128.0", + "x-grpc-web" => "1", + "x-user-agent" => "grpc-web-javascript/0.1" + } + @requested_allowed_headers "Authorized" + @custom_allowed_headers "MySpecialHeader,AndAnother" + + def allow_origin(_req, _stream), do: @function_header_value + def allow_headers(_req, _stream), do: @custom_allowed_headers + + def create_stream() do + %Stream{ + adapter: @adaptor, + server: @server_name, + rpc: @rpc, + http_request_headers: @default_http_headers + } + end + + test "Sends headers CORS for for http transcoding and grpcweb requests" do + request = %FakeRequest{} + stream = create_stream() + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :http_transcode}, + fn _request, _stream -> {:ok, :ok} end, + CORSInterceptor.init(allow_origin: "*") + ) + + assert_received({:setting_headers, _headers}, "Failed to set CORS headers during grpcweb") + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + CORSInterceptor.init(allow_origin: "*") + ) + + assert_received({:setting_headers, _headers}, "Failed to set CORS headers during grpcweb") + end + + test "Does not send CORS headers for normal grpc requests" do + request = %FakeRequest{} + stream = create_stream() + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpc}, + fn _request, _stream -> {:ok, :ok} end, + CORSInterceptor.init(allow_origin: "*") + ) + + refute_received({:setting_headers, _headers}, "Set CORS headers during grpc") + end + + test "CORS allow origin header value is configuraable with a static string" do + request = %FakeRequest{} + stream = Map.put(create_stream(), :access_mode, :grpcweb) + domain = "https://mydomain.io" + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + CORSInterceptor.init(allow_origin: domain) + ) + + assert_received( + {:setting_headers, %{"access-control-allow-origin" => ^domain}}, + "Incorrect static header" + ) + end + + test "CORS allow origin init does not accept non-string arguments" do + assert_raise(ArgumentError, fn -> CORSInterceptor.init(allow_origin: :atom) end) + assert_raise(ArgumentError, fn -> CORSInterceptor.init(allow_origin: 1) end) + assert_raise(ArgumentError, fn -> CORSInterceptor.init(allow_origin: 1.0) end) + assert_raise(ArgumentError, fn -> CORSInterceptor.init(allow_origin: []) end) + assert_raise(ArgumentError, fn -> CORSInterceptor.init(allow_origin: %{}) end) + end + + test "CORS allow origin header value is configuraable with a two-arity function" do + request = %FakeRequest{} + stream = Map.put(create_stream(), :access_mode, :grpcweb) + + # fetch the interceptor state from the fake endpoint + [{_interceptor, interceptor_state}] = + GRPC.Server.Interceptors.CORSTest.Endpoint.FunctionCapture.__meta__(:interceptors) + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + interceptor_state + ) + + assert_received( + {:setting_headers, %{"access-control-allow-origin" => @function_header_value}}, + "Incorrect header when using function" + ) + end + + test "CORS allow origin header value is configuraable with binary concatenation" do + request = %FakeRequest{} + stream = Map.put(create_stream(), :access_mode, :grpcweb) + + # fetch the interceptor state from the fake endpoint + [{_interceptor, interceptor_state}] = + GRPC.Server.Interceptors.CORSTest.Endpoint.BinaryConcatenation.__meta__(:interceptors) + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + interceptor_state + ) + + assert_received( + {:setting_headers, + %{ + "access-control-allow-origin" => + "https://subdomain1.domain.com,https://subdomain2.domain.com" + }}, + "Incorrect header when using function" + ) + end + + test "CORS Access-Control-Allowed-Headers is included in response when clients request it" do + request = %FakeRequest{} + + stream = %{ + create_stream() + | access_mode: :grpcweb, + http_request_headers: + Map.put( + @default_http_headers, + "access-control-request-headers", + @requested_allowed_headers + ) + } + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + CORSInterceptor.init(allow_origin: "*") + ) + + assert_received( + {:setting_headers, %{"access-control-allow-headers" => @requested_allowed_headers}}, + "Incorrect header when using function" + ) + end + + test "CORS Access-Control-Allowed-Headers is configurable with a static string" do + request = %FakeRequest{} + + stream = %{ + create_stream() + | access_mode: :grpcweb, + http_request_headers: + Map.put( + @default_http_headers, + "access-control-request-headers", + @requested_allowed_headers + ) + } + + allowed_headers = "Test" + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + CORSInterceptor.init(allow_origin: "*", allow_headers: allowed_headers) + ) + + assert_received( + {:setting_headers, %{"access-control-allow-headers" => ^allowed_headers}}, + "Incorrect header when using function" + ) + end + + test "CORS Access-Control-Allowed-Headers is configurable with a two-arity function" do + request = %FakeRequest{} + + stream = %{ + create_stream() + | access_mode: :grpcweb, + http_request_headers: + Map.put( + @default_http_headers, + "access-control-request-headers", + @requested_allowed_headers + ) + } + + # fetch the interceptor state from the fake endpoint + [{_interceptor, interceptor_state}] = + GRPC.Server.Interceptors.CORSTest.Endpoint.FunctionCapture.__meta__(:interceptors) + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + interceptor_state + ) + + assert_received( + {:setting_headers, %{"access-control-allow-headers" => @custom_allowed_headers}}, + "Incorrect header when using function" + ) + end + + test "CORS only on cors sec-fetch-mode" do + request = %FakeRequest{} + + stream = %{ + create_stream() + | access_mode: :grpcweb, + http_request_headers: Map.put(@default_http_headers, "sec-fetch-mode", "same-origin") + } + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + CORSInterceptor.init(allow_origin: "*") + ) + + refute_received({:setting_headers, _}, "Set CORS header") + end + + test "No CORS if missing sec-fetch-mode header" do + request = %FakeRequest{} + + stream = %{ + create_stream() + | access_mode: :grpcweb, + http_request_headers: Map.delete(@default_http_headers, "sec-fetch-mode") + } + + {:ok, :ok} = + CORSInterceptor.call( + request, + %{stream | access_mode: :grpcweb}, + fn _request, _stream -> {:ok, :ok} end, + CORSInterceptor.init(allow_origin: "*") + ) + + refute_received({:setting_headers, _}, "Set CORS header") + end +end diff --git a/test/support/test_adapter.exs b/test/support/test_adapter.exs index ff603557..4fe3c161 100644 --- a/test/support/test_adapter.exs +++ b/test/support/test_adapter.exs @@ -37,4 +37,9 @@ defmodule GRPC.Test.ServerAdapter do def has_sent_headers?(_stream) do false end + + def set_headers(stream, headers) do + send(self(), {:setting_headers, headers}) + stream + end end