|
| 1 | +use std::fmt::Debug; |
1 | 2 | use std::io; |
2 | 3 | use std::net::SocketAddr; |
3 | 4 | use std::pin::Pin; |
4 | 5 | use std::task::{Context, Poll}; |
5 | 6 |
|
6 | 7 | use futures::{Sink, SinkExt, StreamExt}; |
7 | 8 | use http_body_util::combinators::BoxBody; |
8 | | -use http_body_util::BodyExt; |
9 | | -use hyper::body::{Bytes, Incoming}; |
10 | | -use hyper::{Request, Response}; |
11 | | -use hyper_tungstenite::HyperWebsocket; |
| 9 | +use hyper::body::Bytes; |
| 10 | +use hyper::header::{CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE}; |
| 11 | +use hyper::{Request, Response, StatusCode}; |
| 12 | +use hyper_util::rt::TokioIo; |
12 | 13 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| 14 | +use tokio_tungstenite::tungstenite::handshake::derive_accept_key; |
13 | 15 | use tokio_tungstenite::tungstenite::protocol::Message; |
14 | 16 | use tokio_tungstenite::{tungstenite, WebSocketStream}; |
15 | 17 | use tracing::{error, instrument}; |
16 | 18 |
|
| 19 | +use crate::empty; |
17 | 20 | use crate::error::Error; |
18 | 21 | use crate::gateway_uri::GatewayUri; |
19 | 22 |
|
20 | | -pub(crate) fn is_websocket_request(req: &Request<Incoming>) -> bool { |
21 | | - hyper_tungstenite::is_upgrade_request(req) |
| 23 | +/// Check if the request is a WebSocket upgrade request. |
| 24 | +/// |
| 25 | +/// This is done manually to support generic body types. |
| 26 | +/// When bootstrapping moves to axum, this can be replaced with |
| 27 | +/// `axum::extract::ws::WebSocketUpgrade`. |
| 28 | +pub(crate) fn is_websocket_request<B>(req: &Request<B>) -> bool { |
| 29 | + let dominated_by_upgrade = req |
| 30 | + .headers() |
| 31 | + .get(CONNECTION) |
| 32 | + .and_then(|v| v.to_str().ok()) |
| 33 | + .map(|v| v.to_ascii_lowercase().contains("upgrade")) |
| 34 | + .unwrap_or(false); |
| 35 | + |
| 36 | + let upgrade_to_websocket = req |
| 37 | + .headers() |
| 38 | + .get(UPGRADE) |
| 39 | + .and_then(|v| v.to_str().ok()) |
| 40 | + .map(|v| v.eq_ignore_ascii_case("websocket")) |
| 41 | + .unwrap_or(false); |
| 42 | + |
| 43 | + dominated_by_upgrade && upgrade_to_websocket && req.headers().contains_key(SEC_WEBSOCKET_KEY) |
22 | 44 | } |
23 | 45 |
|
| 46 | +/// Upgrade the request to a WebSocket connection and proxy to the gateway. |
| 47 | +/// |
| 48 | +/// This performs the WebSocket handshake to support generic body types. |
| 49 | +/// When bootstrapping moves to axum, this can be replaced with |
| 50 | +/// `axum::extract::ws::WebSocketUpgrade`. |
24 | 51 | #[instrument] |
25 | | -pub(crate) async fn try_upgrade( |
26 | | - req: &mut Request<Incoming>, |
| 52 | +pub(crate) async fn try_upgrade<B>( |
| 53 | + req: Request<B>, |
27 | 54 | gateway_origin: GatewayUri, |
28 | | -) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> { |
| 55 | +) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> |
| 56 | +where |
| 57 | + B: Send + Debug + 'static, |
| 58 | +{ |
29 | 59 | let gateway_addr = gateway_origin |
30 | 60 | .to_socket_addr() |
31 | 61 | .await |
32 | 62 | .map_err(|e| Error::InternalServerError(Box::new(e)))? |
33 | 63 | .ok_or_else(|| Error::NotFound)?; |
34 | 64 |
|
35 | | - let (res, websocket) = hyper_tungstenite::upgrade(req, None) |
36 | | - .map_err(|e| Error::BadRequest(format!("Error upgrading to websocket: {}", e)))?; |
| 65 | + let key = req |
| 66 | + .headers() |
| 67 | + .get(SEC_WEBSOCKET_KEY) |
| 68 | + .ok_or_else(|| Error::BadRequest("Missing Sec-WebSocket-Key header".to_string()))? |
| 69 | + .to_str() |
| 70 | + .map_err(|_| Error::BadRequest("Invalid Sec-WebSocket-Key header".to_string()))? |
| 71 | + .to_string(); |
| 72 | + |
| 73 | + let accept_key = derive_accept_key(key.as_bytes()); |
37 | 74 |
|
38 | 75 | tokio::spawn(async move { |
39 | | - if let Err(e) = serve_websocket(websocket, gateway_addr).await { |
40 | | - error!("Error in websocket connection: {e}"); |
| 76 | + match hyper::upgrade::on(req).await { |
| 77 | + Ok(upgraded) => { |
| 78 | + let ws_stream = WebSocketStream::from_raw_socket( |
| 79 | + TokioIo::new(upgraded), |
| 80 | + tungstenite::protocol::Role::Server, |
| 81 | + None, |
| 82 | + ) |
| 83 | + .await; |
| 84 | + if let Err(e) = serve_websocket(ws_stream, gateway_addr).await { |
| 85 | + error!("Error in websocket connection: {e}"); |
| 86 | + } |
| 87 | + } |
| 88 | + Err(e) => error!("WebSocket upgrade error: {}", e), |
41 | 89 | } |
42 | 90 | }); |
43 | | - let (parts, body) = res.into_parts(); |
44 | | - let boxbody = body.map_err(|never| match never {}).boxed(); |
45 | | - Ok(Response::from_parts(parts, boxbody)) |
| 91 | + |
| 92 | + let res = Response::builder() |
| 93 | + .status(StatusCode::SWITCHING_PROTOCOLS) |
| 94 | + .header(UPGRADE, "websocket") |
| 95 | + .header(CONNECTION, "Upgrade") |
| 96 | + .header(SEC_WEBSOCKET_ACCEPT, accept_key) |
| 97 | + .body(empty()) |
| 98 | + .map_err(|e| Error::InternalServerError(Box::new(e)))?; |
| 99 | + |
| 100 | + Ok(res) |
46 | 101 | } |
47 | 102 |
|
48 | 103 | /// Stream WebSocket frames from the client to the gateway server's TCP socket and vice versa. |
49 | | -#[instrument] |
50 | | -async fn serve_websocket( |
51 | | - websocket: HyperWebsocket, |
| 104 | +#[instrument(skip(ws_stream))] |
| 105 | +async fn serve_websocket<S>( |
| 106 | + ws_stream: WebSocketStream<S>, |
52 | 107 | gateway_addr: SocketAddr, |
53 | | -) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { |
| 108 | +) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> |
| 109 | +where |
| 110 | + S: AsyncRead + AsyncWrite + Unpin, |
| 111 | +{ |
54 | 112 | let mut tcp_stream = tokio::net::TcpStream::connect(gateway_addr).await?; |
55 | | - let mut ws_io = WsIo::new(websocket.await?); |
| 113 | + let mut ws_io = WsIo::new(ws_stream); |
56 | 114 | let (_, _) = tokio::io::copy_bidirectional(&mut ws_io, &mut tcp_stream).await?; |
57 | 115 | Ok(()) |
58 | 116 | } |
|
0 commit comments