Make all test cases use poem::test.

This commit is contained in:
Sunli
2022-03-08 10:17:45 +08:00
parent 509e24d6df
commit 0e610e7ebd
48 changed files with 1195 additions and 1618 deletions

View File

@@ -34,7 +34,6 @@ serde = { version = "1.0.130", features = ["derive"] }
derive_more = "0.99.16"
num-traits = "0.2.14"
regex = "1.5.5"
typed-headers = "0.2.0"
mime = "0.3.16"
thiserror = "1.0.30"
bytes = "1.1.0"

View File

@@ -1,5 +1,7 @@
use poem::{Request, Result};
use typed_headers::{AuthScheme, Authorization, HeaderMapExt};
use poem::{
web::headers::{Authorization, HeaderMapExt},
Request, Result,
};
use crate::{auth::BasicAuthorization, error::AuthorizationError};
@@ -14,24 +16,14 @@ pub struct Basic {
impl BasicAuthorization for Basic {
fn from_request(req: &Request) -> Result<Self> {
if let Some(auth) = req.headers().typed_get::<Authorization>().ok().flatten() {
if auth.0.scheme() == &AuthScheme::BASIC {
if let Some(token68) = auth.token68() {
if let Ok(value) = base64::decode(token68.as_str()) {
if let Ok(value) = String::from_utf8(value) {
let mut s = value.split(':');
if let (Some(username), Some(password), None) =
(s.next(), s.next(), s.next())
{
return Ok(Basic {
username: username.to_string(),
password: password.to_string(),
});
}
}
}
}
}
if let Some(auth) = req
.headers()
.typed_get::<Authorization<poem::web::headers::authorization::Basic>>()
{
return Ok(Basic {
username: auth.username().to_string(),
password: auth.password().to_string(),
});
}
Err(AuthorizationError.into())

View File

@@ -1,5 +1,7 @@
use poem::{Request, Result};
use typed_headers::{AuthScheme, Authorization, HeaderMapExt};
use poem::{
web::headers::{Authorization, HeaderMapExt},
Request, Result,
};
use crate::{auth::BearerAuthorization, error::AuthorizationError};
@@ -11,14 +13,13 @@ pub struct Bearer {
impl BearerAuthorization for Bearer {
fn from_request(req: &Request) -> Result<Self> {
if let Some(auth) = req.headers().typed_get::<Authorization>().ok().flatten() {
if auth.0.scheme() == &AuthScheme::BEARER {
if let Some(token68) = auth.token68() {
return Ok(Bearer {
token: token68.as_str().to_string(),
});
}
}
if let Some(auth) = req
.headers()
.typed_get::<Authorization<poem::web::headers::authorization::Bearer>>()
{
return Ok(Bearer {
token: auth.token().to_string(),
});
}
Err(AuthorizationError.into())

View File

@@ -19,6 +19,7 @@ use crate::{
/// use poem::{
/// error::BadRequest,
/// http::{Method, StatusCode, Uri},
/// test::TestClient,
/// Body, IntoEndpoint, Request, Result,
/// };
/// use poem_openapi::{
@@ -48,33 +49,22 @@ use crate::{
/// }
/// }
///
/// let api = OpenApiService::new(MyApi::default(), "Demo", "0.1.0").into_endpoint();
/// let api = OpenApiService::new(MyApi::default(), "Demo", "0.1.0");
/// let cli = TestClient::new(api);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = api
/// .call(
/// Request::builder()
/// .method(Method::POST)
/// .content_type("text/plain")
/// .uri(Uri::from_static("/upload"))
/// .body("YWJjZGVm"),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "6");
/// let resp = cli
/// .post("/upload")
/// .content_type("text/plain")
/// .body("YWJjZGVm")
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("6").await;
///
/// let resp = api
/// .call(
/// Request::builder()
/// .method(Method::GET)
/// .uri(Uri::from_static("/download"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "YWJjZGVm");
/// let resp = cli.get("/download").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("YWJjZGVm").await;
/// # });
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]

View File

@@ -17,6 +17,7 @@ use crate::{
/// use poem::{
/// error::BadRequest,
/// http::{Method, StatusCode, Uri},
/// test::TestClient,
/// Body, IntoEndpoint, Request, Result,
/// };
/// use poem_openapi::{
@@ -43,34 +44,27 @@ use crate::{
/// }
/// }
///
/// let api = OpenApiService::new(MyApi, "Demo", "0.1.0").into_endpoint();
/// let api = OpenApiService::new(MyApi, "Demo", "0.1.0");
/// let cli = TestClient::new(api);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = api
/// .call(
/// Request::builder()
/// .method(Method::POST)
/// .content_type("application/octet-stream")
/// .uri(Uri::from_static("/upload"))
/// .body("abcdef"),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "6");
/// let resp = cli
/// .post("/upload")
/// .content_type("application/octet-stream")
/// .body("abcdef")
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("6").await;
///
/// let resp = api
/// .call(
/// Request::builder()
/// .method(Method::POST)
/// .content_type("application/octet-stream")
/// .uri(Uri::from_static("/upload_stream"))
/// .body("abcdef"),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "6");
/// let resp = cli
/// .post("/upload_stream")
/// .content_type("application/octet-stream")
/// .body("abcdef")
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("6").await;
/// # });
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]

View File

@@ -11,6 +11,41 @@ use crate::{
/// A response type wrapper.
///
/// Use it to modify the status code and HTTP headers.
///
/// # Examples
///
/// ```
/// use poem::{
/// error::BadRequest,
/// http::{Method, StatusCode, Uri},
/// test::TestClient,
/// Body, IntoEndpoint, Request, Result,
/// };
/// use poem_openapi::{
/// payload::{Json, Response},
/// OpenApi, OpenApiService,
/// };
/// use tokio::io::AsyncReadExt;
///
/// struct MyApi;
///
/// #[OpenApi]
/// impl MyApi {
/// #[oai(path = "/test", method = "get")]
/// async fn test(&self) -> Response<Json<i32>> {
/// Response::new(Json(100)).header("foo", "bar")
/// }
/// }
///
/// let api = OpenApiService::new(MyApi, "Demo", "0.1.0");
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = TestClient::new(api).get("/test").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_header("foo", "bar");
/// resp.assert_text("100").await;
/// # });
/// ```
pub struct Response<T> {
inner: T,
status: Option<StatusCode>,

View File

@@ -1,7 +1,8 @@
use poem::{
http::{Method, StatusCode, Uri},
http::{Method, StatusCode},
test::TestClient,
web::Data,
Endpoint, EndpointExt, Error, IntoEndpoint,
EndpointExt, Error,
};
use poem_openapi::{
param::Query,
@@ -25,17 +26,9 @@ async fn path_and_method() {
assert_eq!(meta.paths[0].path, "/abc");
assert_eq!(meta.paths[0].operations[0].method, Method::POST);
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
poem::Request::builder()
.method(Method::POST)
.uri(Uri::from_static("/abc"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let ep = OpenApiService::new(Api, "test", "1.0");
let cli = TestClient::new(ep);
cli.post("/abc").send().await.assert_status_is_ok();
}
#[test]
@@ -105,17 +98,12 @@ async fn common_attributes() {
vec!["CommonOperations", "UserOperations"]
);
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/hello/world"))
.finish(),
)
let ep = OpenApiService::new(Api, "test", "1.0");
TestClient::new(ep)
.get("/hello/world")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -171,42 +159,29 @@ async fn request() {
}
);
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/"))
.content_type("application/json")
.body("100"),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let ep = OpenApiService::new(Api, "test", "1.0");
let cli = TestClient::new(ep);
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/"))
.content_type("text/plain")
.body("abc"),
)
cli.get("/")
.content_type("application/json")
.body("100")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/"))
.content_type("application/octet-stream")
.body(vec![1, 2, 3]),
)
cli.get("/")
.content_type("text/plain")
.body("abc")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
cli.get("/")
.content_type("application/octet-stream")
.body(vec![1, 2, 3])
.send()
.await
.assert_status_is_ok();
}
#[tokio::test]
@@ -228,29 +203,22 @@ async fn payload_request() {
assert_eq!(meta_request.content[0].content_type, "application/json");
assert_eq!(meta_request.content[0].schema, i32::schema_ref());
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
poem::Request::builder()
.method(Method::POST)
.uri(Uri::from_static("/"))
.content_type("application/json")
.body("100"),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let ep = OpenApiService::new(Api, "test", "1.0");
let cli = TestClient::new(ep);
let resp = ep
.get_response(
poem::Request::builder()
.method(Method::POST)
.uri(Uri::from_static("/"))
.content_type("text/plain")
.body("100"),
)
.await;
assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
cli.post("/")
.content_type("application/json")
.body("100")
.send()
.await
.assert_status_is_ok();
cli.post("/")
.content_type("text/plain")
.body("100")
.send()
.await
.assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
}
#[tokio::test]
@@ -312,45 +280,22 @@ async fn response() {
String::schema_ref()
);
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let ep = OpenApiService::new(Api, "test", "1.0");
let cli = TestClient::new(ep);
let mut resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/?code=200"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.take_body().into_string().await.unwrap(), "");
let resp = cli.get("/").query("code", &200).send().await;
resp.assert_status_is_ok();
resp.assert_text("").await;
let mut resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/?code=409"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::CONFLICT);
assert_eq!(resp.content_type(), Some("application/json; charset=utf-8"));
assert_eq!(resp.take_body().into_string().await.unwrap(), "409");
let resp = cli.get("/").query("code", &409).send().await;
resp.assert_status(StatusCode::CONFLICT);
resp.assert_content_type("application/json; charset=utf-8");
resp.assert_text("409").await;
let mut resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/?code=404"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
assert_eq!(resp.content_type(), Some("text/plain; charset=utf-8"));
assert_eq!(resp.take_body().into_string().await.unwrap(), "code: 404");
let resp = cli.get("/").query("code", &404).send().await;
resp.assert_status(StatusCode::NOT_FOUND);
resp.assert_content_type("text/plain; charset=utf-8");
resp.assert_text("code: 404").await;
}
#[tokio::test]
@@ -380,36 +325,21 @@ async fn bad_request_handler() {
}
}
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let ep = OpenApiService::new(Api, "test", "1.0");
let cli = TestClient::new(ep);
let mut resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/?code=200"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.content_type(), Some("text/plain; charset=utf-8"));
assert_eq!(resp.take_body().into_string().await.unwrap(), "code: 200");
let resp = cli.get("/").query("code", &200).send().await;
resp.assert_status_is_ok();
resp.assert_content_type("text/plain; charset=utf-8");
resp.assert_text("code: 200").await;
let mut resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
assert_eq!(resp.content_type(), Some("text/plain; charset=utf-8"));
assert_eq!(
resp.take_body().into_string().await.unwrap(),
r#"!!! failed to parse parameter `code`: Type "integer(uint16)" expects an input value."#
);
let resp = cli.get("/").send().await;
resp.assert_status(StatusCode::BAD_REQUEST);
resp.assert_content_type("text/plain; charset=utf-8");
resp.assert_text(
r#"!!! failed to parse parameter `code`: Type "integer(uint16)" expects an input value."#,
)
.await;
}
#[tokio::test]
@@ -442,36 +372,18 @@ async fn bad_request_handler_for_validator() {
}
}
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let ep = OpenApiService::new(Api, "test", "1.0");
let cli = TestClient::new(ep);
let mut resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/?code=50"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.content_type(), Some("text/plain; charset=utf-8"));
assert_eq!(resp.take_body().into_string().await.unwrap(), "code: 50");
let resp = cli.get("/").query("code", &50).send().await;
resp.assert_status_is_ok();
resp.assert_content_type("text/plain; charset=utf-8");
resp.assert_text("code: 50").await;
let mut resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/?code=200"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
assert_eq!(resp.content_type(), Some("text/plain; charset=utf-8"));
assert_eq!(
resp.take_body().into_string().await.unwrap(),
r#"!!! failed to parse parameter `code`: verification failed. maximum(100, exclusive: false)"#
);
let resp = cli.get("/").query("code", &200).send().await;
resp.assert_status(StatusCode::BAD_REQUEST);
resp.assert_content_type("text/plain; charset=utf-8");
resp.assert_text(r#"!!! failed to parse parameter `code`: verification failed. maximum(100, exclusive: false)"#).await;
}
#[tokio::test]
@@ -486,19 +398,12 @@ async fn poem_extract() {
}
}
let ep = OpenApiService::new(Api, "test", "1.0")
.data(100i32)
.into_endpoint();
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/"))
.finish(),
)
let ep = OpenApiService::new(Api, "test", "1.0").data(100i32);
TestClient::new(ep)
.get("/")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -541,56 +446,24 @@ async fn returning_borrowed_value() {
"test",
"1.0",
)
.into_endpoint()
.data(888i32);
let cli = TestClient::new(ep);
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/value1"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_string().await.unwrap(), "999");
let resp = cli.get("/value1").send().await;
resp.assert_status_is_ok();
resp.assert_text("999").await;
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/value2"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_string().await.unwrap(), "\"abc\"");
let resp = cli.get("/value2").send().await;
resp.assert_status_is_ok();
resp.assert_text("\"abc\"").await;
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/value3"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_string().await.unwrap(), "888");
let resp = cli.get("/value3").send().await;
resp.assert_status_is_ok();
resp.assert_text("888").await;
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/values"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_string().await.unwrap(), "[1,2,3,4,5]");
let resp = cli.get("/values").send().await;
resp.assert_status_is_ok();
resp.assert_text("[1,2,3,4,5]").await;
}
#[tokio::test]
@@ -644,21 +517,12 @@ async fn generic() {
}
}
let ep = OpenApiService::new(MyOpenApi { api: MyApiA }, "test", "1.0").into_endpoint();
let resp = ep
.call(
poem::Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/some_call"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.into_body().into_json::<String>().await.unwrap(),
"test"
);
let ep = OpenApiService::new(MyOpenApi { api: MyApiA }, "test", "1.0");
let cli = TestClient::new(ep);
let resp = cli.get("/some_call").send().await;
resp.assert_status_is_ok();
resp.assert_json("test").await;
}
#[tokio::test]

