diff --git a/bitreq/src/client.rs b/bitreq/src/client.rs index b5de6f2fb..f298480bf 100644 --- a/bitreq/src/client.rs +++ b/bitreq/src/client.rs @@ -9,10 +9,182 @@ use std::collections::{hash_map, HashMap, VecDeque}; use std::sync::{Arc, Mutex}; +#[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") +))] +use crate::connection::certificates::{Certificates, CertificatesBuilder}; use crate::connection::AsyncConnection; use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest}; use crate::{Error, Request, Response}; +#[derive(Clone)] +pub(crate) struct ClientConfig { + #[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") + ))] + pub(crate) tls: Option, +} + +#[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") +))] +#[derive(Clone)] +pub(crate) struct TlsConfig { + pub(crate) certificates: Certificates, +} + +#[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") +))] +impl TlsConfig { + fn new(certificates: Certificates) -> Self { Self { certificates } } +} + +pub struct ClientBuilder { + capacity: usize, + #[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") + ))] + certificates: Option, +} + +/// Builder for configuring a `Client` with custom settings. +/// +/// # Example +/// +/// ```no_run +/// # async fn example() -> Result<(), bitreq::Error> { +/// use bitreq::{Client, RequestExt}; +/// +/// let client = Client::builder().with_capacity(20).build()?; +/// +/// let response = bitreq::get("https://example.com") +/// .send_async_with_client(&client) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +impl ClientBuilder { + /// Creates a new `ClientBuilder` with a default pool capacity of 10. + #[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") + ))] + pub fn new() -> Self { Self { capacity: 10, certificates: None } } + + /// Creates a new `ClientBuilder` with a default pool capacity of 10. + #[cfg(not(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") + )))] + pub fn new() -> Self { Self { capacity: 10 } } + + /// Sets the maximum number of connections to keep in the pool. + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + + #[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") + ))] + /// Builds the `Client` with the configured settings. + pub fn build(self) -> Result { + let build_config = if let Some(builder) = self.certificates { + let certificates = builder.build()?; + let tls_config = TlsConfig::new(certificates); + Some(ClientConfig { tls: Some(tls_config) }) + } else { + None + }; + let client_config = build_config.map(Arc::new); + + Ok(Client { + r#async: Arc::new(Mutex::new(ClientImpl { + connections: HashMap::new(), + lru_order: VecDeque::new(), + capacity: self.capacity, + client_config, + })), + }) + } + + /// Builds the `Client` with the configured settings. + #[cfg(not(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") + )))] + pub fn build(self) -> Result { + Ok(Client { + r#async: Arc::new(Mutex::new(ClientImpl { + connections: HashMap::new(), + lru_order: VecDeque::new(), + capacity: self.capacity, + client_config: None, + })), + }) + } + + /// Adds a custom DER-encoded root certificate for TLS verification. + /// The certificate must be provided in DER format. This method accepts any type + /// that can be converted into a `Vec`. + /// The certificate is appended to the default trust store rather than replacing it. + /// The trust store used depends on the TLS backend: system certificates for native-tls, + /// Mozilla's root certificates(rustls-webpki) and/or system certificates(rustls-native-certs) for rustls. + /// + /// # Example + /// + /// ```no_run + /// # use bitreq::Client; + /// # async fn example() -> Result<(), bitreq::Error> { + /// let client = Client::builder() + /// .with_root_certificate(include_bytes!("../tests/test_cert.der"))? + /// .build()?; + /// # Ok(()) + /// # } + /// ``` + #[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") + ))] + pub fn with_root_certificate>>(mut self, cert_der: T) -> Result { + let cert_der = cert_der.into(); + if let Some(ref mut certificates) = self.certificates { + certificates.append_certificate(cert_der)?; + + return Ok(self); + } + + self.certificates = Some(CertificatesBuilder::new(Some(cert_der))?); + Ok(self) + } + + /// Disables default root certificates for TLS connections. + /// Returns [`Error::InvalidTlsConfig`] if TLS has not been configured. + #[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") + ))] + pub fn disable_default_certificates(mut self) -> Result { + match self.certificates { + Some(ref mut certificates) => certificates.disable_default()?, + None => return Err(Error::InvalidTlsConfig), + }; + + Ok(self) + } +} + +impl Default for ClientBuilder { + fn default() -> Self { Self::new() } +} + /// A client that caches connections for reuse. /// /// The client maintains a pool of up to `capacity` connections, evicting @@ -39,10 +211,11 @@ struct ClientImpl { connections: HashMap>, lru_order: VecDeque, capacity: usize, + client_config: Option>, } impl Client { - /// Creates a new `Client` with the specified connection cache capacity. + /// Creates a new `Client` with the specified connection pool capacity. /// /// # Arguments /// @@ -54,10 +227,14 @@ impl Client { connections: HashMap::new(), lru_order: VecDeque::new(), capacity, + client_config: None, })), } } + /// Create a builder for a client + pub fn builder() -> ClientBuilder { ClientBuilder::new() } + /// Sends a request asynchronously using a cached connection if available. pub async fn send_async(&self, request: Request) -> Result { let parsed_request = ParsedRequest::new(request)?; @@ -77,7 +254,13 @@ impl Client { let conn = if let Some(conn) = conn_opt { conn } else { - let connection = AsyncConnection::new(key, parsed_request.timeout_at).await?; + let client_config = { + let state = self.r#async.lock().unwrap(); + state.client_config.as_ref().map(Arc::clone) + }; + + let connection = + AsyncConnection::new(key, parsed_request.timeout_at, client_config).await?; let connection = Arc::new(connection); let mut state = self.r#async.lock().unwrap(); diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index f8b98c133..573c69a0f 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -22,6 +22,8 @@ use tokio::net::TcpStream as AsyncTcpStream; #[cfg(feature = "async")] use tokio::sync::Mutex as AsyncMutex; +#[cfg(feature = "async")] +use crate::client::ClientConfig; use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest}; #[cfg(feature = "async")] use crate::Response; @@ -29,14 +31,19 @@ use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; -#[cfg(feature = "rustls")] +#[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") +))] +pub(crate) mod certificates; +#[cfg(any(feature = "rustls", feature = "native-tls"))] mod rustls_stream; -#[cfg(feature = "rustls")] +#[cfg(any(feature = "rustls", feature = "native-tls"))] type SecuredStream = rustls_stream::SecuredStream; pub(crate) enum HttpStream { Unsecured(UnsecuredStream, Option), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "native-tls"))] Secured(Box, Option), #[cfg(feature = "async")] Buffer(std::io::Cursor>), @@ -81,7 +88,7 @@ impl Read for HttpStream { timeout(inner, *timeout_at)?; inner.read(buf) } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "native-tls"))] HttpStream::Secured(inner, timeout_at) => { timeout(inner.get_ref(), *timeout_at)?; inner.read(buf) @@ -111,7 +118,7 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.write(buf) } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "native-tls"))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.write(buf) @@ -137,7 +144,7 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.flush() } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "native-tls"))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.flush() @@ -158,13 +165,13 @@ impl Write for HttpStream { } } -#[cfg(feature = "tokio-rustls")] +#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] type AsyncSecuredStream = rustls_stream::AsyncSecuredStream; #[cfg(feature = "async")] pub(crate) enum AsyncHttpStream { Unsecured(AsyncTcpStream), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] Secured(Box), } @@ -177,7 +184,7 @@ impl AsyncRead for AsyncHttpStream { ) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_read(cx, buf), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_read(cx, buf), } } @@ -192,7 +199,7 @@ impl AsyncWrite for AsyncHttpStream { ) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_write(cx, buf), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_write(cx, buf), } } @@ -200,7 +207,7 @@ impl AsyncWrite for AsyncHttpStream { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_flush(cx), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_flush(cx), } } @@ -208,7 +215,7 @@ impl AsyncWrite for AsyncHttpStream { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_shutdown(cx), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_shutdown(cx), } } @@ -238,6 +245,7 @@ struct AsyncConnectionState { /// Defaults to 60 seconds after open to align with nginx's default timeout of 75 seconds, but /// can be overridden by the `Keep-Alive` header. socket_new_requests_timeout: Mutex, + client_config: Option>, } #[cfg(feature = "async")] @@ -266,15 +274,15 @@ impl AsyncConnection { pub(crate) async fn new( params: ConnectionParams<'_>, timeout_at: Option, + client_config: Option>, ) -> Result { + let client_config_ref = &client_config; + let future = async move { let socket = Self::connect(params).await?; if params.https { - #[cfg(not(feature = "tokio-rustls"))] - return Err(Error::HttpsFeatureNotEnabled); - #[cfg(feature = "tokio-rustls")] - rustls_stream::wrap_async_stream(socket, params.host).await + Self::wrap_async_stream(socket, params.host, client_config_ref).await } else { Ok(AsyncHttpStream::Unsecured(socket)) } @@ -295,9 +303,36 @@ impl AsyncConnection { readable_request_id: AtomicUsize::new(0), min_dropped_reader_id: AtomicUsize::new(usize::MAX), socket_new_requests_timeout: Mutex::new(Instant::now() + Duration::from_secs(60)), + client_config, })))) } + /// Call the correct wrapper function depending on whether client_configs are present + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] + async fn wrap_async_stream( + socket: AsyncTcpStream, + host: &str, + client_config: &Option>, + ) -> Result { + if let Some(client_config) = client_config { + let tls_config = client_config.tls.as_ref().unwrap(); + let certificates = tls_config.certificates.clone(); + rustls_stream::wrap_async_stream_with_configs(socket, host, certificates).await + } else { + rustls_stream::wrap_async_stream(socket, host).await + } + } + + /// Error treatment function, should not be called under normal circustances + #[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))] + async fn wrap_async_stream( + _socket: AsyncTcpStream, + _host: &str, + _client_config: &Option>, + ) -> Result { + Err(Error::HttpsFeatureNotEnabled) + } + async fn tcp_connect(host: &str, port: u16) -> Result { #[cfg(feature = "log")] log::trace!("Looking up host {host}"); @@ -446,9 +481,13 @@ impl AsyncConnection { retry_new_connection!(_internal); }; (_internal) => { - let new_connection = - AsyncConnection::new(request.connection_params(), request.timeout_at) - .await?; + let config = conn.client_config.as_ref().map(Arc::clone); + let new_connection = AsyncConnection::new( + request.connection_params(), + request.timeout_at, + config, + ) + .await?; *self.0.lock().unwrap() = Arc::clone(&*new_connection.0.lock().unwrap()); core::mem::drop(read); // Note that this cannot recurse infinitely as we'll always be able to send at @@ -653,13 +692,10 @@ impl Connection { let socket = Self::connect(params, timeout_at)?; let stream = if params.https { - #[cfg(not(feature = "rustls"))] + #[cfg(not(any(feature = "rustls", feature = "native-tls")))] return Err(Error::HttpsFeatureNotEnabled); - #[cfg(feature = "rustls")] - { - let tls = rustls_stream::wrap_stream(socket, params.host)?; - HttpStream::Secured(Box::new(tls), timeout_at) - } + #[cfg(any(feature = "rustls", feature = "native-tls"))] + rustls_stream::wrap_stream(socket, params.host)? } else { HttpStream::create_unsecured(socket, timeout_at) }; @@ -806,7 +842,8 @@ async fn async_handle_redirects( let new_connection; if needs_new_connection { new_connection = - AsyncConnection::new(request.connection_params(), request.timeout_at).await?; + AsyncConnection::new(request.connection_params(), request.timeout_at, None) + .await?; connection = &new_connection; } connection.send(request).await diff --git a/bitreq/src/connection/certificates.rs b/bitreq/src/connection/certificates.rs new file mode 100644 index 000000000..2e9212181 --- /dev/null +++ b/bitreq/src/connection/certificates.rs @@ -0,0 +1,128 @@ +#[cfg(any(feature = "rustls", feature = "native-tls"))] +use std::sync::Arc; + +#[cfg(all(feature = "native-tls", not(feature = "rustls")))] +use native_tls::{Certificate, TlsConnector, TlsConnectorBuilder}; +#[cfg(feature = "rustls")] +use rustls::RootCertStore; +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +use tokio_native_tls::TlsConnector as AsyncTlsConnector; +#[cfg(feature = "rustls-webpki")] +use webpki_roots::TLS_SERVER_ROOTS; + +use crate::Error; + +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +pub(crate) struct CertificatesBuilder { + pub(crate) inner: RootCertStore, + pub(crate) disable_default: bool, +} + +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +pub(crate) struct CertificatesBuilder { + pub(crate) inner: TlsConnectorBuilder, +} + +impl CertificatesBuilder { + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + pub(crate) fn new(cert_der: Option>) -> Result { + let mut certificates = Self { inner: RootCertStore::empty(), disable_default: false }; + + if let Some(cert_der) = cert_der { + certificates.append_certificate(cert_der)?; + } + + Ok(certificates) + } + + #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] + pub(crate) fn new(cert_der: Option>) -> Result { + let builder = TlsConnector::builder(); + let mut certificates = Self { inner: builder }; + + if let Some(cert_der) = cert_der { + certificates.append_certificate(cert_der)?; + } + + Ok(certificates) + } + + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + pub(crate) fn append_certificate(&mut self, cert_der: Vec) -> Result<&mut Self, Error> { + self.inner.add(&rustls::Certificate(cert_der)).map_err(Error::RustlsAppendCert)?; + + Ok(self) + } + + #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] + pub(crate) fn append_certificate(&mut self, cert_der: Vec) -> Result<&mut Self, Error> { + let certificate = Certificate::from_der(&cert_der)?; + self.inner.add_root_certificate(certificate); + + Ok(self) + } + + #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] + pub(crate) fn build(self) -> Result { + let connector = self.inner.build()?; + let async_connector = AsyncTlsConnector::from(connector); + + Ok(Certificates(Arc::new(async_connector))) + } + + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + pub(crate) fn build(mut self) -> Result { + if !self.disable_default { + self.with_root_certificates(); + } + + Ok(Certificates(Arc::new(self.inner))) + } + + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + fn with_root_certificates(&mut self) -> &mut Self { + // Try to load native certs + #[cfg(feature = "https-rustls-probe")] + if let Ok(os_roots) = rustls_native_certs::load_native_certs() { + for root_cert in os_roots { + // Ignore erroneous OS certificates, there's nothing + // to do differently in that situation anyways. + let _ = self.inner.add(&rustls::Certificate(root_cert.0)); + } + } + + #[cfg(feature = "rustls-webpki")] + { + #[allow(deprecated)] + // Need to use add_server_trust_anchors to compile with rustls 0.21.1 + self.inner.add_server_trust_anchors(TLS_SERVER_ROOTS.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + } + self + } + + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + pub(crate) fn disable_default(&mut self) -> Result<&mut Self, Error> { + self.disable_default = true; + Ok(self) + } + + #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] + pub(crate) fn disable_default(&mut self) -> Result<&mut Self, Error> { + self.inner.disable_built_in_roots(true); + Ok(self) + } +} + +#[derive(Clone)] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +pub(crate) struct Certificates(pub(crate) Arc); + +#[derive(Clone)] +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +pub(crate) struct Certificates(pub(crate) Arc); diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index 01a3c417f..26e47bda7 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -5,6 +5,7 @@ use alloc::sync::Arc; #[cfg(feature = "rustls")] use core::convert::TryFrom; +#[cfg(any(feature = "rustls", feature = "native-tls"))] use std::io; use std::net::TcpStream; use std::sync::OnceLock; @@ -20,10 +21,18 @@ use tokio_rustls::{client::TlsStream, TlsConnector}; #[cfg(feature = "rustls-webpki")] use webpki_roots::TLS_SERVER_ROOTS; -#[cfg(feature = "tokio-rustls")] -use super::{AsyncHttpStream, AsyncTcpStream}; -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +#[cfg(any(feature = "rustls", feature = "native-tls"))] +use super::HttpStream; +#[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") +))] use super::{AsyncHttpStream, AsyncTcpStream}; +#[cfg(any( + all(feature = "native-tls", feature = "tokio-native-tls"), + all(feature = "rustls", feature = "tokio-rustls") +))] +use crate::connection::certificates::Certificates; use crate::Error; #[cfg(feature = "rustls")] @@ -63,8 +72,17 @@ fn build_client_config() -> Arc { Arc::new(config) } +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +fn build_rustls_client_config(certificates: Arc) -> Arc { + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(certificates) + .with_no_client_auth(); + Arc::new(config) +} + #[cfg(feature = "rustls")] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {host}."); let dns_name = match ServerName::try_from(host) { @@ -73,10 +91,12 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + let dns_name = match ServerName::try_from(host) { + Ok(result) => result, + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + let certificates = certificates.0; + let client_config = build_rustls_client_config(certificates); + let connector = TlsConnector::from(client_config); + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = connector.connect(dns_name, tcp).await.map_err(Error::IoError)?; + + Ok(AsyncHttpStream::Secured(Box::new(tls))) +} + #[cfg(all(feature = "native-tls", not(feature = "rustls")))] pub type SecuredStream = TlsStream; @@ -115,7 +159,7 @@ static CONNECTOR: OnceLock> = OnceLock::new(); #[cfg(all(feature = "native-tls", not(feature = "rustls")))] fn native_tls_err(e: HandshakeError) -> Error { match e { - HandshakeError::Failure(e) => Error::NativeTlsError(e), + HandshakeError::Failure(err) => Error::NativeTlsCreateConnection(err), HandshakeError::WouldBlock(_) => { debug_assert!(false, "We shouldn't hit a blocking error"); Error::Other("Got a WouldBlock error from native-tls") @@ -125,22 +169,27 @@ fn native_tls_err(e: HandshakeError) -> Error { #[cfg(all(feature = "native-tls", not(feature = "rustls")))] fn build_tls_connector() -> Result { - TlsConnector::builder().build().map_err(Error::NativeTlsError) + TlsConnector::builder().build().map_err(Error::from) } #[cfg(all(feature = "native-tls", not(feature = "rustls")))] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {host}."); // TODO: Once we can `get_or_try_init`, so that instead // https://github.com/rust-lang/rust/issues/109737 - let connector = CONNECTOR.get_or_init(build_tls_connector)?; + let connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; #[cfg(feature = "log")] log::trace!("Establishing TLS session to {host}."); - connector.connect(host, tcp).map_err(native_tls_err) + let tls = connector.connect(host, tcp).map_err(native_tls_err)?; + + Ok(HttpStream::Secured(Box::new(tls), None)) } #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] @@ -156,12 +205,36 @@ pub(super) async fn wrap_async_stream( // TODO: Once we can `get_or_try_init`, so that instead // https://github.com/rust-lang/rust/issues/109737 - let connector = AsyncTlsConnector::from(CONNECTOR.get_or_init(build_tls_connector)?.clone()); + let sync_connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + let async_connector = AsyncTlsConnector::from(sync_connector); + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = async_connector.connect(host, tcp).await?; + + Ok(AsyncHttpStream::Secured(Box::new(tls))) +} + +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +pub(super) async fn wrap_async_stream_with_configs( + tcp: AsyncTcpStream, + host: &str, + client_configs: Certificates, +) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + + let async_connector = client_configs.0; #[cfg(feature = "log")] log::trace!("Establishing TLS session to {host}."); - let tls = connector.connect(host, tcp).await.map_err(native_tls_err)?; + let tls = async_connector.connect(host, tcp).await?; Ok(AsyncHttpStream::Secured(Box::new(tls))) } diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index ca9d1421d..200ad0abe 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -22,9 +22,18 @@ pub enum Error { #[cfg(feature = "rustls")] /// Ran into a rustls error while creating the connection. RustlsCreateConnection(rustls::Error), + #[cfg(feature = "rustls")] + /// Ran into a rustls error while appending a certificate. + RustlsAppendCert(rustls::Error), #[cfg(feature = "native-tls")] /// Ran into a native-tls error while creating the connection. NativeTlsCreateConnection(native_tls::Error), + #[cfg(feature = "native-tls")] + /// Ran into a native-tls error while appending a certificate. + NativeTlsAppendCert, + #[cfg(any(feature = "rustls", feature = "native-tls"))] + /// The current TLS configuration is invalid. + InvalidTlsConfig, /// Ran into an IO problem while loading the response. #[cfg(feature = "std")] IoError(io::Error), @@ -104,8 +113,14 @@ impl fmt::Display for Error { InvalidUtf8InBody(err) => write!(f, "{}", err), #[cfg(feature = "rustls")] RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err), + #[cfg(feature = "rustls")] + RustlsAppendCert(err) => write!(f, "error appending certificate: {}", err), + #[cfg(feature = "native-tls")] + NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err), #[cfg(feature = "native-tls")] - NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {err}"), + NativeTlsAppendCert => write!(f, "error appending certificate"), + #[cfg(any(feature = "rustls", feature = "native-tls"))] + InvalidTlsConfig => write!(f, "error disabling default certificates. Must have custom cert."), MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"), MalformedChunkEnd => write!(f, "chunk did not end after reading the expected amount of bytes"), MalformedContentLength => write!(f, "non-usize content length"), @@ -147,6 +162,8 @@ impl error::Error for Error { InvalidUtf8InBody(err) => Some(err), #[cfg(feature = "rustls")] RustlsCreateConnection(err) => Some(err), + #[cfg(feature = "rustls")] + RustlsAppendCert(err) => Some(err), _ => None, } } @@ -160,3 +177,8 @@ impl From for Error { impl From for Error { fn from(other: UrlParseError) -> Error { Error::InvalidUrl(other) } } + +#[cfg(feature = "native-tls")] +impl From for Error { + fn from(err: native_tls::Error) -> Error { Error::NativeTlsCreateConnection(err) } +} diff --git a/bitreq/src/request.rs b/bitreq/src/request.rs index d39d6d89a..24411bc0b 100644 --- a/bitreq/src/request.rs +++ b/bitreq/src/request.rs @@ -327,7 +327,7 @@ impl Request { #[cfg(feature = "async")] pub async fn send_async(self) -> Result { let parsed_request = ParsedRequest::new(self)?; - AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at) + AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at, None) .await? .send(parsed_request) .await diff --git a/bitreq/tests/ca_cert.der b/bitreq/tests/ca_cert.der new file mode 100644 index 000000000..994da6aa5 Binary files /dev/null and b/bitreq/tests/ca_cert.der differ diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 8d357f354..5aa500d34 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -16,6 +16,117 @@ async fn test_https() { assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); } +#[tokio::test] +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +async fn test_https() { + // TODO: Implement this locally. + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); + // Test reusing the existing connection in client: + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); +} + +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_client_builder() { + setup(); + let client = bitreq::Client::builder().build().unwrap(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +async fn test_https_with_client_builder() { + setup(); + let client = bitreq::Client::builder().build().unwrap(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_client_builder_and_cert() { + setup(); + let cert_der = include_bytes!("test_cert.der"); + let client = bitreq::Client::builder() + .with_root_certificate(cert_der.as_slice()) + .unwrap() + .build() + .unwrap(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +async fn test_https_with_client_builder_and_cert() { + setup(); + let cert_der = include_bytes!("test_cert.der"); + let client = bitreq::Client::builder() + .with_root_certificate(cert_der.as_slice()) + .unwrap() + .build() + .unwrap(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +async fn test_https_with_multiple_certs() { + setup(); + let cert_der = include_bytes!("test_cert.der"); + let ca_der = include_bytes!("ca_cert.der"); + + let client = bitreq::Client::builder() + .with_root_certificate(cert_der.as_slice()) + .unwrap() + .with_root_certificate(ca_der.as_slice()) + .unwrap() + .build() + .unwrap(); + + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_multiple_certs() { + setup(); + let cert_der = include_bytes!("test_cert.der"); + let ca_der = include_bytes!("ca_cert.der"); + + let client = bitreq::Client::builder() + .with_root_certificate(cert_der.as_slice()) + .unwrap() + .with_root_certificate(ca_der.as_slice()) + .unwrap() + .build() + .unwrap(); + + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + #[tokio::test] #[cfg(feature = "json-using-serde")] async fn test_json_using_serde() { diff --git a/bitreq/tests/test_cert.der b/bitreq/tests/test_cert.der new file mode 100644 index 000000000..f8d4129e3 Binary files /dev/null and b/bitreq/tests/test_cert.der differ