mirror of
https://github.com/poem-web/poem.git
synced 2026-01-25 04:18:25 +00:00
Merge branch 'hyper-1'
This commit is contained in:
@@ -45,8 +45,8 @@ quick-xml = { version = "0.30.0", features = ["serialize"] }
|
||||
base64 = "0.21.0"
|
||||
serde_urlencoded = "0.7.1"
|
||||
indexmap = "2.0.0"
|
||||
reqwest = { version = "0.11.23", default-features = false }
|
||||
|
||||
# rustls, update together
|
||||
hyper-rustls = { version = "0.24.0", default-features = false }
|
||||
rustls = "0.21.0"
|
||||
tokio-rustls = "0.24.0"
|
||||
rustls = "0.22.0"
|
||||
tokio-rustls = "0.25.0"
|
||||
|
||||
@@ -11,7 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>
|
||||
.unwrap(),
|
||||
);
|
||||
let request = Request::new(HelloRequest {
|
||||
name: "Tonic".into(),
|
||||
name: "Poem".into(),
|
||||
});
|
||||
let response = client.say_hello(request).await?;
|
||||
println!("RESPONSE={response:?}");
|
||||
|
||||
@@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>
|
||||
.unwrap(),
|
||||
);
|
||||
let request = Request::new(HelloRequest {
|
||||
name: "Tonic".into(),
|
||||
name: "Poem".into(),
|
||||
});
|
||||
let response = client.say_hello(request).await?;
|
||||
println!("RESPONSE={response:?}");
|
||||
|
||||
@@ -21,9 +21,8 @@ json-codec = ["serde", "serde_json"]
|
||||
poem = { workspace = true, default-features = true }
|
||||
|
||||
futures-util.workspace = true
|
||||
hyper = { version = "0.14.20", features = ["client"] }
|
||||
async-stream = "0.3.3"
|
||||
tokio = { workspace = true, features = ["io-util", "rt", "sync"] }
|
||||
tokio = { workspace = true, features = ["io-util", "rt", "sync", "net"] }
|
||||
flate2 = "1.0.24"
|
||||
itoa = "1.0.2"
|
||||
percent-encoding = "2.1.0"
|
||||
@@ -32,16 +31,18 @@ prost = "0.12.0"
|
||||
base64 = "0.21.0"
|
||||
prost-types = "0.12.0"
|
||||
tokio-stream = { workspace = true, features = ["sync"] }
|
||||
hyper-rustls = { workspace = true, features = [
|
||||
"webpki-roots",
|
||||
"http2",
|
||||
"native-tokio",
|
||||
] }
|
||||
serde = { workspace = true, optional = true }
|
||||
serde_json = { workspace = true, optional = true }
|
||||
rustls.workspace = true
|
||||
rustls = { workspace = true }
|
||||
thiserror.workspace = true
|
||||
fastrand = "2.0.0"
|
||||
http = "1.0.0"
|
||||
hyper = { version = "1.0.0", features = ["http1", "http2"] }
|
||||
hyper-util = { version = "0.1.1", features = ["client-legacy", "tokio"] }
|
||||
http-body-util = "0.1.0"
|
||||
tokio-rustls.workspace = true
|
||||
tower-service = "0.3.2"
|
||||
webpki-roots = "0.26"
|
||||
|
||||
[build-dependencies]
|
||||
poem-grpc-build.workspace = true
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
use std::{io::Error as IoError, sync::Arc};
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures_util::TryStreamExt;
|
||||
use hyper_rustls::HttpsConnectorBuilder;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
|
||||
use poem::{
|
||||
http::{
|
||||
header, header::InvalidHeaderValue, uri::InvalidUri, Extensions, HeaderValue, Method,
|
||||
@@ -14,10 +16,13 @@ use rustls::ClientConfig as TlsClientConfig;
|
||||
|
||||
use crate::{
|
||||
codec::Codec,
|
||||
connector::HttpsConnector,
|
||||
encoding::{create_decode_response_body, create_encode_request_body},
|
||||
Code, Metadata, Request, Response, Status, Streaming,
|
||||
};
|
||||
|
||||
pub(crate) type BoxBody = http_body_util::combinators::BoxBody<Bytes, IoError>;
|
||||
|
||||
/// A configuration for GRPC client
|
||||
#[derive(Default)]
|
||||
pub struct ClientConfig {
|
||||
@@ -392,29 +397,17 @@ fn create_client_endpoint(
|
||||
config: ClientConfig,
|
||||
) -> Arc<dyn Endpoint<Output = HttpResponse> + 'static> {
|
||||
let mut config = config;
|
||||
let cli = match config.tls_config.take() {
|
||||
Some(tls_config) => hyper::Client::builder().http2_only(true).build(
|
||||
HttpsConnectorBuilder::new()
|
||||
.with_tls_config(tls_config)
|
||||
.https_or_http()
|
||||
.enable_http2()
|
||||
.build(),
|
||||
),
|
||||
None => hyper::Client::builder().http2_only(true).build(
|
||||
HttpsConnectorBuilder::new()
|
||||
.with_webpki_roots()
|
||||
.https_or_http()
|
||||
.enable_http2()
|
||||
.build(),
|
||||
),
|
||||
};
|
||||
let cli = Client::builder(TokioExecutor::new())
|
||||
.http2_only(true)
|
||||
.build(HttpsConnector::new(config.tls_config.take()));
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
Arc::new(poem::endpoint::make(move |request| {
|
||||
let config = config.clone();
|
||||
let cli = cli.clone();
|
||||
async move {
|
||||
let mut request: hyper::Request<hyper::Body> = request.into();
|
||||
let mut request: hyper::Request<BoxBody> = request.into();
|
||||
|
||||
if config.uris.is_empty() {
|
||||
return Err(poem::Error::from_string(
|
||||
@@ -443,7 +436,12 @@ fn create_client_endpoint(
|
||||
}
|
||||
|
||||
let resp = cli.request(request).await.map_err(to_boxed_error)?;
|
||||
Ok::<_, poem::Error>(HttpResponse::from(resp))
|
||||
let (parts, body) = resp.into_parts();
|
||||
|
||||
Ok::<_, poem::Error>(HttpResponse::from(hyper::Response::from_parts(
|
||||
parts,
|
||||
body.map_err(|err| IoError::other(err)),
|
||||
)))
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
155
poem-grpc/src/connector.rs
Normal file
155
poem-grpc/src/connector.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use std::{
|
||||
io::{Error as IoError, Result as IoResult},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::{future::BoxFuture, FutureExt};
|
||||
use http::{uri::Scheme, Uri};
|
||||
use hyper::rt::{Read, ReadBufCursor, Write};
|
||||
use hyper_util::{
|
||||
client::legacy::connect::{Connected, Connection},
|
||||
rt::TokioIo,
|
||||
};
|
||||
use rustls::{ClientConfig, RootCertStore};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
use tower_service::Service;
|
||||
|
||||
pub(crate) enum MaybeHttpsStream {
|
||||
TcpStream(TokioIo<TcpStream>),
|
||||
TlsStream {
|
||||
stream: TokioIo<TlsStream<TcpStream>>,
|
||||
is_http2: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl Read for MaybeHttpsStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: ReadBufCursor<'_>,
|
||||
) -> Poll<IoResult<()>> {
|
||||
match self.get_mut() {
|
||||
MaybeHttpsStream::TcpStream(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
MaybeHttpsStream::TlsStream { stream, .. } => Pin::new(stream).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for MaybeHttpsStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, IoError>> {
|
||||
match self.get_mut() {
|
||||
MaybeHttpsStream::TcpStream(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
MaybeHttpsStream::TlsStream { stream, .. } => Pin::new(stream).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
|
||||
match self.get_mut() {
|
||||
MaybeHttpsStream::TcpStream(stream) => Pin::new(stream).poll_flush(cx),
|
||||
MaybeHttpsStream::TlsStream { stream, .. } => Pin::new(stream).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
|
||||
match self.get_mut() {
|
||||
MaybeHttpsStream::TcpStream(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
MaybeHttpsStream::TlsStream { stream, .. } => Pin::new(stream).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for MaybeHttpsStream {
|
||||
fn connected(&self) -> Connected {
|
||||
match self {
|
||||
MaybeHttpsStream::TcpStream(_) => Connected::new(),
|
||||
MaybeHttpsStream::TlsStream { is_http2, .. } => {
|
||||
let mut connected = Connected::new();
|
||||
if *is_http2 {
|
||||
connected = connected.negotiated_h2();
|
||||
}
|
||||
connected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct HttpsConnector {
|
||||
tls_config: Option<ClientConfig>,
|
||||
}
|
||||
|
||||
impl HttpsConnector {
|
||||
#[inline]
|
||||
pub(crate) fn new(tls_config: Option<ClientConfig>) -> Self {
|
||||
HttpsConnector { tls_config }
|
||||
}
|
||||
}
|
||||
|
||||
impl Service<Uri> for HttpsConnector {
|
||||
type Response = MaybeHttpsStream;
|
||||
type Error = IoError;
|
||||
type Future = BoxFuture<'static, Result<MaybeHttpsStream, IoError>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, uri: Uri) -> Self::Future {
|
||||
do_connect(uri, self.tls_config.clone()).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
fn default_tls_config() -> ClientConfig {
|
||||
let mut root_cert_store = RootCertStore::empty();
|
||||
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
ClientConfig::builder()
|
||||
.with_root_certificates(root_cert_store)
|
||||
.with_no_client_auth()
|
||||
}
|
||||
|
||||
async fn do_connect(
|
||||
uri: Uri,
|
||||
tls_config: Option<ClientConfig>,
|
||||
) -> Result<MaybeHttpsStream, IoError> {
|
||||
let scheme = uri
|
||||
.scheme()
|
||||
.ok_or_else(|| IoError::other("missing scheme"))?
|
||||
.clone();
|
||||
let host = uri
|
||||
.host()
|
||||
.ok_or_else(|| IoError::other("missing host"))?
|
||||
.to_string();
|
||||
let port = uri
|
||||
.port_u16()
|
||||
.unwrap_or_else(|| if scheme == Scheme::HTTPS { 443 } else { 80 });
|
||||
|
||||
if scheme == Scheme::HTTP {
|
||||
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
||||
Ok(MaybeHttpsStream::TcpStream(TokioIo::new(stream)))
|
||||
} else if scheme == Scheme::HTTPS {
|
||||
let mut tls_config = tls_config.unwrap_or_else(default_tls_config);
|
||||
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
let connector = TlsConnector::from(Arc::new(tls_config));
|
||||
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
||||
let domain = host.try_into().map_err(IoError::other)?;
|
||||
let mut is_http2 = false;
|
||||
let stream = connector
|
||||
.connect_with(domain, stream, |conn| {
|
||||
is_http2 = conn.alpn_protocol() == Some(b"h2");
|
||||
})
|
||||
.await?;
|
||||
Ok(MaybeHttpsStream::TlsStream {
|
||||
stream: TokioIo::new(stream),
|
||||
is_http2,
|
||||
})
|
||||
} else {
|
||||
Err(IoError::other("invalid scheme"))
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
use std::io::Result as IoResult;
|
||||
use std::io::{Error as IoError, Result as IoResult};
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use flate2::read::GzDecoder;
|
||||
use futures_util::StreamExt;
|
||||
use hyper::{body::HttpBody, HeaderMap};
|
||||
use http_body_util::{BodyExt, StreamBody};
|
||||
use hyper::{body::Frame, HeaderMap};
|
||||
use poem::Body;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
use crate::{
|
||||
client::BoxBody,
|
||||
codec::{Decoder, Encoder},
|
||||
Code, Status, Streaming,
|
||||
};
|
||||
@@ -76,18 +80,20 @@ pub(crate) fn create_decode_request_body<T: Decoder>(
|
||||
mut decoder: T,
|
||||
body: Body,
|
||||
) -> Streaming<T::Item> {
|
||||
let mut body: hyper::Body = body.into();
|
||||
let mut body: BoxBody = body.into();
|
||||
|
||||
Streaming::new(async_stream::try_stream! {
|
||||
let mut frame_decoder = DataFrameDecoder::default();
|
||||
|
||||
loop {
|
||||
match body.data().await.transpose().map_err(Status::from_std_error)? {
|
||||
Some(data) => {
|
||||
frame_decoder.put_slice(data);
|
||||
while let Some(data) = frame_decoder.next()? {
|
||||
let message = decoder.decode(&data).map_err(Status::from_std_error)?;
|
||||
yield message;
|
||||
match body.frame().await.transpose().map_err(Status::from_std_error)? {
|
||||
Some(frame) => {
|
||||
if let Ok(data) = frame.into_data() {
|
||||
frame_decoder.put_slice(data);
|
||||
while let Some(data) = frame_decoder.next()? {
|
||||
let message = decoder.decode(&data).map_err(Status::from_std_error)?;
|
||||
yield message;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
@@ -103,7 +109,7 @@ pub(crate) fn create_encode_response_body<T: Encoder>(
|
||||
mut encoder: T,
|
||||
mut stream: Streaming<T::Item>,
|
||||
) -> Body {
|
||||
let (mut sender, body) = hyper::Body::channel();
|
||||
let (tx, rx) = mpsc::channel(16);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = BytesMut::new();
|
||||
@@ -112,45 +118,51 @@ pub(crate) fn create_encode_response_body<T: Encoder>(
|
||||
match item {
|
||||
Ok(message) => {
|
||||
if let Ok(data) = encode_data_frame(&mut encoder, &mut buf, message) {
|
||||
if sender.send_data(data).await.is_err() {
|
||||
if tx.send(Frame::data(data)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(status) => {
|
||||
let _ = sender.send_trailers(status.to_headers()).await;
|
||||
_ = tx.send(Frame::trailers(status.to_headers())).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = sender
|
||||
.send_trailers(Status::new(Code::Ok).to_headers())
|
||||
_ = tx
|
||||
.send(Frame::trailers(Status::new(Code::Ok).to_headers()))
|
||||
.await;
|
||||
});
|
||||
|
||||
body.into()
|
||||
BodyExt::boxed(StreamBody::new(
|
||||
ReceiverStream::new(rx).map(|frame| Ok::<_, IoError>(frame)),
|
||||
))
|
||||
.into()
|
||||
}
|
||||
|
||||
pub(crate) fn create_encode_request_body<T: Encoder>(
|
||||
mut encoder: T,
|
||||
mut stream: Streaming<T::Item>,
|
||||
) -> Body {
|
||||
let (mut sender, body) = hyper::Body::channel();
|
||||
let (tx, rx) = mpsc::channel(16);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
while let Some(Ok(message)) = stream.next().await {
|
||||
if let Ok(data) = encode_data_frame(&mut encoder, &mut buf, message) {
|
||||
if sender.send_data(data).await.is_err() {
|
||||
if tx.send(Frame::data(data)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
body.into()
|
||||
BodyExt::boxed(StreamBody::new(
|
||||
ReceiverStream::new(rx).map(|frame| Ok::<_, IoError>(frame)),
|
||||
))
|
||||
.into()
|
||||
}
|
||||
|
||||
pub(crate) fn create_decode_response_body<T: Decoder>(
|
||||
@@ -167,35 +179,33 @@ pub(crate) fn create_decode_response_body<T: Decoder>(
|
||||
};
|
||||
}
|
||||
|
||||
let mut body: hyper::Body = body.into();
|
||||
let mut body: BoxBody = body.into();
|
||||
|
||||
Ok(Streaming::new(async_stream::try_stream! {
|
||||
let mut frame_decoder = DataFrameDecoder::default();
|
||||
let mut status = None;
|
||||
|
||||
loop {
|
||||
if let Some(data) = body.data().await.transpose().map_err(Status::from_std_error)? {
|
||||
while let Some(frame) = body.frame().await.transpose().map_err(Status::from_std_error)? {
|
||||
if frame.is_data() {
|
||||
let data = frame.into_data().unwrap();
|
||||
frame_decoder.put_slice(data);
|
||||
while let Some(data) = frame_decoder.next()? {
|
||||
let message = decoder.decode(&data).map_err(Status::from_std_error)?;
|
||||
yield message;
|
||||
}
|
||||
continue;
|
||||
frame_decoder.check_incomplete()?;
|
||||
} else if frame.is_trailers() {
|
||||
let headers = frame.into_trailers().unwrap();
|
||||
status = Some(Status::from_headers(&headers)?
|
||||
.ok_or_else(|| Status::new(Code::Internal)
|
||||
.with_message("missing grpc-status"))?);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
frame_decoder.check_incomplete()?;
|
||||
|
||||
match body.trailers().await.map_err(Status::from_std_error)? {
|
||||
Some(trailers) => {
|
||||
let status = Status::from_headers(&trailers)?
|
||||
.ok_or_else(|| Status::new(Code::Internal).with_message("missing grpc-status"))?;
|
||||
if !status.is_ok() {
|
||||
Err(status)?;
|
||||
}
|
||||
}
|
||||
None => Err(Status::new(Code::Internal).with_message("missing trailers"))?,
|
||||
}
|
||||
|
||||
break;
|
||||
let status = status.ok_or_else(|| Status::new(Code::Internal).with_message("missing trailers"))?;
|
||||
if !status.is_ok() {
|
||||
Err(status)?;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ pub mod service;
|
||||
pub mod codec;
|
||||
pub mod metadata;
|
||||
|
||||
mod connector;
|
||||
mod encoding;
|
||||
mod health;
|
||||
mod reflection;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use base64::engine::{general_purpose::STANDARD_NO_PAD, Engine};
|
||||
use hyper::header::HeaderName;
|
||||
use http::HeaderName;
|
||||
use poem::http::{HeaderMap, HeaderValue};
|
||||
|
||||
/// A metadata map
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::{
|
||||
};
|
||||
|
||||
use futures_util::Stream;
|
||||
use hyper::http::Extensions;
|
||||
use http::Extensions;
|
||||
|
||||
use crate::{Metadata, Status, Streaming};
|
||||
|
||||
@@ -90,7 +90,7 @@ impl<T> Request<T> {
|
||||
/// Inserts a value to extensions, similar to
|
||||
/// `self.extensions().insert(data)`.
|
||||
#[inline]
|
||||
pub fn set_data(&mut self, data: impl Send + Sync + 'static) {
|
||||
pub fn set_data(&mut self, data: impl Send + Sync + Clone + 'static) {
|
||||
self.extensions.insert(data);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use hyper::{header::HeaderValue, HeaderMap};
|
||||
use http::{header::HeaderValue, HeaderMap};
|
||||
use percent_encoding::{percent_decode_str, percent_encode, AsciiSet, CONTROLS};
|
||||
|
||||
use crate::Metadata;
|
||||
|
||||
@@ -21,7 +21,7 @@ categories = [
|
||||
[dependencies]
|
||||
poem = { workspace = true, default-features = false }
|
||||
|
||||
lambda_http = { version = "0.8.0" }
|
||||
lambda_http = { version = "0.9.0" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
|
||||
|
||||
@@ -30,6 +30,7 @@ use poem::{Body, Endpoint, EndpointExt, FromRequest, IntoEndpoint, Request, Requ
|
||||
/// println!("request_id: {}", ctx.request_id);
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Context(pub lambda_runtime::Context);
|
||||
|
||||
impl Deref for Context {
|
||||
|
||||
@@ -24,7 +24,7 @@ syn = { workspace = true, features = ["full", "visit-mut"] }
|
||||
thiserror.workspace = true
|
||||
indexmap.workspace = true
|
||||
regex.workspace = true
|
||||
http = "0.2.5"
|
||||
http = "1.0.0"
|
||||
mime.workspace = true
|
||||
|
||||
[package.metadata.workspaces]
|
||||
|
||||
@@ -31,6 +31,7 @@ pub enum ApiExtractorType {
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[derive(Clone)]
|
||||
pub struct UrlQuery(pub Vec<(String, String)>);
|
||||
|
||||
impl Deref for UrlQuery {
|
||||
|
||||
@@ -21,7 +21,7 @@ categories = [
|
||||
[features]
|
||||
default = ["server"]
|
||||
|
||||
server = ["tokio/rt", "tokio/net", "hyper/server", "hyper/runtime"]
|
||||
server = ["tokio/rt", "tokio/net", "hyper/server"]
|
||||
websocket = ["tokio/rt", "tokio-tungstenite", "base64"]
|
||||
multipart = ["multer"]
|
||||
rustls = ["server", "tokio-rustls", "rustls-pemfile"]
|
||||
@@ -51,14 +51,13 @@ i18n = [
|
||||
"intl-memoizer",
|
||||
]
|
||||
acme = ["acme-native-roots"]
|
||||
acme-native-roots = ["acme-base", "hyper-rustls/native-tokio"]
|
||||
acme-webpki-roots = ["acme-base", "hyper-rustls/webpki-tokio"]
|
||||
acme-native-roots = ["acme-base", "reqwest/rustls-tls-native-roots"]
|
||||
acme-webpki-roots = ["acme-base", "reqwest/rustls-tls-webpki-roots"]
|
||||
acme-base = [
|
||||
"server",
|
||||
"hyper/client",
|
||||
"reqwest",
|
||||
"rustls",
|
||||
"ring",
|
||||
"hyper-rustls",
|
||||
"base64",
|
||||
"rcgen",
|
||||
"x509-parser",
|
||||
@@ -74,8 +73,10 @@ poem-derive.workspace = true
|
||||
async-trait = "0.1.51"
|
||||
bytes.workspace = true
|
||||
futures-util = { workspace = true, features = ["sink"] }
|
||||
http = "0.2.5"
|
||||
hyper = { version = "0.14.17", features = ["http1", "http2", "stream"] }
|
||||
http = "1.0.0"
|
||||
hyper = { version = "1.0.0", features = ["http1", "http2"] }
|
||||
hyper-util = { version = "0.1.1", features = ["server-auto", "tokio"] }
|
||||
http-body-util = "0.1.0"
|
||||
tokio = { workspace = true, features = ["sync", "time", "macros", "net"] }
|
||||
tokio-util = { version = "0.7.0", features = ["io"] }
|
||||
serde.workspace = true
|
||||
@@ -87,15 +88,16 @@ percent-encoding = "2.1.0"
|
||||
regex.workspace = true
|
||||
smallvec = "1.6.1"
|
||||
tracing.workspace = true
|
||||
headers = "0.3.7"
|
||||
headers = "0.4.0"
|
||||
thiserror.workspace = true
|
||||
rfc7239 = "0.1.0"
|
||||
mime.workspace = true
|
||||
wildmatch = "2"
|
||||
sync_wrapper = { version = "0.1.2", features = ["futures"] }
|
||||
|
||||
# Non-feature optional dependencies
|
||||
multer = { version = "2.1.0", features = ["tokio"], optional = true }
|
||||
tokio-tungstenite = { version = "0.20.0", optional = true }
|
||||
multer = { version = "3.0.0", features = ["tokio"], optional = true }
|
||||
tokio-tungstenite = { version = "0.21.0", optional = true }
|
||||
tokio-rustls = { workspace = true, optional = true }
|
||||
rustls-pemfile = { version = "1.0.0", optional = true }
|
||||
async-compression = { version = "0.4.0", optional = true, features = [
|
||||
@@ -148,12 +150,7 @@ fluent-syntax = { version = "0.11.0", optional = true }
|
||||
unic-langid = { version = "0.9.0", optional = true, features = ["macros"] }
|
||||
intl-memoizer = { version = "0.5.1", optional = true }
|
||||
ring = { version = "0.16.20", optional = true }
|
||||
hyper-rustls = { workspace = true, optional = true, features = [
|
||||
"http1",
|
||||
"http2",
|
||||
"tls12",
|
||||
"logging",
|
||||
] }
|
||||
reqwest = { workspace = true, features = ["json"], optional = true }
|
||||
rcgen = { version = "0.11.1", optional = true }
|
||||
x509-parser = { version = "0.15.0", optional = true }
|
||||
tokio-metrics = { version = "0.3.0", optional = true }
|
||||
|
||||
141
poem/src/body.rs
141
poem/src/body.rs
@@ -1,14 +1,16 @@
|
||||
use std::{
|
||||
fmt::{Debug, Display, Formatter},
|
||||
fmt::{Debug, Formatter},
|
||||
io::{Error as IoError, ErrorKind},
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
task::Poll,
|
||||
};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures_util::{Stream, TryStreamExt};
|
||||
use hyper::body::HttpBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::body::{Body as _, Frame};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use sync_wrapper::SyncStream;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
use crate::{
|
||||
@@ -16,9 +18,25 @@ use crate::{
|
||||
Result,
|
||||
};
|
||||
|
||||
pub(crate) type BoxBody = http_body_util::combinators::BoxBody<Bytes, IoError>;
|
||||
|
||||
/// A body object for requests and responses.
|
||||
#[derive(Default)]
|
||||
pub struct Body(pub(crate) hyper::Body);
|
||||
pub struct Body(pub(crate) BoxBody);
|
||||
|
||||
impl From<Body> for BoxBody {
|
||||
#[inline]
|
||||
fn from(body: Body) -> Self {
|
||||
body.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BoxBody> for Body {
|
||||
#[inline]
|
||||
fn from(body: BoxBody) -> Self {
|
||||
Body(body)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Body {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -26,50 +44,50 @@ impl Debug for Body {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<hyper::Body> for Body {
|
||||
fn from(body: hyper::Body) -> Self {
|
||||
Body(body)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Body> for hyper::Body {
|
||||
fn from(body: Body) -> Self {
|
||||
body.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&'static [u8]> for Body {
|
||||
#[inline]
|
||||
fn from(data: &'static [u8]) -> Self {
|
||||
Self(data.into())
|
||||
Self(BoxBody::new(
|
||||
http_body_util::Full::new(data.into()).map_err::<_, IoError>(|_| unreachable!()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&'static str> for Body {
|
||||
#[inline]
|
||||
fn from(data: &'static str) -> Self {
|
||||
Self(data.into())
|
||||
Self(BoxBody::new(
|
||||
http_body_util::Full::new(data.into()).map_err::<_, IoError>(|_| unreachable!()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Bytes> for Body {
|
||||
#[inline]
|
||||
fn from(data: Bytes) -> Self {
|
||||
Self(data.into())
|
||||
Self(
|
||||
http_body_util::Full::new(data)
|
||||
.map_err::<_, IoError>(|_| unreachable!())
|
||||
.boxed(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<u8>> for Body {
|
||||
#[inline]
|
||||
fn from(data: Vec<u8>) -> Self {
|
||||
Self(data.into())
|
||||
Self(
|
||||
http_body_util::Full::new(data.into())
|
||||
.map_err::<_, IoError>(|_| unreachable!())
|
||||
.boxed(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Body {
|
||||
#[inline]
|
||||
fn from(data: String) -> Self {
|
||||
Self(data.into())
|
||||
data.into_bytes().into()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,8 +120,8 @@ impl Body {
|
||||
/// Create a body object from reader.
|
||||
#[inline]
|
||||
pub fn from_async_read(reader: impl AsyncRead + Send + 'static) -> Self {
|
||||
Self(hyper::Body::wrap_stream(tokio_util::io::ReaderStream::new(
|
||||
reader,
|
||||
Self(BoxBody::new(http_body_util::StreamBody::new(
|
||||
SyncStream::new(tokio_util::io::ReaderStream::new(reader).map_ok(Frame::data)),
|
||||
)))
|
||||
}
|
||||
|
||||
@@ -112,9 +130,15 @@ impl Body {
|
||||
where
|
||||
S: Stream<Item = Result<O, E>> + Send + 'static,
|
||||
O: Into<Bytes> + 'static,
|
||||
E: std::error::Error + Send + Sync + 'static,
|
||||
E: Into<IoError> + 'static,
|
||||
{
|
||||
Self(hyper::Body::wrap_stream(stream))
|
||||
Self(BoxBody::new(http_body_util::StreamBody::new(
|
||||
SyncStream::new(
|
||||
stream
|
||||
.map_ok(|data| Frame::data(data.into()))
|
||||
.map_err(Into::into),
|
||||
),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Create a body object from JSON.
|
||||
@@ -125,29 +149,33 @@ impl Body {
|
||||
/// Create an empty body.
|
||||
#[inline]
|
||||
pub fn empty() -> Self {
|
||||
Self(hyper::Body::empty())
|
||||
Self(
|
||||
http_body_util::Empty::new()
|
||||
.map_err::<_, IoError>(|_| unreachable!())
|
||||
.boxed(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns `true` if this body is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
let size_hint = hyper::body::HttpBody::size_hint(&self.0);
|
||||
let size_hint = hyper::body::Body::size_hint(&self.0);
|
||||
size_hint.lower() == 0 && size_hint.upper() == Some(0)
|
||||
}
|
||||
|
||||
/// Consumes this body object to return a [`Bytes`] that contains all data.
|
||||
pub async fn into_bytes(self) -> Result<Bytes, ReadBodyError> {
|
||||
hyper::body::to_bytes(self.0)
|
||||
Ok(self
|
||||
.0
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|err| ReadBodyError::Io(IoError::new(ErrorKind::Other, err)))
|
||||
.map_err(|err| ReadBodyError::Io(IoError::new(ErrorKind::Other, err)))?
|
||||
.to_bytes())
|
||||
}
|
||||
|
||||
/// Consumes this body object to return a [`Vec<u8>`] that contains all
|
||||
/// data.
|
||||
pub async fn into_vec(self) -> Result<Vec<u8>, ReadBodyError> {
|
||||
Ok(hyper::body::to_bytes(self.0)
|
||||
.await
|
||||
.map_err(|err| ReadBodyError::Io(IoError::new(ErrorKind::Other, err)))?
|
||||
.to_vec())
|
||||
self.into_bytes().await.map(|data| data.to_vec())
|
||||
}
|
||||
|
||||
/// Consumes this body object to return a [`Bytes`] that contains all
|
||||
@@ -223,41 +251,23 @@ impl Body {
|
||||
|
||||
/// Consumes this body object to return a reader.
|
||||
pub fn into_async_read(self) -> impl AsyncRead + Unpin + Send + 'static {
|
||||
tokio_util::io::StreamReader::new(BodyStream::new(self.0))
|
||||
tokio_util::io::StreamReader::new(self.into_bytes_stream())
|
||||
}
|
||||
|
||||
/// Consumes this body object to return a bytes stream.
|
||||
pub fn into_bytes_stream(self) -> impl Stream<Item = Result<Bytes, IoError>> + Send + 'static {
|
||||
TryStreamExt::map_err(self.0, |err| IoError::new(ErrorKind::Other, err))
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub(crate) struct BodyStream<T> {
|
||||
#[pin] inner: T,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> BodyStream<T> {
|
||||
#[inline]
|
||||
pub(crate) fn new(inner: T) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Stream for BodyStream<T>
|
||||
where
|
||||
T: HttpBody,
|
||||
T::Error: Display,
|
||||
{
|
||||
type Item = Result<T::Data, std::io::Error>;
|
||||
|
||||
#[inline]
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.project()
|
||||
.inner
|
||||
.poll_data(cx)
|
||||
.map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))
|
||||
let mut body = self.0;
|
||||
futures_util::stream::poll_fn(move |ctx| loop {
|
||||
match Pin::new(&mut body).poll_frame(ctx) {
|
||||
Poll::Ready(Some(Ok(frame))) => match frame.into_data() {
|
||||
Ok(data) => return Poll::Ready(Some(Ok(data))),
|
||||
Err(_) => continue,
|
||||
},
|
||||
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,9 +277,6 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn create() {
|
||||
let body = Body::from(hyper::Body::from("abc"));
|
||||
assert_eq!(body.into_string().await.unwrap(), "abc");
|
||||
|
||||
let body = Body::from(b"abc".as_ref());
|
||||
assert_eq!(body.into_vec().await.unwrap(), b"abc");
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::{error::Error as StdError, future::Future};
|
||||
|
||||
use bytes::Bytes;
|
||||
use hyper::body::{HttpBody, Sender};
|
||||
use http_body_util::BodyExt;
|
||||
use tower::{Service, ServiceExt};
|
||||
|
||||
use crate::{Endpoint, Error, Request, Response, Result};
|
||||
use crate::{body::BoxBody, Endpoint, Error, Request, Response, Result};
|
||||
|
||||
/// Extension trait for tower service compat.
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
|
||||
@@ -12,12 +12,12 @@ pub trait TowerCompatExt {
|
||||
/// Converts a tower service to a poem endpoint.
|
||||
fn compat<ResBody, Err, Fut>(self) -> TowerCompatEndpoint<Self>
|
||||
where
|
||||
ResBody: HttpBody + Send + 'static,
|
||||
ResBody: hyper::body::Body + Send + Sync + 'static,
|
||||
ResBody::Data: Into<Bytes> + Send + 'static,
|
||||
ResBody::Error: StdError + Send + Sync + 'static,
|
||||
Err: Into<Error>,
|
||||
Self: Service<
|
||||
http::Request<hyper::Body>,
|
||||
http::Request<BoxBody>,
|
||||
Response = hyper::Response<ResBody>,
|
||||
Error = Err,
|
||||
Future = Fut,
|
||||
@@ -41,12 +41,12 @@ pub struct TowerCompatEndpoint<Svc>(Svc);
|
||||
#[async_trait::async_trait]
|
||||
impl<Svc, ResBody, Err, Fut> Endpoint for TowerCompatEndpoint<Svc>
|
||||
where
|
||||
ResBody: HttpBody + Send + 'static,
|
||||
ResBody: hyper::body::Body + Send + Sync + 'static,
|
||||
ResBody::Data: Into<Bytes> + Send + 'static,
|
||||
ResBody::Error: StdError + Send + Sync + 'static,
|
||||
Err: Into<Error>,
|
||||
Svc: Service<
|
||||
http::Request<hyper::Body>,
|
||||
http::Request<BoxBody>,
|
||||
Response = hyper::Response<ResBody>,
|
||||
Error = Err,
|
||||
Future = Fut,
|
||||
@@ -62,59 +62,14 @@ where
|
||||
let mut svc = self.0.clone();
|
||||
|
||||
svc.ready().await.map_err(Into::into)?;
|
||||
|
||||
let hyper_req: http::Request<hyper::Body> = req.into();
|
||||
let hyper_resp = svc
|
||||
.call(hyper_req.map(Into::into))
|
||||
.await
|
||||
.map_err(Into::into)?;
|
||||
|
||||
if !hyper_resp.body().is_end_stream() {
|
||||
Ok(hyper_resp
|
||||
.map(|body| {
|
||||
let (sender, new_body) = hyper::Body::channel();
|
||||
tokio::spawn(copy_body(body, sender));
|
||||
new_body
|
||||
})
|
||||
.into())
|
||||
} else {
|
||||
Ok(hyper_resp.map(|_| hyper::Body::empty()).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn copy_body<T>(body: T, mut sender: Sender)
|
||||
where
|
||||
T: HttpBody + Send + 'static,
|
||||
T::Data: Into<Bytes> + Send + 'static,
|
||||
T::Error: StdError + Send + Sync + 'static,
|
||||
{
|
||||
tokio::pin!(body);
|
||||
|
||||
loop {
|
||||
match body.data().await {
|
||||
Some(Ok(data)) => {
|
||||
if sender.send_data(data.into()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Err(_)) => break,
|
||||
None => {}
|
||||
}
|
||||
|
||||
match body.trailers().await {
|
||||
Ok(Some(trailers)) => {
|
||||
if sender.send_trailers(trailers).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(_) => break,
|
||||
}
|
||||
|
||||
if body.is_end_stream() {
|
||||
break;
|
||||
}
|
||||
svc.call(req.into()).await.map_err(Into::into).map(|resp| {
|
||||
let (parts, body) = resp.into_parts();
|
||||
let body = body
|
||||
.map_frame(|frame| frame.map_data(Into::into))
|
||||
.map_err(std::io::Error::other)
|
||||
.boxed();
|
||||
hyper::Response::from_parts(parts, body).into()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -463,7 +463,7 @@ impl Error {
|
||||
/// assert_eq!(resp.data::<i32>(), Some(&100));
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn set_data(&mut self, data: impl Send + Sync + 'static) {
|
||||
pub fn set_data(&mut self, data: impl Clone + Send + Sync + 'static) {
|
||||
self.extensions.insert(data);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,15 +3,13 @@ use std::{
|
||||
path::PathBuf,
|
||||
};
|
||||
|
||||
use http::Uri;
|
||||
|
||||
use crate::listener::acme::{
|
||||
builder::AutoCertBuilder, endpoint::Http01Endpoint, ChallengeType, Http01TokensMap,
|
||||
};
|
||||
|
||||
/// ACME configuration
|
||||
pub struct AutoCert {
|
||||
pub(crate) directory_url: Uri,
|
||||
pub(crate) directory_url: String,
|
||||
pub(crate) domains: Vec<String>,
|
||||
pub(crate) contacts: Vec<String>,
|
||||
pub(crate) challenge_type: ChallengeType,
|
||||
|
||||
@@ -4,26 +4,21 @@ use std::{
|
||||
};
|
||||
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||
use http::{header, Uri};
|
||||
use hyper::{client::HttpConnector, Client};
|
||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||
use reqwest::Client;
|
||||
|
||||
use crate::{
|
||||
listener::acme::{
|
||||
jose,
|
||||
keypair::KeyPair,
|
||||
protocol::{
|
||||
CsrRequest, Directory, FetchAuthorizationResponse, Identifier, NewAccountRequest,
|
||||
NewOrderRequest, NewOrderResponse,
|
||||
},
|
||||
ChallengeType,
|
||||
use crate::listener::acme::{
|
||||
jose,
|
||||
keypair::KeyPair,
|
||||
protocol::{
|
||||
CsrRequest, Directory, FetchAuthorizationResponse, Identifier, NewAccountRequest,
|
||||
NewOrderRequest, NewOrderResponse,
|
||||
},
|
||||
Body,
|
||||
ChallengeType,
|
||||
};
|
||||
|
||||
/// A client for ACME-supporting TLS certificate services.
|
||||
pub struct AcmeClient {
|
||||
client: Client<HttpsConnector<HttpConnector>>,
|
||||
client: Client,
|
||||
directory: Directory,
|
||||
pub(crate) key_pair: Arc<KeyPair>,
|
||||
contacts: Vec<String>,
|
||||
@@ -33,14 +28,8 @@ pub struct AcmeClient {
|
||||
impl AcmeClient {
|
||||
/// Create a new client. `directory_url` is the url for the ACME provider. `contacts` is a list
|
||||
/// of URLS (ex: `mailto:`) the ACME service can use to reach you if there's issues with your certificates.
|
||||
pub async fn try_new(directory_url: &Uri, contacts: Vec<String>) -> IoResult<Self> {
|
||||
let client_builder = HttpsConnectorBuilder::new();
|
||||
#[cfg(feature = "acme-native-roots")]
|
||||
let client_builder1 = client_builder.with_native_roots();
|
||||
#[cfg(all(feature = "acme-webpki-roots", not(feature = "acme-native-roots")))]
|
||||
let client_builder1 = client_builder.with_webpki_roots();
|
||||
let client =
|
||||
Client::builder().build(client_builder1.https_or_http().enable_http1().build());
|
||||
pub(crate) async fn try_new(directory_url: &str, contacts: Vec<String>) -> IoResult<Self> {
|
||||
let client = Client::new();
|
||||
let directory = get_directory(&client, directory_url).await?;
|
||||
Ok(Self {
|
||||
client,
|
||||
@@ -98,7 +87,7 @@ impl AcmeClient {
|
||||
|
||||
pub(crate) async fn fetch_authorization(
|
||||
&self,
|
||||
auth_url: &Uri,
|
||||
auth_url: &str,
|
||||
) -> IoResult<FetchAuthorizationResponse> {
|
||||
tracing::debug!(auth_uri = %auth_url, "fetch authorization");
|
||||
|
||||
@@ -126,7 +115,7 @@ impl AcmeClient {
|
||||
&self,
|
||||
domain: &str,
|
||||
challenge_type: ChallengeType,
|
||||
url: &Uri,
|
||||
url: &str,
|
||||
) -> IoResult<()> {
|
||||
tracing::debug!(
|
||||
auth_uri = %url,
|
||||
@@ -149,7 +138,7 @@ impl AcmeClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn send_csr(&self, url: &Uri, csr: &[u8]) -> IoResult<NewOrderResponse> {
|
||||
pub(crate) async fn send_csr(&self, url: &str, csr: &[u8]) -> IoResult<NewOrderResponse> {
|
||||
tracing::debug!(url = %url, "send certificate request");
|
||||
|
||||
let nonce = get_nonce(&self.client, &self.directory).await?;
|
||||
@@ -166,7 +155,7 @@ impl AcmeClient {
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn obtain_certificate(&self, url: &Uri) -> IoResult<Vec<u8>> {
|
||||
pub(crate) async fn obtain_certificate(&self, url: &str) -> IoResult<Vec<u8>> {
|
||||
tracing::debug!(url = %url, "send certificate request");
|
||||
|
||||
let nonce = get_nonce(&self.client, &self.directory).await?;
|
||||
@@ -180,22 +169,23 @@ impl AcmeClient {
|
||||
)
|
||||
.await?;
|
||||
|
||||
resp.into_body().into_vec().await.map_err(|err| {
|
||||
IoError::new(
|
||||
ErrorKind::Other,
|
||||
format!("failed to download certificate: {err}"),
|
||||
)
|
||||
})
|
||||
Ok(resp
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
IoError::new(
|
||||
ErrorKind::Other,
|
||||
format!("failed to download certificate: {err}"),
|
||||
)
|
||||
})?
|
||||
.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_directory(
|
||||
client: &Client<HttpsConnector<HttpConnector>>,
|
||||
directory_url: &Uri,
|
||||
) -> IoResult<Directory> {
|
||||
async fn get_directory(client: &Client, directory_url: &str) -> IoResult<Directory> {
|
||||
tracing::debug!("loading directory");
|
||||
|
||||
let resp = client.get(directory_url.clone()).await.map_err(|err| {
|
||||
let resp = client.get(directory_url).send().await.map_err(|err| {
|
||||
IoError::new(ErrorKind::Other, format!("failed to load directory: {err}"))
|
||||
})?;
|
||||
|
||||
@@ -206,12 +196,9 @@ async fn get_directory(
|
||||
));
|
||||
}
|
||||
|
||||
let directory = Body(resp.into_body())
|
||||
.into_json::<Directory>()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
IoError::new(ErrorKind::Other, format!("failed to load directory: {err}"))
|
||||
})?;
|
||||
let directory = resp.json::<Directory>().await.map_err(|err| {
|
||||
IoError::new(ErrorKind::Other, format!("failed to load directory: {err}"))
|
||||
})?;
|
||||
|
||||
tracing::debug!(
|
||||
new_nonce = ?directory.new_nonce,
|
||||
@@ -222,14 +209,12 @@ async fn get_directory(
|
||||
Ok(directory)
|
||||
}
|
||||
|
||||
async fn get_nonce(
|
||||
client: &Client<HttpsConnector<HttpConnector>>,
|
||||
directory: &Directory,
|
||||
) -> IoResult<String> {
|
||||
async fn get_nonce(client: &Client, directory: &Directory) -> IoResult<String> {
|
||||
tracing::debug!("creating nonce");
|
||||
|
||||
let resp = client
|
||||
.get(directory.new_nonce.clone())
|
||||
.get(&directory.new_nonce)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| IoError::new(ErrorKind::Other, format!("failed to get nonce: {err}")))?;
|
||||
|
||||
@@ -252,7 +237,7 @@ async fn get_nonce(
|
||||
}
|
||||
|
||||
async fn create_acme_account(
|
||||
client: &Client<HttpsConnector<HttpConnector>>,
|
||||
client: &Client,
|
||||
directory: &Directory,
|
||||
key_pair: &KeyPair,
|
||||
contacts: Vec<String>,
|
||||
@@ -274,9 +259,11 @@ async fn create_acme_account(
|
||||
)
|
||||
.await?;
|
||||
let kid = resp
|
||||
.header(header::LOCATION)
|
||||
.ok_or_else(|| IoError::new(ErrorKind::Other, "unable to get account id"))?
|
||||
.to_string();
|
||||
.headers()
|
||||
.get("location")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToString::to_string)
|
||||
.ok_or_else(|| IoError::new(ErrorKind::Other, "unable to get account id"))?;
|
||||
|
||||
tracing::debug!(kid = kid.as_str(), "account created");
|
||||
Ok(kid)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
|
||||
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||
use http::{Method, Uri};
|
||||
use hyper::{client::HttpConnector, Client};
|
||||
use hyper_rustls::HttpsConnector;
|
||||
use reqwest::{Client, Response};
|
||||
use ring::digest::{digest, Digest, SHA256};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
use crate::{listener::acme::keypair::KeyPair, Request, Response};
|
||||
use crate::listener::acme::keypair::KeyPair;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Protected<'a> {
|
||||
@@ -100,11 +98,11 @@ struct Body {
|
||||
}
|
||||
|
||||
pub(crate) async fn request(
|
||||
cli: &Client<HttpsConnector<HttpConnector>>,
|
||||
cli: &Client,
|
||||
key_pair: &KeyPair,
|
||||
kid: Option<&str>,
|
||||
nonce: &str,
|
||||
uri: &Uri,
|
||||
uri: &str,
|
||||
payload: Option<impl Serialize>,
|
||||
) -> IoResult<Response> {
|
||||
let jwk = match kid {
|
||||
@@ -121,26 +119,26 @@ pub(crate) async fn request(
|
||||
let payload = URL_SAFE_NO_PAD.encode(payload);
|
||||
let combined = format!("{}.{}", &protected, &payload);
|
||||
let signature = URL_SAFE_NO_PAD.encode(key_pair.sign(combined.as_bytes())?);
|
||||
let body = serde_json::to_vec(&Body {
|
||||
protected,
|
||||
payload,
|
||||
signature,
|
||||
})
|
||||
.unwrap();
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri.clone())
|
||||
.content_type("application/jose+json")
|
||||
.body(body);
|
||||
|
||||
tracing::debug!(uri = %uri, "http request");
|
||||
|
||||
let resp = cli.request(req.into()).await.map_err(|err| {
|
||||
IoError::new(
|
||||
ErrorKind::Other,
|
||||
format!("failed to send http request: {err}"),
|
||||
)
|
||||
})?;
|
||||
let resp = cli
|
||||
.post(uri)
|
||||
.json(&Body {
|
||||
protected,
|
||||
payload,
|
||||
signature,
|
||||
})
|
||||
.header("content-type", "application/jose+json")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
IoError::new(
|
||||
ErrorKind::Other,
|
||||
format!("failed to send http request: {err}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(IoError::new(
|
||||
ErrorKind::Other,
|
||||
@@ -151,11 +149,11 @@ pub(crate) async fn request(
|
||||
}
|
||||
|
||||
pub(crate) async fn request_json<T, R>(
|
||||
cli: &Client<HttpsConnector<HttpConnector>>,
|
||||
cli: &Client,
|
||||
key_pair: &KeyPair,
|
||||
kid: Option<&str>,
|
||||
nonce: &str,
|
||||
uri: &Uri,
|
||||
uri: &str,
|
||||
payload: Option<T>,
|
||||
) -> IoResult<R>
|
||||
where
|
||||
@@ -165,8 +163,7 @@ where
|
||||
let resp = request(cli, key_pair, kid, nonce, uri, payload).await?;
|
||||
|
||||
let data = resp
|
||||
.into_body()
|
||||
.into_string()
|
||||
.text()
|
||||
.await
|
||||
.map_err(|_| IoError::new(ErrorKind::Other, "failed to read response"))?;
|
||||
serde_json::from_str(&data)
|
||||
|
||||
@@ -10,8 +10,10 @@ use rcgen::{
|
||||
};
|
||||
use tokio_rustls::{
|
||||
rustls::{
|
||||
sign::{any_ecdsa_type, CertifiedKey},
|
||||
PrivateKey, ServerConfig,
|
||||
crypto::ring::sign::any_ecdsa_type,
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
sign::CertifiedKey,
|
||||
ServerConfig,
|
||||
},
|
||||
server::TlsStream,
|
||||
TlsAcceptor,
|
||||
@@ -37,7 +39,6 @@ pub(crate) async fn auto_cert_acceptor<T: Listener>(
|
||||
challenge_type: ChallengeType,
|
||||
) -> IoResult<AutoCertAcceptor<T::Acceptor>> {
|
||||
let mut server_config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver);
|
||||
server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
@@ -137,7 +138,7 @@ impl<T: Listener> Listener for AutoCertListener<T> {
|
||||
if let (Some(certs), Some(key)) = (cache_certs, cert_key) {
|
||||
let certs = certs
|
||||
.into_iter()
|
||||
.map(tokio_rustls::rustls::Certificate)
|
||||
.map(CertificateDer::from)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expires_at = match certs
|
||||
@@ -156,7 +157,7 @@ impl<T: Listener> Listener for AutoCertListener<T> {
|
||||
);
|
||||
*cert_resolver.cert.write() = Some(Arc::new(CertifiedKey::new(
|
||||
certs,
|
||||
any_ecdsa_type(&PrivateKey(key)).unwrap(),
|
||||
any_ecdsa_type(&PrivateKeyDer::Pkcs8(key.into())).unwrap(),
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -231,13 +232,14 @@ fn gen_acme_cert(domain: &str, acme_hash: &[u8]) -> IoResult<CertifiedKey> {
|
||||
params.custom_extensions = vec![CustomExtension::new_acme_identifier(acme_hash)];
|
||||
let cert = Certificate::from_params(params)
|
||||
.map_err(|_| IoError::new(ErrorKind::Other, "failed to generate acme certificate"))?;
|
||||
let key = any_ecdsa_type(&PrivateKey(cert.serialize_private_key_der())).unwrap();
|
||||
let key = any_ecdsa_type(&PrivateKeyDer::Pkcs8(
|
||||
cert.serialize_private_key_der().into(),
|
||||
))
|
||||
.unwrap();
|
||||
Ok(CertifiedKey::new(
|
||||
vec![tokio_rustls::rustls::Certificate(
|
||||
cert.serialize_der().map_err(|_| {
|
||||
IoError::new(ErrorKind::Other, "failed to serialize acme certificate")
|
||||
})?,
|
||||
)],
|
||||
vec![CertificateDer::from(cert.serialize_der().map_err(
|
||||
|_| IoError::new(ErrorKind::Other, "failed to serialize acme certificate"),
|
||||
)?)],
|
||||
key,
|
||||
))
|
||||
}
|
||||
@@ -353,7 +355,10 @@ pub async fn issue_cert<T: AsRef<str>>(
|
||||
format!("failed create certificate request: {err}"),
|
||||
)
|
||||
})?;
|
||||
let pk = any_ecdsa_type(&PrivateKey(cert.serialize_private_key_der())).unwrap();
|
||||
let pk = any_ecdsa_type(&PrivateKeyDer::Pkcs8(
|
||||
cert.serialize_private_key_der().into(),
|
||||
))
|
||||
.unwrap();
|
||||
let csr = cert.serialize_request_der().map_err(|err| {
|
||||
IoError::new(
|
||||
ErrorKind::Other,
|
||||
@@ -400,7 +405,7 @@ pub async fn issue_cert<T: AsRef<str>>(
|
||||
let cert_chain = rustls_pemfile::certs(&mut acme_cert_pem.as_slice())
|
||||
.map_err(|err| IoError::new(ErrorKind::Other, format!("invalid pem: {err}")))?
|
||||
.into_iter()
|
||||
.map(tokio_rustls::rustls::Certificate)
|
||||
.map(CertificateDer::from)
|
||||
.collect();
|
||||
let cert_key = CertifiedKey::new(cert_chain, pk);
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ mod keypair;
|
||||
mod listener;
|
||||
mod protocol;
|
||||
mod resolver;
|
||||
mod serde;
|
||||
|
||||
pub use auto_cert::AutoCert;
|
||||
pub use builder::AutoCertBuilder;
|
||||
|
||||
@@ -5,8 +5,6 @@ use std::{
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::listener::acme::serde::SerdeUri;
|
||||
|
||||
/// HTTP-01 challenge
|
||||
const CHALLENGE_TYPE_HTTP_01: &str = "http-01";
|
||||
|
||||
@@ -38,9 +36,9 @@ impl Display for ChallengeType {
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub(crate) struct Directory {
|
||||
pub(crate) new_nonce: SerdeUri,
|
||||
pub(crate) new_account: SerdeUri,
|
||||
pub(crate) new_order: SerdeUri,
|
||||
pub(crate) new_nonce: String,
|
||||
pub(crate) new_account: String,
|
||||
pub(crate) new_order: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -75,17 +73,17 @@ pub(crate) struct Problem {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub(crate) struct NewOrderResponse {
|
||||
pub(crate) status: String,
|
||||
pub(crate) authorizations: Vec<SerdeUri>,
|
||||
pub(crate) authorizations: Vec<String>,
|
||||
pub(crate) error: Option<Problem>,
|
||||
pub(crate) finalize: SerdeUri,
|
||||
pub(crate) certificate: Option<SerdeUri>,
|
||||
pub(crate) finalize: String,
|
||||
pub(crate) certificate: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct Challenge {
|
||||
#[serde(rename = "type")]
|
||||
pub(crate) ty: String,
|
||||
pub(crate) url: SerdeUri,
|
||||
pub(crate) url: String,
|
||||
pub(crate) token: String,
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ pub fn seconds_until_expiry(cert: &CertifiedKey) -> i64 {
|
||||
}
|
||||
|
||||
/// Shared ACME key state.
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ResolveServerCert {
|
||||
/// The current TLS certificate. Swap it with `Arc::write`.
|
||||
pub cert: RwLock<Option<Arc<CertifiedKey>>>,
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
use std::{
|
||||
fmt::{self, Debug, Formatter},
|
||||
ops::Deref,
|
||||
};
|
||||
|
||||
use http::Uri;
|
||||
use serde::{de::Error, Deserialize, Deserializer};
|
||||
|
||||
pub(crate) struct SerdeUri(pub(crate) Uri);
|
||||
|
||||
impl<'de> Deserialize<'de> for SerdeUri {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
String::deserialize(deserializer)?
|
||||
.parse::<Uri>()
|
||||
.map(SerdeUri)
|
||||
.map_err(|err| D::Error::custom(err.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for SerdeUri {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
Debug::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SerdeUri {
|
||||
type Target = Uri;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
@@ -9,12 +9,11 @@ use rustls_pemfile::Item;
|
||||
use tokio::io::{Error as IoError, ErrorKind, Result as IoResult};
|
||||
use tokio_rustls::{
|
||||
rustls::{
|
||||
server::{
|
||||
AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ClientHello,
|
||||
NoClientAuth, ResolvesServerCert,
|
||||
},
|
||||
sign::{self, CertifiedKey},
|
||||
Certificate, PrivateKey, RootCertStore, ServerConfig,
|
||||
crypto::ring::sign::any_supported_type,
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier},
|
||||
sign::CertifiedKey,
|
||||
RootCertStore, ServerConfig,
|
||||
},
|
||||
server::TlsStream,
|
||||
};
|
||||
@@ -72,7 +71,7 @@ impl RustlsCertificate {
|
||||
impl RustlsCertificate {
|
||||
fn create_certificate_key(&self) -> IoResult<CertifiedKey> {
|
||||
let cert = rustls_pemfile::certs(&mut self.cert.as_slice())
|
||||
.map(|mut certs| certs.drain(..).map(Certificate).collect())
|
||||
.map(|mut certs| certs.drain(..).map(CertificateDer::from).collect())
|
||||
.map_err(|_| IoError::new(ErrorKind::Other, "failed to parse tls certificates"))?;
|
||||
|
||||
let priv_key = {
|
||||
@@ -90,12 +89,12 @@ impl RustlsCertificate {
|
||||
_ => continue,
|
||||
};
|
||||
if !key.is_empty() {
|
||||
break PrivateKey(key);
|
||||
break PrivateKeyDer::Pkcs8(key.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let key = sign::any_supported_type(&priv_key)
|
||||
let key = any_supported_type(&priv_key)
|
||||
.map_err(|_| IoError::new(ErrorKind::Other, "invalid private key"))?;
|
||||
|
||||
Ok(CertifiedKey {
|
||||
@@ -106,7 +105,6 @@ impl RustlsCertificate {
|
||||
} else {
|
||||
None
|
||||
},
|
||||
sct_list: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -239,24 +237,30 @@ impl RustlsConfig {
|
||||
);
|
||||
}
|
||||
|
||||
let client_auth = match &self.client_auth {
|
||||
TlsClientAuth::Off => NoClientAuth::boxed(),
|
||||
let builder = ServerConfig::builder();
|
||||
let builder = match &self.client_auth {
|
||||
TlsClientAuth::Off => builder.with_no_client_auth(),
|
||||
TlsClientAuth::Optional(trust_anchor) => {
|
||||
AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?)
|
||||
.boxed()
|
||||
let verifier =
|
||||
WebPkiClientVerifier::builder(read_trust_anchor(trust_anchor)?.into())
|
||||
.allow_unauthenticated()
|
||||
.build()
|
||||
.map_err(|err| IoError::other(err))?;
|
||||
builder.with_client_cert_verifier(verifier)
|
||||
}
|
||||
TlsClientAuth::Required(trust_anchor) => {
|
||||
AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?).boxed()
|
||||
let verifier =
|
||||
WebPkiClientVerifier::builder(read_trust_anchor(trust_anchor)?.into())
|
||||
.build()
|
||||
.map_err(|err| IoError::other(err))?;
|
||||
builder.with_client_cert_verifier(verifier)
|
||||
}
|
||||
};
|
||||
|
||||
let mut server_config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_client_cert_verifier(client_auth)
|
||||
.with_cert_resolver(Arc::new(ResolveServerCert {
|
||||
certifcate_keys,
|
||||
fallback,
|
||||
}));
|
||||
let mut server_config = builder.with_cert_resolver(Arc::new(ResolveServerCert {
|
||||
certifcate_keys,
|
||||
fallback,
|
||||
}));
|
||||
server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
|
||||
|
||||
Ok(server_config)
|
||||
@@ -268,7 +272,7 @@ fn read_trust_anchor(mut trust_anchor: &[u8]) -> IoResult<RootCertStore> {
|
||||
let ders = rustls_pemfile::certs(&mut trust_anchor)?;
|
||||
for der in ders {
|
||||
store
|
||||
.add(&Certificate(der))
|
||||
.add(CertificateDer::from(der))
|
||||
.map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?;
|
||||
}
|
||||
Ok(store)
|
||||
@@ -402,6 +406,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ResolveServerCert {
|
||||
certifcate_keys: HashMap<String, Arc<CertifiedKey>>,
|
||||
fallback: Option<Arc<CertifiedKey>>,
|
||||
@@ -422,7 +427,7 @@ mod tests {
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
};
|
||||
use tokio_rustls::rustls::{ClientConfig, ServerName};
|
||||
use tokio_rustls::rustls::{pki_types::ServerName, ClientConfig};
|
||||
|
||||
use super::*;
|
||||
use crate::listener::TcpListener;
|
||||
@@ -441,7 +446,6 @@ mod tests {
|
||||
|
||||
tokio::spawn(async move {
|
||||
let config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(
|
||||
read_trust_anchor(include_bytes!("certs/chain1.pem")).unwrap(),
|
||||
)
|
||||
|
||||
@@ -2,10 +2,10 @@ use std::sync::Arc;
|
||||
|
||||
use libopentelemetry::{
|
||||
global,
|
||||
propagation::Extractor,
|
||||
trace::{FutureExt, Span, SpanKind, TraceContextExt, Tracer},
|
||||
Context, Key,
|
||||
};
|
||||
use opentelemetry_http::HeaderExtractor;
|
||||
use opentelemetry_semantic_conventions::{resource, trace};
|
||||
|
||||
use crate::{
|
||||
@@ -52,6 +52,21 @@ pub struct OpenTelemetryTracingEndpoint<T, E> {
|
||||
inner: E,
|
||||
}
|
||||
|
||||
struct HeaderExtractor<'a>(&'a http::HeaderMap);
|
||||
|
||||
impl<'a> Extractor for HeaderExtractor<'a> {
|
||||
fn get(&self, key: &str) -> Option<&str> {
|
||||
self.0.get(key).and_then(|value| value.to_str().ok())
|
||||
}
|
||||
|
||||
fn keys(&self) -> Vec<&str> {
|
||||
self.0
|
||||
.keys()
|
||||
.map(|value| value.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<T, E> Endpoint for OpenTelemetryTracingEndpoint<T, E>
|
||||
where
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::{
|
||||
any::Any,
|
||||
fmt::{self, Debug, Formatter},
|
||||
future::Future,
|
||||
io::Error,
|
||||
@@ -9,6 +8,8 @@ use std::{
|
||||
};
|
||||
|
||||
use http::uri::Scheme;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{body::Incoming, rt::Write as _};
|
||||
use parking_lot::Mutex;
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
@@ -16,7 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
#[cfg(feature = "cookie")]
|
||||
use crate::web::cookie::CookieJar;
|
||||
use crate::{
|
||||
body::Body,
|
||||
body::{Body, BoxBody},
|
||||
error::{ParsePathError, ParseQueryError, UpgradeError},
|
||||
http::{
|
||||
header::{self, HeaderMap, HeaderName, HeaderValue},
|
||||
@@ -108,10 +109,10 @@ impl Debug for Request {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(http::Request<hyper::Body>, LocalAddr, RemoteAddr, Scheme)> for Request {
|
||||
impl From<(http::Request<Incoming>, LocalAddr, RemoteAddr, Scheme)> for Request {
|
||||
fn from(
|
||||
(req, local_addr, remote_addr, scheme): (
|
||||
http::Request<hyper::Body>,
|
||||
http::Request<Incoming>,
|
||||
LocalAddr,
|
||||
RemoteAddr,
|
||||
Scheme,
|
||||
@@ -131,7 +132,7 @@ impl From<(http::Request<hyper::Body>, LocalAddr, RemoteAddr, Scheme)> for Reque
|
||||
version: parts.version,
|
||||
headers: parts.headers,
|
||||
extensions: parts.extensions,
|
||||
body: Body(body),
|
||||
body: Body(body.map_err(|err| Error::other(err)).boxed()),
|
||||
state: RequestState {
|
||||
local_addr,
|
||||
remote_addr,
|
||||
@@ -146,13 +147,13 @@ impl From<(http::Request<hyper::Body>, LocalAddr, RemoteAddr, Scheme)> for Reque
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Request> for hyper::Request<hyper::Body> {
|
||||
impl From<Request> for hyper::Request<BoxBody> {
|
||||
fn from(req: Request) -> Self {
|
||||
let mut hyper_req = http::Request::builder()
|
||||
.method(req.method)
|
||||
.uri(req.uri)
|
||||
.version(req.version)
|
||||
.body(req.body.into())
|
||||
.body(req.body.0)
|
||||
.unwrap();
|
||||
*hyper_req.headers_mut() = req.headers;
|
||||
*hyper_req.extensions_mut() = req.extensions;
|
||||
@@ -372,7 +373,7 @@ impl Request {
|
||||
/// Inserts a value to extensions, similar to
|
||||
/// `self.extensions().insert(data)`.
|
||||
#[inline]
|
||||
pub fn set_data(&mut self, data: impl Send + Sync + 'static) {
|
||||
pub fn set_data(&mut self, data: impl Clone + Send + Sync + 'static) {
|
||||
self.extensions.insert(data);
|
||||
}
|
||||
|
||||
@@ -490,7 +491,7 @@ impl AsyncRead for Upgraded {
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
self.project().stream.poll_read(cx, buf)
|
||||
Pin::new(&mut hyper_util::rt::TokioIo::new(self.project().stream)).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -597,7 +598,7 @@ impl RequestBuilder {
|
||||
#[must_use]
|
||||
pub fn extension<T>(mut self, extension: T) -> Self
|
||||
where
|
||||
T: Any + Send + Sync + 'static,
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
self.extensions.insert(extension);
|
||||
self
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use std::{
|
||||
any::Any,
|
||||
fmt::{self, Debug, Formatter},
|
||||
};
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
|
||||
use bytes::Bytes;
|
||||
use headers::HeaderMapExt;
|
||||
use http_body_util::BodyExt;
|
||||
|
||||
use crate::{
|
||||
body::BoxBody,
|
||||
http::{
|
||||
header::{self, HeaderMap, HeaderName, HeaderValue},
|
||||
Extensions, StatusCode, Version,
|
||||
@@ -77,9 +77,9 @@ impl<T: Into<Body>> From<(StatusCode, T)> for Response {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Response> for hyper::Response<hyper::Body> {
|
||||
impl From<Response> for hyper::Response<BoxBody> {
|
||||
fn from(resp: Response) -> Self {
|
||||
let mut hyper_resp = hyper::Response::new(resp.body.into());
|
||||
let mut hyper_resp = hyper::Response::new(resp.body.0);
|
||||
*hyper_resp.status_mut() = resp.status;
|
||||
*hyper_resp.version_mut() = resp.version;
|
||||
*hyper_resp.headers_mut() = resp.headers;
|
||||
@@ -88,15 +88,24 @@ impl From<Response> for hyper::Response<hyper::Body> {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<hyper::Response<hyper::Body>> for Response {
|
||||
fn from(hyper_resp: hyper::Response<hyper::Body>) -> Self {
|
||||
impl<T: hyper::body::Body> From<hyper::Response<T>> for Response
|
||||
where
|
||||
T: hyper::body::Body + Send + Sync + 'static,
|
||||
T::Data: Into<Bytes>,
|
||||
T::Error: Into<std::io::Error>,
|
||||
{
|
||||
fn from(hyper_resp: hyper::Response<T>) -> Self {
|
||||
let (parts, body) = hyper_resp.into_parts();
|
||||
Response {
|
||||
status: parts.status,
|
||||
version: parts.version,
|
||||
headers: parts.headers,
|
||||
extensions: parts.extensions,
|
||||
body: body.into(),
|
||||
body: Body(
|
||||
body.map_frame(|frame| frame.map_data(Into::into))
|
||||
.map_err(Into::into)
|
||||
.boxed(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -217,7 +226,7 @@ impl Response {
|
||||
/// Inserts a value to extensions, similar to
|
||||
/// `self.extensions().insert(data)`.
|
||||
#[inline]
|
||||
pub fn set_data(&mut self, data: impl Send + Sync + 'static) {
|
||||
pub fn set_data(&mut self, data: impl Clone + Send + Sync + 'static) {
|
||||
self.extensions.insert(data);
|
||||
}
|
||||
|
||||
@@ -312,7 +321,7 @@ impl ResponseBuilder {
|
||||
#[must_use]
|
||||
pub fn extension<T>(mut self, extension: T) -> Self
|
||||
where
|
||||
T: Any + Send + Sync + 'static,
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
self.extensions.insert(extension);
|
||||
self
|
||||
|
||||
@@ -1085,6 +1085,7 @@ mod tests {
|
||||
("/*p1", 9),
|
||||
("/abc/<\\d+>/def", 10),
|
||||
("/kcd/:p1<\\d+>", 11),
|
||||
("/:package/-/:package_tgz<.*tgz$>", 12),
|
||||
];
|
||||
|
||||
for (path, id) in paths {
|
||||
@@ -1154,6 +1155,16 @@ mod tests {
|
||||
NodeData::new(11, "/kcd/:p1<\\d+>"),
|
||||
)),
|
||||
),
|
||||
(
|
||||
"/is-number/-/is-number-7.0.0.tgz",
|
||||
Some((
|
||||
create_url_params(vec![
|
||||
("package", "is-number"),
|
||||
("package_tgz", "is-number-7.0.0.tgz"),
|
||||
]),
|
||||
NodeData::new(12, "/:package/-/:package_tgz<.*tgz$>"),
|
||||
)),
|
||||
),
|
||||
];
|
||||
|
||||
for (path, mut res) in matches {
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::{
|
||||
Endpoint, EndpointExt, IntoEndpoint, IntoResponse, Request, Response, Result,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct PathPrefix(usize);
|
||||
|
||||
/// Routing object
|
||||
|
||||
@@ -12,7 +12,8 @@ use std::{
|
||||
};
|
||||
|
||||
use http::uri::Scheme;
|
||||
use hyper::server::conn::Http;
|
||||
use hyper::body::Incoming;
|
||||
use hyper_util::server::conn::auto;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite, ReadBuf, Result as IoResult},
|
||||
@@ -321,7 +322,7 @@ async fn serve_connection(
|
||||
let service = hyper::service::service_fn({
|
||||
let remote_addr = remote_addr.clone();
|
||||
|
||||
move |req: hyper::Request<hyper::Body>| {
|
||||
move |req: http::Request<Incoming>| {
|
||||
let ep = ep.clone();
|
||||
let local_addr = local_addr.clone();
|
||||
let remote_addr = remote_addr.clone();
|
||||
@@ -352,12 +353,13 @@ async fn serve_connection(
|
||||
None => tokio_util::either::Either::Right(socket),
|
||||
};
|
||||
|
||||
let mut conn = Http::new()
|
||||
.serve_connection(socket, service)
|
||||
.with_upgrades();
|
||||
let builder = auto::Builder::new(hyper_util::rt::TokioExecutor::new());
|
||||
let conn =
|
||||
builder.serve_connection_with_upgrades(hyper_util::rt::TokioIo::new(socket), service);
|
||||
futures_util::pin_mut!(conn);
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut conn => {
|
||||
_ = conn => {
|
||||
// Connection completed successfully.
|
||||
return;
|
||||
},
|
||||
@@ -366,10 +368,4 @@ async fn serve_connection(
|
||||
}
|
||||
_ = server_graceful_shutdown_token.cancelled() => {}
|
||||
}
|
||||
|
||||
// Init graceful shutdown for connection (`GOAWAY` for `HTTP/2` or disabling `keep-alive` for `HTTP/1`)
|
||||
Pin::new(&mut conn).graceful_shutdown();
|
||||
|
||||
// Continue awaiting after graceful-shutdown is initiated to handle existed requests.
|
||||
let _ = conn.await;
|
||||
}
|
||||
|
||||
@@ -200,7 +200,7 @@ impl<'a, E> TestRequestBuilder<'a, E> {
|
||||
#[must_use]
|
||||
pub fn data<T>(mut self, data: T) -> Self
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
self.extensions.insert(data);
|
||||
self
|
||||
|
||||
@@ -33,6 +33,7 @@ impl<'a> FromRequest<'a> for &'a CsrfToken {
|
||||
/// A verifier for CSRF Token.
|
||||
///
|
||||
/// See also [`Csrf`](crate::middleware::Csrf)
|
||||
#[derive(Clone)]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "csrf")))]
|
||||
pub struct CsrfVerifier {
|
||||
cookie: Option<UnencryptedCsrfCookie>,
|
||||
|
||||
@@ -129,7 +129,10 @@ impl StaticFileRequest {
|
||||
let mut content_length = data.len() as u64;
|
||||
let mut content_range = None;
|
||||
|
||||
let body = if let Some((start, end)) = self.range.and_then(|range| range.iter().next()) {
|
||||
let body = if let Some((start, end)) = self
|
||||
.range
|
||||
.and_then(|range| range.satisfiable_ranges(data.len() as u64).next())
|
||||
{
|
||||
let start = match start {
|
||||
Bound::Included(n) => n,
|
||||
Bound::Excluded(n) => n + 1,
|
||||
@@ -232,7 +235,10 @@ impl StaticFileRequest {
|
||||
|
||||
let mut content_range = None;
|
||||
|
||||
let body = if let Some((start, end)) = self.range.and_then(|range| range.iter().next()) {
|
||||
let body = if let Some((start, end)) = self
|
||||
.range
|
||||
.and_then(|range| range.satisfiable_ranges(metadata.len()).next())
|
||||
{
|
||||
let start = match start {
|
||||
Bound::Included(n) => n,
|
||||
Bound::Excluded(n) => n + 1,
|
||||
|
||||
Reference in New Issue
Block a user