View File

@@ -1,7 +1,7 @@
use poem::{
http::{header, Method, StatusCode, Uri},
http::header,
test::TestClient,
web::cookie::{Cookie, CookieJar, CookieKey},
Endpoint, IntoEndpoint, Request,
};
use poem_openapi::{
param::{Cookie as ParamCookie, CookiePrivate, CookieSigned, Header, Path, Query},
@@ -30,17 +30,13 @@ async fn param_name() {
let meta: MetaApi = Api::meta().remove(0);
assert_eq!(meta.paths[0].operations[0].params[0].name, "a");
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/abc?a=10"))
.finish(),
)
let ep = OpenApiService::new(Api, "test", "1.0");
TestClient::new(ep)
.get("/abc")
.query("a", &10)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -62,12 +58,13 @@ async fn query() {
);
assert_eq!(meta.paths[0].operations[0].params[0].name, "v");
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api
.call(Request::builder().uri(Uri::from_static("/?v=10")).finish())
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/")
.query("v", &10)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -96,16 +93,15 @@ async fn query_multiple_values() {
"array"
);
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api
.call(
Request::builder()
.uri(Uri::from_static("/?v=10&v=20&v=30"))
.finish(),
)
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/")
.query("v", &10)
.query("v", &20)
.query("v", &30)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -135,9 +131,12 @@ async fn query_default() {
}))
);
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api.call(Request::default()).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/")
.send()
.await
.assert_status_is_ok();
}
#[tokio::test]
@@ -152,12 +151,13 @@ async fn header() {
}
}
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api
.call(Request::builder().header("v", 10).finish())
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/")
.header("v", 10)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -172,18 +172,15 @@ async fn header_multiple_values() {
}
}
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api
.call(
Request::builder()
.header("v", 10)
.header("v", 20)
.header("v", 30)
.finish(),
)
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/")
.header("v", 10)
.header("v", 20)
.header("v", 30)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -198,9 +195,12 @@ async fn header_default() {
}
}
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api.call(Request::default()).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/")
.send()
.await
.assert_status_is_ok();
}
#[tokio::test]
@@ -215,12 +215,12 @@ async fn path() {
}
}
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api
.call(Request::builder().uri(Uri::from_static("/k/10")).finish())
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/k/10")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -238,9 +238,7 @@ async fn cookie() {
}
let cookie_key = CookieKey::generate();
let api = OpenApiService::new(Api, "test", "1.0")
.cookie_key(cookie_key.clone())
.into_endpoint();
let api = OpenApiService::new(Api, "test", "1.0").cookie_key(cookie_key.clone());
let cookie_jar = CookieJar::default();
cookie_jar.add(Cookie::new_with_str("v1", "10"));
@@ -257,11 +255,12 @@ async fn cookie() {
cookie_jar.get("v3").unwrap()
);
let resp = api
.call(Request::builder().header(header::COOKIE, cookie).finish())
TestClient::new(api)
.get("/")
.header(header::COOKIE, cookie)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -276,9 +275,12 @@ async fn cookie_default() {
}
}
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api.call(Request::builder().finish()).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/")
.send()
.await
.assert_status_is_ok();
}
#[tokio::test]
@@ -362,12 +364,12 @@ async fn default_opt() {
Some(json!(88))
);
let api = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = api
.call(Request::builder().uri(Uri::from_static("/")).finish())
let api = OpenApiService::new(Api, "test", "1.0");
TestClient::new(api)
.get("/")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -416,17 +418,13 @@ async fn query_rename() {
MetaParamIn::Query
);
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/abc?fooBar=10"))
.finish(),
)
let ep = OpenApiService::new(Api, "test", "1.0");
TestClient::new(ep)
.get("/abc")
.query("fooBar", &10)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -448,17 +446,12 @@ async fn path_rename() {
MetaParamIn::Path
);
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/abc/10"))
.finish(),
)
let ep = OpenApiService::new(Api, "test", "1.0");
TestClient::new(ep)
.get("/abc/10")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -480,18 +473,13 @@ async fn header_rename() {
MetaParamIn::Header
);
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/abc"))
.header("foo-bar", "10")
.finish(),
)
let ep = OpenApiService::new(Api, "test", "1.0");
TestClient::new(ep)
.get("/abc")
.header("foo-bar", "10")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -513,16 +501,11 @@ async fn cookie_rename() {
MetaParamIn::Cookie
);
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.uri(Uri::from_static("/abc"))
.header(header::COOKIE, "fooBar=10")
.finish(),
)
let ep = OpenApiService::new(Api, "test", "1.0");
TestClient::new(ep)
.get("/abc")
.header(header::COOKIE, "fooBar=10")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}

View File

@@ -1,7 +1,4 @@
use poem::{
http::{StatusCode, Uri},
Endpoint, Error, IntoEndpoint, Request,
};
use poem::{http::StatusCode, test::TestClient, Error};
use poem_openapi::{
param::Query,
payload::{Json, Response},
@@ -39,30 +36,18 @@ async fn response_wrapper() {
}
}
let ep = OpenApiService::new(Api, "test", "1.0").into_endpoint();
let ep = OpenApiService::new(Api, "test", "1.0");
let cli = TestClient::new(ep);
let resp = ep
.call(Request::builder().uri(Uri::from_static("/a")).finish())
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.header("myheader"), Some("abc"));
let resp = cli.get("/a").send().await;
resp.assert_status_is_ok();
resp.assert_header("myheader", "abc");
let resp = ep
.call(
Request::builder()
.uri(Uri::from_static("/b?p1=qwe"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.header("myheader"), Some("qwe"));
let resp = cli.get("/b").query("p1", &"qwe").send().await;
resp.assert_status_is_ok();
resp.assert_header("myheader", "qwe");
let resp = ep
.call(Request::builder().uri(Uri::from_static("/b")).finish())
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
assert_eq!(resp.header("MY-HEADER1"), Some("def"));
let resp = cli.get("/b").send().await;
resp.assert_status(StatusCode::BAD_REQUEST);
resp.assert_header("MY-HEADER1", "def");
}

View File

@@ -1,7 +1,7 @@
use poem::{
http::{header, Uri},
web::cookie::Cookie,
Endpoint, IntoEndpoint,
http::header,
test::TestClient,
web::{cookie::Cookie, headers},
};
use poem_openapi::{
auth::{ApiKey, Basic, Bearer},
@@ -9,7 +9,8 @@ use poem_openapi::{
registry::{MetaOAuthFlow, MetaOAuthFlows, MetaOAuthScope, MetaSecurityScheme, Registry},
ApiExtractor, OAuthScopes, OpenApi, OpenApiService, SecurityScheme,
};
use typed_headers::{http::StatusCode, Token68};
use crate::headers::Authorization;
#[test]
fn rename() {
@@ -105,23 +106,14 @@ async fn basic_auth() {
}
}
let service = OpenApiService::new(MyApi, "test", "1.0").into_endpoint();
let mut resp = service
.call(
poem::Request::builder()
.uri(Uri::from_static("/test"))
.header(
header::AUTHORIZATION,
typed_headers::Credentials::basic("abc", "123456")
.unwrap()
.to_string(),
)
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.take_body().into_string().await.unwrap(), "abc/123456");
let service = OpenApiService::new(MyApi, "test", "1.0");
let resp = TestClient::new(service)
.get("/test")
.typed_header(Authorization::basic("abc", "123456"))
.send()
.await;
resp.assert_status_is_ok();
resp.assert_text("abc/123456").await;
}
#[tokio::test]
@@ -156,21 +148,14 @@ async fn bearer_auth() {
}
}
let service = OpenApiService::new(MyApi, "test", "1.0").into_endpoint();
let mut resp = service
.call(
poem::Request::builder()
.uri(Uri::from_static("/test"))
.header(
header::AUTHORIZATION,
typed_headers::Credentials::bearer(Token68::new("abcdef").unwrap()).to_string(),
)
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.take_body().into_string().await.unwrap(), "abcdef");
let service = OpenApiService::new(MyApi, "test", "1.0");
let resp = TestClient::new(service)
.get("/test")
.typed_header(headers::Authorization::bearer("abcdef").unwrap())
.send()
.await;
resp.assert_status_is_ok();
resp.assert_text("abcdef").await;
}
#[tokio::test]
@@ -263,44 +248,31 @@ async fn api_key_auth() {
}
}
let service = OpenApiService::new(MyApi, "test", "1.0").into_endpoint();
let mut resp = service
.call(
poem::Request::builder()
.uri(Uri::from_static("/header"))
.header("X-API-Key", "abcdef")
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.take_body().into_string().await.unwrap(), "abcdef");
let service = OpenApiService::new(MyApi, "test", "1.0");
let cli = TestClient::new(service);
let mut resp = service
.call(
poem::Request::builder()
.uri(Uri::from_static("/query?key=abcdef"))
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.take_body().into_string().await.unwrap(), "abcdef");
let resp = cli
.get("/header")
.header("X-API-Key", "abcdef")
.send()
.await;
resp.assert_status_is_ok();
resp.assert_text("abcdef").await;
let mut resp = service
.call(
poem::Request::builder()
.uri(Uri::from_static("/cookie"))
.header(
header::COOKIE,
Cookie::new_with_str("key", "abcdef").to_string(),
)
.finish(),
let resp = cli.get("/query").query("key", &"abcdef").send().await;
resp.assert_status_is_ok();
resp.assert_text("abcdef").await;
let resp = cli
.get("/cookie")
.header(
header::COOKIE,
Cookie::new_with_str("key", "abcdef").to_string(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.take_body().into_string().await.unwrap(), "abcdef");
.send()
.await;
resp.assert_status_is_ok();
resp.assert_text("abcdef").await;
}
#[tokio::test]

View File

@@ -4,10 +4,7 @@ use std::{
ops::Range,
};
use poem::{
http::{StatusCode, Uri},
Endpoint, IntoEndpoint, Request, Result,
};
use poem::{http::StatusCode, test::TestClient, Result};
use poem_openapi::{
param::Query,
payload::Payload,
@@ -299,15 +296,15 @@ async fn param_validator() {
}
}
let api = OpenApiService::new(Api, "test1", "1.0").into_endpoint();
let err = api
.call(Request::builder().uri(Uri::from_static("/?v=999")).finish())
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"failed to parse parameter `v`: verification failed. maximum(100, exclusive: true)"
);
let api = OpenApiService::new(Api, "test1", "1.0");
let cli = TestClient::new(api);
let resp = cli.get("/").query("v", &999).send().await;
resp.assert_status(StatusCode::BAD_REQUEST);
resp.assert_text(
"failed to parse parameter `v`: verification failed. maximum(100, exclusive: true)",
)
.await;
let meta: MetaApi = Api::meta().remove(0);
assert_eq!(
@@ -325,34 +322,24 @@ async fn param_validator() {
Some(true)
);
let resp = api
.call(Request::builder().uri(Uri::from_static("/?v=50")).finish())
cli.get("/")
.query("v", &50)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
let err = api
.call(
Request::builder()
.uri(Uri::from_static("/test2?v=101"))
.finish(),
)
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"failed to parse parameter `v`: verification failed. maximum(100, exclusive: true)"
);
let resp = cli.get("/test2").query("v", &101).send().await;
resp.assert_status(StatusCode::BAD_REQUEST);
resp.assert_text(
"failed to parse parameter `v`: verification failed. maximum(100, exclusive: true)",
)
.await;
let resp = api
.call(
Request::builder()
.uri(Uri::from_static("/test2?v=50"))
.finish(),
)
cli.get("/test2")
.query("v", &50)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[test]

View File

