diff --git a/poem/CHANGELOG.md b/poem/CHANGELOG.md index 3fe86b39..567b869c 100644 --- a/poem/CHANGELOG.md +++ b/poem/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# [1.2.21] 2022-1-4 + +- Add test utilities. + # [1.2.20] 2022-1-1 - `RouteMethod` returns `MethodNotAllowedError` error instead of `NotFoundError` when the corresponding method is not found. diff --git a/poem/Cargo.toml b/poem/Cargo.toml index 8f84521b..6a517f69 100644 --- a/poem/Cargo.toml +++ b/poem/Cargo.toml @@ -34,6 +34,7 @@ opentelemetry = ["libopentelemetry", "opentelemetry-http", "opentelemetry-semant prometheus = ["libopentelemetry", "opentelemetry-prometheus", "libprometheus"] tempfile = ["libtempfile"] csrf = ["cookie", "base64", "libcsrf"] +test = ["sse", "sse-codec", "tokio-util/compat"] [dependencies] poem-derive = { path = "../poem-derive", version = "1.2.21" } @@ -84,6 +85,7 @@ sha1 = { version = "0.6.0", optional = true } base64 = { version = "0.13.0", optional = true } libcsrf = { package = "csrf", version = "0.4.1", optional = true } httpdate = { version = "1.0.2", optional = true } +sse-codec = { version = "0.3.2", optional = true } # Feature optional dependencies anyhow = { version = "1.0.0", optional = true } diff --git a/poem/src/lib.rs b/poem/src/lib.rs index 05bff4eb..71cf5da5 100644 --- a/poem/src/lib.rs +++ b/poem/src/lib.rs @@ -256,6 +256,9 @@ pub mod middleware; #[cfg(feature = "session")] #[cfg_attr(docsrs, doc(cfg(feature = "session")))] pub mod session; +#[cfg(feature = "test")] +#[cfg_attr(docsrs, doc(cfg(feature = "test")))] +pub mod test; pub mod web; #[doc(inline)] diff --git a/poem/src/test/client.rs b/poem/src/test/client.rs new file mode 100644 index 00000000..282b4dd7 --- /dev/null +++ b/poem/src/test/client.rs @@ -0,0 +1,90 @@ +use http::{header, header::HeaderName, HeaderMap, HeaderValue, Method}; + +use crate::{endpoint::BoxEndpoint, test::TestRequestBuilder, Endpoint, Response}; + +macro_rules! impl_methods { + ($($(#[$docs:meta])* ($name:ident, $method:ident)),*) => { + $( + $(#[$docs])* + pub fn $name(&self, uri: impl Into) -> TestRequestBuilder<'_, E> { + TestRequestBuilder::new(self, Method::$method, uri.into()) + } + )* + }; +} + +/// A client for testing. +/// +/// # Examples +/// +/// ``` +/// use poem::{handler, test::TestClient, Route}; +/// +/// #[handler] +/// fn index() {} +/// +/// let app = Route::new().at("/", index); +/// +/// let cli = TestClient::new(index); +/// +/// # tokio::runtime::Runtime::new().unwrap().block_on(async { +/// cli.get("/").send().await.assert_status_is_ok(); +/// # }); +/// ``` +pub struct TestClient> { + pub(crate) ep: E, + pub(crate) default_headers: HeaderMap, +} + +impl TestClient { + /// Create a new client for the specified endpoint. + pub fn new(ep: E) -> Self { + Self { + ep, + default_headers: Default::default(), + } + } + + /// Sets the default header for each requests. + #[must_use] + pub fn default_header(mut self, key: K, value: V) -> Self + where + K: TryInto, + V: TryInto, + { + let key = key.try_into().map_err(|_| ()).expect("valid header name"); + let value = value + .try_into() + .map_err(|_| ()) + .expect("valid header value"); + self.default_headers.append(key, value); + self + } + + /// Sets the default content type for each requests. + #[must_use] + pub fn default_content_type(self, content_type: impl AsRef) -> Self { + self.default_header(header::CONTENT_TYPE, content_type.as_ref()) + } + + impl_methods!( + /// Create a `GET` request. + (get, GET), + /// Create a `POST` request. + (post, POST), + /// Create a `PUT` request. + (put, PUT), + /// Create a `DELETE` request. + (delete, DELETE), + /// Create a `HEAD` request. + (head, HEAD), + /// Create a `OPTIONS` request. + (options, OPTIONS), + /// Create a `CONNECT` request. + (connect, CONNECT), + /// Create a `PATCH` request. + (patch, PATCH), + /// Create a `TRACE` request. + (trace, TRACE) + ); +} diff --git a/poem/src/test/json.rs b/poem/src/test/json.rs new file mode 100644 index 00000000..810b9ebd --- /dev/null +++ b/poem/src/test/json.rs @@ -0,0 +1,217 @@ +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::{Map, Value}; + +/// A JSON object for testing. +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub struct TestJson(Value); + +impl TestJson { + /// Returns a reference the value. + #[inline] + pub fn value(&self) -> TestJsonValue<'_> { + TestJsonValue(&self.0) + } +} + +macro_rules! impl_types { + ($($(#[$docs:meta])* ($ty:ty, $name:ident, $method:ident)),*) => { + $( + $(#[$docs])* + pub fn $name(&self) -> $ty { + self.0.$method().expect(stringify!($name)) + } + )* + }; +} + +macro_rules! impl_assert_types { + ($($(#[$docs:meta])* ($ty:ty, $name:ident, $method:ident)),*) => { + $( + $(#[$docs])* + pub fn $name(&self, value: $ty) { + assert_eq!(self.$method(), value); + } + )* + }; +} + +macro_rules! impl_array_types { + ($($(#[$docs:meta])* ($ty:ty, $name:ident, $method:ident)),*) => { + $( + $(#[$docs])* + pub fn $name(&self) -> Vec<$ty> { + self.array().iter().map(|value| value.$method()).collect() + } + )* + }; +} + +macro_rules! impl_assert_array_types { + ($($(#[$docs:meta])* ($ty:ty, $name:ident, $method:ident)),*) => { + $( + $(#[$docs])* + pub fn $name(&self, values: &[$ty]) { + assert_eq!(self.$method(), values); + } + )* + }; +} + +/// A JSON value. +#[derive(Debug, PartialEq, Copy, Clone)] +pub struct TestJsonValue<'a>(&'a Value); + +impl<'a> PartialEq for TestJsonValue<'a> { + fn eq(&self, other: &Value) -> bool { + self.0 == other + } +} + +impl<'a> TestJsonValue<'a> { + impl_types!( + /// Returns the `i64` value. + (i64, i64, as_i64), + /// Returns the `f64` value. + (f64, f64, as_f64), + /// Returns the `f64` value. + (bool, bool, as_bool) + ); + + impl_array_types!( + /// Returns the `i64` array. + (i64, i64_array, i64), + /// Returns the `i64` array. + (f64, f64_array, f64), + /// Returns the `i64` array. + (bool, bool_array, bool) + ); + + impl_assert_types!( + /// Asserts that value is `integer` and it equals to `value`. + (i64, assert_i64, i64), + /// Asserts that value is `float` and it equals to `value`. + (f64, assert_f64, f64), + /// Asserts that value is `boolean` and it equals to `value`. + (bool, assert_bool, bool), + /// Asserts that value is `string` and it equals to `value`. + (&str, assert_string, string) + ); + + impl_assert_array_types!( + /// Asserts that value is `integer` array and it equals to `values`. + (i64, assert_i64_array, i64_array), + /// Asserts that value is `float` array and it equals to `values`. + (f64, assert_f64_array, f64_array), + /// Asserts that value is `boolean` array and it equals to `values`. + (bool, assert_bool_array, bool_array), + /// Asserts that value is `string` array and it equals to `values`. + (&str, assert_string_array, string_array) + ); + + /// Returns the `string` value. + pub fn string(&self) -> &'a str { + self.0.as_str().expect("string") + } + + /// Returns the `string` array. + pub fn string_array(&self) -> Vec<&'a str> { + self.array().iter().map(|value| value.string()).collect() + } + + /// Asserts that the value is an array and return `TestJsonArray`. + pub fn array(&self) -> TestJsonArray<'a> { + TestJsonArray(self.0.as_array().expect("array")) + } + + /// Asserts that the value is an object and return `TestJsonArray`. + pub fn object(&self) -> TestJsonObject<'a> { + TestJsonObject(self.0.as_object().expect("object")) + } + + /// Asserts that the value is an object array and return + /// `Vec`. + pub fn object_array(&self) -> Vec> { + self.array().iter().map(|value| value.object()).collect() + } + + /// Asserts that the value is null. + pub fn assert_null(&self) { + assert!(self.0.is_null()) + } + + /// Deserialize the value to `T`. + pub fn deserialize(&self) -> T { + serde_json::from_value(self.0.clone()).expect("valid json") + } +} + +/// A JSON array. +#[derive(Debug, Copy, Clone)] +pub struct TestJsonArray<'a>(&'a [Value]); + +impl<'a, T> PartialEq for TestJsonArray<'a> +where + T: AsRef<[Value]>, +{ + fn eq(&self, other: &T) -> bool { + self.0 == other.as_ref() + } +} + +impl<'a> TestJsonArray<'a> { + /// Returns the number of elements in the array. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns true if the array contains no elements. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Returns the element at index `idx`. + pub fn get(&self, idx: usize) -> TestJsonValue<'a> { + TestJsonValue(&self.0[idx]) + } + + /// Returns an iterator over the array. + pub fn iter(&self) -> impl Iterator> { + self.0.iter().map(TestJsonValue) + } + + /// Asserts the array length is equals to `len`. + pub fn assert_len(&self, len: usize) { + assert_eq!(self.len(), len); + } +} + +/// A JSON object. +#[derive(Debug, Copy, Clone)] +pub struct TestJsonObject<'a>(&'a Map); + +impl<'a> TestJsonObject<'a> { + /// Returns the number of elements in the object. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns true if the object contains no elements. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Returns the element corresponding to the `name`. + pub fn get(&self, name: impl AsRef) -> TestJsonValue<'a> { + TestJsonValue(&self.0[name.as_ref()]) + } + + /// Returns an iterator over the object. + pub fn iter(&self) -> impl Iterator)> { + self.0.iter().map(|(k, v)| (k, TestJsonValue(v))) + } + + /// Asserts the object length is equals to `len`. + pub fn assert_len(&self, len: usize) { + assert_eq!(self.len(), len); + } +} diff --git a/poem/src/test/mod.rs b/poem/src/test/mod.rs new file mode 100644 index 00000000..4181f03c --- /dev/null +++ b/poem/src/test/mod.rs @@ -0,0 +1,11 @@ +//! Test utilities to test your endpoints. + +mod client; +mod json; +mod request_builder; +mod response; + +pub use client::TestClient; +pub use json::{TestJson, TestJsonArray, TestJsonObject, TestJsonValue}; +pub use request_builder::TestRequestBuilder; +pub use response::TestResponse; diff --git a/poem/src/test/request_builder.rs b/poem/src/test/request_builder.rs new file mode 100644 index 00000000..93dc6bbd --- /dev/null +++ b/poem/src/test/request_builder.rs @@ -0,0 +1,121 @@ +use http::{header, header::HeaderName, Extensions, HeaderMap, HeaderValue, Method}; +use serde::Serialize; + +use crate::{ + test::{TestClient, TestResponse}, + Body, Endpoint, Request, +}; + +/// A request builder for testing. +pub struct TestRequestBuilder<'a, E> { + cli: &'a TestClient, + uri: String, + method: Method, + query: String, + headers: HeaderMap, + body: Body, + extensions: Extensions, +} + +impl<'a, E> TestRequestBuilder<'a, E> +where + E: Endpoint, +{ + pub(crate) fn new(cli: &'a TestClient, method: Method, uri: String) -> Self { + Self { + cli, + uri, + method, + query: Default::default(), + headers: Default::default(), + body: Body::empty(), + extensions: Default::default(), + } + } + + /// Sets the query string for this request. + #[must_use] + pub fn query(self, params: impl Serialize) -> Self { + Self { + query: serde_urlencoded::to_string(params).expect("valid query params"), + ..self + } + } + + /// Sets the header value for this request. + #[must_use] + pub fn header(mut self, key: K, value: V) -> Self + where + K: TryInto, + V: TryInto, + { + let key = key.try_into().map_err(|_| ()).expect("valid header name"); + let value = value + .try_into() + .map_err(|_| ()) + .expect("valid header value"); + self.headers.append(key, value); + self + } + + /// Sets the content type for this request. + #[must_use] + pub fn content_type(self, content_type: impl AsRef) -> Self { + self.header(header::CONTENT_TYPE, content_type.as_ref()) + } + + /// Sets the body for this request. + #[must_use] + pub fn body(self, body: impl Into) -> Self { + Self { + body: body.into(), + ..self + } + } + + /// Sets the JSON body for this request. + #[must_use] + pub fn body_json(self, body: &impl Serialize) -> Self { + Self { + body: serde_json::to_string(&body).expect("valid json").into(), + ..self + } + } + + fn make_request(self) -> Request { + let uri = if self.query.is_empty() { + self.uri + } else { + format!("{}?{}", self.uri, self.query) + }; + + let mut req = Request::builder() + .method(self.method) + .uri(uri.parse().expect("valid uri")) + .finish(); + req.headers_mut().extend(self.cli.default_headers.clone()); + req.headers_mut().extend(self.headers); + *req.extensions_mut() = self.extensions; + req.set_body(self.body); + + req + } + + /// Sets the extension data for this request. + #[must_use] + pub fn data(mut self, data: T) -> Self + where + T: Send + Sync + 'static, + { + self.extensions.insert(data); + self + } + + /// Send this request to endpoint to get the response. + pub async fn send(self) -> TestResponse { + let ep = &self.cli.ep; + let req = self.make_request(); + let resp = ep.get_response(req).await; + TestResponse::new(resp) + } +} diff --git a/poem/src/test/response.rs b/poem/src/test/response.rs new file mode 100644 index 00000000..39430c12 --- /dev/null +++ b/poem/src/test/response.rs @@ -0,0 +1,139 @@ +use futures_util::{Stream, StreamExt}; +use http::{header, header::HeaderName, HeaderValue, StatusCode}; +use serde::{de::DeserializeOwned, Serialize}; +use serde_json::Value; +use tokio_util::compat::TokioAsyncReadCompatExt; + +use crate::{test::json::TestJson, web::sse::Event, Body, Response}; + +/// A response object for testing. +pub struct TestResponse(Response); + +impl TestResponse { + pub(crate) fn new(resp: Response) -> Self { + Self(resp) + } + + /// Consumes this object and returns the [`Response`]. + pub fn into_inner(self) -> Response { + self.0 + } + + /// Asserts that the status code is equals to `status`. + pub fn assert_status(&self, status: StatusCode) { + assert_eq!(self.0.status(), status); + } + + /// Asserts that the status code is `200 OK`. + pub fn assert_status_is_ok(&self) { + self.assert_status(StatusCode::OK); + } + + /// Asserts that header `key` is equals to `value`. + pub fn assert_header(&self, key: K, value: V) + where + K: TryInto, + V: TryInto, + { + let key = key.try_into().map_err(|_| ()).expect("valid header name"); + let value = value + .try_into() + .map_err(|_| ()) + .expect("valid header value"); + + let value2 = self + .0 + .headers() + .get(&key) + .unwrap_or_else(|| panic!("expect header `{}`", key)); + + assert_eq!(value2, value); + } + + /// Asserts that content type is equals to `content_type`. + pub fn assert_content_type(&self, content_type: &str) { + self.assert_header(header::CONTENT_TYPE, content_type); + } + + /// Consumes this object and return the response body. + #[inline] + pub fn into_body(self) -> Body { + self.0.into_body() + } + + /// Asserts that the response body is utf8 string and it equals to `text`. + pub async fn assert_text(self, text: impl AsRef) { + assert_eq!( + self.into_body().into_string().await.expect("expect body"), + text.as_ref() + ); + } + + /// Asserts that the response body is bytes and it equals to `bytes`. + pub async fn assert_bytes(self, bytes: impl AsRef<[u8]>) { + assert_eq!( + self.into_body().into_vec().await.expect("expect body"), + bytes.as_ref() + ); + } + + /// Asserts that the response body is JSON and it equals to `json`. + pub async fn assert_json(self, json: impl Serialize) { + assert_eq!( + self.into_body() + .into_json::() + .await + .expect("expect body"), + serde_json::to_value(json).expect("valid json") + ); + } + + /// Consumes this object and return the [`TestJson`]. + pub async fn json(self) -> TestJson { + self.into_body() + .into_json::() + .await + .expect("expect body") + } + + /// Consumes this object and return the SSE events stream. + pub fn sse_stream(self) -> impl Stream + Send + Unpin + 'static { + self.assert_content_type("text/event-stream"); + sse_codec::decode_stream(self.into_body().into_async_read().compat()) + .map(|res| { + let event = res.expect("valid sse frame"); + match event { + sse_codec::Event::Message { id, event, data } => Event::Message { + id: id.unwrap_or_default(), + event, + data, + }, + sse_codec::Event::Retry { retry } => Event::Retry { retry }, + } + }) + .boxed() + } + + /// Consumes this object and return the SSE events stream which deserialize + /// the message data to `T`. + pub fn typed_sse_stream( + self, + ) -> impl Stream + Send + Unpin + 'static { + self.sse_stream() + .filter_map(|event| async move { + match event { + Event::Message { data, .. } => { + Some(serde_json::from_str::(&data).expect("valid data")) + } + Event::Retry { .. } => None, + } + }) + .boxed() + } + + /// Consumes this object and return the SSE events stream which deserialize + /// the message data to [`TestJson`]. + pub fn json_sse_stream(self) -> impl Stream + Send + Unpin + 'static { + self.typed_sse_stream::() + } +}