Merge branch 'hyper-1'

This commit is contained in:
Sunli
2024-01-05 21:17:48 +08:00
37 changed files with 540 additions and 422 deletions

View File

@@ -45,8 +45,8 @@ quick-xml = { version = "0.30.0", features = ["serialize"] }
base64 = "0.21.0"
serde_urlencoded = "0.7.1"
indexmap = "2.0.0"
reqwest = { version = "0.11.23", default-features = false }
# rustls, update together
hyper-rustls = { version = "0.24.0", default-features = false }
rustls = "0.21.0"
tokio-rustls = "0.24.0"
rustls = "0.22.0"
tokio-rustls = "0.25.0"

View File

@@ -11,7 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>
.unwrap(),
);
let request = Request::new(HelloRequest {
name: "Tonic".into(),
name: "Poem".into(),
});
let response = client.say_hello(request).await?;
println!("RESPONSE={response:?}");

View File

@@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>
.unwrap(),
);
let request = Request::new(HelloRequest {
name: "Tonic".into(),
name: "Poem".into(),
});
let response = client.say_hello(request).await?;
println!("RESPONSE={response:?}");

View File

@@ -21,9 +21,8 @@ json-codec = ["serde", "serde_json"]
poem = { workspace = true, default-features = true }
futures-util.workspace = true
hyper = { version = "0.14.20", features = ["client"] }
async-stream = "0.3.3"
tokio = { workspace = true, features = ["io-util", "rt", "sync"] }
tokio = { workspace = true, features = ["io-util", "rt", "sync", "net"] }
flate2 = "1.0.24"
itoa = "1.0.2"
percent-encoding = "2.1.0"
@@ -32,16 +31,18 @@ prost = "0.12.0"
base64 = "0.21.0"
prost-types = "0.12.0"
tokio-stream = { workspace = true, features = ["sync"] }
hyper-rustls = { workspace = true, features = [
"webpki-roots",
"http2",
"native-tokio",
] }
serde = { workspace = true, optional = true }
serde_json = { workspace = true, optional = true }
rustls.workspace = true
rustls = { workspace = true }
thiserror.workspace = true
fastrand = "2.0.0"
http = "1.0.0"
hyper = { version = "1.0.0", features = ["http1", "http2"] }
hyper-util = { version = "0.1.1", features = ["client-legacy", "tokio"] }
http-body-util = "0.1.0"
tokio-rustls.workspace = true
tower-service = "0.3.2"
webpki-roots = "0.26"
[build-dependencies]
poem-grpc-build.workspace = true

View File