@@ -27,7 +27,10 @@ pub trait Endpoint: Send + Sync {
/// # Example
///
/// ```
/// use poem::{error::NotFoundError, handler, http::StatusCode, Endpoint, Request, Result};
/// use poem::{
/// error::NotFoundError, handler, http::StatusCode, test::TestClient, Endpoint, Request,
/// Result,
/// };
///
/// #[handler]
/// fn index() -> Result<()> {
@@ -35,8 +38,11 @@ pub trait Endpoint: Send + Sync {
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = index.get_response(Request::default()).await;
/// assert_eq!(resp.status(), StatusCode::NOT_FOUND);
/// TestClient::new(index)
/// .get("/")
/// .send()
/// .await
/// .assert_status(StatusCode::NOT_FOUND);
/// # });
/// ```
async fn get_response(&self, req: Request) -> Response {
@@ -115,16 +121,15 @@ where
/// # Example
///
/// ```
/// use poem::{endpoint::make_sync, http::Method, Endpoint, Request};
/// use poem::{endpoint::make_sync, http::Method, test::TestClient, Endpoint, Request};
///
/// let ep = make_sync(|req| req.method().to_string());
/// let cli = TestClient::new(ep);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = ep
/// .call(Request::builder().method(Method::GET).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp, "GET");
/// let resp = cli.get("/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("GET").await;
/// # });
/// ```
pub fn make_sync<F, T, R>(f: F) -> impl Endpoint<Output = T>
@@ -146,16 +151,15 @@ where
/// # Example
///
/// ```
/// use poem::{endpoint::make, http::Method, Endpoint, Request};
/// use poem::{endpoint::make, http::Method, test::TestClient, Endpoint, Request};
///
/// let ep = make(|req| async move { req.method().to_string() });
/// let app = TestClient::new(ep);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = ep
/// .call(Request::builder().method(Method::GET).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp, "GET");
/// let resp = app.get("/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("GET").await;
/// # });
/// ```
pub fn make<F, Fut, T, R>(f: F) -> impl Endpoint<Output = T>
@@ -218,8 +222,8 @@ pub trait EndpointExt: IntoEndpoint {
///
/// ```
/// use poem::{
/// get, handler, http::StatusCode, middleware::AddData, web::Data, Endpoint, EndpointExt,
/// Request, Route,
/// get, handler, http::StatusCode, middleware::AddData, test::TestClient, web::Data, Endpoint,
/// EndpointExt, Request, Route,
/// };
///
/// #[handler]
@@ -228,10 +232,12 @@ pub trait EndpointExt: IntoEndpoint {
/// }
///
/// let app = Route::new().at("/", get(index)).with(AddData::new(100i32));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app.call(Request::default()).await.unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "100");
/// let resp = cli.get("/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("100").await;
/// # });
/// ```
fn with<T>(self, middleware: T) -> T::Output
@@ -251,6 +257,7 @@ pub trait EndpointExt: IntoEndpoint {
/// get, handler,
/// http::{StatusCode, Uri},
/// middleware::AddData,
/// test::TestClient,
/// web::Data,
/// Endpoint, EndpointExt, Request, Route,
/// };
@@ -266,21 +273,16 @@ pub trait EndpointExt: IntoEndpoint {
/// let app = Route::new()
/// .at("/a", get(index).with_if(true, AddData::new(100i32)))
/// .at("/b", get(index).with_if(false, AddData::new(100i32)));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(Request::builder().uri(Uri::from_static("/a")).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "100");
/// let resp = cli.get("/a").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("100").await;
///
/// let resp = app
/// .call(Request::builder().uri(Uri::from_static("/b")).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "none");
/// let resp = cli.get("/b").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("none").await;
/// # });
/// ```
fn with_if<T>(self, enable: bool, middleware: T) -> EitherEndpoint<Self, T::Output>
@@ -300,7 +302,9 @@ pub trait EndpointExt: IntoEndpoint {
/// # Example
///
/// ```
/// use poem::{handler, http::StatusCode, web::Data, Endpoint, EndpointExt, Request};
/// use poem::{
/// handler, http::StatusCode, test::TestClient, web::Data, Endpoint, EndpointExt, Request,
/// };
///
/// #[handler]
/// async fn index(data: Data<&i32>) -> String {
@@ -308,9 +312,9 @@ pub trait EndpointExt: IntoEndpoint {
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let mut resp = index.data(100i32).call(Request::default()).await.unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.take_body().into_string().await.unwrap(), "100");
/// let resp = TestClient::new(index.data(100i32)).get("/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("100").await;
/// # });
/// ```
fn data<T>(self, data: T) -> AddDataEndpoint<Self::Endpoint, T>
@@ -326,7 +330,9 @@ pub trait EndpointExt: IntoEndpoint {
/// # Example
///
/// ```
/// use poem::{handler, http::StatusCode, Endpoint, EndpointExt, Error, Request, Result};
/// use poem::{
/// handler, http::StatusCode, test::TestClient, Endpoint, EndpointExt, Error, Request, Result,
/// };
///
/// #[handler]
/// async fn index(data: String) -> String {

View File

@@ -128,6 +128,7 @@ mod tests {
use futures_util::future::Ready;
use super::*;
use crate::test::TestClient;
#[tokio::test]
async fn test_tower_compat() {
@@ -149,7 +150,8 @@ mod tests {
}
let ep = MyTowerService.compat();
let resp = ep.call(Request::builder().body("abc")).await.unwrap();
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
let resp = TestClient::new(ep).get("/").body("abc").send().await;
resp.assert_status_is_ok();
resp.assert_text("abc").await;
}
}

View File

@@ -828,7 +828,10 @@ pub enum SizedLimitError {
impl ResponseError for SizedLimitError {
fn status(&self) -> StatusCode {
StatusCode::BAD_REQUEST
match self {
SizedLimitError::MissingContentLength => StatusCode::BAD_REQUEST,
SizedLimitError::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
}
}
}

View File

@@ -22,6 +22,7 @@ type LanguageArray = SmallVec<[LanguageIdentifier; 8]>;
/// handler,
/// http::header,
/// i18n::{I18NResources, Locale},
/// test::TestClient,
/// Endpoint, EndpointExt, Request, Route,
/// };
///
@@ -39,22 +40,24 @@ type LanguageArray = SmallVec<[LanguageIdentifier; 8]>;
/// }
///
/// let app = Route::new().at("/", index).data(resources);
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let req = Request::builder()
/// let resp = cli
/// .get("/")
/// .header(header::ACCEPT_LANGUAGE, "en-US")
/// .finish();
/// let resp = app.get_response(req).await;
/// assert_eq!(
/// resp.into_body().into_string().await.unwrap(),
/// "hello world!"
/// );
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("hello world!").await;
///
/// let req = Request::builder()
/// let resp = cli
/// .get("/")
/// .header(header::ACCEPT_LANGUAGE, "zh-CN")
/// .finish();
/// let resp = app.get_response(req).await;
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "你好世界!");
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("你好世界!").await;
/// # });
/// ```
pub struct Locale {

View File

@@ -39,7 +39,10 @@
//! The [`handler`] macro is used to convert a function into an endpoint.
//!
//! ```
//! use poem::{error::NotFoundError, handler, Endpoint, Request, Result};
//! use poem::{
//! error::NotFoundError, handler, http::StatusCode, test::TestClient, Endpoint, Request,
//! Result,
//! };
//!
//! #[handler]
//! fn return_str() -> &'static str {
@@ -52,11 +55,12 @@
//! }
//!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let resp = return_str.call(Request::default()).await.unwrap();
//! assert_eq!(resp.into_body().into_string().await.unwrap(), "hello");
//! let resp = TestClient::new(return_str).get("/").send().await;
//! resp.assert_status_is_ok();
//! resp.assert_text("hello").await;
//!
//! let err = return_err.call(Request::default()).await.unwrap_err();
//! assert!(err.is::<NotFoundError>());
//! let resp = TestClient::new(return_err).get("/").send().await;
//! resp.assert_status(StatusCode::NOT_FOUND);
//! # });
//! ```
//!

View File

@@ -50,7 +50,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{handler, EndpointExt};
use crate::{handler, test::TestClient, EndpointExt};
#[tokio::test]
async fn test_add_data() {
@@ -59,7 +59,7 @@ mod tests {
assert_eq!(req.extensions().get::<i32>(), Some(&100));
}
let app = index.with(AddData::new(100i32));
app.call(Request::default()).await.unwrap();
let cli = TestClient::new(index.with(AddData::new(100i32)));
cli.get("/").send().await.assert_status_is_ok();
}
}

View File

@@ -82,7 +82,7 @@ mod tests {
use tokio::io::AsyncReadExt;
use super::*;
use crate::{handler, EndpointExt, Request};
use crate::{handler, test::TestClient, EndpointExt};
const DATA: &str = "abcdefghijklmnopqrstuvwxyz1234567890";
const DATA_REV: &str = "0987654321zyxwvutsrqponmlkjihgfedcba";
@@ -94,25 +94,21 @@ mod tests {
async fn test_algo(algo: CompressionAlgo) {
let ep = index.with(Compression);
let mut resp = ep
.call(
Request::builder()
.header("Content-Encoding", algo.as_str())
.header("Accept-Encoding", algo.as_str())
.body(Body::from_async_read(algo.compress(DATA.as_bytes()))),
)
.await
.unwrap();
let cli = TestClient::new(ep);
assert_eq!(
resp.headers()
.get("Content-Encoding")
.and_then(|value| value.to_str().ok()),
Some(algo.as_str())
);
let resp = cli
.post("/")
.header("Content-Encoding", algo.as_str())
.header("Accept-Encoding", algo.as_str())
.body(Body::from_async_read(algo.compress(DATA.as_bytes())))
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header("Content-Encoding", algo.as_str());
let mut data = Vec::new();
let mut reader = algo.decompress(resp.take_body().into_async_read());
let mut reader = algo.decompress(resp.0.into_body().into_async_read());
reader.read_to_end(&mut data).await.unwrap();
assert_eq!(data, DATA_REV.as_bytes());
}
@@ -127,24 +123,19 @@ mod tests {
#[tokio::test]
async fn test_negotiate() {
let ep = index.with(Compression);
let mut resp = ep
.call(
Request::builder()
.header("Accept-Encoding", "identity; q=0.5, gzip;q=1.0, br;q=0.3")
.body(DATA),
)
.await
.unwrap();
let cli = TestClient::new(ep);
assert_eq!(
resp.headers()
.get("Content-Encoding")
.and_then(|value| value.to_str().ok()),
Some("gzip")
);
let resp = cli
.post("/")
.header("Accept-Encoding", "identity; q=0.5, gzip;q=1.0, br;q=0.3")
.body(DATA)
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header("Content-Encoding", "gzip");
let mut data = Vec::new();
let mut reader = CompressionAlgo::GZIP.decompress(resp.take_body().into_async_read());
let mut reader = CompressionAlgo::GZIP.decompress(resp.0.into_body().into_async_read());
reader.read_to_end(&mut data).await.unwrap();
assert_eq!(data, DATA_REV.as_bytes());
}
@@ -152,24 +143,19 @@ mod tests {
#[tokio::test]
async fn test_star() {
let ep = index.with(Compression);
let mut resp = ep
.call(
Request::builder()
.header("Accept-Encoding", "identity; q=0.5, *;q=1.0, br;q=0.3")
.body(DATA),
)
.await
.unwrap();
let cli = TestClient::new(ep);
assert_eq!(
resp.headers()
.get("Content-Encoding")
.and_then(|value| value.to_str().ok()),
Some("gzip")
);
let resp = cli
.post("/")
.header("Accept-Encoding", "identity; q=0.5, *;q=1.0, br;q=0.3")
.body(DATA)
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header("Content-Encoding", "gzip");
let mut data = Vec::new();
let mut reader = CompressionAlgo::GZIP.decompress(resp.take_body().into_async_read());
let mut reader = CompressionAlgo::GZIP.decompress(resp.0.into_body().into_async_read());
reader.read_to_end(&mut data).await.unwrap();
assert_eq!(data, DATA_REV.as_bytes());
}

View File

@@ -70,7 +70,7 @@ impl<E: Endpoint> Endpoint for CookieJarManagerEndpoint<E> {
#[cfg(test)]
mod tests {
use super::*;
use crate::{handler, http::StatusCode, web::cookie::Cookie, EndpointExt};
use crate::{handler, test::TestClient, web::cookie::Cookie, EndpointExt};
#[tokio::test]
async fn test_cookie_jar_manager() {
@@ -80,11 +80,12 @@ mod tests {
}
let ep = index.with(CookieJarManager::new());
let resp = ep
.call(Request::builder().header("Cookie", "value=88").finish())
let cli = TestClient::new(ep);
cli.get("/")
.header("Cookie", "value=88")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[tokio::test]
@@ -99,8 +100,8 @@ mod tests {
}
let key = CookieKey::generate();
let cli = TestClient::new(index.with(CookieJarManager::with_key(key.clone())));
let ep = index.with(CookieJarManager::with_key(key.clone()));
let cookie_jar = CookieJar::default();
cookie_jar
.private_with_key(&key)
@@ -108,21 +109,18 @@ mod tests {
cookie_jar
.signed_with_key(&key)
.add(Cookie::new_with_str("value2", "99"));
let resp = ep
.call(
Request::builder()
.header(
"Cookie",
&format!(
"value1={}; value2={}",
cookie_jar.get("value1").unwrap().value_str(),
cookie_jar.get("value2").unwrap().value_str()
),
)
.finish(),
cli.get("/")
.header(
"cookie",
format!(
"value1={}; value2={}",
cookie_jar.get("value1").unwrap().value_str(),
cookie_jar.get("value2").unwrap().value_str()
),
)
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
}

View File

@@ -400,7 +400,11 @@ mod tests {
use http::StatusCode;
use super::*;
use crate::{endpoint::make_sync, EndpointExt, Error};
use crate::{
endpoint::make_sync,
test::{TestClient, TestRequestBuilder},
EndpointExt, Error,
};
const ALLOW_ORIGIN: &str = "https://example.com";
const ALLOW_HEADER: &str = "X-Token";
@@ -415,312 +419,185 @@ mod tests {
.allow_credentials(true)
}
fn opt_request() -> Request {
Request::builder()
.method(Method::OPTIONS)
fn opt_request<T: Endpoint>(cli: &TestClient<T>) -> TestRequestBuilder<'_, T> {
cli.options("/")
.header(header::ORIGIN, ALLOW_ORIGIN)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.finish()
}
fn get_request() -> Request {
Request::builder()
.method(Method::GET)
.header(header::ORIGIN, ALLOW_ORIGIN)
.finish()
fn get_request<T: Endpoint>(cli: &TestClient<T>) -> TestRequestBuilder<'_, T> {
cli.get("/").header(header::ORIGIN, ALLOW_ORIGIN)
}
#[tokio::test]
async fn preflight_request() {
let ep = make_sync(|_| "hello").with(cors());
let resp = ep.map_to_response().call(opt_request()).await.unwrap();
let cli = TestClient::new(ep);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap(),
ALLOW_ORIGIN
);
let allow_methods = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_METHODS)
.and_then(|value| value.to_str().ok())
.map(|value| value.split(',').map(|s| s.trim()).collect::<HashSet<_>>());
assert_eq!(
allow_methods,
Some(
vec!["DELETE", "GET", "OPTIONS", "POST"]
.into_iter()
.collect::<HashSet<_>>()
),
);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_HEADERS)
.unwrap(),
"x-token"
);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
.unwrap(),
"x-my-custom-header"
);
assert_eq!(
resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap(),
"86400"
);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS)
.unwrap(),
"true"
let resp = opt_request(&cli).send().await;
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, ALLOW_ORIGIN);
resp.assert_header_csv(
header::ACCESS_CONTROL_ALLOW_METHODS,
["DELETE", "GET", "OPTIONS", "POST"],
);
resp.assert_header(header::ACCESS_CONTROL_ALLOW_HEADERS, "x-token");
resp.assert_header(header::ACCESS_CONTROL_EXPOSE_HEADERS, "x-my-custom-header");
resp.assert_header(header::ACCESS_CONTROL_MAX_AGE, "86400");
resp.assert_header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}
#[tokio::test]
async fn default_cors() {
let ep = make_sync(|_| "hello").with(Cors::new()).map_to_response();
let resp = ep
.call(
Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, ALLOW_ORIGIN)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.finish(),
)
.await
.unwrap();
let ep = make_sync(|_| "hello").with(Cors::new());
let cli = TestClient::new(ep);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap(),
ALLOW_ORIGIN
);
let allow_methods = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_METHODS)
.and_then(|value| value.to_str().ok())
.map(|value| value.split(',').map(|s| s.trim()).collect::<HashSet<_>>());
assert_eq!(
allow_methods,
Some(
vec![
"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "CONNECT", "PATCH", "TRACE"
]
.into_iter()
.collect::<HashSet<_>>()
),
);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_HEADERS)
.unwrap(),
"X-Token"
);
assert_eq!(
resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap(),
"86400"
);
let resp = cli
.options("/")
.header(header::ORIGIN, ALLOW_ORIGIN)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.send()
.await;
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.header(header::ORIGIN, ALLOW_ORIGIN)
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN,),
Some(&HeaderValue::from_static(ALLOW_ORIGIN))
);
assert_eq!(
resp.headers().get(header::VARY),
Some(&HeaderValue::from_static("Origin"))
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, ALLOW_ORIGIN);
resp.assert_header_csv(
header::ACCESS_CONTROL_ALLOW_METHODS,
[
"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "CONNECT", "PATCH", "TRACE",
],
);
resp.assert_header(header::ACCESS_CONTROL_ALLOW_HEADERS, "X-Token");
resp.assert_header(header::ACCESS_CONTROL_MAX_AGE, "86400");
let resp = cli
.get("/")
.header(header::ORIGIN, ALLOW_ORIGIN)
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, ALLOW_ORIGIN);
resp.assert_header(header::VARY, "Origin");
}
#[tokio::test]
async fn allow_origins_fn_1() {
let ep = make_sync(|_| "hello")
.with(Cors::new().allow_origins_fn(|_| true))
.map_to_response();
let ep = make_sync(|_| "hello").with(Cors::new().allow_origins_fn(|_| true));
let cli = TestClient::new(ep);
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.header(header::ORIGIN, ALLOW_ORIGIN)
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN,),
Some(&HeaderValue::from_static(ALLOW_ORIGIN))
);
assert_eq!(
resp.headers().get(header::VARY),
Some(&HeaderValue::from_static("Origin"))
);
let resp = cli
.get("/")
.header(header::ORIGIN, ALLOW_ORIGIN)
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, ALLOW_ORIGIN);
resp.assert_header(header::VARY, "Origin");
}
#[tokio::test]
async fn allow_origins_fn_2() {
let ep = make_sync(|_| "hello")
.with(
Cors::new()
.allow_origin(ALLOW_ORIGIN)
.allow_origins_fn(|_| true),
)
.map_to_response();
let ep = make_sync(|_| "hello").with(
Cors::new()
.allow_origin(ALLOW_ORIGIN)
.allow_origins_fn(|_| true),
);
let cli = TestClient::new(ep);
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.header(header::ORIGIN, ALLOW_ORIGIN)
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN,),
Some(&HeaderValue::from_static(ALLOW_ORIGIN))
);
assert!(resp.headers().get(header::VARY).is_none());
let resp = cli
.get("/")
.header(header::ORIGIN, ALLOW_ORIGIN)
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, ALLOW_ORIGIN);
resp.assert_header_is_not_exist(header::VARY);
let resp = ep
.call(
Request::builder()
.method(Method::GET)
.header(header::ORIGIN, "https://abc.com")
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN,),
Some(&HeaderValue::from_static("https://abc.com"))
);
assert_eq!(
resp.headers().get(header::VARY),
Some(&HeaderValue::from_static("Origin"))
);
let resp = cli
.get("/")
.header(header::ORIGIN, "https://abc.com")
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "https://abc.com");
resp.assert_header(header::VARY, "Origin");
}
#[tokio::test]
async fn allow_origins_fn_3() {
let ep = make_sync(|_| "hello")
.with(Cors::new().allow_origins_fn(|_| false))
.map_to_response();
let ep = make_sync(|_| "hello").with(Cors::new().allow_origins_fn(|_| false));
let cli = TestClient::new(ep);
assert!(ep
.call(
Request::builder()
.method(Method::GET)
.header(header::ORIGIN, ALLOW_ORIGIN)
.finish(),
)
.await
.unwrap_err()
.is::<CorsError>());
let resp = cli
.get("/")
.header(header::ORIGIN, ALLOW_ORIGIN)
.send()
.await;
resp.assert_status(StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn default_cors_middleware() {
let ep = make_sync(|_| "hello").with(Cors::new()).map_to_response();
let resp = ep.call(get_request()).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap(),
"https://example.com"
);
let ep = make_sync(|_| "hello").with(Cors::new());
let cli = TestClient::new(ep);
let resp = get_request(&cli).send().await;
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "https://example.com");
}
#[tokio::test]
async fn unauthorized_origin() {
let ep = make_sync(|_| "hello").with(cors()).map_to_response();
assert!(ep
.call(
Request::builder()
.method(Method::GET)
.header(header::ORIGIN, "https://foo.com")
.finish(),
)
.await
.unwrap_err()
.is::<CorsError>());
let ep = make_sync(|_| "hello").with(cors());
let cli = TestClient::new(ep);
let resp = cli
.get("/")
.header(header::ORIGIN, "https://foo.com")
.send()
.await;
resp.assert_status(StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn unauthorized_options() {
let ep = make_sync(|_| "hello").with(cors()).map_to_response();
let ep = make_sync(|_| "hello").with(cors());
let cli = TestClient::new(ep);
assert!(ep
.call(
Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, "https://abc.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.finish(),
)
cli.options("/")
.header(header::ORIGIN, "https://abc.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.send()
.await
.unwrap_err()
.is::<CorsError>());
.assert_status(StatusCode::UNAUTHORIZED);
assert!(ep
.call(
Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "TRACE")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.finish(),
)
cli.options("/")
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "TRACE")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.send()
.await
.unwrap_err()
.is::<CorsError>());
.assert_status(StatusCode::UNAUTHORIZED);
assert!(ep
.call(
Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-ABC")
.finish(),
)
cli.options("/")
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Abc")
.send()
.await
.unwrap_err()
.is::<CorsError>());
.assert_status(StatusCode::UNAUTHORIZED);
let resp = ep
.call(
Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.finish(),
)
cli.options("/")
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token")
.send()
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
.assert_status_is_ok();
}
#[cfg(feature = "cookie")]
@@ -731,51 +608,43 @@ mod tests {
middleware::CookieJarManager,
web::cookie::{Cookie, CookieJar},
};
#[handler(internal)]
async fn index(cookie_jar: &CookieJar) {
cookie_jar.add(Cookie::new_with_str("foo", "bar"));
}
let ep = index.with(CookieJarManager::new()).with(cors());
let resp = ep.map_to_response().call(get_request()).await.unwrap();
let cli = TestClient::new(ep);
assert_eq!(resp.headers().get(header::SET_COOKIE).unwrap(), "foo=bar");
let resp = get_request(&cli).send().await;
resp.assert_status_is_ok();
resp.assert_header(header::SET_COOKIE, "foo=bar");
}
#[tokio::test]
async fn set_cors_headers_to_error_responses() {
let ep =
make_sync(|_| Err::<(), _>(Error::from_status(StatusCode::BAD_REQUEST))).with(cors());
let resp = ep.map_to_response().get_response(get_request()).await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap(),
ALLOW_ORIGIN
);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
.and_then(|value| value.to_str().ok())
.unwrap(),
EXPOSE_HEADER.to_lowercase()
let cli = TestClient::new(ep);
let resp = get_request(&cli).send().await;
resp.assert_status(StatusCode::BAD_REQUEST);
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, ALLOW_ORIGIN);
resp.assert_header(
header::ACCESS_CONTROL_EXPOSE_HEADERS,
EXPOSE_HEADER.to_lowercase(),
);
}
#[tokio::test]
async fn no_cors_requests() {
let ep = make_sync(|_| "hello").with(Cors::new().allow_origin(ALLOW_ORIGIN));
let resp = ep
.map_to_response()
.call(Request::builder().method(Method::GET).finish())
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none());
let cli = TestClient::new(ep);
let resp = cli.get("/").send().await;
resp.assert_status_is_ok();
resp.assert_header_is_not_exist(header::ACCESS_CONTROL_ALLOW_ORIGIN);
}
#[tokio::test]
@@ -785,24 +654,16 @@ mod tests {
.allow_origin(ALLOW_ORIGIN)
.allow_method(Method::GET),
);
let resp = ep
.map_to_response()
.call(
Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, ALLOW_ORIGIN)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type")
.finish(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_HEADERS)
.unwrap(),
"content-type"
)
let cli = TestClient::new(ep);
let resp = cli
.options("/")
.header(header::ORIGIN, ALLOW_ORIGIN)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type")
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_HEADERS, "content-type");
}
}

