Skip to content

Commit 7f27460

Browse files
authored
Unified payjoin service (#1232)
2 parents 3abd99c + 776943f commit 7f27460

File tree

24 files changed

+1542
-173
lines changed

24 files changed

+1542
-173
lines changed

Cargo-minimal.lock

Lines changed: 366 additions & 28 deletions
Large diffs are not rendered by default.

Cargo-recent.lock

Lines changed: 366 additions & 28 deletions
Large diffs are not rendered by default.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ members = [
66
"payjoin-directory",
77
"payjoin-test-utils",
88
"payjoin-ffi",
9+
"payjoin-service",
910
]
1011
resolver = "2"
1112

1213
[patch.crates-io]
1314
ohttp-relay = { path = "ohttp-relay" }
1415
payjoin = { path = "payjoin" }
1516
payjoin-directory = { path = "payjoin-directory" }
17+
payjoin-service = { path = "payjoin-service" }
1618
payjoin-test-utils = { path = "payjoin-test-utils" }
1719

1820
[profile.crane]

ohttp-relay/Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ exclude = ["tests"]
1616
default = ["bootstrap"]
1717
bootstrap = ["connect-bootstrap", "ws-bootstrap"]
1818
connect-bootstrap = []
19-
ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"]
19+
ws-bootstrap = ["futures", "rustls", "tokio-tungstenite"]
2020
_test-util = []
2121

2222
[dependencies]
2323
byteorder = "1.5.0"
2424
bytes = "1.10.1"
2525
futures = { version = "0.3.31", optional = true }
26+
hex = { package = "hex-conservative", version = "0.1.1" }
2627
http = "1.3.1"
2728
http-body-util = "0.1.3"
2829
hyper = { version = "1.6.0", features = ["http1", "server"] }
@@ -31,8 +32,7 @@ hyper-rustls = { version = "0.27.7", default-features = false, features = [
3132
"http1",
3233
"ring",
3334
] }
34-
hyper-tungstenite = { version = "0.18.0", optional = true }
35-
hyper-util = { version = "0.1.16", features = ["client-legacy"] }
35+
hyper-util = { version = "0.1.16", features = ["client-legacy", "service"] }
3636
rustls = { version = "0.23.31", optional = true, default-features = false, features = [
3737
"ring",
3838
] }
@@ -44,11 +44,11 @@ tokio = { version = "1.47.1", features = [
4444
] }
4545
tokio-tungstenite = { version = "0.27.0", optional = true }
4646
tokio-util = { version = "0.7.16", features = ["net", "codec"] }
47+
tower = "0.5"
4748
tracing = "0.1.41"
4849
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
4950

5051
[dev-dependencies]
51-
hex = { package = "hex-conservative", version = "0.1.1" }
5252
mockito = "1.7.0"
5353
rcgen = "0.12"
5454
reqwest = { version = "0.12.23", default-features = false, features = [

ohttp-relay/src/bootstrap/connect.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
use std::fmt::Debug;
12
use std::net::SocketAddr;
23

34
use http_body_util::combinators::BoxBody;
4-
use hyper::body::{Bytes, Incoming};
5+
use hyper::body::Bytes;
56
use hyper::upgrade::Upgraded;
67
use hyper::{Method, Request, Response};
78
use hyper_util::rt::TokioIo;
@@ -11,15 +12,16 @@ use tracing::{error, instrument};
1112
use crate::error::Error;
1213
use crate::{empty, GatewayUri};
1314

14-
pub(crate) fn is_connect_request(req: &Request<Incoming>) -> bool {
15-
Method::CONNECT == req.method()
16-
}
15+
pub(crate) fn is_connect_request<B>(req: &Request<B>) -> bool { Method::CONNECT == req.method() }
1716

1817
#[instrument]
19-
pub(crate) async fn try_upgrade(
20-
req: Request<Incoming>,
18+
pub(crate) async fn try_upgrade<B>(
19+
req: Request<B>,
2120
gateway_origin: GatewayUri,
22-
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
21+
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error>
22+
where
23+
B: Send + Debug + 'static,
24+
{
2325
let addr = gateway_origin
2426
.to_socket_addr()
2527
.await

ohttp-relay/src/bootstrap/mod.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
use std::fmt::Debug;
2+
13
use http_body_util::combinators::BoxBody;
2-
use hyper::body::{Bytes, Incoming};
4+
use hyper::body::Bytes;
35
use hyper::{Request, Response};
46
use tracing::instrument;
57

@@ -13,18 +15,21 @@ pub mod connect;
1315
pub mod ws;
1416

1517
#[instrument]
16-
pub(crate) async fn handle_ohttp_keys(
17-
mut req: Request<Incoming>,
18+
pub(crate) async fn handle_ohttp_keys<B>(
19+
req: Request<B>,
1820
gateway_origin: GatewayUri,
19-
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
21+
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error>
22+
where
23+
B: Send + Debug + 'static,
24+
{
2025
#[cfg(feature = "connect-bootstrap")]
2126
if connect::is_connect_request(&req) {
2227
return connect::try_upgrade(req, gateway_origin).await;
2328
}
2429

2530
#[cfg(feature = "ws-bootstrap")]
2631
if ws::is_websocket_request(&req) {
27-
return ws::try_upgrade(&mut req, gateway_origin).await;
32+
return ws::try_upgrade(req, gateway_origin).await;
2833
}
2934

3035
Err(Error::BadRequest("Not a supported proxy upgrade request".to_string()))

ohttp-relay/src/bootstrap/ws.rs

Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,116 @@
1+
use std::fmt::Debug;
12
use std::io;
23
use std::net::SocketAddr;
34
use std::pin::Pin;
45
use std::task::{Context, Poll};
56

67
use futures::{Sink, SinkExt, StreamExt};
78
use http_body_util::combinators::BoxBody;
8-
use http_body_util::BodyExt;
9-
use hyper::body::{Bytes, Incoming};
10-
use hyper::{Request, Response};
11-
use hyper_tungstenite::HyperWebsocket;
9+
use hyper::body::Bytes;
10+
use hyper::header::{CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE};
11+
use hyper::{Request, Response, StatusCode};
12+
use hyper_util::rt::TokioIo;
1213
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14+
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
1315
use tokio_tungstenite::tungstenite::protocol::Message;
1416
use tokio_tungstenite::{tungstenite, WebSocketStream};
1517
use tracing::{error, instrument};
1618

19+
use crate::empty;
1720
use crate::error::Error;
1821
use crate::gateway_uri::GatewayUri;
1922

20-
pub(crate) fn is_websocket_request(req: &Request<Incoming>) -> bool {
21-
hyper_tungstenite::is_upgrade_request(req)
23+
/// Check if the request is a WebSocket upgrade request.
24+
///
25+
/// This is done manually to support generic body types.
26+
/// When bootstrapping moves to axum, this can be replaced with
27+
/// `axum::extract::ws::WebSocketUpgrade`.
28+
pub(crate) fn is_websocket_request<B>(req: &Request<B>) -> bool {
29+
let dominated_by_upgrade = req
30+
.headers()
31+
.get(CONNECTION)
32+
.and_then(|v| v.to_str().ok())
33+
.map(|v| v.to_ascii_lowercase().contains("upgrade"))
34+
.unwrap_or(false);
35+
36+
let upgrade_to_websocket = req
37+
.headers()
38+
.get(UPGRADE)
39+
.and_then(|v| v.to_str().ok())
40+
.map(|v| v.eq_ignore_ascii_case("websocket"))
41+
.unwrap_or(false);
42+
43+
dominated_by_upgrade && upgrade_to_websocket && req.headers().contains_key(SEC_WEBSOCKET_KEY)
2244
}
2345

46+
/// Upgrade the request to a WebSocket connection and proxy to the gateway.
47+
///
48+
/// This performs the WebSocket handshake to support generic body types.
49+
/// When bootstrapping moves to axum, this can be replaced with
50+
/// `axum::extract::ws::WebSocketUpgrade`.
2451
#[instrument]
25-
pub(crate) async fn try_upgrade(
26-
req: &mut Request<Incoming>,
52+
pub(crate) async fn try_upgrade<B>(
53+
req: Request<B>,
2754
gateway_origin: GatewayUri,
28-
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
55+
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error>
56+
where
57+
B: Send + Debug + 'static,
58+
{
2959
let gateway_addr = gateway_origin
3060
.to_socket_addr()
3161
.await
3262
.map_err(|e| Error::InternalServerError(Box::new(e)))?
3363
.ok_or_else(|| Error::NotFound)?;
3464

35-
let (res, websocket) = hyper_tungstenite::upgrade(req, None)
36-
.map_err(|e| Error::BadRequest(format!("Error upgrading to websocket: {}", e)))?;
65+
let key = req
66+
.headers()
67+
.get(SEC_WEBSOCKET_KEY)
68+
.ok_or_else(|| Error::BadRequest("Missing Sec-WebSocket-Key header".to_string()))?
69+
.to_str()
70+
.map_err(|_| Error::BadRequest("Invalid Sec-WebSocket-Key header".to_string()))?
71+
.to_string();
72+
73+
let accept_key = derive_accept_key(key.as_bytes());
3774

3875
tokio::spawn(async move {
39-
if let Err(e) = serve_websocket(websocket, gateway_addr).await {
40-
error!("Error in websocket connection: {e}");
76+
match hyper::upgrade::on(req).await {
77+
Ok(upgraded) => {
78+
let ws_stream = WebSocketStream::from_raw_socket(
79+
TokioIo::new(upgraded),
80+
tungstenite::protocol::Role::Server,
81+
None,
82+
)
83+
.await;
84+
if let Err(e) = serve_websocket(ws_stream, gateway_addr).await {
85+
error!("Error in websocket connection: {e}");
86+
}
87+
}
88+
Err(e) => error!("WebSocket upgrade error: {}", e),
4189
}
4290
});
43-
let (parts, body) = res.into_parts();
44-
let boxbody = body.map_err(|never| match never {}).boxed();
45-
Ok(Response::from_parts(parts, boxbody))
91+
92+
let res = Response::builder()
93+
.status(StatusCode::SWITCHING_PROTOCOLS)
94+
.header(UPGRADE, "websocket")
95+
.header(CONNECTION, "Upgrade")
96+
.header(SEC_WEBSOCKET_ACCEPT, accept_key)
97+
.body(empty())
98+
.map_err(|e| Error::InternalServerError(Box::new(e)))?;
99+
100+
Ok(res)
46101
}
47102

48103
/// Stream WebSocket frames from the client to the gateway server's TCP socket and vice versa.
49-
#[instrument]
50-
async fn serve_websocket(
51-
websocket: HyperWebsocket,
104+
#[instrument(skip(ws_stream))]
105+
async fn serve_websocket<S>(
106+
ws_stream: WebSocketStream<S>,
52107
gateway_addr: SocketAddr,
53-
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
108+
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>
109+
where
110+
S: AsyncRead + AsyncWrite + Unpin,
111+
{
54112
let mut tcp_stream = tokio::net::TcpStream::connect(gateway_addr).await?;
55-
let mut ws_io = WsIo::new(websocket.await?);
113+
let mut ws_io = WsIo::new(ws_stream);
56114
let (_, _) = tokio::io::copy_bidirectional(&mut ws_io, &mut tcp_stream).await?;
57115
Ok(())
58116
}

0 commit comments

Comments
 (0)