@@ -1,7 +1,9 @@
use std::sync::Arc;
use std::{io::Error as IoError, sync::Arc};
use bytes::Bytes;
use futures_util::TryStreamExt;
use hyper_rustls::HttpsConnectorBuilder;
use http_body_util::BodyExt;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use poem::{
http::{
header, header::InvalidHeaderValue, uri::InvalidUri, Extensions, HeaderValue, Method,
@@ -14,10 +16,13 @@ use rustls::ClientConfig as TlsClientConfig;
use crate::{
codec::Codec,
connector::HttpsConnector,
encoding::{create_decode_response_body, create_encode_request_body},
Code, Metadata, Request, Response, Status, Streaming,
};
pub(crate) type BoxBody = http_body_util::combinators::BoxBody<Bytes, IoError>;
/// A configuration for GRPC client
#[derive(Default)]
pub struct ClientConfig {
@@ -392,29 +397,17 @@ fn create_client_endpoint(
config: ClientConfig,
) -> Arc<dyn Endpoint<Output = HttpResponse> + 'static> {
let mut config = config;
let cli = match config.tls_config.take() {
Some(tls_config) => hyper::Client::builder().http2_only(true).build(
HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http2()
.build(),
),
None => hyper::Client::builder().http2_only(true).build(
HttpsConnectorBuilder::new()
.with_webpki_roots()
.https_or_http()
.enable_http2()
.build(),
),
};
let cli = Client::builder(TokioExecutor::new())
.http2_only(true)
.build(HttpsConnector::new(config.tls_config.take()));
let config = Arc::new(config);
Arc::new(poem::endpoint::make(move |request| {
let config = config.clone();
let cli = cli.clone();
async move {
let mut request: hyper::Request<hyper::Body> = request.into();
let mut request: hyper::Request<BoxBody> = request.into();
if config.uris.is_empty() {
return Err(poem::Error::from_string(
@@ -443,7 +436,12 @@ fn create_client_endpoint(
}
let resp = cli.request(request).await.map_err(to_boxed_error)?;
Ok::<_, poem::Error>(HttpResponse::from(resp))
let (parts, body) = resp.into_parts();
Ok::<_, poem::Error>(HttpResponse::from(hyper::Response::from_parts(
parts,
body.map_err(|err| IoError::other(err)),
)))
}
}))
}

155
poem-grpc/src/connector.rs Normal file
View File

@@ -0,0 +1,155 @@
use std::{
io::{Error as IoError, Result as IoResult},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures_util::{future::BoxFuture, FutureExt};
use http::{uri::Scheme, Uri};
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_util::{
client::legacy::connect::{Connected, Connection},
rt::TokioIo,
};
use rustls::{ClientConfig, RootCertStore};
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};
use tower_service::Service;
pub(crate) enum MaybeHttpsStream {
TcpStream(TokioIo<TcpStream>),
TlsStream {
stream: TokioIo<TlsStream<TcpStream>>,
is_http2: bool,
},
}
impl Read for MaybeHttpsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: ReadBufCursor<'_>,
) -> Poll<IoResult<()>> {
match self.get_mut() {
MaybeHttpsStream::TcpStream(stream) => Pin::new(stream).poll_read(cx, buf),
MaybeHttpsStream::TlsStream { stream, .. } => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl Write for MaybeHttpsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, IoError>> {
match self.get_mut() {
MaybeHttpsStream::TcpStream(stream) => Pin::new(stream).poll_write(cx, buf),
MaybeHttpsStream::TlsStream { stream, .. } => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
match self.get_mut() {
MaybeHttpsStream::TcpStream(stream) => Pin::new(stream).poll_flush(cx),
MaybeHttpsStream::TlsStream { stream, .. } => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
match self.get_mut() {
MaybeHttpsStream::TcpStream(stream) => Pin::new(stream).poll_shutdown(cx),
MaybeHttpsStream::TlsStream { stream, .. } => Pin::new(stream).poll_shutdown(cx),
}
}
}
impl Connection for MaybeHttpsStream {
fn connected(&self) -> Connected {
match self {
MaybeHttpsStream::TcpStream(_) => Connected::new(),
MaybeHttpsStream::TlsStream { is_http2, .. } => {
let mut connected = Connected::new();
if *is_http2 {
connected = connected.negotiated_h2();
}
connected
}
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct HttpsConnector {
tls_config: Option<ClientConfig>,
}
impl HttpsConnector {
#[inline]
pub(crate) fn new(tls_config: Option<ClientConfig>) -> Self {
HttpsConnector { tls_config }
}
}
impl Service<Uri> for HttpsConnector {
type Response = MaybeHttpsStream;
type Error = IoError;
type Future = BoxFuture<'static, Result<MaybeHttpsStream, IoError>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, uri: Uri) -> Self::Future {
do_connect(uri, self.tls_config.clone()).boxed()
}
}
fn default_tls_config() -> ClientConfig {
let mut root_cert_store = RootCertStore::empty();
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth()
}
async fn do_connect(
uri: Uri,
tls_config: Option<ClientConfig>,
) -> Result<MaybeHttpsStream, IoError> {
let scheme = uri
.scheme()
.ok_or_else(|| IoError::other("missing scheme"))?
.clone();
let host = uri
.host()
.ok_or_else(|| IoError::other("missing host"))?
.to_string();
let port = uri
.port_u16()
.unwrap_or_else(|| if scheme == Scheme::HTTPS { 443 } else { 80 });
if scheme == Scheme::HTTP {
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
Ok(MaybeHttpsStream::TcpStream(TokioIo::new(stream)))
} else if scheme == Scheme::HTTPS {
let mut tls_config = tls_config.unwrap_or_else(default_tls_config);
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let connector = TlsConnector::from(Arc::new(tls_config));
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
let domain = host.try_into().map_err(IoError::other)?;
let mut is_http2 = false;
let stream = connector
.connect_with(domain, stream, |conn| {
is_http2 = conn.alpn_protocol() == Some(b"h2");
})
.await?;
Ok(MaybeHttpsStream::TlsStream {
stream: TokioIo::new(stream),
is_http2,
})
} else {
Err(IoError::other("invalid scheme"))
}
}

View File

@@ -1,12 +1,16 @@
use std::io::Result as IoResult;
use std::io::{Error as IoError, Result as IoResult};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use flate2::read::GzDecoder;
use futures_util::StreamExt;
use hyper::{body::HttpBody, HeaderMap};
use http_body_util::{BodyExt, StreamBody};
use hyper::{body::Frame, HeaderMap};
use poem::Body;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::{
client::BoxBody,
codec::{Decoder, Encoder},
Code, Status, Streaming,
};
@@ -76,18 +80,20 @@ pub(crate) fn create_decode_request_body<T: Decoder>(
mut decoder: T,
body: Body,
) -> Streaming<T::Item> {
let mut body: hyper::Body = body.into();
let mut body: BoxBody = body.into();
Streaming::new(async_stream::try_stream! {
let mut frame_decoder = DataFrameDecoder::default();
loop {
match body.data().await.transpose().map_err(Status::from_std_error)? {
Some(data) => {
frame_decoder.put_slice(data);
while let Some(data) = frame_decoder.next()? {
let message = decoder.decode(&data).map_err(Status::from_std_error)?;
yield message;
match body.frame().await.transpose().map_err(Status::from_std_error)? {
Some(frame) => {
if let Ok(data) = frame.into_data() {
frame_decoder.put_slice(data);
while let Some(data) = frame_decoder.next()? {
let message = decoder.decode(&data).map_err(Status::from_std_error)?;
yield message;
}
}
}
None => {
@@ -103,7 +109,7 @@ pub(crate) fn create_encode_response_body<T: Encoder>(
mut encoder: T,
mut stream: Streaming<T::Item>,
) -> Body {
let (mut sender, body) = hyper::Body::channel();
let (tx, rx) = mpsc::channel(16);
tokio::spawn(async move {
let mut buf = BytesMut::new();
@@ -112,45 +118,51 @@ pub(crate) fn create_encode_response_body<T: Encoder>(
match item {
Ok(message) => {
if let Ok(data) = encode_data_frame(&mut encoder, &mut buf, message) {
if sender.send_data(data).await.is_err() {
if tx.send(Frame::data(data)).await.is_err() {
return;
}
}
}
Err(status) => {
let _ = sender.send_trailers(status.to_headers()).await;
_ = tx.send(Frame::trailers(status.to_headers())).await;
return;
}
}
}
let _ = sender
.send_trailers(Status::new(Code::Ok).to_headers())
_ = tx
.send(Frame::trailers(Status::new(Code::Ok).to_headers()))
.await;
});
body.into()
BodyExt::boxed(StreamBody::new(
ReceiverStream::new(rx).map(|frame| Ok::<_, IoError>(frame)),
))
.into()
}
pub(crate) fn create_encode_request_body<T: Encoder>(
mut encoder: T,
mut stream: Streaming<T::Item>,
) -> Body {
let (mut sender, body) = hyper::Body::channel();
let (tx, rx) = mpsc::channel(16);
tokio::spawn(async move {
let mut buf = BytesMut::new();
while let Some(Ok(message)) = stream.next().await {
if let Ok(data) = encode_data_frame(&mut encoder, &mut buf, message) {
if sender.send_data(data).await.is_err() {
if tx.send(Frame::data(data)).await.is_err() {
return;
}
}
}
});
body.into()
BodyExt::boxed(StreamBody::new(
ReceiverStream::new(rx).map(|frame| Ok::<_, IoError>(frame)),
))
.into()
}
pub(crate) fn create_decode_response_body<T: Decoder>(
@@ -167,35 +179,33 @@ pub(crate) fn create_decode_response_body<T: Decoder>(
};
}
let mut body: hyper::Body = body.into();
let mut body: BoxBody = body.into();
Ok(Streaming::new(async_stream::try_stream! {
let mut frame_decoder = DataFrameDecoder::default();
let mut status = None;
loop {
if let Some(data) = body.data().await.transpose().map_err(Status::from_std_error)? {
while let Some(frame) = body.frame().await.transpose().map_err(Status::from_std_error)? {
if frame.is_data() {
let data = frame.into_data().unwrap();
frame_decoder.put_slice(data);
while let Some(data) = frame_decoder.next()? {
let message = decoder.decode(&data).map_err(Status::from_std_error)?;
yield message;
}
continue;
frame_decoder.check_incomplete()?;
} else if frame.is_trailers() {
let headers = frame.into_trailers().unwrap();
status = Some(Status::from_headers(&headers)?
.ok_or_else(|| Status::new(Code::Internal)
.with_message("missing grpc-status"))?);
break;
}
}
frame_decoder.check_incomplete()?;
match body.trailers().await.map_err(Status::from_std_error)? {
Some(trailers) => {
let status = Status::from_headers(&trailers)?
.ok_or_else(|| Status::new(Code::Internal).with_message("missing grpc-status"))?;
if !status.is_ok() {
Err(status)?;
}
}
None => Err(Status::new(Code::Internal).with_message("missing trailers"))?,
}
break;
let status = status.ok_or_else(|| Status::new(Code::Internal).with_message("missing trailers"))?;
if !status.is_ok() {
Err(status)?;
}
}))
}

View File

@@ -20,6 +20,7 @@ pub mod service;
pub mod codec;
pub mod metadata;
mod connector;
mod encoding;
mod health;
mod reflection;

View File

@@ -3,7 +3,7 @@
use std::str::FromStr;
use base64::engine::{general_purpose::STANDARD_NO_PAD, Engine};
use hyper::header::HeaderName;
use http::HeaderName;
use poem::http::{HeaderMap, HeaderValue};
/// A metadata map

View File

@@ -4,7 +4,7 @@ use std::{
};
use futures_util::Stream;
use hyper::http::Extensions;
use http::Extensions;
use crate::{Metadata, Status, Streaming};
@@ -90,7 +90,7 @@ impl<T> Request<T> {
/// Inserts a value to extensions, similar to
/// `self.extensions().insert(data)`.
#[inline]
pub fn set_data(&mut self, data: impl Send + Sync + 'static) {
pub fn set_data(&mut self, data: impl Send + Sync + Clone + 'static) {
self.extensions.insert(data);
}
}

View File

@@ -1,6 +1,6 @@
use std::fmt::Display;
use hyper::{header::HeaderValue, HeaderMap};
use http::{header::HeaderValue, HeaderMap};
use percent_encoding::{percent_decode_str, percent_encode, AsciiSet, CONTROLS};
use crate::Metadata;

View File

@@ -21,7 +21,7 @@ categories = [
[dependencies]
poem = { workspace = true, default-features = false }
lambda_http = { version = "0.8.0" }
lambda_http = { version = "0.9.0" }
[dev-dependencies]
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }

View File

@@ -30,6 +30,7 @@ use poem::{Body, Endpoint, EndpointExt, FromRequest, IntoEndpoint, Request, Requ
/// println!("request_id: {}", ctx.request_id);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Context(pub lambda_runtime::Context);
impl Deref for Context {

View File

@@ -24,7 +24,7 @@ syn = { workspace = true, features = ["full", "visit-mut"] }
thiserror.workspace = true
indexmap.workspace = true
regex.workspace = true
http = "0.2.5"
http = "1.0.0"
mime.workspace = true
[package.metadata.workspaces]

View File

@@ -31,6 +31,7 @@ pub enum ApiExtractorType {
}
#[doc(hidden)]
#[derive(Clone)]
pub struct UrlQuery(pub Vec<(String, String)>);
impl Deref for UrlQuery {

View File

@@ -21,7 +21,7 @@ categories = [
[features]
default = ["server"]
server = ["tokio/rt", "tokio/net", "hyper/server", "hyper/runtime"]
server = ["tokio/rt", "tokio/net", "hyper/server"]
websocket = ["tokio/rt", "tokio-tungstenite", "base64"]
multipart = ["multer"]
rustls = ["server", "tokio-rustls", "rustls-pemfile"]
@@ -51,14 +51,13 @@ i18n = [
"intl-memoizer",
]
acme = ["acme-native-roots"]
acme-native-roots = ["acme-base", "hyper-rustls/native-tokio"]
acme-webpki-roots = ["acme-base", "hyper-rustls/webpki-tokio"]
acme-native-roots = ["acme-base", "reqwest/rustls-tls-native-roots"]
acme-webpki-roots = ["acme-base", "reqwest/rustls-tls-webpki-roots"]
acme-base = [
"server",
"hyper/client",
"reqwest",
"rustls",
"ring",
"hyper-rustls",
"base64",
"rcgen",
"x509-parser",
@@ -74,8 +73,10 @@ poem-derive.workspace = true
async-trait = "0.1.51"
bytes.workspace = true
futures-util = { workspace = true, features = ["sink"] }
http = "0.2.5"
hyper = { version = "0.14.17", features = ["http1", "http2", "stream"] }
http = "1.0.0"
hyper = { version = "1.0.0", features = ["http1", "http2"] }
hyper-util = { version = "0.1.1", features = ["server-auto", "tokio"] }
http-body-util = "0.1.0"
tokio = { workspace = true, features = ["sync", "time", "macros", "net"] }
tokio-util = { version = "0.7.0", features = ["io"] }
serde.workspace = true
@@ -87,15 +88,16 @@ percent-encoding = "2.1.0"
regex.workspace = true
smallvec = "1.6.1"
tracing.workspace = true
headers = "0.3.7"
headers = "0.4.0"
thiserror.workspace = true
rfc7239 = "0.1.0"
mime.workspace = true
wildmatch = "2"
sync_wrapper = { version = "0.1.2", features = ["futures"] }
# Non-feature optional dependencies
multer = { version = "2.1.0", features = ["tokio"], optional = true }
tokio-tungstenite = { version = "0.20.0", optional = true }
multer = { version = "3.0.0", features = ["tokio"], optional = true }
tokio-tungstenite = { version = "0.21.0", optional = true }
tokio-rustls = { workspace = true, optional = true }
rustls-pemfile = { version = "1.0.0", optional = true }
async-compression = { version = "0.4.0", optional = true, features = [
@@ -148,12 +150,7 @@ fluent-syntax = { version = "0.11.0", optional = true }
unic-langid = { version = "0.9.0", optional = true, features = ["macros"] }
intl-memoizer = { version = "0.5.1", optional = true }
ring = { version = "0.16.20", optional = true }
hyper-rustls = { workspace = true, optional = true, features = [
"http1",
"http2",
"tls12",
"logging",
] }
reqwest = { workspace = true, features = ["json"], optional = true }
rcgen = { version = "0.11.1", optional = true }
x509-parser = { version = "0.15.0", optional = true }
tokio-metrics = { version = "0.3.0", optional = true }

View File

@@ -1,14 +1,16 @@
use std::{
fmt::{Debug, Display, Formatter},
fmt::{Debug, Formatter},
io::{Error as IoError, ErrorKind},
pin::Pin,
task::{Context, Poll},
task::Poll,
};
use bytes::{Bytes, BytesMut};
use futures_util::{Stream, TryStreamExt};
use hyper::body::HttpBody;
use http_body_util::BodyExt;
use hyper::body::{Body as _, Frame};
use serde::{de::DeserializeOwned, Serialize};
use sync_wrapper::SyncStream;
use tokio::io::{AsyncRead, AsyncReadExt};
use crate::{
@@ -16,9 +18,25 @@ use crate::{
Result,
};
pub(crate) type BoxBody = http_body_util::combinators::BoxBody<Bytes, IoError>;
/// A body object for requests and responses.
#[derive(Default)]
pub struct Body(pub(crate) hyper::Body);
pub struct Body(pub(crate) BoxBody);
impl From<Body> for BoxBody {
#[inline]
fn from(body: Body) -> Self {
body.0
}
}
impl From<BoxBody> for Body {
#[inline]
fn from(body: BoxBody) -> Self {
Body(body)
}
}
impl Debug for Body {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
@@ -26,50 +44,50 @@ impl Debug for Body {
}
}
impl From<hyper::Body> for Body {
fn from(body: hyper::Body) -> Self {
Body(body)
}
}
impl From<Body> for hyper::Body {
fn from(body: Body) -> Self {
body.0
}
}
impl From<&'static [u8]> for Body {
#[inline]
fn from(data: &'static [u8]) -> Self {
Self(data.into())
Self(BoxBody::new(
http_body_util::Full::new(data.into()).map_err::<_, IoError>(|_| unreachable!()),
))
}
}
impl From<&'static str> for Body {
#[inline]
fn from(data: &'static str) -> Self {
Self(data.into())
Self(BoxBody::new(
http_body_util::Full::new(data.into()).map_err::<_, IoError>(|_| unreachable!()),
))
}
}
impl From<Bytes> for Body {
#[inline]
fn from(data: Bytes) -> Self {
Self(data.into())
Self(
http_body_util::Full::new(data)
.map_err::<_, IoError>(|_| unreachable!())
.boxed(),
)
}
}
impl From<Vec<u8>> for Body {
#[inline]
fn from(data: Vec<u8>) -> Self {
Self(data.into())
Self(
http_body_util::Full::new(data.into())
.map_err::<_, IoError>(|_| unreachable!())
.boxed(),
)
}
}
impl From<String> for Body {
#[inline]
fn from(data: String) -> Self {
Self(data.into())
data.into_bytes().into()
}
}
@@ -102,8 +120,8 @@ impl Body {
/// Create a body object from reader.
#[inline]
pub fn from_async_read(reader: impl AsyncRead + Send + 'static) -> Self {
Self(hyper::Body::wrap_stream(tokio_util::io::ReaderStream::new(
reader,
Self(BoxBody::new(http_body_util::StreamBody::new(
SyncStream::new(tokio_util::io::ReaderStream::new(reader).map_ok(Frame::data)),
)))
}
@@ -112,9 +130,15 @@ impl Body {
where
S: Stream<Item = Result<O, E>> + Send + 'static,
O: Into<Bytes> + 'static,
E: std::error::Error + Send + Sync + 'static,
E: Into<IoError> + 'static,
{
Self(hyper::Body::wrap_stream(stream))
Self(BoxBody::new(http_body_util::StreamBody::new(
SyncStream::new(
stream
.map_ok(|data| Frame::data(data.into()))
.map_err(Into::into),
),
)))
}
/// Create a body object from JSON.
@@ -125,29 +149,33 @@ impl Body {
/// Create an empty body.
#[inline]
pub fn empty() -> Self {
Self(hyper::Body::empty())
Self(
http_body_util::Empty::new()
.map_err::<_, IoError>(|_| unreachable!())
.boxed(),
)
}
/// Returns `true` if this body is empty.
pub fn is_empty(&self) -> bool {
let size_hint = hyper::body::HttpBody::size_hint(&self.0);
let size_hint = hyper::body::Body::size_hint(&self.0);
size_hint.lower() == 0 && size_hint.upper() == Some(0)
}
/// Consumes this body object to return a [`Bytes`] that contains all data.
pub async fn into_bytes(self) -> Result<Bytes, ReadBodyError> {
hyper::body::to_bytes(self.0)
Ok(self
.0
.collect()
.await
.map_err(|err| ReadBodyError::Io(IoError::new(ErrorKind::Other, err)))
.map_err(|err| ReadBodyError::Io(IoError::new(ErrorKind::Other, err)))?
.to_bytes())
}
/// Consumes this body object to return a [`Vec<u8>`] that contains all
/// data.
pub async fn into_vec(self) -> Result<Vec<u8>, ReadBodyError> {
Ok(hyper::body::to_bytes(self.0)
.await
.map_err(|err| ReadBodyError::Io(IoError::new(ErrorKind::Other, err)))?
.to_vec())
self.into_bytes().await.map(|data| data.to_vec())
}
/// Consumes this body object to return a [`Bytes`] that contains all
@@ -223,41 +251,23 @@ impl Body {
/// Consumes this body object to return a reader.
pub fn into_async_read(self) -> impl AsyncRead + Unpin + Send + 'static {
tokio_util::io::StreamReader::new(BodyStream::new(self.0))
tokio_util::io::StreamReader::new(self.into_bytes_stream())
}
/// Consumes this body object to return a bytes stream.
pub fn into_bytes_stream(self) -> impl Stream<Item = Result<Bytes, IoError>> + Send + 'static {
TryStreamExt::map_err(self.0, |err| IoError::new(ErrorKind::Other, err))
}
}
pin_project_lite::pin_project! {
pub(crate) struct BodyStream<T> {
#[pin] inner: T,
}
}
impl<T> BodyStream<T> {
#[inline]
pub(crate) fn new(inner: T) -> Self {
Self { inner }
}
}
impl<T> Stream for BodyStream<T>
where
T: HttpBody,
T::Error: Display,
{
type Item = Result<T::Data, std::io::Error>;
#[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project()
.inner
.poll_data(cx)
.map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))
let mut body = self.0;
futures_util::stream::poll_fn(move |ctx| loop {
match Pin::new(&mut body).poll_frame(ctx) {
Poll::Ready(Some(Ok(frame))) => match frame.into_data() {
Ok(data) => return Poll::Ready(Some(Ok(data))),
Err(_) => continue,
},
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
}
})
}
}
@@ -267,9 +277,6 @@ mod tests {
#[tokio::test]
async fn create() {
let body = Body::from(hyper::Body::from("abc"));
assert_eq!(body.into_string().await.unwrap(), "abc");
let body = Body::from(b"abc".as_ref());
assert_eq!(body.into_vec().await.unwrap(), b"abc");

View File

@@ -1,10 +1,10 @@
use std::{error::Error as StdError, future::Future};
use bytes::Bytes;
use hyper::body::{HttpBody, Sender};
use http_body_util::BodyExt;
use tower::{Service, ServiceExt};
use crate::{Endpoint, Error, Request, Response, Result};
use crate::{body::BoxBody, Endpoint, Error, Request, Response, Result};
/// Extension trait for tower service compat.
#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
@@ -12,12 +12,12 @@ pub trait TowerCompatExt {
/// Converts a tower service to a poem endpoint.
fn compat<ResBody, Err, Fut>(self) -> TowerCompatEndpoint<Self>
where
ResBody: HttpBody + Send + 'static,
ResBody: hyper::body::Body + Send + Sync + 'static,
ResBody::Data: Into<Bytes> + Send + 'static,
ResBody::Error: StdError + Send + Sync + 'static,
Err: Into<Error>,
Self: Service<
http::Request<hyper::Body>,
http::Request<BoxBody>,
Response = hyper::Response<ResBody>,
Error = Err,
Future = Fut,
@@ -41,12 +41,12 @@ pub struct TowerCompatEndpoint<Svc>(Svc);
#[async_trait::async_trait]
impl<Svc, ResBody, Err, Fut> Endpoint for TowerCompatEndpoint<Svc>
where
ResBody: HttpBody + Send + 'static,
ResBody: hyper::body::Body + Send + Sync + 'static,
ResBody::Data: Into<Bytes> + Send + 'static,
ResBody::Error: StdError + Send + Sync + 'static,
Err: Into<Error>,
Svc: Service<
http::Request<hyper::Body>,
http::Request<BoxBody>,
Response = hyper::Response<ResBody>,
Error = Err,
Future = Fut,
@@ -62,59 +62,14 @@ where
let mut svc = self.0.clone();
svc.ready().await.map_err(Into::into)?;
let hyper_req: http::Request<hyper::Body> = req.into();
let hyper_resp = svc
.call(hyper_req.map(Into::into))
.await
.map_err(Into::into)?;
if !hyper_resp.body().is_end_stream() {
Ok(hyper_resp
.map(|body| {
let (sender, new_body) = hyper::Body::channel();
tokio::spawn(copy_body(body, sender));
new_body
})
.into())
} else {
Ok(hyper_resp.map(|_| hyper::Body::empty()).into())
}
}
}
async fn copy_body<T>(body: T, mut sender: Sender)
where
T: HttpBody + Send + 'static,
T::Data: Into<Bytes> + Send + 'static,
T::Error: StdError + Send + Sync + 'static,
{
tokio::pin!(body);
loop {
match body.data().await {
Some(Ok(data)) => {
if sender.send_data(data.into()).await.is_err() {
break;
}
}
Some(Err(_)) => break,
None => {}
}
match body.trailers().await {
Ok(Some(trailers)) => {
if sender.send_trailers(trailers).await.is_err() {
break;
}
}
Ok(None) => {}
Err(_) => break,
}
if body.is_end_stream() {
break;
}
svc.call(req.into()).await.map_err(Into::into).map(|resp| {
let (parts, body) = resp.into_parts();
let body = body
.map_frame(|frame| frame.map_data(Into::into))
.map_err(std::io::Error::other)
.boxed();
hyper::Response::from_parts(parts, body).into()
})
}
}

View File

@@ -463,7 +463,7 @@ impl Error {
/// assert_eq!(resp.data::<i32>(), Some(&100));
/// ```
#[inline]
pub fn set_data(&mut self, data: impl Send + Sync + 'static) {
pub fn set_data(&mut self, data: impl Clone + Send + Sync + 'static) {
self.extensions.insert(data);
}

View File

@@ -3,15 +3,13 @@ use std::{
path::PathBuf,
};
use http::Uri;
use crate::listener::acme::{
builder::AutoCertBuilder, endpoint::Http01Endpoint, ChallengeType, Http01TokensMap,
};
/// ACME configuration
pub struct AutoCert {
pub(crate) directory_url: Uri,
pub(crate) directory_url: String,
pub(crate) domains: Vec<String>,
pub(crate) contacts: Vec<String>,
pub(crate) challenge_type: ChallengeType,

View File

@@ -4,26 +4,21 @@ use std::{
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use http::{header, Uri};
use hyper::{client::HttpConnector, Client};
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use reqwest::Client;
use crate::{
listener::acme::{
jose,
keypair::KeyPair,
protocol::{
CsrRequest, Directory, FetchAuthorizationResponse, Identifier, NewAccountRequest,
NewOrderRequest, NewOrderResponse,
},
ChallengeType,
use crate::listener::acme::{
jose,
keypair::KeyPair,
protocol::{
CsrRequest, Directory, FetchAuthorizationResponse, Identifier, NewAccountRequest,
NewOrderRequest, NewOrderResponse,
},
Body,
ChallengeType,
};
/// A client for ACME-supporting TLS certificate services.
pub struct AcmeClient {
client: Client<HttpsConnector<HttpConnector>>,
client: Client,
directory: Directory,
pub(crate) key_pair: Arc<KeyPair>,
contacts: Vec<String>,
@@ -33,14 +28,8 @@ pub struct AcmeClient {
impl AcmeClient {
/// Create a new client. `directory_url` is the url for the ACME provider. `contacts` is a list
/// of URLS (ex: `mailto:`) the ACME service can use to reach you if there's issues with your certificates.
pub async fn try_new(directory_url: &Uri, contacts: Vec<String>) -> IoResult<Self> {
let client_builder = HttpsConnectorBuilder::new();
#[cfg(feature = "acme-native-roots")]
let client_builder1 = client_builder.with_native_roots();
#[cfg(all(feature = "acme-webpki-roots", not(feature = "acme-native-roots")))]
let client_builder1 = client_builder.with_webpki_roots();
let client =
Client::builder().build(client_builder1.https_or_http().enable_http1().build());
pub(crate) async fn try_new(directory_url: &str, contacts: Vec<String>) -> IoResult<Self> {
let client = Client::new();
let directory = get_directory(&client, directory_url).await?;
Ok(Self {
client,
@@ -98,7 +87,7 @@ impl AcmeClient {
pub(crate) async fn fetch_authorization(
&self,
auth_url: &Uri,
auth_url: &str,
) -> IoResult<FetchAuthorizationResponse> {
tracing::debug!(auth_uri = %auth_url, "fetch authorization");
@@ -126,7 +115,7 @@ impl AcmeClient {
&self,
domain: &str,
challenge_type: ChallengeType,
url: &Uri,
url: &str,
) -> IoResult<()> {
tracing::debug!(
auth_uri = %url,
@@ -149,7 +138,7 @@ impl AcmeClient {
Ok(())
}
pub(crate) async fn send_csr(&self, url: &Uri, csr: &[u8]) -> IoResult<NewOrderResponse> {
pub(crate) async fn send_csr(&self, url: &str, csr: &[u8]) -> IoResult<NewOrderResponse> {
tracing::debug!(url = %url, "send certificate request");
let nonce = get_nonce(&self.client, &self.directory).await?;
@@ -166,7 +155,7 @@ impl AcmeClient {
.await
}
pub(crate) async fn obtain_certificate(&self, url: &Uri) -> IoResult<Vec<u8>> {
pub(crate) async fn obtain_certificate(&self, url: &str) -> IoResult<Vec<u8>> {
tracing::debug!(url = %url, "send certificate request");
let nonce = get_nonce(&self.client, &self.directory).await?;
@@ -180,22 +169,23 @@ impl AcmeClient {
)
.await?;
resp.into_body().into_vec().await.map_err(|err| {
IoError::new(
ErrorKind::Other,
format!("failed to download certificate: {err}"),
)
})
Ok(resp
.bytes()
.await
.map_err(|err| {
IoError::new(
ErrorKind::Other,
format!("failed to download certificate: {err}"),
)
})?
.to_vec())
}
}
async fn get_directory(
client: &Client<HttpsConnector<HttpConnector>>,
directory_url: &Uri,
) -> IoResult<Directory> {
async fn get_directory(client: &Client, directory_url: &str) -> IoResult<Directory> {
tracing::debug!("loading directory");
let resp = client.get(directory_url.clone()).await.map_err(|err| {
let resp = client.get(directory_url).send().await.map_err(|err| {
IoError::new(ErrorKind::Other, format!("failed to load directory: {err}"))
})?;
@@ -206,12 +196,9 @@ async fn get_directory(
));
}
let directory = Body(resp.into_body())
.into_json::<Directory>()
.await
.map_err(|err| {
IoError::new(ErrorKind::Other, format!("failed to load directory: {err}"))
})?;
let directory = resp.json::<Directory>().await.map_err(|err| {
IoError::new(ErrorKind::Other, format!("failed to load directory: {err}"))
})?;
tracing::debug!(
new_nonce = ?directory.new_nonce,
@@ -222,14 +209,12 @@ async fn get_directory(
Ok(directory)
}
async fn get_nonce(
client: &Client<HttpsConnector<HttpConnector>>,
directory: &Directory,
) -> IoResult<String> {
async fn get_nonce(client: &Client, directory: &Directory) -> IoResult<String> {
tracing::debug!("creating nonce");
let resp = client
.get(directory.new_nonce.clone())
.get(&directory.new_nonce)
.send()
.await
.map_err(|err| IoError::new(ErrorKind::Other, format!("failed to get nonce: {err}")))?;
@@ -252,7 +237,7 @@ async fn get_nonce(
}
async fn create_acme_account(
client: &Client<HttpsConnector<HttpConnector>>,
client: &Client,
directory: &Directory,
key_pair: &KeyPair,
contacts: Vec<String>,
@@ -274,9 +259,11 @@ async fn create_acme_account(
)
.await?;
let kid = resp
.header(header::LOCATION)
.ok_or_else(|| IoError::new(ErrorKind::Other, "unable to get account id"))?
.to_string();
.headers()
.get("location")
.and_then(|value| value.to_str().ok())
.map(ToString::to_string)
.ok_or_else(|| IoError::new(ErrorKind::Other, "unable to get account id"))?;
tracing::debug!(kid = kid.as_str(), "account created");
Ok(kid)

View File

@@ -1,13 +1,11 @@
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use http::{Method, Uri};
use hyper::{client::HttpConnector, Client};
use hyper_rustls::HttpsConnector;
use reqwest::{Client, Response};
use ring::digest::{digest, Digest, SHA256};
use serde::{de::DeserializeOwned, Serialize};
use crate::{listener::acme::keypair::KeyPair, Request, Response};
use crate::listener::acme::keypair::KeyPair;
#[derive(Serialize)]
struct Protected<'a> {
@@ -100,11 +98,11 @@ struct Body {
}
pub(crate) async fn request(
cli: &Client<HttpsConnector<HttpConnector>>,
cli: &Client,
key_pair: &KeyPair,
kid: Option<&str>,
nonce: &str,
uri: &Uri,
uri: &str,
payload: Option<impl Serialize>,
) -> IoResult<Response> {
let jwk = match kid {
@@ -121,26 +119,26 @@ pub(crate) async fn request(
let payload = URL_SAFE_NO_PAD.encode(payload);
let combined = format!("{}.{}", &protected, &payload);
let signature = URL_SAFE_NO_PAD.encode(key_pair.sign(combined.as_bytes())?);
let body = serde_json::to_vec(&Body {
protected,
payload,
signature,
})
.unwrap();
let req = Request::builder()
.method(Method::POST)
.uri(uri.clone())
.content_type("application/jose+json")
.body(body);
tracing::debug!(uri = %uri, "http request");
let resp = cli.request(req.into()).await.map_err(|err| {
IoError::new(
ErrorKind::Other,
format!("failed to send http request: {err}"),
)
})?;
let resp = cli
.post(uri)
.json(&Body {
protected,
payload,
signature,
})
.header("content-type", "application/jose+json")
.send()
.await
.map_err(|err| {
IoError::new(
ErrorKind::Other,
format!("failed to send http request: {err}"),
)
})?;
if !resp.status().is_success() {
return Err(IoError::new(
ErrorKind::Other,
@@ -151,11 +149,11 @@ pub(crate) async fn request(
}
pub(crate) async fn request_json<T, R>(
cli: &Client<HttpsConnector<HttpConnector>>,
cli: &Client,
key_pair: &KeyPair,
kid: Option<&str>,
nonce: &str,
uri: &Uri,
uri: &str,
payload: Option<T>,
) -> IoResult<R>
where
@@ -165,8 +163,7 @@ where
let resp = request(cli, key_pair, kid, nonce, uri, payload).await?;
let data = resp
.into_body()
.into_string()
.text()
.await
.map_err(|_| IoError::new(ErrorKind::Other, "failed to read response"))?;
serde_json::from_str(&data)

View File

@@ -10,8 +10,10 @@ use rcgen::{
};
use tokio_rustls::{
rustls::{
sign::{any_ecdsa_type, CertifiedKey},
PrivateKey, ServerConfig,
crypto::ring::sign::any_ecdsa_type,
pki_types::{CertificateDer, PrivateKeyDer},
sign::CertifiedKey,
ServerConfig,
},
server::TlsStream,
TlsAcceptor,
@@ -37,7 +39,6 @@ pub(crate) async fn auto_cert_acceptor<T: Listener>(
challenge_type: ChallengeType,
) -> IoResult<AutoCertAcceptor<T::Acceptor>> {
let mut server_config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(cert_resolver);
server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
@@ -137,7 +138,7 @@ impl<T: Listener> Listener for AutoCertListener<T> {
if let (Some(certs), Some(key)) = (cache_certs, cert_key) {
let certs = certs
.into_iter()
.map(tokio_rustls::rustls::Certificate)
.map(CertificateDer::from)
.collect::<Vec<_>>();
let expires_at = match certs
@@ -156,7 +157,7 @@ impl<T: Listener> Listener for AutoCertListener<T> {
);
*cert_resolver.cert.write() = Some(Arc::new(CertifiedKey::new(
certs,
any_ecdsa_type(&PrivateKey(key)).unwrap(),
any_ecdsa_type(&PrivateKeyDer::Pkcs8(key.into())).unwrap(),
)));
}
@@ -231,13 +232,14 @@ fn gen_acme_cert(domain: &str, acme_hash: &[u8]) -> IoResult<CertifiedKey> {
params.custom_extensions = vec![CustomExtension::new_acme_identifier(acme_hash)];
let cert = Certificate::from_params(params)
.map_err(|_| IoError::new(ErrorKind::Other, "failed to generate acme certificate"))?;
let key = any_ecdsa_type(&PrivateKey(cert.serialize_private_key_der())).unwrap();
let key = any_ecdsa_type(&PrivateKeyDer::Pkcs8(
cert.serialize_private_key_der().into(),
))
.unwrap();
Ok(CertifiedKey::new(
vec![tokio_rustls::rustls::Certificate(
cert.serialize_der().map_err(|_| {
IoError::new(ErrorKind::Other, "failed to serialize acme certificate")
})?,
)],
vec![CertificateDer::from(cert.serialize_der().map_err(
|_| IoError::new(ErrorKind::Other, "failed to serialize acme certificate"),
)?)],
key,
))
}
@@ -353,7 +355,10 @@ pub async fn issue_cert<T: AsRef<str>>(
format!("failed create certificate request: {err}"),
)
})?;
let pk = any_ecdsa_type(&PrivateKey(cert.serialize_private_key_der())).unwrap();
let pk = any_ecdsa_type(&PrivateKeyDer::Pkcs8(
cert.serialize_private_key_der().into(),
))
.unwrap();
let csr = cert.serialize_request_der().map_err(|err| {
IoError::new(
ErrorKind::Other,
@@ -400,7 +405,7 @@ pub async fn issue_cert<T: AsRef<str>>(
let cert_chain = rustls_pemfile::certs(&mut acme_cert_pem.as_slice())
.map_err(|err| IoError::new(ErrorKind::Other, format!("invalid pem: {err}")))?
.into_iter()
.map(tokio_rustls::rustls::Certificate)
.map(CertificateDer::from)
.collect();
let cert_key = CertifiedKey::new(cert_chain, pk);

View File

@@ -12,7 +12,6 @@ mod keypair;
mod listener;
mod protocol;
mod resolver;
mod serde;
pub use auto_cert::AutoCert;
pub use builder::AutoCertBuilder;

View File

@@ -5,8 +5,6 @@ use std::{
use serde::{Deserialize, Serialize};
use crate::listener::acme::serde::SerdeUri;
/// HTTP-01 challenge
const CHALLENGE_TYPE_HTTP_01: &str = "http-01";
@@ -38,9 +36,9 @@ impl Display for ChallengeType {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct Directory {
pub(crate) new_nonce: SerdeUri,
pub(crate) new_account: SerdeUri,
pub(crate) new_order: SerdeUri,
pub(crate) new_nonce: String,
pub(crate) new_account: String,
pub(crate) new_order: String,
}
#[derive(Serialize)]
@@ -75,17 +73,17 @@ pub(crate) struct Problem {
#[serde(rename_all = "camelCase")]
pub(crate) struct NewOrderResponse {
pub(crate) status: String,
pub(crate) authorizations: Vec<SerdeUri>,
pub(crate) authorizations: Vec<String>,
pub(crate) error: Option<Problem>,
pub(crate) finalize: SerdeUri,
pub(crate) certificate: Option<SerdeUri>,
pub(crate) finalize: String,
pub(crate) certificate: Option<String>,
}
#[derive(Debug, Deserialize)]
pub(crate) struct Challenge {
#[serde(rename = "type")]
pub(crate) ty: String,
pub(crate) url: SerdeUri,
pub(crate) url: String,
pub(crate) token: String,
}

View File

@@ -30,7 +30,7 @@ pub fn seconds_until_expiry(cert: &CertifiedKey) -> i64 {
}
/// Shared ACME key state.
#[derive(Default)]
#[derive(Default, Debug)]
pub struct ResolveServerCert {
/// The current TLS certificate. Swap it with `Arc::write`.
pub cert: RwLock<Option<Arc<CertifiedKey>>>,

View File

@@ -1,35 +0,0 @@
use std::{
fmt::{self, Debug, Formatter},
ops::Deref,
};
use http::Uri;
use serde::{de::Error, Deserialize, Deserializer};
pub(crate) struct SerdeUri(pub(crate) Uri);
impl<'de> Deserialize<'de> for SerdeUri {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
String::deserialize(deserializer)?
.parse::<Uri>()
.map(SerdeUri)
.map_err(|err| D::Error::custom(err.to_string()))
}
}
impl Debug for SerdeUri {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(&self.0, f)
}
}
impl Deref for SerdeUri {
type Target = Uri;
fn deref(&self) -> &Self::Target {
&self.0
}
}

View File

@@ -9,12 +9,11 @@ use rustls_pemfile::Item;
use tokio::io::{Error as IoError, ErrorKind, Result as IoResult};
use tokio_rustls::{
rustls::{
server::{
AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ClientHello,
NoClientAuth, ResolvesServerCert,
},
sign::{self, CertifiedKey},
Certificate, PrivateKey, RootCertStore, ServerConfig,
crypto::ring::sign::any_supported_type,
pki_types::{CertificateDer, PrivateKeyDer},
server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier},
sign::CertifiedKey,
RootCertStore, ServerConfig,
},
server::TlsStream,
};
@@ -72,7 +71,7 @@ impl RustlsCertificate {
impl RustlsCertificate {
fn create_certificate_key(&self) -> IoResult<CertifiedKey> {
let cert = rustls_pemfile::certs(&mut self.cert.as_slice())
.map(|mut certs| certs.drain(..).map(Certificate).collect())
.map(|mut certs| certs.drain(..).map(CertificateDer::from).collect())
.map_err(|_| IoError::new(ErrorKind::Other, "failed to parse tls certificates"))?;
let priv_key = {
@@ -90,12 +89,12 @@ impl RustlsCertificate {
_ => continue,
};
if !key.is_empty() {
break PrivateKey(key);
break PrivateKeyDer::Pkcs8(key.into());
}
}
};
let key = sign::any_supported_type(&priv_key)
let key = any_supported_type(&priv_key)
.map_err(|_| IoError::new(ErrorKind::Other, "invalid private key"))?;
Ok(CertifiedKey {
@@ -106,7 +105,6 @@ impl RustlsCertificate {
} else {
None
},
sct_list: None,
})
}
}
@@ -239,24 +237,30 @@ impl RustlsConfig {
);
}
let client_auth = match &self.client_auth {
TlsClientAuth::Off => NoClientAuth::boxed(),
let builder = ServerConfig::builder();
let builder = match &self.client_auth {
TlsClientAuth::Off => builder.with_no_client_auth(),
TlsClientAuth::Optional(trust_anchor) => {
AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?)
.boxed()
let verifier =
WebPkiClientVerifier::builder(read_trust_anchor(trust_anchor)?.into())
.allow_unauthenticated()
.build()
.map_err(|err| IoError::other(err))?;
builder.with_client_cert_verifier(verifier)
}
TlsClientAuth::Required(trust_anchor) => {
AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?).boxed()
let verifier =
WebPkiClientVerifier::builder(read_trust_anchor(trust_anchor)?.into())
.build()
.map_err(|err| IoError::other(err))?;
builder.with_client_cert_verifier(verifier)
}
};
let mut server_config = ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(client_auth)
.with_cert_resolver(Arc::new(ResolveServerCert {
certifcate_keys,
fallback,
}));
let mut server_config = builder.with_cert_resolver(Arc::new(ResolveServerCert {
certifcate_keys,
fallback,
}));
server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
Ok(server_config)
@@ -268,7 +272,7 @@ fn read_trust_anchor(mut trust_anchor: &[u8]) -> IoResult<RootCertStore> {
let ders = rustls_pemfile::certs(&mut trust_anchor)?;
for der in ders {
store
.add(&Certificate(der))
.add(CertificateDer::from(der))
.map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?;
}
Ok(store)
@@ -402,6 +406,7 @@ where
}
}
#[derive(Debug)]
struct ResolveServerCert {
certifcate_keys: HashMap<String, Arc<CertifiedKey>>,
fallback: Option<Arc<CertifiedKey>>,
@@ -422,7 +427,7 @@ mod tests {
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use tokio_rustls::rustls::{ClientConfig, ServerName};
use tokio_rustls::rustls::{pki_types::ServerName, ClientConfig};
use super::*;
use crate::listener::TcpListener;
@@ -441,7 +446,6 @@ mod tests {
tokio::spawn(async move {
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(
read_trust_anchor(include_bytes!("certs/chain1.pem")).unwrap(),
)

View File

@@ -2,10 +2,10 @@ use std::sync::Arc;
use libopentelemetry::{
global,
propagation::Extractor,
trace::{FutureExt, Span, SpanKind, TraceContextExt, Tracer},
Context, Key,
};
use opentelemetry_http::HeaderExtractor;
use opentelemetry_semantic_conventions::{resource, trace};
use crate::{
@@ -52,6 +52,21 @@ pub struct OpenTelemetryTracingEndpoint<T, E> {
inner: E,
}
struct HeaderExtractor<'a>(&'a http::HeaderMap);
impl<'a> Extractor for HeaderExtractor<'a> {
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|value| value.to_str().ok())
}
fn keys(&self) -> Vec<&str> {
self.0
.keys()
.map(|value| value.as_str())
.collect::<Vec<_>>()
}
}
#[async_trait::async_trait]
impl<T, E> Endpoint for OpenTelemetryTracingEndpoint<T, E>
where

View File

@@ -1,5 +1,4 @@
use std::{
any::Any,
fmt::{self, Debug, Formatter},
future::Future,
io::Error,
@@ -9,6 +8,8 @@ use std::{
};
use http::uri::Scheme;
use http_body_util::BodyExt;
use hyper::{body::Incoming, rt::Write as _};
use parking_lot::Mutex;
use serde::de::DeserializeOwned;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
@@ -16,7 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(feature = "cookie")]
use crate::web::cookie::CookieJar;
use crate::{
body::Body,
body::{Body, BoxBody},
error::{ParsePathError, ParseQueryError, UpgradeError},
http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
@@ -108,10 +109,10 @@ impl Debug for Request {
}
}
impl From<(http::Request<hyper::Body>, LocalAddr, RemoteAddr, Scheme)> for Request {
impl From<(http::Request<Incoming>, LocalAddr, RemoteAddr, Scheme)> for Request {
fn from(
(req, local_addr, remote_addr, scheme): (
http::Request<hyper::Body>,
http::Request<Incoming>,
LocalAddr,
RemoteAddr,
Scheme,
@@ -131,7 +132,7 @@ impl From<(http::Request<hyper::Body>, LocalAddr, RemoteAddr, Scheme)> for Reque
version: parts.version,
headers: parts.headers,
extensions: parts.extensions,
body: Body(body),
body: Body(body.map_err(|err| Error::other(err)).boxed()),
state: RequestState {
local_addr,
remote_addr,
@@ -146,13 +147,13 @@ impl From<(http::Request<hyper::Body>, LocalAddr, RemoteAddr, Scheme)> for Reque
}
}
impl From<Request> for hyper::Request<hyper::Body> {
impl From<Request> for hyper::Request<BoxBody> {
fn from(req: Request) -> Self {
let mut hyper_req = http::Request::builder()
.method(req.method)
.uri(req.uri)
.version(req.version)
.body(req.body.into())
.body(req.body.0)
.unwrap();
*hyper_req.headers_mut() = req.headers;
*hyper_req.extensions_mut() = req.extensions;
@@ -372,7 +373,7 @@ impl Request {
/// Inserts a value to extensions, similar to
/// `self.extensions().insert(data)`.
#[inline]
pub fn set_data(&mut self, data: impl Send + Sync + 'static) {
pub fn set_data(&mut self, data: impl Clone + Send + Sync + 'static) {
self.extensions.insert(data);
}
@@ -490,7 +491,7 @@ impl AsyncRead for Upgraded {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.project().stream.poll_read(cx, buf)
Pin::new(&mut hyper_util::rt::TokioIo::new(self.project().stream)).poll_read(cx, buf)
}
}
@@ -597,7 +598,7 @@ impl RequestBuilder {
#[must_use]
pub fn extension<T>(mut self, extension: T) -> Self
where
T: Any + Send + Sync + 'static,
T: Clone + Send + Sync + 'static,
{
self.extensions.insert(extension);
self

View File

@@ -1,11 +1,11 @@
use std::{
any::Any,
fmt::{self, Debug, Formatter},
};
use std::fmt::{self, Debug, Formatter};
use bytes::Bytes;
use headers::HeaderMapExt;
use http_body_util::BodyExt;
use crate::{
body::BoxBody,
http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
Extensions, StatusCode, Version,
@@ -77,9 +77,9 @@ impl<T: Into<Body>> From<(StatusCode, T)> for Response {
}
}
impl From<Response> for hyper::Response<hyper::Body> {
impl From<Response> for hyper::Response<BoxBody> {
fn from(resp: Response) -> Self {
let mut hyper_resp = hyper::Response::new(resp.body.into());
let mut hyper_resp = hyper::Response::new(resp.body.0);
*hyper_resp.status_mut() = resp.status;
*hyper_resp.version_mut() = resp.version;
*hyper_resp.headers_mut() = resp.headers;
@@ -88,15 +88,24 @@ impl From<Response> for hyper::Response<hyper::Body> {
}
}
impl From<hyper::Response<hyper::Body>> for Response {
fn from(hyper_resp: hyper::Response<hyper::Body>) -> Self {
impl<T: hyper::body::Body> From<hyper::Response<T>> for Response
where
T: hyper::body::Body + Send + Sync + 'static,
T::Data: Into<Bytes>,
T::Error: Into<std::io::Error>,
{
fn from(hyper_resp: hyper::Response<T>) -> Self {
let (parts, body) = hyper_resp.into_parts();
Response {
status: parts.status,
version: parts.version,
headers: parts.headers,
extensions: parts.extensions,
body: body.into(),
body: Body(
body.map_frame(|frame| frame.map_data(Into::into))
.map_err(Into::into)
.boxed(),
),
}
}
}
@@ -217,7 +226,7 @@ impl Response {
/// Inserts a value to extensions, similar to
/// `self.extensions().insert(data)`.
#[inline]
pub fn set_data(&mut self, data: impl Send + Sync + 'static) {
pub fn set_data(&mut self, data: impl Clone + Send + Sync + 'static) {
self.extensions.insert(data);
}
@@ -312,7 +321,7 @@ impl ResponseBuilder {
#[must_use]
pub fn extension<T>(mut self, extension: T) -> Self
where
T: Any + Send + Sync + 'static,
T: Clone + Send + Sync + 'static,
{
self.extensions.insert(extension);
self

View File

@@ -1085,6 +1085,7 @@ mod tests {
("/*p1", 9),
("/abc/<\\d+>/def", 10),
("/kcd/:p1<\\d+>", 11),
("/:package/-/:package_tgz<.*tgz$>", 12),
];
for (path, id) in paths {
@@ -1154,6 +1155,16 @@ mod tests {
NodeData::new(11, "/kcd/:p1<\\d+>"),
)),
),
(
"/is-number/-/is-number-7.0.0.tgz",
Some((
create_url_params(vec![
("package", "is-number"),
("package_tgz", "is-number-7.0.0.tgz"),
]),
NodeData::new(12, "/:package/-/:package_tgz<.*tgz$>"),
)),
),
];
for (path, mut res) in matches {

View File

@@ -10,7 +10,7 @@ use crate::{
Endpoint, EndpointExt, IntoEndpoint, IntoResponse, Request, Response, Result,
};
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
struct PathPrefix(usize);
/// Routing object

View File

@@ -12,7 +12,8 @@ use std::{
};
use http::uri::Scheme;
use hyper::server::conn::Http;
use hyper::body::Incoming;
use hyper_util::server::conn::auto;
use pin_project_lite::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf, Result as IoResult},
@@ -321,7 +322,7 @@ async fn serve_connection(
let service = hyper::service::service_fn({
let remote_addr = remote_addr.clone();
move |req: hyper::Request<hyper::Body>| {
move |req: http::Request<Incoming>| {
let ep = ep.clone();
let local_addr = local_addr.clone();
let remote_addr = remote_addr.clone();
@@ -352,12 +353,13 @@ async fn serve_connection(
None => tokio_util::either::Either::Right(socket),
};
let mut conn = Http::new()
.serve_connection(socket, service)
.with_upgrades();
let builder = auto::Builder::new(hyper_util::rt::TokioExecutor::new());
let conn =
builder.serve_connection_with_upgrades(hyper_util::rt::TokioIo::new(socket), service);
futures_util::pin_mut!(conn);
tokio::select! {
_ = &mut conn => {
_ = conn => {
// Connection completed successfully.
return;
},
@@ -366,10 +368,4 @@ async fn serve_connection(
}
_ = server_graceful_shutdown_token.cancelled() => {}
}
// Init graceful shutdown for connection (`GOAWAY` for `HTTP/2` or disabling `keep-alive` for `HTTP/1`)
Pin::new(&mut conn).graceful_shutdown();
// Continue awaiting after graceful-shutdown is initiated to handle existed requests.
let _ = conn.await;
}

View File

@@ -200,7 +200,7 @@ impl<'a, E> TestRequestBuilder<'a, E> {
#[must_use]
pub fn data<T>(mut self, data: T) -> Self
where
T: Send + Sync + 'static,
T: Clone + Send + Sync + 'static,
{
self.extensions.insert(data);
self

View File

@@ -33,6 +33,7 @@ impl<'a> FromRequest<'a> for &'a CsrfToken {
/// A verifier for CSRF Token.
///
/// See also [`Csrf`](crate::middleware::Csrf)
#[derive(Clone)]
#[cfg_attr(docsrs, doc(cfg(feature = "csrf")))]
pub struct CsrfVerifier {
cookie: Option<UnencryptedCsrfCookie>,

View File

@@ -129,7 +129,10 @@ impl StaticFileRequest {
let mut content_length = data.len() as u64;
let mut content_range = None;
let body = if let Some((start, end)) = self.range.and_then(|range| range.iter().next()) {
let body = if let Some((start, end)) = self
.range
.and_then(|range| range.satisfiable_ranges(data.len() as u64).next())
{
let start = match start {
Bound::Included(n) => n,
Bound::Excluded(n) => n + 1,
@@ -232,7 +235,10 @@ impl StaticFileRequest {
let mut content_range = None;
let body = if let Some((start, end)) = self.range.and_then(|range| range.iter().next()) {
let body = if let Some((start, end)) = self
.range
.and_then(|range| range.satisfiable_ranges(metadata.len()).next())
{
let start = match start {
Bound::Included(n) => n,
Bound::Excluded(n) => n + 1,