View File

@@ -24,6 +24,7 @@ use crate::{
/// http::{header, Method, StatusCode},
/// middleware::Csrf,
/// post,
/// test::TestClient,
/// web::{cookie::Cookie, CsrfToken, CsrfVerifier},
/// Endpoint, EndpointExt, Error, Request, Result, Route,
/// };
@@ -51,31 +52,26 @@ use crate::{
/// let app = Route::new()
/// .at("/", get(login_ui).post(login))
/// .with(Csrf::new());
/// let cli = TestClient::new(app);
///
/// let resp = app.call(Request::default()).await.unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// let cookie = resp.headers().get(header::SET_COOKIE).unwrap();
/// let resp = cli.get("/").send().await;
/// resp.assert_status_is_ok();
///
/// let cookie = resp.0.headers().get(header::SET_COOKIE).unwrap();
/// let cookie = Cookie::parse(cookie.to_str().unwrap()).unwrap();
/// let csrf_token = resp.into_body().into_string().await.unwrap();
/// let csrf_token = resp.0.into_body().into_string().await.unwrap();
///
/// let resp = app
/// .call(
/// Request::builder()
/// .method(Method::POST)
/// .header("X-CSRF-Token", csrf_token)
/// .header(
/// header::COOKIE,
/// format!("{}={}", cookie.name(), cookie.value_str()),
/// )
/// .finish(),
/// let resp = cli
/// .post("/")
/// .header("X-CSRF-Token", csrf_token)
/// .header(
/// header::COOKIE,
/// format!("{}={}", cookie.name(), cookie.value_str()),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(
/// resp.into_body().into_string().await.unwrap(),
/// "login success"
/// );
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("login success").await;
/// # });
/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "csrf")))]

