This commit is contained in:
Sunli
2022-02-24 12:53:08 +08:00
parent 5bb9c247d2
commit bc5b3e9d0e
3 changed files with 65 additions and 46 deletions

View File

@@ -7,8 +7,8 @@ publish = false
[dependencies]
poem = { path = "../../../poem", features = ["tower-compat"] }
tokio = { version = "1.17.0", features = ["rt-multi-thread", "macros"] }
prost = "0.8.0"
tonic = "0.5.2"
prost = "0.9.0"
tonic = "0.6.2"
tracing-subscriber = "0.3.9"
[build-dependencies]

View File

@@ -34,9 +34,19 @@ async fn main() -> Result<(), std::io::Error> {
tracing_subscriber::fmt::init();
let app = Route::new().nest_no_strip(
format!("/{}", GreeterServer::<MyGreeter>::NAME),
GreeterServer::new(MyGreeter).compat(),
"/",
tonic::transport::Server::builder()
.add_service(GreeterServer::new(MyGreeter))
.into_service()
.compat(),
);
tokio::spawn(
tonic::transport::Server::builder()
.add_service(GreeterServer::new(MyGreeter))
.serve("127.0.0.1:3001".parse().unwrap()),
);
Server::new(TcpListener::bind("127.0.0.1:3000"))
.run(app)
.await

View File

@@ -1,34 +1,32 @@
use std::{error::Error as StdError, future::Future};
use std::{error::Error as StdError, future::Future, marker::PhantomData};
use bytes::Bytes;
use hyper::body::HttpBody;
use tower::{Service, ServiceExt};
use crate::{body::BodyStream, error::InternalServerError, Endpoint, Request, Response, Result};
use crate::{body::BodyStream, Endpoint, Error, Request, Response, Result};
/// Extension trait for tower service compat.
#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
pub trait TowerCompatExt {
/// Converts a tower service to a poem endpoint.
fn compat<ResBody, Err, Fut>(self) -> TowerCompatEndpoint<Self>
fn compat<Req, Resp, Err, Fut>(self) -> TowerCompatEndpoint<Req, Self>
where
ResBody: HttpBody + Send + 'static,
ResBody::Data: Into<Bytes> + Send + 'static,
ResBody::Error: StdError + Send + Sync + 'static,
Err: StdError + Send + Sync + 'static,
Self: Service<
http::Request<hyper::Body>,
Response = hyper::Response<ResBody>,
Error = Err,
Future = Fut,
> + Clone
Req: From<Request> + Send + Sync,
Resp: Into<Response>,
Err: Into<Error>,
Fut: Future<Output = Result<Resp, Err>> + Send + 'static,
Self: Service<Req, Response = Resp, Error = Err, Future = Fut>
+ Clone
+ Send
+ Sync
+ Sized
+ 'static,
Fut: Future<Output = Result<hyper::Response<ResBody>, Err>> + Send + 'static,
{
TowerCompatEndpoint(self)
TowerCompatEndpoint {
marker: PhantomData,
svc: self,
}
}
}
@@ -36,42 +34,29 @@ impl<T> TowerCompatExt for T {}
/// A tower service adapter.
#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
pub struct TowerCompatEndpoint<Svc>(Svc);
pub struct TowerCompatEndpoint<Req, Svc> {
marker: PhantomData<Req>,
svc: Svc,
}
#[async_trait::async_trait]
impl<Svc, ResBody, Err, Fut> Endpoint for TowerCompatEndpoint<Svc>
impl<Svc, Req, Resp, Err, Fut> Endpoint for TowerCompatEndpoint<Req, Svc>
where
ResBody: HttpBody + Send + 'static,
ResBody::Data: Into<Bytes> + Send + 'static,
ResBody::Error: StdError + Send + Sync + 'static,
Err: StdError + Send + Sync + 'static,
Svc: Service<
http::Request<hyper::Body>,
Response = hyper::Response<ResBody>,
Error = Err,
Future = Fut,
> + Clone
+ Send
+ Sync
+ 'static,
Fut: Future<Output = Result<hyper::Response<ResBody>, Err>> + Send + 'static,
Req: From<Request> + Send + Sync,
Resp: Into<Response>,
Err: Into<Error>,
Fut: Future<Output = Result<Resp, Err>> + Send + 'static,
Svc: Service<Req, Response = Resp, Error = Err, Future = Fut> + Clone + Send + Sync + 'static,
{
type Output = Response;
async fn call(&self, req: Request) -> Result<Self::Output> {
let mut svc = self.0.clone();
let mut svc = self.svc.clone();
svc.ready().await.map_err(InternalServerError)?;
svc.ready().await.map_err(Into::into)?;
let hyper_req: http::Request<hyper::Body> = req.into();
let hyper_resp = svc
.call(hyper_req.map(Into::into))
.await
.map_err(InternalServerError)?;
Ok(hyper_resp
.map(|body| hyper::Body::wrap_stream(BodyStream::new(body)))
.into())
let req: Req = req.into();
svc.call(req).await.map(Into::into).map_err(Into::into)
}
}
@@ -79,6 +64,7 @@ where
mod tests {
use std::{
convert::Infallible,
num::ParseIntError,
task::{Context, Poll},
};
@@ -109,4 +95,27 @@ mod tests {
let resp = ep.call(Request::builder().body("abc")).await.unwrap();
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
}
#[tokio::test]
async fn test_map() {
#[derive(Clone)]
struct MyTowerService;
impl Service<&str> for MyTowerService {
type Response = i32;
type Error = ParseIntError;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: &str) -> Self::Future {
futures_util::future::ready(req.parse())
}
}
let ep =
ServiceExt::map_request(MyTowerService, |req| Request::builder().body(req)).compat();
}
}