View File

@@ -52,7 +52,9 @@ use crate::endpoint::Endpoint;
/// # Create you own middleware
///
/// ```
/// use poem::{handler, web::Data, Endpoint, EndpointExt, Middleware, Request, Result};
/// use poem::{
/// handler, test::TestClient, web::Data, Endpoint, EndpointExt, Middleware, Request, Result,
/// };
///
/// /// A middleware that extract token from HTTP headers.
/// struct TokenMiddleware;
@@ -104,11 +106,13 @@ use crate::endpoint::Endpoint;
/// let ep = index.with(TokenMiddleware);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let mut resp = ep
/// .call(Request::builder().header(TOKEN_HEADER, "abc").finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.take_body().into_string().await.unwrap(), "abc");
/// let mut resp = TestClient::new(ep)
/// .get("/")
/// .header(TOKEN_HEADER, "abc")
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("abc").await;
/// # });
/// ```
///
@@ -117,7 +121,9 @@ use crate::endpoint::Endpoint;
/// ```rust
/// use std::sync::Arc;
///
/// use poem::{handler, web::Data, Endpoint, EndpointExt, IntoResponse, Request, Result};
/// use poem::{
/// handler, test::TestClient, web::Data, Endpoint, EndpointExt, IntoResponse, Request, Result,
/// };
/// const TOKEN_HEADER: &str = "X-Token";
///
/// #[handler]
@@ -144,13 +150,12 @@ use crate::endpoint::Endpoint;
/// }
///
/// let ep = index.around(token_middleware);
/// let cli = TestClient::new(ep);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let mut resp = ep
/// .call(Request::builder().header(TOKEN_HEADER, "abc").finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.take_body().into_string().await.unwrap(), "abc");
/// let resp = cli.get("/").header(TOKEN_HEADER, "abc").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("abc").await;
/// # });
/// ```
pub trait Middleware<E: Endpoint> {
@@ -193,7 +198,8 @@ mod tests {
use super::*;
use crate::{
handler,
http::{header::HeaderName, HeaderValue, StatusCode},
http::{header::HeaderName, HeaderValue},
test::TestClient,
web::Data,
EndpointExt, IntoResponse, Request, Response, Result,
};
@@ -228,14 +234,11 @@ mod tests {
header: HeaderName::from_static("hello"),
value: HeaderValue::from_static("world"),
}));
let mut resp = ep.call(Request::default()).await.unwrap();
assert_eq!(
resp.headers()
.get(HeaderName::from_static("hello"))
.cloned(),
Some(HeaderValue::from_static("world"))
);
assert_eq!(resp.take_body().into_string().await.unwrap(), "abc");
let cli = TestClient::new(ep);
let resp = cli.get("/").send().await;
resp.assert_header("hello", "world");
resp.assert_text("abc").await;
}
#[tokio::test]
@@ -250,17 +253,12 @@ mod tests {
SetHeader::new().appending("myheader-1", "a"),
SetHeader::new().appending("myheader-2", "b"),
));
let cli = TestClient::new(ep);
let mut resp = ep.call(Request::default()).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get("myheader-1"),
Some(&HeaderValue::from_static("a"))
);
assert_eq!(
resp.headers().get("myheader-2"),
Some(&HeaderValue::from_static("b"))
);
assert_eq!(resp.take_body().into_string().await.unwrap(), "10");
let resp = cli.get("/").send().await;
resp.assert_status_is_ok();
resp.assert_header("myheader-1", "a");
resp.assert_header("myheader-2", "b");
resp.assert_text("10").await;
}
}

View File

@@ -34,6 +34,7 @@ impl Default for TrailingSlash {
/// get, handler,
/// http::{StatusCode, Uri},
/// middleware::{NormalizePath, TrailingSlash},
/// test::TestClient,
/// Endpoint, EndpointExt, Request, Route,
/// };
///
@@ -45,18 +46,12 @@ impl Default for TrailingSlash {
/// let app = Route::new()
/// .at("/foo/bar", get(index))
/// .with(NormalizePath::new(TrailingSlash::Trim));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/foo/bar/"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello");
/// let resp = cli.get("/foo/bar/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("hello").await;
/// # });
/// ```
pub struct NormalizePath(TrailingSlash);
@@ -133,7 +128,7 @@ impl<E: Endpoint> Endpoint for NormalizePathEndpoint<E> {
#[cfg(test)]
mod tests {
use super::*;
use crate::{endpoint::make_sync, error::NotFoundError, http::StatusCode, EndpointExt, Route};
use crate::{endpoint::make_sync, http::StatusCode, test::TestClient, EndpointExt, Route};
#[tokio::test]
async fn trim_trailing_slashes() {
@@ -150,6 +145,7 @@ mod tests {
}),
)
.with(NormalizePath::new(TrailingSlash::Trim));
let cli = TestClient::new(ep);
let test_uris = [
"/",
@@ -167,9 +163,8 @@ mod tests {
];
for uri in test_uris {
let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish();
let res = ep.call(req).await.unwrap();
assert!(res.status().is_success(), "Failed uri: {}", uri);
let resp = cli.get(uri).send().await;
assert!(resp.0.status().is_success(), "Failed uri: {}", uri);
}
}
@@ -186,13 +181,12 @@ mod tests {
}),
)
.with(NormalizePath::new(TrailingSlash::Trim));
let cli = TestClient::new(ep);
let test_uris = ["/?query=test", "//?query=test", "///?query=test"];
for uri in test_uris {
let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish();
let res = ep.call(req).await.unwrap();
assert!(res.status().is_success(), "Failed uri: {}", uri);
let resp = cli.get(uri).send().await;
assert!(resp.0.status().is_success(), "Failed uri: {}", uri);
}
}
@@ -211,6 +205,7 @@ mod tests {
}),
)
.with(NormalizePath::new(TrailingSlash::Always));
let cli = TestClient::new(ep);
let test_uris = [
"/",
@@ -228,9 +223,8 @@ mod tests {
];
for uri in test_uris {
let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish();
let res = ep.call(req).await.unwrap();
assert!(res.status().is_success(), "Failed uri: {}", uri);
let resp = cli.get(uri).send().await;
assert!(resp.0.status().is_success(), "Failed uri: {}", uri);
}
}
@@ -247,13 +241,13 @@ mod tests {
}),
)
.with(NormalizePath::new(TrailingSlash::Always));
let cli = TestClient::new(ep);
let test_uris = ["/?query=test", "//?query=test", "///?query=test"];
for uri in test_uris {
let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish();
let res = ep.call(req).await.unwrap();
assert!(res.status().is_success(), "Failed uri: {}", uri);
let resp = cli.get(uri).send().await;
assert!(resp.0.status().is_success(), "Failed uri: {}", uri);
}
}
@@ -273,6 +267,7 @@ mod tests {
}),
)
.with(NormalizePath::new(TrailingSlash::MergeOnly));
let cli = TestClient::new(ep);
let test_uris = [
("/", true), // root paths should still work
@@ -293,14 +288,14 @@ mod tests {
];
for (uri, success) in test_uris {
let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish();
let res = ep.call(req).await;
let resp = cli.get(uri).send().await;
if success {
assert_eq!(res.unwrap().status(), StatusCode::OK, "Failed uri: {}", uri);
assert_eq!(resp.0.status(), StatusCode::OK, "Failed uri: {}", uri);
} else {
assert!(
res.unwrap_err().is::<NotFoundError>(),
assert_eq!(
resp.0.status(),
StatusCode::NOT_FOUND,
"Failed uri: {}",
uri
);

View File

@@ -69,24 +69,16 @@ impl<E: Endpoint> Endpoint for PropagateHeaderEndpoint<E> {
#[cfg(test)]
mod tests {
use super::*;
use crate::{handler, EndpointExt};
use crate::{handler, test::TestClient, EndpointExt};
#[tokio::test]
async fn test_propagate_header() {
#[handler(internal)]
fn index() {}
let resp = index
.with(PropagateHeader::new().header("x-request-id"))
.call(Request::builder().header("x-request-id", "100").finish())
.await
.unwrap();
assert_eq!(
resp.headers()
.get("x-request-id")
.and_then(|value| value.to_str().ok()),
Some("100")
);
let cli = TestClient::new(index.with(PropagateHeader::new().header("x-request-id")));
let resp = cli.get("/").header("x-request-id", "100").send().await;
resp.assert_status_is_ok();
resp.assert_header("x-request-id", "100");
}
}

View File

@@ -127,7 +127,11 @@ fn set_sensitive(headers: &mut HeaderMap, names: &HashSet<HeaderName>) {
#[cfg(test)]
mod tests {
use super::*;
use crate::{handler, EndpointExt};
use crate::{
handler,
test::{TestClient, TestRequestBuilder},
EndpointExt,
};
fn create_middleware() -> SensitiveHeader {
SensitiveHeader::new()
@@ -137,11 +141,10 @@ mod tests {
.header("x-api-key4")
}
fn create_request() -> Request {
Request::builder()
fn create_request<T: Endpoint>(cli: &TestClient<T>) -> TestRequestBuilder<'_, T> {
cli.get("/")
.header("x-api-key1", "a")
.header("x-api-key2", "b")
.finish()
}
#[tokio::test]
@@ -155,13 +158,11 @@ mod tests {
.with_header("x-api-key4", "c")
}
let resp = index
.with(create_middleware().request_only())
.get_response(create_request())
.await;
let cli = TestClient::new(index.with(create_middleware().request_only()));
assert!(!resp.headers().get("x-api-key3").unwrap().is_sensitive());
assert!(!resp.headers().get("x-api-key4").unwrap().is_sensitive());
let resp = create_request(&cli).send().await;
assert!(!resp.0.headers().get("x-api-key3").unwrap().is_sensitive());
assert!(!resp.0.headers().get("x-api-key4").unwrap().is_sensitive());
}
#[tokio::test]
@@ -175,13 +176,11 @@ mod tests {
.with_header("x-api-key4", "c")
}
let resp = index
.with(create_middleware().response_only())
.get_response(create_request())
.await;
let cli = TestClient::new(index.with(create_middleware().response_only()));
assert!(resp.headers().get("x-api-key3").unwrap().is_sensitive());
assert!(resp.headers().get("x-api-key4").unwrap().is_sensitive());
let resp = create_request(&cli).send().await;
assert!(resp.0.headers().get("x-api-key3").unwrap().is_sensitive());
assert!(resp.0.headers().get("x-api-key4").unwrap().is_sensitive());
}
#[tokio::test]
@@ -195,12 +194,10 @@ mod tests {
.with_header("x-api-key4", "c")
}
let resp = index
.with(create_middleware())
.get_response(create_request())
.await;
let cli = TestClient::new(index.with(create_middleware()));
let resp = create_request(&cli).send().await;
assert!(resp.headers().get("x-api-key3").unwrap().is_sensitive());
assert!(resp.headers().get("x-api-key4").unwrap().is_sensitive());
assert!(resp.0.headers().get("x-api-key3").unwrap().is_sensitive());
assert!(resp.0.headers().get("x-api-key4").unwrap().is_sensitive());
}
}

View File

@@ -20,6 +20,7 @@ enum Action {
/// get, handler,
/// http::{HeaderValue, StatusCode},
/// middleware::SetHeader,
/// test::TestClient,
/// Endpoint, EndpointExt, Request, Route,
/// };
///
@@ -37,22 +38,10 @@ enum Action {
/// );
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app.call(Request::default()).await.unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(
/// resp.headers()
/// .get_all("MyHeader1")
/// .iter()
/// .collect::<Vec<_>>(),
/// vec![HeaderValue::from_static("a"), HeaderValue::from_static("b")]
/// );
/// assert_eq!(
/// resp.headers()
/// .get_all("MyHeader2")
/// .iter()
/// .collect::<Vec<_>>(),
/// vec![HeaderValue::from_static("b")]
/// );
/// let resp = TestClient::new(app).get("/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_header_all("MyHeader1", ["a", "b"]);
/// resp.assert_header_all("MyHeader2", ["b"]);
/// # });
/// ```
#[derive(Default)]
@@ -146,41 +135,27 @@ impl<E: Endpoint> Endpoint for SetHeaderEndpoint<E> {
#[cfg(test)]
mod tests {
use super::*;
use crate::{handler, EndpointExt};
use crate::{handler, test::TestClient, EndpointExt};
#[tokio::test]
async fn test_set_header() {
#[handler(internal)]
fn index() {}
let resp = index
.with(
let cli = TestClient::new(
index.with(
SetHeader::new()
.overriding("custom-a", "a")
.overriding("custom-a", "b")
.appending("custom-b", "a")
.appending("custom-b", "b"),
)
.call(Request::default())
.await
.unwrap();
assert_eq!(
resp.headers()
.get_all("custom-a")
.into_iter()
.filter_map(|value| value.to_str().ok())
.collect::<Vec<_>>(),
vec!["b"]
),
);
assert_eq!(
resp.headers()
.get_all("custom-b")
.into_iter()
.filter_map(|value| value.to_str().ok())
.collect::<Vec<_>>(),
vec!["a", "b"]
);
let resp = cli.get("/").send().await;
resp.assert_status_is_ok();
resp.assert_header_all("custom-a", ["b"]);
resp.assert_header_all("custom-b", ["a", "b"]);
}
}

View File

@@ -63,57 +63,38 @@ mod tests {
use super::*;
use crate::{
endpoint::{make_sync, EndpointExt},
IntoResponse,
test::TestClient,
};
#[tokio::test]
async fn size_limit() {
let ep = make_sync(|_| ()).with(SizeLimit::new(5));
let cli = TestClient::new(ep);
assert_eq!(
ep.call(Request::builder().body(&b"123456"[..]))
.await
.unwrap_err()
.downcast_ref::<SizedLimitError>(),
Some(&SizedLimitError::MissingContentLength)
);
assert_eq!(
ep.call(
Request::builder()
.header("content-length", 6)
.body(&b"123456"[..])
)
cli.post("/")
.send()
.await
.unwrap_err()
.downcast_ref::<SizedLimitError>(),
Some(&SizedLimitError::PayloadTooLarge)
);
.assert_status(StatusCode::BAD_REQUEST);
assert_eq!(
ep.call(
Request::builder()
.header("content-length", 4)
.body(&b"1234"[..])
)
cli.post("/")
.header("content-length", 6)
.body(&b"123456"[..])
.send()
.await
.unwrap()
.into_response()
.status(),
StatusCode::OK
);
.assert_status(StatusCode::PAYLOAD_TOO_LARGE);
assert_eq!(
ep.call(
Request::builder()
.header("content-length", 5)
.body(&b"12345"[..])
)
cli.post("/")
.header("content-length", 4)
.body(&b"1234"[..])
.send()
.await
.unwrap()
.into_response()
.status(),
StatusCode::OK
);
.assert_status_is_ok();
cli.post("/")
.header("content-length", 5)
.body(&b"12345"[..])
.send()
.await
.assert_status_is_ok();
}
}

View File

@@ -101,10 +101,9 @@ where
#[cfg(test)]
mod tests {
use http::StatusCode;
use super::*;
use crate::{endpoint::make_sync, EndpointExt};
use crate::{endpoint::make_sync, test::TestClient, EndpointExt};
#[tokio::test]
async fn test_tower_layer() {
@@ -140,7 +139,7 @@ mod tests {
}
let ep = make_sync(|_| ()).with(MyServiceLayer.compat());
let resp = ep.call(Request::default()).await.unwrap().into_response();
assert_eq!(resp.status(), StatusCode::OK);
let cli = TestClient::new(ep);
cli.get("/").send().await.assert_status_is_ok();
}
}

View File

@@ -273,6 +273,7 @@ impl Request {
/// use poem::{
/// handler,
/// http::{StatusCode, Uri},
/// test::TestClient,
/// Endpoint, Request, Result, Route,
/// };
///
@@ -283,18 +284,12 @@ impl Request {
/// }
///
/// let app = Route::new().at("/:a/:b", index);
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/100/abc"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "100:abc");
/// let resp = cli.get("/100/abc").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("100:abc").await;
/// # });
/// ```
pub fn path_params<T: DeserializeOwned>(&self) -> Result<T, ParsePathError> {
@@ -312,6 +307,7 @@ impl Request {
/// use poem::{
/// handler,
/// http::{StatusCode, Uri},
/// test::TestClient,
/// Endpoint, Request, Result, Route,
/// };
/// use serde::Deserialize;
@@ -329,18 +325,17 @@ impl Request {
/// }
///
/// let app = Route::new().at("/", index);
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/?a=100&b=abc"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "100:abc");
/// let resp = cli
/// .get("/")
/// .query("a", &100)
/// .query("b", &"abc")
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("100:abc").await;
/// # });
/// ```
pub fn params<T: DeserializeOwned>(&self) -> Result<T, ParseQueryError> {

View File

@@ -25,6 +25,7 @@ use crate::{
/// use poem::{
/// get, handler,
/// http::{StatusCode, Uri},
/// test::TestClient,
/// web::Path,
/// Endpoint, Request, Route,
/// };
@@ -56,44 +57,22 @@ use crate::{
/// .at("/e/:name<\\d+>", get(a));
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let cli = TestClient::new(app);
///
/// // /a/b
/// let resp = app
/// .call(Request::builder().uri(Uri::from_static("/a/b")).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// cli.get("/a/b").send().await.assert_status_is_ok();
///
/// // /b/:group/:name
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/b/foo/bar"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// cli.get("/b/foo/bar").send().await.assert_status_is_ok();
///
/// // /c/*path
/// let resp = app
/// .call(Request::builder().uri(Uri::from_static("/c/d/e")).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// cli.get("/c/d/e").send().await.assert_status_is_ok();
///
/// // /d/<\\d>
/// let resp = app
/// .call(Request::builder().uri(Uri::from_static("/d/123")).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// cli.get("/d/123").send().await.assert_status_is_ok();
///
/// // /e/:name<\\d>
/// let resp = app
/// .call(Request::builder().uri(Uri::from_static("/e/123")).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// cli.get("/e/123").send().await.assert_status_is_ok();
/// # });
/// ```
///
@@ -103,6 +82,7 @@ use crate::{
/// use poem::{
/// handler,
/// http::{StatusCode, Uri},
/// test::TestClient,
/// Endpoint, Request, Route,
/// };
///
@@ -112,18 +92,12 @@ use crate::{
/// }
///
/// let app = Route::new().nest("/foo", Route::new().at("/bar", index));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/foo/bar"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello");
/// let resp = cli.get("/foo/bar").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("hello").await;
/// # });
/// ```
///
@@ -133,6 +107,7 @@ use crate::{
/// use poem::{
/// handler,
/// http::{StatusCode, Uri},
/// test::TestClient,
/// Endpoint, Request, Route,
/// };
///
@@ -142,18 +117,12 @@ use crate::{
/// }
///
/// let app = Route::new().nest_no_strip("/foo", Route::new().at("/foo/bar", index));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/foo/bar"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello");
/// let resp = cli.get("/foo/bar").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("hello").await;
/// # });
/// ```
#[derive(Default)]

View File

@@ -15,7 +15,13 @@ use crate::{
/// # Example
///
/// ```
/// use poem::{endpoint::make_sync, handler, http::header, Endpoint, Request, RouteDomain};
/// use poem::{
/// endpoint::make_sync,
/// handler,
/// http::header,
/// test::{TestClient, TestRequestBuilder},
/// Endpoint, Request, RouteDomain,
/// };
///
/// let app = RouteDomain::new()
/// .at("example.com", make_sync(|_| "1"))
@@ -23,26 +29,23 @@ use crate::{
/// .at("*.example.com", make_sync(|_| "3"))
/// .at("*", make_sync(|_| "4"));
///
/// fn make_request(host: &str) -> Request {
/// Request::builder().header(header::HOST, host).finish()
/// }
///
/// async fn do_request(app: &RouteDomain, req: Request) -> String {
/// app.call(req)
/// .await
/// .unwrap()
/// .into_body()
/// .into_string()
/// .await
/// .unwrap()
/// async fn check(app: impl Endpoint, domain: Option<&str>, res: &str) {
/// let cli = TestClient::new(app);
/// let mut req = cli.get("/");
/// if let Some(domain) = domain {
/// req = req.header(header::HOST, domain);
/// }
/// let resp = req.send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text(res).await;
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// assert_eq!(do_request(&app, make_request("example.com")).await, "1");
/// assert_eq!(do_request(&app, make_request("www.abc.com")).await, "2");
/// assert_eq!(do_request(&app, make_request("a.b.example.com")).await, "3");
/// assert_eq!(do_request(&app, make_request("rust-lang.org")).await, "4");
/// assert_eq!(do_request(&app, Request::default()).await, "4");
/// check(&app, Some("example.com"), "1").await;
/// check(&app, Some("www.abc.com"), "2").await;
/// check(&app, Some("a.b.example.com"), "3").await;
/// check(&app, Some("rust-lang.org"), "4").await;
/// check(&app, None, "4").await;
/// # });
/// ```
#[derive(Default)]
@@ -103,24 +106,20 @@ impl Endpoint for RouteDomain {
#[cfg(test)]
mod tests {
use http::StatusCode;
use super::*;
use crate::{endpoint::make_sync, handler, http::HeaderMap};
use crate::{endpoint::make_sync, handler, http::HeaderMap, test::TestClient};
async fn check(r: &RouteDomain, host: &str, value: &str) {
let mut req = Request::builder();
let cli = TestClient::new(r);
let mut req = cli.get("/");
if !host.is_empty() {
req = req.header(header::HOST, host);
}
assert_eq!(
r.call(req.finish())
.await
.unwrap()
.into_body()
.into_string()
.await
.unwrap(),
value
);
let resp = req.send().await;
resp.assert_status_is_ok();
resp.assert_text(value).await;
}
#[tokio::test]
@@ -156,22 +155,18 @@ mod tests {
.at("www.example.com", make_sync(|_| "2"))
.at("www.+.com", make_sync(|_| "3"))
.at("*.com", make_sync(|_| "4"));
let cli = TestClient::new(r);
assert!(r
.call(
Request::builder()
.header(header::HOST, "rust-lang.org")
.finish()
)
cli.get("/")
.header(header::HOST, "rust-lang.org")
.send()
.await
.unwrap_err()
.is::<NotFoundError>());
.assert_status(StatusCode::NOT_FOUND);
assert!(r
.call(Request::default())
cli.get("/")
.send()
.await
.unwrap_err()
.is::<NotFoundError>());
.assert_status(StatusCode::NOT_FOUND);
}
#[handler(internal)]

View File

@@ -275,13 +275,13 @@ mod tests {
use crate::{
handler,
http::{Method, StatusCode},
Request,
test::TestClient,
};
#[tokio::test]
async fn method_not_allowed() {
let resp = RouteMethod::new().get_response(Request::default()).await;
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let resp = TestClient::new(RouteMethod::new()).get("/").send().await;
resp.assert_status(StatusCode::METHOD_NOT_ALLOWED);
}
#[tokio::test]
@@ -303,22 +303,21 @@ mod tests {
Method::TRACE,
] {
let route = RouteMethod::new().method(method.clone(), index).post(index);
let resp = route
.get_response(Request::builder().method(method.clone()).finish())
let resp = TestClient::new(route)
.request(method.clone(), "/")
.send()
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_string().await.unwrap(), "hello");
resp.assert_status_is_ok();
resp.assert_text("hello").await;
}
macro_rules! test_method {
($(($id:ident, $method:ident)),*) => {
$(
let route = RouteMethod::new().$id(index).post(index);
let resp = route
.get_response(Request::builder().method(Method::$method).finish())
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_string().await.unwrap(), "hello");
let resp = TestClient::new(route).request(Method::$method, "/").send().await;
resp.assert_status_is_ok();
resp.assert_text("hello").await;
)*
};
}
@@ -344,10 +343,8 @@ mod tests {
}
let route = RouteMethod::new().get(index);
let resp = route
.get_response(Request::builder().method(Method::HEAD).finish())
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.into_body().into_vec().await.unwrap().is_empty());
let resp = TestClient::new(route).head("/").send().await;
resp.assert_status_is_ok();
resp.assert_text("").await;
}
}

View File

@@ -1,6 +1,6 @@
use http::{header, header::HeaderName, HeaderMap, HeaderValue, Method};
use crate::{test::TestRequestBuilder, Endpoint};
use crate::{test::TestRequestBuilder, Endpoint, IntoEndpoint};
macro_rules! impl_methods {
($($(#[$docs:meta])* ($name:ident, $method:ident)),*) => {
@@ -21,9 +21,12 @@ pub struct TestClient<E> {
impl<E: Endpoint> TestClient<E> {
/// Create a new client for the specified endpoint.
pub fn new(ep: E) -> Self {
Self {
ep,
pub fn new<T>(ep: T) -> TestClient<T::Endpoint>
where
T: IntoEndpoint<Endpoint = E>,
{
TestClient {
ep: ep.into_endpoint(),
default_headers: Default::default(),
}
}
@@ -74,6 +77,11 @@ impl<E: Endpoint> TestClient<E> {
self.default_header(header::CONTENT_TYPE, content_type.as_ref())
}
/// Create a [`TestRequestBuilder`].
pub fn request(&self, method: Method, uri: impl Into<String>) -> TestRequestBuilder<'_, E> {
TestRequestBuilder::new(self, method, uri.into())
}
impl_methods!(
/// Create a [`TestRequestBuilder`] with `GET` method.
(get, GET),

View File

@@ -1,5 +1,3 @@
use std::collections::HashMap;
use headers::{Header, HeaderMapExt};
use http::{header, header::HeaderName, Extensions, HeaderMap, HeaderValue, Method};
use serde::Serialize;
@@ -15,16 +13,13 @@ pub struct TestRequestBuilder<'a, E> {
cli: &'a TestClient<E>,
uri: String,
method: Method,
query: HashMap<String, Value>,
query: Vec<(String, Value)>,
headers: HeaderMap,
body: Body,
extensions: Extensions,
}
impl<'a, E> TestRequestBuilder<'a, E>
where
E: Endpoint,
{
impl<'a, E> TestRequestBuilder<'a, E> {
pub(crate) fn new(cli: &'a TestClient<E>, method: Method, uri: String) -> Self {
Self {
cli,
@@ -73,7 +68,7 @@ where
#[must_use]
pub fn query(mut self, name: impl Into<String>, value: &impl Serialize) -> Self {
if let Ok(value) = serde_json::to_value(value) {
self.query.insert(name.into(), value);
self.query.push((name.into(), value));
}
self
}
@@ -194,7 +189,10 @@ where
}
/// Send this request to endpoint to get the response.
pub async fn send(self) -> TestResponse {
pub async fn send(self) -> TestResponse
where
E: Endpoint,
{
let ep = &self.cli.ep;
let req = self.make_request();
let resp = ep.get_response(req).await;

View File

@@ -1,24 +1,21 @@
use std::collections::HashSet;
use futures_util::{Stream, StreamExt};
use http::{header, header::HeaderName, HeaderValue, StatusCode};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use tokio_util::compat::TokioAsyncReadCompatExt;
use crate::{test::json::TestJson, web::sse::Event, Body, Response};
use crate::{test::json::TestJson, web::sse::Event, Response};
/// A response object for testing.
pub struct TestResponse(Response);
pub struct TestResponse(pub Response);
impl TestResponse {
pub(crate) fn new(resp: Response) -> Self {
Self(resp)
}
/// Consumes this object and returns the [`Response`].
pub fn into_inner(self) -> Response {
self.0
}
/// Asserts that the status code is equals to `status`.
pub fn assert_status(&self, status: StatusCode) {
assert_eq!(self.0.status(), status);
@@ -29,6 +26,24 @@ impl TestResponse {
self.assert_status(StatusCode::OK);
}
/// Asserts that header `key` is not exist.
pub fn assert_header_is_not_exist<K>(&self, key: K)
where
K: TryInto<HeaderName>,
{
let key = key.try_into().map_err(|_| ()).expect("valid header name");
assert!(!self.0.headers().contains_key(key));
}
/// Asserts that header `key` exist.
pub fn assert_header_exist<K>(&self, key: K)
where
K: TryInto<HeaderName>,
{
let key = key.try_into().map_err(|_| ()).expect("valid header name");
assert!(self.0.headers().contains_key(key));
}
/// Asserts that header `key` is equals to `value`.
pub fn assert_header<K, V>(&self, key: K, value: V)
where
@@ -50,21 +65,75 @@ impl TestResponse {
assert_eq!(value2, value);
}
/// Asserts that the header `key` is equal to `values` separated by commas.
pub fn assert_header_csv<K, V, I>(&self, key: K, values: I)
where
K: TryInto<HeaderName>,
V: AsRef<str>,
I: IntoIterator<Item = V>,
{
let expect_values = values.into_iter().collect::<Vec<_>>();
let expect_values = expect_values
.iter()
.map(|value| value.as_ref())
.collect::<HashSet<_>>();
let key = key.try_into().map_err(|_| ()).expect("valid header name");
let value = self
.0
.headers()
.get(&key)
.unwrap_or_else(|| panic!("expect header `{}`", key));
let values = value
.to_str()
.expect("valid header value")
.split(',')
.map(|s| s.trim())
.collect::<HashSet<_>>();
assert_eq!(values, expect_values);
}
/// Asserts that header `key` is equals to `values`.
pub fn assert_header_all<K, V, I>(&self, key: K, values: I)
where
K: TryInto<HeaderName>,
V: TryInto<HeaderValue>,
I: IntoIterator<Item = V>,
{
let key = key.try_into().map_err(|_| ()).expect("valid header name");
let mut values = values
.into_iter()
.map(|value| {
value
.try_into()
.map_err(|_| ())
.expect("valid header value")
})
.collect::<Vec<_>>();
let mut values2 = self
.0
.headers()
.get_all(&key)
.iter()
.cloned()
.collect::<Vec<_>>();
values.sort();
values2.sort();
assert_eq!(values, values2);
}
/// Asserts that content type is equals to `content_type`.
pub fn assert_content_type(&self, content_type: &str) {
self.assert_header(header::CONTENT_TYPE, content_type);
}
/// Consumes this object and return the response body.
#[inline]
pub fn into_body(self) -> Body {
self.0.into_body()
}
/// Asserts that the response body is utf8 string and it equals to `text`.
pub async fn assert_text(self, text: impl AsRef<str>) {
assert_eq!(
self.into_body().into_string().await.expect("expect body"),
self.0.into_body().into_string().await.expect("expect body"),
text.as_ref()
);
}
@@ -72,7 +141,7 @@ impl TestResponse {
/// Asserts that the response body is bytes and it equals to `bytes`.
pub async fn assert_bytes(self, bytes: impl AsRef<[u8]>) {
assert_eq!(
self.into_body().into_vec().await.expect("expect body"),
self.0.into_body().into_vec().await.expect("expect body"),
bytes.as_ref()
);
}
@@ -80,7 +149,8 @@ impl TestResponse {
/// Asserts that the response body is JSON and it equals to `json`.
pub async fn assert_json(self, json: impl Serialize) {
assert_eq!(
self.into_body()
self.0
.into_body()
.into_json::<Value>()
.await
.expect("expect body"),
@@ -90,7 +160,8 @@ impl TestResponse {
/// Consumes this object and return the [`TestJson`].
pub async fn json(self) -> TestJson {
self.into_body()
self.0
.into_body()
.into_json::<TestJson>()
.await
.expect("expect body")
@@ -99,7 +170,7 @@ impl TestResponse {
/// Consumes this object and return the SSE events stream.
pub fn sse_stream(self) -> impl Stream<Item = Event> + Send + Unpin + 'static {
self.assert_content_type("text/event-stream");
sse_codec::decode_stream(self.into_body().into_async_read().compat())
sse_codec::decode_stream(self.0.into_body().into_async_read().compat())
.map(|res| {
let event = res.expect("valid sse frame");
match event {

View File

@@ -139,7 +139,7 @@ mod tests {
use tokio::io::AsyncReadExt;
use super::*;
use crate::{handler, Endpoint, EndpointExt, Request};
use crate::{handler, test::TestClient, EndpointExt};
async fn decompress_data(algo: CompressionAlgo, data: &[u8]) -> String {
let mut output = Vec::new();
@@ -157,18 +157,16 @@ mod tests {
DATA
}
let mut resp = index
.and_then(move |resp| async move { Ok(Compress::new(resp, algo)) })
.call(Request::default())
.await
.unwrap()
.into_response();
let resp = TestClient::new(
index.and_then(move |resp| async move { Ok(Compress::new(resp, algo)) }),
)
.get("/")
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header(header::CONTENT_ENCODING, algo.as_str());
assert_eq!(
resp.headers().get(header::CONTENT_ENCODING),
Some(&HeaderValue::from_static(algo.as_str()))
);
assert_eq!(
decompress_data(algo, &resp.take_body().into_bytes().await.unwrap()).await,
decompress_data(algo, &resp.0.into_body().into_bytes().await.unwrap()).await,
DATA
);
}

View File

@@ -316,6 +316,7 @@ impl<'a> FromRequest<'a> for Cookie {
/// get, handler,
/// http::{header, StatusCode},
/// middleware::CookieJarManager,
/// test::TestClient,
/// web::cookie::{Cookie, CookieJar},
/// Endpoint, EndpointExt, Request, Route,
/// };
@@ -334,17 +335,16 @@ impl<'a> FromRequest<'a> for Cookie {
/// let app = Route::new()
/// .at("/", get(index))
/// .with(CookieJarManager::new());
/// let cli = TestClient::new(app);
///
/// let resp = app.call(Request::default()).await.unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// let cookie = resp.headers().get(header::SET_COOKIE).cloned().unwrap();
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "count: 1");
/// let resp = cli.get("/").send().await;
/// resp.assert_status_is_ok();
/// let cookie = resp.0.headers().get(header::SET_COOKIE).cloned().unwrap();
/// resp.assert_text("count: 1").await;
///
/// let resp = app
/// .call(Request::builder().header(header::COOKIE, cookie).finish())
/// .await
/// .unwrap();
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "count: 2");
/// let resp = cli.get("/").header(header::COOKIE, cookie).send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("count: 2").await;
/// # });
/// ```
#[derive(Default, Clone)]

View File

@@ -50,8 +50,10 @@ impl<'a, T: Send + Sync + 'static> FromRequest<'a> for Data<&'a T> {
#[cfg(test)]
mod tests {
use http::StatusCode;
use super::*;
use crate::{handler, middleware::AddData, Endpoint, EndpointExt};
use crate::{handler, middleware::AddData, test::TestClient, EndpointExt};
#[tokio::test]
async fn test_data_extractor() {
@@ -61,7 +63,11 @@ mod tests {
}
let app = index.with(AddData::new(100i32));
app.call(Request::default()).await.unwrap();
TestClient::new(app)
.get("/")
.send()
.await
.assert_status_is_ok();
}
#[tokio::test]
@@ -71,14 +77,11 @@ mod tests {
todo!()
}
let app = index;
assert_eq!(
app.call(Request::default())
.await
.unwrap_err()
.downcast_ref::<GetDataError>(),
Some(&GetDataError("i32"))
);
TestClient::new(index)
.get("/")
.send()
.await
.assert_status(StatusCode::INTERNAL_SERVER_ERROR);
}
#[tokio::test]
@@ -88,7 +91,10 @@ mod tests {
assert_eq!(value.to_uppercase(), "ABC");
}
let app = index.with(AddData::new("abc".to_string()));
app.call(Request::default()).await.unwrap();
TestClient::new(index.with(AddData::new("abc".to_string())))
.get("/")
.send()
.await
.assert_status_is_ok();
}
}

View File

@@ -31,6 +31,7 @@ use crate::{
/// use poem::{
/// get, handler,
/// http::{Method, StatusCode, Uri},
/// test::TestClient,
/// web::Form,
/// Endpoint, Request, Route,
/// };
@@ -47,32 +48,26 @@ use crate::{
/// format!("{}:{}", title, content)
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let app = Route::new().at("/", get(index).post(index));
/// let cli = TestClient::new(app);
///
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/?title=foo&content=bar"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "foo:bar");
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = cli
/// .get("/")
/// .query("title", &"foo")
/// .query("content", &"bar")
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("foo:bar").await;
///
/// let resp = app
/// .call(
/// Request::builder()
/// .method(Method::POST)
/// .uri(Uri::from_static("/"))
/// .content_type("application/x-www-form-urlencoded")
/// .body("title=foo&content=bar"),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "foo:bar");
/// let resp = cli
/// .post("/")
/// .form(&[("title", "foo"), ("content", "bar")])
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("foo:bar").await;
/// # });
/// ```
pub struct Form<T>(pub T);
@@ -123,10 +118,11 @@ impl<'a, T: DeserializeOwned> FromRequest<'a> for Form<T> {
#[cfg(test)]
mod tests {
use http::StatusCode;
use serde::Deserialize;
use super::*;
use crate::{handler, http::Uri, Endpoint};
use crate::{handler, test::TestClient};
#[tokio::test]
async fn test_form_extractor() {
@@ -142,34 +138,26 @@ mod tests {
assert_eq!(form.value, 100);
}
index
.call(
Request::builder()
.uri(Uri::from_static("/?name=abc&value=100"))
.finish(),
)
.await
.unwrap();
let cli = TestClient::new(index);
index
.call(
Request::builder()
.method(Method::POST)
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body("name=abc&value=100"),
)
cli.get("/")
.query("name", &"abc")
.query("value", &"100")
.send()
.await
.unwrap();
.assert_status_is_ok();
assert!(index
.call(
Request::builder()
.method(Method::POST)
.header(header::CONTENT_TYPE, "application/json")
.body("name=abc&value=100"),
)
cli.post("/")
.form(&[("name", "abc"), ("value", "100")])
.send()
.await
.unwrap_err()
.is::<ParseFormError>());
.assert_status_is_ok();
cli.post("/")
.content_type("application/json")
.body("name=abc&value=100")
.send()
.await
.assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
}
}

View File

@@ -23,6 +23,7 @@ use crate::{
/// handler,
/// http::{Method, StatusCode},
/// post,
/// test::TestClient,
/// web::Json,
/// Endpoint, Request, Route,
/// };
@@ -38,21 +39,13 @@ use crate::{
/// format!("welcome {}!", user.name)
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let app = Route::new().at("/", post(index));
/// let resp = app
/// .call(
/// Request::builder()
/// .method(Method::POST)
/// .body(r#"{"name": "foo"}"#),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(
/// resp.into_body().into_string().await.unwrap(),
/// "welcome foo!"
/// );
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = cli.post("/").body(r#"{"name": "foo"}"#).send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("welcome foo!").await;
/// # });
/// ```
///
@@ -62,7 +55,9 @@ use crate::{
/// [`serde::Serialize`].
///
/// ```
/// use poem::{get, handler, http::StatusCode, web::Json, Endpoint, Request, Route};
/// use poem::{
/// get, handler, http::StatusCode, test::TestClient, web::Json, Endpoint, Request, Route,
/// };
/// use serde::Serialize;
///
/// #[derive(Serialize)]
@@ -77,14 +72,13 @@ use crate::{
/// })
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let app = Route::new().at("/", get(index));
/// let resp = app.call(Request::default()).await.unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(
/// resp.into_body().into_string().await.unwrap(),
/// r#"{"name":"foo"}"#
/// )
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = cli.get("/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text(r#"{"name":"foo"}"#).await;
/// # });
/// ```
#[derive(Debug, Clone, Eq, PartialEq, Default)]
@@ -131,13 +125,10 @@ impl<T: Serialize + Send> IntoResponse for Json<T> {
#[cfg(test)]
mod tests {
use serde::{Deserialize, Serialize};
use serde_json::json;
use super::*;
use crate::{
handler,
http::{Method, StatusCode},
Endpoint,
};
use crate::{handler, test::TestClient};
#[derive(Deserialize, Serialize, Debug, Eq, PartialEq)]
struct CreateResource {
@@ -153,22 +144,13 @@ mod tests {
assert_eq!(query.value, 100);
}
index
.call(
Request::builder()
.method(Method::POST)
.header(header::CONTENT_TYPE, "application/json")
.body(
r#"
{
"name": "abc",
"value": 100
}
"#,
),
)
let cli = TestClient::new(index);
cli.post("/")
.header(header::CONTENT_TYPE, "application/json")
.body_json(&json!({"name": "abc", "value": 100}))
.send()
.await
.unwrap();
.assert_status_is_ok();
}
#[tokio::test]
@@ -181,15 +163,13 @@ mod tests {
})
}
let mut resp = index.call(Request::default()).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
serde_json::from_str::<CreateResource>(&resp.take_body().into_string().await.unwrap())
.unwrap(),
CreateResource {
name: "abc".to_string(),
value: 100,
}
);
let cli = TestClient::new(index);
let resp = cli.get("/").send().await;
resp.assert_status_is_ok();
resp.assert_json(&CreateResource {
name: "abc".to_string(),
value: 100,
})
.await;
}
}

View File

@@ -242,8 +242,8 @@ impl RequestBody {
/// use std::fmt::{self, Display, Formatter};
///
/// use poem::{
/// get, handler, http::StatusCode, Endpoint, Error, FromRequest, Request, RequestBody, Result,
/// Route,
/// get, handler, http::StatusCode, test::TestClient, Endpoint, Error, FromRequest, Request,
/// RequestBody, Result, Route,
/// };
///
/// struct Token(String);
@@ -266,10 +266,14 @@ impl RequestBody {
/// }
///
/// let app = Route::new().at("/", get(index));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let _ = index
/// .call(Request::builder().header("MyToken", "token123").finish())
/// .await;
/// cli.get("/")
/// .header("MyToken", "token123")
/// .send()
/// .await
/// .assert_status_is_ok();
/// # });
/// ```
#[async_trait::async_trait]
@@ -372,7 +376,9 @@ pub trait FromRequest<'a>: Sized {
/// # Create you own response
///
/// ```
/// use poem::{handler, http::Uri, web::Query, Endpoint, IntoResponse, Request, Response};
/// use poem::{
/// handler, http::Uri, test::TestClient, web::Query, Endpoint, IntoResponse, Request, Response,
/// };
/// use serde::Deserialize;
///
/// struct Hello(Option<String>);
@@ -397,34 +403,16 @@ pub trait FromRequest<'a>: Sized {
/// Hello(params.0.name)
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// assert_eq!(
/// index
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/?name=sunli"))
/// .finish()
/// )
/// .await
/// .unwrap()
/// .take_body()
/// .into_string()
/// .await
/// .unwrap(),
/// "hello sunli"
/// );
/// let cli = TestClient::new(index);
///
/// assert_eq!(
/// index
/// .call(Request::builder().uri(Uri::from_static("/")).finish())
/// .await
/// .unwrap()
/// .take_body()
/// .into_string()
/// .await
/// .unwrap(),
/// "hello"
/// );
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = cli.get("/").query("name", &"sunli").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("hello sunli").await;
///
/// let resp = cli.get("/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("hello").await;
/// # });
/// ```
pub trait IntoResponse: Send {

View File

@@ -170,7 +170,7 @@ impl Multipart {
#[cfg(test)]
mod tests {
use super::*;
use crate::{handler, http::StatusCode, Endpoint};
use crate::{handler, http::StatusCode, test::TestClient};
#[tokio::test]
async fn test_multipart_extractor_content_type() {
@@ -179,18 +179,14 @@ mod tests {
todo!()
}
let err = index
.call(
Request::builder()
.header("content-type", "multipart/json; boundary=X-BOUNDARY")
.body(()),
)
.await
.unwrap_err();
match err.downcast_ref::<ParseMultipartError>().unwrap() {
ParseMultipartError::InvalidContentType(ct) if ct == "multipart/json" => {}
_ => panic!(),
}
let cli = TestClient::new(index);
let resp = cli
.post("/")
.header("content-type", "multipart/json; boundary=X-BOUNDARY")
.body(())
.send()
.await;
resp.assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
}
#[tokio::test]
@@ -212,14 +208,14 @@ mod tests {
}
let data = "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_file_field\"; filename=\"a-text-file.txt\"\r\nContent-Type: text/plain\r\n\r\nHello world\nHello\r\nWorld\rAgain\r\n--X-BOUNDARY--\r\n";
let resp = index
.call(
Request::builder()
.header("content-type", "multipart/form-data; boundary=X-BOUNDARY")
.body(data),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let cli = TestClient::new(index);
let resp = cli
.post("/")
.header("content-type", "multipart/form-data; boundary=X-BOUNDARY")
.body(data)
.send()
.await;
resp.assert_status_is_ok();
}
}

View File

@@ -20,6 +20,7 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result};
/// use poem::{
/// get, handler,
/// http::{StatusCode, Uri},
/// test::TestClient,
/// web::Path,
/// Endpoint, Request, Route,
/// };
@@ -30,17 +31,12 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result};
/// }
///
/// let app = Route::new().at("/users/:user_id/team/:team_id", get(users_teams_show));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/users/100/team/300"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "100:300");
/// let resp = cli.get("/users/100/team/300").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("100:300").await;
/// # });
/// ```
///
@@ -50,6 +46,7 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result};
/// use poem::{
/// get, handler,
/// http::{StatusCode, Uri},
/// test::TestClient,
/// web::Path,
/// Endpoint, Request, Route,
/// };
@@ -60,17 +57,12 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result};
/// }
///
/// let app = Route::new().at("/users/:user_id", get(user_info));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/users/100"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "100");
/// let resp = cli.get("/users/100").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("100").await;
/// # });
/// ```
///
@@ -81,6 +73,7 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result};
/// use poem::{
/// get, handler,
/// http::{StatusCode, Uri},
/// test::TestClient,
/// web::Path,
/// Endpoint, Request, Route,
/// };
@@ -98,17 +91,12 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result};
/// }
///
/// let app = Route::new().at("/users/:user_id/team/:team_id", get(users_teams_show));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/users/foo/team/100"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "foo:100");
/// let resp = cli.get("/users/foo/team/100").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("foo:100").await;
/// # });
/// ```
#[derive(Debug, Eq, PartialEq, Clone)]

View File

@@ -16,6 +16,7 @@ use crate::{error::ParseQueryError, FromRequest, Request, RequestBody, Result};
/// use poem::{
/// get, handler,
/// http::{Method, StatusCode, Uri},
/// test::TestClient,
/// web::Query,
/// Endpoint, Request, Route,
/// };
@@ -34,17 +35,16 @@ use crate::{error::ParseQueryError, FromRequest, Request, RequestBody, Result};
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let app = Route::new().at("/", get(index).post(index));
/// let cli = TestClient::new(app);
///
/// let resp = app
/// .call(
/// Request::builder()
/// .uri(Uri::from_static("/?title=foo&content=bar"))
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "foo:bar");
/// let resp = cli
/// .get("/")
/// .query("title", &"foo")
/// .query("content", &"bar")
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("foo:bar").await;
/// # });
/// ```
#[derive(Debug, Clone, Eq, PartialEq, Default)]
@@ -82,7 +82,7 @@ mod tests {
use serde::Deserialize;
use super::*;
use crate::{handler, http::Uri, Endpoint};
use crate::{handler, test::TestClient};
#[tokio::test]
async fn test_query_extractor() {
@@ -98,13 +98,12 @@ mod tests {
assert_eq!(query.value, 100);
}
index
.call(
Request::builder()
.uri(Uri::from_static("/?name=abc&value=100"))
.finish(),
)
let cli = TestClient::new(index);
cli.get("/")
.query("name", &"abc")
.query("value", &100)
.send()
.await
.unwrap();
.assert_status_is_ok();
}
}

View File

@@ -13,6 +13,7 @@ use crate::{
/// use poem::{
/// get, handler,
/// http::{header, HeaderValue, StatusCode, Uri},
/// test::TestClient,
/// web::Redirect,
/// Endpoint, Request, Route,
/// };
@@ -24,12 +25,9 @@ use crate::{
///
/// let app = Route::new().at("/", get(index));
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app.call(Request::default()).await.unwrap();
/// assert_eq!(resp.status(), StatusCode::MOVED_PERMANENTLY);
/// assert_eq!(
/// resp.headers().get(header::LOCATION),
/// Some(&HeaderValue::from_static("https://www.google.com"))
/// );
/// let resp = TestClient::new(app).get("/").send().await;
/// resp.assert_status(StatusCode::MOVED_PERMANENTLY);
/// resp.assert_header(header::LOCATION, "https://www.google.com");
/// # });
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]

View File

@@ -14,6 +14,7 @@ use crate::{Body, IntoResponse, Response};
/// use poem::{
/// handler,
/// http::StatusCode,
/// test::TestClient,
/// web::sse::{Event, SSE},
/// Endpoint, Request,
/// };
@@ -27,13 +28,12 @@ use crate::{Body, IntoResponse, Response};
/// ]))
/// }
///
/// let cli = TestClient::new(index);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let mut resp = index.call(Request::default()).await.unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(
/// resp.take_body().into_string().await.unwrap(),
/// "data: a\n\ndata: b\n\ndata: c\n\n"
/// );
/// let resp = cli.get("/").send().await;
/// resp.assert_status_is_ok();
/// resp.assert_text("data: a\n\ndata: b\n\ndata: c\n\n").await;
/// # });
/// ```
pub struct SSE {

View File

@@ -52,20 +52,22 @@ mod tests {
use tokio::io::AsyncReadExt;
use super::*;
use crate::{handler, Endpoint};
use crate::{handler, test::TestClient};
#[tokio::test]
async fn test_tempfile_extractor() {
#[handler(internal)]
async fn index123(mut file: TempFile) {
async fn index(mut file: TempFile) {
let mut s = String::new();
file.read_to_string(&mut s).await.unwrap();
assert_eq!(s, "abcdef");
}
index123
.call(Request::builder().body("abcdef"))
let cli = TestClient::new(index);
cli.get("/")
.body("abcdef")
.send()
.await
.unwrap();
.assert_status_is_ok();
}
}

View File

@@ -16,6 +16,7 @@ use crate::{error::ParseTypedHeaderError, FromRequest, Request, RequestBody, Res
/// use poem::{
/// get, handler,
/// http::{header, StatusCode},
/// test::TestClient,
/// web::{headers::Host, TypedHeader},
/// Endpoint, Request, Route,
/// };
@@ -26,18 +27,16 @@ use crate::{error::ParseTypedHeaderError, FromRequest, Request, RequestBody, Res
/// }
///
/// let app = Route::new().at("/", get(index));
/// let cli = TestClient::new(app);
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let resp = app
/// .call(
/// Request::builder()
/// .header(header::HOST, "example.com")
/// .finish(),
/// )
/// .await
/// .unwrap();
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(resp.into_body().into_string().await.unwrap(), "example.com");
/// let resp = cli
/// .get("/")
/// .header(header::HOST, "example.com")
/// .send()
/// .await;
/// resp.assert_status_is_ok();
/// resp.assert_text("example.com").await;
/// # });
/// ```
#[derive(Debug)]
@@ -78,8 +77,8 @@ mod tests {
use super::*;
use crate::{
handler,
test::TestClient,
web::headers::{ContentLength, Host},
Endpoint,
};
#[tokio::test]
@@ -89,10 +88,14 @@ mod tests {
assert_eq!(content_length.0 .0, 3);
}
index
.call(Request::builder().header("content-length", 3).body("abc"))
.await
.unwrap();
let cli = TestClient::new(index);
let resp = cli
.get("/")
.header("content-length", 3)
.body("abc")
.send()
.await;
resp.assert_status_is_ok();
}
#[tokio::test]