diff --git a/.github/workflows/book.yml b/.github/workflows/book.yml deleted file mode 100644 index 8452ff03..00000000 --- a/.github/workflows/book.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Book - -on: - push: - branches: - - release - paths: - - 'docs/**' - - '.github/workflows/**' - -jobs: - deploy_en: - name: Deploy book on gh-pages - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v2 - - name: Install mdBook - uses: peaceiris/actions-mdbook@v1 - - name: Render book - run: | - mdbook build -d gh-pages docs/en - mdbook build -d gh-pages docs/zh-CN - mkdir docs/gh-pages - mv docs/en/gh-pages docs/gh-pages/en - mv docs/zh-CN/gh-pages docs/gh-pages/zh-CN - mv docs/index.html docs/gh-pages - cp -r docs/assets docs/gh-pages/en - cp -r docs/assets docs/gh-pages/zh-CN - - name: Deploy - uses: peaceiris/actions-gh-pages@v3.8.0 - with: - emptyCommits: true - keepFiles: false - deploy_key: ${{ secrets.ACTIONS_DEPLOY_KEY }} - publish_branch: gh-pages - publish_dir: docs/gh-pages - cname: poem.rs diff --git a/README.md b/README.md index 2121bd2f..4b6f598d 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,6 @@ The following are cases of community use: ### Resources -- [Book](https://poem.rs) - [Examples](https://github.com/poem-web/poem/tree/master/examples) ## Contributing diff --git a/docs/.gitignore b/docs/.gitignore deleted file mode 100644 index e9c07289..00000000 --- a/docs/.gitignore +++ /dev/null @@ -1 +0,0 @@ -book \ No newline at end of file diff --git a/docs/en/book.toml b/docs/en/book.toml deleted file mode 100644 index 718925c5..00000000 --- a/docs/en/book.toml +++ /dev/null @@ -1,9 +0,0 @@ -[book] -authors = ["sunli"] -description = "Poem Book" -src = "src" -language = "en" -title = "Poem Book" - -[rust] -edition = "2021" diff --git a/docs/en/src/SUMMARY.md b/docs/en/src/SUMMARY.md deleted file mode 100644 index 3661cc3d..00000000 --- a/docs/en/src/SUMMARY.md +++ /dev/null @@ -1,28 +0,0 @@ -# Poem Book - -## Poem - -- [Poem](poem.md) - - [Quickstart](poem/quickstart.md) - - [Endpoint](poem/endpoint.md) - - [Routing](poem/routing.md) - - [Extractors](poem/extractors.md) - - [Responses](poem/responses.md) - - [Handling errors](poem/handling_errors.md) - - [Middleware](poem/middleware.md) - - [Protocols](poem/protocols.md) - - [Websocket](poem/protocols/websocket.md) - - [Server-Sent Events (SSE)](poem/protocols/sse.md) - - [Listeners](poem/listeners.md) -- [OpenAPI](openapi.md) - - [Quickstart](openapi/quickstart.md) - - [Type System](openapi/type_system.md) - - [Basic types](openapi/type_system/basic_types.md) - - [Enum](openapi/type_system/enum.md) - - [Object](openapi/type_system/object.md) - - [API](openapi/api.md) - - [Custom Request](openapi/custom_request.md) - - [Custom Response](openapi/custom_response.md) - - [Upload files](openapi/upload_files.md) - - [Validators](openapi/validators.md) - - [Authentication](openapi/authentication.md) diff --git a/docs/en/src/openapi.md b/docs/en/src/openapi.md deleted file mode 100644 index 23d8fd56..00000000 --- a/docs/en/src/openapi.md +++ /dev/null @@ -1,13 +0,0 @@ -# OpenAPI - -The OpenAPI Specification (OAS) defines a standard, language-agnostic interface to RESTful APIs which allows both humans -and computers to discover and understand the capabilities of the service without access to source code, documentation, or -through network traffic inspection. When properly defined, a consumer can understand and interact with the remote service -with a minimal amount of implementation logic. - -`Poem-openapi` is a [OpenAPI](https://swagger.io/specification/) server-side framework based on `Poem`. - -Generally, if you want your API to support the OAS, you first need to create an [OpenAPI Definitions](https://swagger.io/specification/), -and then write the corresponding code according to the definitions, or use `Swagger CodeGen` to generate the boilerplate -server code. But `Poem-openapi` is different from these two, it allows you to only write Rust business code and use -procedural macros to automatically generate lots of boilerplate code that conform to the OpenAPI specification. diff --git a/docs/en/src/openapi/api.md b/docs/en/src/openapi/api.md deleted file mode 100644 index cf84614c..00000000 --- a/docs/en/src/openapi/api.md +++ /dev/null @@ -1,77 +0,0 @@ -# API - -The following defines some API operations to add, delete, modify and query the `pet` table. - -A method represents an API operation. Use the `path` and `method` attributes to specify the path and HTTP method of the operation. - -There can be multiple parameters for each API operation, and the following types can be used: - - - **poem_openapi::param::Query** represents this parameter is parsed from the query string - - - **poem_openapi::param::Header** represents this parameter is parsed from the request headers - - - **poem_openapi::param::Path** represents this parameter is parsed from the URI path - - - **poem_openapi::param::Cookie** represents this parameter is parsed from the cookie - - - **poem_openapi::param::CookiePrivate** represents this parameter is parsed from the private cookie - - - **poem_openapi::param::CookieSigned** represents this parameter is parsed from the signed cookie - - - **poem_openapi::payload::Binary** represents a binary request payload - - - **poem_openapi::payload::Json** represents a request payload encoded with JSON - - - **poem_openapi::payload::PlainText** represents a utf8 string request payload - - - **ApiRequest** parse the request payload generated by `ApiRequest` macro - - - **SecurityScheme** parse the security scheme generated by `SecurityScheme` macro - - - **T: FromRequest** used poem's extractors - -The return value can be any type that implements `ApiResponse`. - -```rust -use poem_api::{ - OpenApi, - poem_api::payload::Json, - param::{Path, Query}, -}; -use poem::Result; - -struct Api; - -#[OpenApi] -impl Api { - /// Add new pet - #[oai(path = "/pet", method = "post")] - async fn add_pet(&self, pet: Json) -> Result<()> { - todo!() - } - - /// Update existing pet - #[oai(path = "/pet", method = "put")] - async fn update_pet(&self, pet: Json) -> Result<()> { - todo!() - } - - /// Delete a pet - #[oai(path = "/pet/:pet_id", method = "delete")] - async fn delete_pet(&self, #[oai(name = "pet_id", in = "path")] id: Path) -> Result<()> { - todo!() - } - - /// Query pet by id - #[oai(path = "/pet", method = "get")] - async fn find_pet_by_id(&self, id: Query) -> Result> { - todo!() - } - - /// Query pets by status - #[oai(path = "/pet/findByStatus", method = "get")] - async fn find_pets_by_status(&self, status: Query) -> Result>> { - todo!() - } -} -``` diff --git a/docs/en/src/openapi/authentication.md b/docs/en/src/openapi/authentication.md deleted file mode 100644 index c23aac59..00000000 --- a/docs/en/src/openapi/authentication.md +++ /dev/null @@ -1,82 +0,0 @@ -# Authentication - -The OpenApi specification defines `apikey`, `basic`, `bearer`, `oauth2` and `openIdConnect` authentication modes, which -describe the authentication parameters required for the specified operation. - -The following example is to log in with `Github` and provide an operation to get all public repositories. - -```rust -use poem_openapi::{ - SecurityScheme, SecurityScope, OpenApi, - auth::Bearer, -}; - -#[derive(OAuthScopes)] -enum GithubScope { - /// access to public repositories. - #[oai(rename = "public_repo")] - PublicRepo, - - /// access to read a user's profile data. - #[oai(rename = "read:user")] - ReadUser, -} - -/// Github authorization -#[derive(SecurityScheme)] -#[oai( - type = "oauth2", - flows(authorization_code( - authorization_url = "https://github.com/login/oauth/authorize", - token_url = "https://github.com/login/oauth/token", - scopes = "GithubScope", - )) -)] -struct GithubAuthorization(Bearer); - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/repo", method = "get")] - async fn repo_list( - &self, - #[oai(scope("GithubScope::PublicRepo"))] auth: GithubAuthorization, - ) -> Result> { - // Use the token in GithubAuthorization to obtain all public repositories from Github. - todo!() - } -} -``` - -For the complete example, please refer to [Example](https://github.com/poem-web/poem/tree/master/examples/openapi/auth-github). - -## Check authentication information - -You can use the `checker` attribute to specify a checker function to check the original authentication information and -convert it to the return type of this function. This function must return `Option`, and return `None` if check fails. - -```rust -struct User { - username: String, -} - -/// ApiKey authorization -#[derive(SecurityScheme)] -#[oai( - type = "api_key", - key_name = "X-API-Key", - in = "header", - checker = "api_checker" -)] -struct MyApiKeyAuthorization(User); - -async fn api_checker(req: &Request, api_key: ApiKey) -> Option { - let connection = req.data::().unwrap(); - - // check in database - todo!() -} -``` - -For the complete example, please refer to [Example](https://github.com/poem-web/poem/tree/master/examples/openapi/auth-apikey). \ No newline at end of file diff --git a/docs/en/src/openapi/custom_request.md b/docs/en/src/openapi/custom_request.md deleted file mode 100644 index 5ea2f083..00000000 --- a/docs/en/src/openapi/custom_request.md +++ /dev/null @@ -1,52 +0,0 @@ -# Custom Request - -The `OpenAPI` specification allows the same operation to support processing different requests of `Content-Type`, -for example, an operation can support `application/json` and `text/plain` types of request content. - -In `Poem-openapi`, to support this type of request, you need to use the `ApiRequest` macro to customize a request object -that implements the `Payload` trait. - -In the following example, the `create_post` function accepts the `CreatePostRequest` request, and when the creation is -successful, it returns the `id`. - -```rust -use poem_open::{ - ApiRequest, Object, - payload::{PlainText, Json}, -}; -use poem::Result; - -#[derive(Object)] -struct Post { - title: String, - content: String, -} - -#[derive(ApiRequest)] -enum CreatePostRequest { - /// Create from json - Json(Json), - /// Create from plain text - Text(PlainText), -} - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/hello", method = "post")] - async fn create_post( - &self, - req: CreatePostRequest, - ) -> Result> { - match req { - CreatePostRequest::Json(Json(blog)) => { - todo!(); - } - CreatePostRequest::Text(content) => { - todo!(); - } - } - } -} -``` diff --git a/docs/en/src/openapi/custom_response.md b/docs/en/src/openapi/custom_response.md deleted file mode 100644 index 0d25f58d..00000000 --- a/docs/en/src/openapi/custom_response.md +++ /dev/null @@ -1,94 +0,0 @@ -# Custom Response - -In all the previous examples, all operations return `Result`. When an error occurs, a `poem::Error` is returned, which -contains the reason and status code of the error. However, the `OpenAPI` specification supports a more detailed definition -of the response of the operation, such as which status codes may be returned, and the reason for the status code and the -content of the response. - -In the following example, we change the return type of the `create_post` function to `CreateBlogResponse`. - -`Ok`, `Forbidden` and `InternalError` specify the response content of a specific status code. - -```rust -use poem_openapi::ApiResponse; -use poem::http::StatusCode; - -#[derive(ApiResponse)] -enum CreateBlogResponse { - /// Created successfully - #[oai(status = 200)] - Ok(Json), - - /// Permission denied - #[oai(status = 403)] - Forbidden, - - /// Internal error - #[oai(status = 500)] - InternalError, -} - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/hello", method = "get")] - async fn create_post( - &self, - req: CreatePostRequest, - ) -> CreateBlogResponse { - match req { - CreatePostRequest::Json(Json(blog)) => { - todo!(); - } - CreatePostRequest::Text(content) => { - todo!(); - } - } - } -} -``` - -When the parsing request fails, the default `400 Bad Request` error will be returned, but sometimes we want to return a -custom error content, we can use the `bad_request_handler` attribute to set an error handling function, this function is -used to convert `ParseRequestError` to specified response type. - -```rust -use poem_openapi::{ - ApiResponse, Object, ParseRequestError, payload::Json, -}; - -#[derive(Object)] -struct ErrorMessage { - code: i32, - reason: String, -} - -#[derive(ApiResponse)] -#[oai(bad_request_handler = "bad_request_handler")] -enum CreateBlogResponse { - /// Created successfully - #[oai(status = 200)] - Ok(Json), - - /// Permission denied - #[oai(status = 403)] - Forbidden, - - /// Internal error - #[oai(status = 500)] - InternalError, - - /// Bad request - #[oai(status = 400)] - BadRequest(Json), -} - -fn bad_request_handler(err: ParseRequestError) -> CreateBlogResponse { - // When the parsing request fails, a custom error content is returned, which is a JSON - CreateBlogResponse::BadRequest(Json(ErrorMessage { - code: -1, - reason: err.to_string(), - })) -} -``` diff --git a/docs/en/src/openapi/quickstart.md b/docs/en/src/openapi/quickstart.md deleted file mode 100644 index 67bfc60e..00000000 --- a/docs/en/src/openapi/quickstart.md +++ /dev/null @@ -1,63 +0,0 @@ -# Quickstart - -In the following example, we define an API with a path of `/hello`, which accepts a URL parameter named `name` and returns -a string as the response content. The type of the `name` parameter is `Option`, which means it is an optional parameter. - -Running the following code, open `http://localhost:3000` with a browser to see `Swagger UI`, you can use it to browse API -definitions and test them. - -```rust -use poem::{listener::TcpListener, Route}; -use poem_openapi::{payload::PlainText, OpenApi, OpenApiService}; - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/hello", method = "get")] - async fn index( - &self, - #[oai(name = "name", in = "query")] name: Option, // in="query" means this parameter is parsed from Url - ) -> PlainText { // PlainText is the response type, which means that the response type of the API is a string, and the Content-Type is `text/plain` - match name { - Some(name) => PlainText(format!("hello, {}!", name)), - None => PlainText("hello!".to_string()), - } - } -} - -#[tokio::main] -async fn main() -> Result<(), std::io::Error> { - // Create a TCP listener - let listener = TcpListener::bind("127.0.0.1:3000"); - - // Create API service - let api_service = OpenApiService::new(Api, "Demo", "0.1.0") - .title("Hello World") - .server("http://localhost:3000/api"); - - // Enable the Swagger UI - let ui = api_service.swagger_ui(); - - // Enable the OpenAPI specification - let spec = api_service.spec_endpoint(); - - // Start the server and specify that the root path of the API is /api, and the path of Swagger UI is / - poem::Server::new(listener) - .await? - .run( - Route::new() - .at("/openapi.json", spec) - .nest("/api", api_service) - .nest("/", ui) - ) - .await -} -``` - -This is an example of `poem-openapi`, so you can also directly execute the following command to play: - -```shell -git clone https://github.com/poem-web/poem -cargo run --bin example-openapi-hello-world -``` diff --git a/docs/en/src/openapi/type_system.md b/docs/en/src/openapi/type_system.md deleted file mode 100644 index efb02caa..00000000 --- a/docs/en/src/openapi/type_system.md +++ /dev/null @@ -1,5 +0,0 @@ -# Type System - - -Poem-openapi implements conversions from OpenAPI types to Rust types, and it's easy to use. - diff --git a/docs/en/src/openapi/type_system/basic_types.md b/docs/en/src/openapi/type_system/basic_types.md deleted file mode 100644 index ff17ec9e..00000000 --- a/docs/en/src/openapi/type_system/basic_types.md +++ /dev/null @@ -1,19 +0,0 @@ -# Basic types - -The basic type can be used as a request parameter, request content or response content. `Poem` defines a `Type` trait to -represent a basic type, which can provide some information about the type at runtime to generate OpenAPI definitions. - -`Poem` implements `Type` traits for most common types, you can use them directly, and you can also customize new types, -but you need to have a certain understanding of [Json Schema](https://json-schema.org/). - -The following table lists the Rust data types corresponding to some OpenAPI data types: - -| Open API | Rust | -|-----------------------------------------|-----------------------------------| -| `{type: "integer", format: "int32" }` | i32 | -| `{type: "integer", format: "float32" }` | f32 | -| `{type: "bool" }` | bool | -| `{type: "string" }` | String, &str | -| `{type: "string", format: "binary" }` | Binary | -| `{type: "string", format: "bytes" }` | Base64 | -| `{type: "array" }` | Vec | diff --git a/docs/en/src/openapi/type_system/enum.md b/docs/en/src/openapi/type_system/enum.md deleted file mode 100644 index c745d451..00000000 --- a/docs/en/src/openapi/type_system/enum.md +++ /dev/null @@ -1,16 +0,0 @@ -# Enum - -Use the procedural macro `Enum` to define an enumerated type. - -**Poem-openapi will automatically change the name of each item to `SCREAMING_SNAKE_CASE` convention. You can use `rename_all` attribute to rename all items.** - -```rust -use poem_api::Enum; - -#[derive(Enum)] -enum PetStatus { - Available, - Pending, - Sold, -} -``` diff --git a/docs/en/src/openapi/type_system/object.md b/docs/en/src/openapi/type_system/object.md deleted file mode 100644 index c3975d5b..00000000 --- a/docs/en/src/openapi/type_system/object.md +++ /dev/null @@ -1,29 +0,0 @@ -# Object - -Use the procedural macro `Object` to define an object. All object members must be types that implement the `Type trait` -(unless you mark it with `#[oai(skip)]`, the field will be ignored serialization and use the default value instead). - -Use the following code to define an object type, which contains four fields, one of which is an enumerated type. - -_Object type is also a kind of basic type, it also implements the `Type` trait, so it can also be a member of another object._ - -**Poem-openapi will automatically change the name of each member to `camelCase` convention. You can use `rename_all` attribute to rename all items.** - -```rust -use poem_api::{Object, Enum}; - -#[derive(Enum)] -enum PetStatus { - Available, - Pending, - Sold, -} - -#[derive(Object)] -struct Pet { - id: u64, - name: String, - photo_urls: Vec, - status: PetStatus, -} -``` diff --git a/docs/en/src/openapi/upload_files.md b/docs/en/src/openapi/upload_files.md deleted file mode 100644 index 5207bd15..00000000 --- a/docs/en/src/openapi/upload_files.md +++ /dev/null @@ -1,29 +0,0 @@ -# Upload files - -The `Multipart` macro is usually used for file upload. It can define a form to contain one or more files and some -additional fields. The following example provides an operation to create a `Pet` object, which can upload some image -files at the same time. - -```rust -use poem_openapi::{Multipart, OpenApi}; -use poem::Result; - -#[derive(Debug, Multipart)] -struct CreatePetPayload { - name: String, - status: PetStatus, - photos: Vec, // some photos -} - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/pet", method = "post")] - async fn create_pet(&self, payload: CreatePetPayload) -> Result> { - todo!() - } -} -``` - -For the complete example, please refer to [Upload Example](https://github.com/poem-web/poem/tree/master/examples/openapi/upload). diff --git a/docs/en/src/openapi/validators.md b/docs/en/src/openapi/validators.md deleted file mode 100644 index 65afe15d..00000000 --- a/docs/en/src/openapi/validators.md +++ /dev/null @@ -1,27 +0,0 @@ -# Validators - -The `OpenAPI` specification supports validation based on `Json Schema`, and `Poem-openapi` also supports them. You can -apply validators to operation parameters, object members, and `Multipart` fields. The validator can only work on specific -data types, otherwise it will fail to compile. For example, `maximum` can only be used for numeric types, and `max_items` -can only be used for array types. - -For more validators, please refer to [document](https://docs.rs/poem-openapi/*/poem_openapi/attr.OpenApi.html#operation-argument-parameters). - -```rust -use poem_openapi::{Object, OpenApi, Multipart}; - -#[derive(Object)] -struct Pet { - id: u64, - - /// The length of the name must be less than 32 - #[oai(validator(max_length = "32"))] - name: String, - - /// Array length must be less than 3 and the url length must be less than 256 - #[oai(validator(max_items = "3", max_length = "256"))] - photo_urls: Vec, - - status: PetStatus, -} -``` diff --git a/docs/en/src/poem.md b/docs/en/src/poem.md deleted file mode 100644 index 4ab1f0a8..00000000 --- a/docs/en/src/poem.md +++ /dev/null @@ -1,4 +0,0 @@ -# Poem - -`Poem` is a full-featured and easy-to-use web framework with the Rust programming language. - diff --git a/docs/en/src/poem/endpoint.md b/docs/en/src/poem/endpoint.md deleted file mode 100644 index 91f4daa3..00000000 --- a/docs/en/src/poem/endpoint.md +++ /dev/null @@ -1,88 +0,0 @@ -# Endpoint - -The endpoint can handle HTTP requests. You can implement the `Endpoint` trait to create your own endpoint. -`Poem` also provides some convenient functions to easily create a custom endpoint type. - -In the previous chapter, we learned how to use the `handler` macro to convert a function to an endpoint. - -Now let's see how to create your own endpoint by implementing the `Endpoint` trait. - -This is the definition of the `Endpoint` trait, you need to specify the type of `Output` and implement the `call` method. - -```rust -/// An HTTP request handler. -#[async_trait] -pub trait Endpoint: Send + Sync + 'static { - /// Represents the response of the endpoint. - type Output: IntoResponse; - - /// Get the response to the request. - async fn call(&self, req: Request) -> Self::Output; -} -``` - -Now we implement an `Endpoint`, which receives HTTP requests and outputs a string containing the request method and path. - -The `Output` associated type must be a type that implements the `IntoResponse` trait. Poem has been implemented by most -commonly used types. - -Since `Endpoint` contains an asynchronous method `call`, we need to decorate it with the `async_trait` macro. - -```rust -struct MyEndpoint; - -#[async_trait] -impl Endpoint for MyEndpoint { - type Output = String; - - async fn call(&self, req: Request) -> Self::Output { - format!("method={} path={}", req.method(), req.uri().path()); - } -} -``` - -## Create from functions - -You can use `poem::endpoint::make` and `poem::endpoint::make_sync` to create endpoints from asynchronous functions and -synchronous functions. - -The following endpoint does the same thing: - -```rust -let ep = poem::endpoint::make(|req| async move { - format!("method={} path={}", req.method(), req.uri().path()) -}); -``` - -## EndpointExt - -The `EndpointExt` trait provides some convenience functions for converting the input or output of the endpoint. - -- `EndpointExt::before` is used to convert the request. -- `EndpointExt::after` is used to convert the output. -- `EndpointExt::map_ok`, `EndpointExt::map_err`, `EndpointExt::and_then` are used to process the output of type `Result`. - -## Using Result type - -`Poem` also implements `IntoResponse` for the `poem::Result` type, so it can also be used as the output type of the -endpoint, so you can use `?` in the `call` method. - -```rust -struct MyEndpoint; - -#[async_trait] -impl Endpoint for MyEndpoint { - type Output = poem::Result; - - async fn call(&self, req: Request) -> Self::Output { - Ok(req.take_body().into_string().await?) - } -} -``` - -You can use the `EndpointExt::map_to_response` method to convert the output of the endpoint to the `Response` type, or -use the `EndpointExt::map_to_result` to convert the output to the `poem::Result` type. - -```rust -let ep = MyEndpoint.map_to_response() // impl Endpoint -``` diff --git a/docs/en/src/poem/extractors.md b/docs/en/src/poem/extractors.md deleted file mode 100644 index 419e6bde..00000000 --- a/docs/en/src/poem/extractors.md +++ /dev/null @@ -1,225 +0,0 @@ -# Extractors - -The extractor is used to extract something from the HTTP request. - -`Poem` provides some commonly used extractors for extracting something from HTTP requests. - -You can use one or more extractors as the parameters of the function, up to 16. - -In the following example, the `index` function uses 3 extractors to extract the remote address, HTTP method and URI. - -```rust -#[handler] -fn index(remote_addr: SocketAddr, method: Method, uri: &Uri) {} -``` - -# Built-in extractors - - - **Option<T>** - - Extracts `T` from the incoming request, returns `None` if it - fails. - - - **&Request** - - Extracts the `Request` from the incoming request. - - - **&RemoteAddr** - - Extracts the remote peer's address [`RemoteAddr`] from request. - - - **&LocalAddr** - - Extracts the local server's address [`LocalAddr`] from request. - - - **Method** - - Extracts the `Method` from the incoming request. - - - **Version** - - Extracts the `Version` from the incoming request. - - - **&Uri** - - Extracts the `Uri` from the incoming request. - - - **&HeaderMap** - - Extracts the `HeaderMap` from the incoming request. - - - **Data<&T>** - - Extracts the `Data` from the incoming request. - - - **TypedHeader<T>** - - Extracts the `TypedHeader` from the incoming request. - - - **Path<T>** - - Extracts the `Path` from the incoming request. - - - **Query<T>** - - Extracts the `Query` from the incoming request. - - - **Form<T>** - - Extracts the `Form` from the incoming request. - - - **Json<T>** - - Extracts the `Json` from the incoming request. - - _This extractor will take over the requested body, so you should avoid - using multiple extractors of this type in one handler._ - - - **TempFile** - - Extracts the `TempFile` from the incoming request. - - _This extractor will take over the requested body, so you should avoid - using multiple extractors of this type in one handler._ - - - **Multipart** - - Extracts the `Multipart` from the incoming request. - - _This extractor will take over the requested body, so you should avoid - using multiple extractors of this type in one handler._ - - - **&CookieJar** - - Extracts the `CookieJar`](cookie::CookieJar) from the incoming request. - - _Requires `CookieJarManager` middleware._ - - - **&Session** - - Extracts the [`Session`](crate::session::Session) from the incoming request. - - _Requires `CookieSession` or `RedisSession` middleware._ - - - **Body** - - Extracts the `Body` from the incoming request. - - _This extractor will take over the requested body, so you should avoid - using multiple extractors of this type in one handler._ - - - **String** - - Extracts the body from the incoming request and parse it into utf8 string. - - _This extractor will take over the requested body, so you should avoid - using multiple extractors of this type in one handler._ - - - **Vec<u8>** - - Extracts the body from the incoming request and collect it into - `Vec`. - - _This extractor will take over the requested body, so you should avoid - using multiple extractors of this type in one handler._ - - - **Bytes** - - Extracts the body from the incoming request and collect it into - `Bytes`. - - _This extractor will take over the requested body, so you should avoid - using multiple extractors of this type in one handler._ - - - **WebSocket** - - Ready to accept a websocket connection. - -## Handling of extractor errors - -By default, the extractor will return a `400 Bad Request` when an error occurs, but sometimes you may want to change -this behavior, so you can handle the error yourself. - -In the following example, when the `Query` extractor fails, it will return a `500 Internal Server Error` response and the reason for the error. - -```rust -use poem::web::Query; -use poem::error::ParseQueryError; -use poem::{IntoResponse, Response}; -use poem::http::StatusCode; - -#[derive(Debug, Deserialize)] -struct Params { - name: String, -} - -#[handler] -fn index(res: Result, ParseQueryError>) -> Response { - match res { - Ok(Query(params)) => params.name.into_response(), - Err(err) => Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(err.to_string()), - } -} -``` - -## Custom extractor - -You can also implement your own extractor. - - The following is an example of a custom token extractor, which extracts the - token from the `MyToken` header. - -```rust -use poem::{ - get, handler, http::StatusCode, listener::TcpListener, FromRequest, Request, - RequestBody, Response, Route, Server, -}; - -struct Token(String); - -// Error type for Token extractor -#[derive(Debug)] -struct MissingToken; - -/// custom-error can also be reused -impl IntoResponse for MissingToken { - fn into_response(self) -> Response { - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("missing token") - } -} - -// Implements a token extractor -#[poem::async_trait] -impl<'a> FromRequest<'a> for Token { - type Error = MissingToken; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { - let token = req - .headers() - .get("MyToken") - .and_then(|value| value.to_str().ok()) - .ok_or(MissingToken)?; - Ok(Token(token.to_string())) - } -} - -#[handler] -async fn index(token: Token) { - assert_eq!(token.0, "token123"); -} - -#[tokio::main] -async fn main() -> Result<(), std::io::Error> { - if std::env::var_os("RUST_LOG").is_none() { - std::env::set_var("RUST_LOG", "poem=debug"); - } - tracing_subscriber::fmt::init(); - - let app = Route::new().at("/", get(index)); - let listener = TcpListener::bind("127.0.0.1:3000"); - let server = Server::new(listener).await?; - server.run(app).await -} -``` \ No newline at end of file diff --git a/docs/en/src/poem/handling_errors.md b/docs/en/src/poem/handling_errors.md deleted file mode 100644 index 3f7bfcf0..00000000 --- a/docs/en/src/poem/handling_errors.md +++ /dev/null @@ -1,119 +0,0 @@ -# Handling errors - -In `Poem`, we handle errors based on the response status code. When the status code is in `400-599`, we can think that -an error occurred while processing this request. - -We can use `EndpointExt::after` to create a new endpoint type to customize the error response. - -In the following example, the `after` function is used to convert the output of the `index` function and output an error -response when an server error occurs. - -**Note that the endpoint type generated by a `handler` macro is always `Endpoint`, even if it returns -a `Result`.** - -```rust -use poem::{handler, Result, Error}; -use poem::http::StatusCode; - -#[handler] -async fn index() -> Result<()> { - Err(Error::new(StatusCode::BAD_REQUEST)) -} - -let ep = index.after(|resp| { - if resp.status().is_server_error() { - Response::builder() - .status(resp.status()) - .body("custom error") - } else { - resp - } -}); -``` - -The `EndpointExt::map_to_result` function can help us convert any type of endpoint to `Endpoint`, so -that we only need to check the status code to know whether an error has occurred. - -```rust -use poem::endpoint::make; -use poem::{Error, EndpointExt}; -use poem::http::StatusCode; - -let ep = make(|_| Ok::<(), Error>(Error::new(StatusCode::new(Status::BAD_REQUEST)))) - .map_to_response(); - -let ep = ep.after(|resp| { - if resp.status().is_server_error() { - Response::builder() - .status(resp.status()) - .body("custom error") - } else { - resp - } -}); -``` - -## poem::Error - -`poem::Error` is a general error type, which implements `From`, so you can easily use the `?` operator to -convert any error type to it. The default status code is `503 Internal Server Error`. - -```rust -use poem::Result; - -#[handler] -fn index(data: Vec) -> Result { - let value: i32 = serde_json::from_slice(&data)?; - Ok(value) -} -``` - -But sometimes we don't want to always use the `503` status code, `Poem` provides some helper functions to convert the error type. - -```rust -use poem::{Result, web::Json, error::BadRequest}; - -#[handler] -fn index(data: Vec) -> Result> { - let value: i32 = serde_json::from_slice(&data).map_err(BadRequest)?; - Ok(Json(value)) -} -``` - -## Custom error type - -Sometimes we can use custom error types to reduce boilerplate code. - -NOTE: `Poem`'s error types usually only needs to implement `IntoResponse`. - -```rust -use poem::{ - Response, - error::ReadBodyError, - http::StatusCode, -}; - -enum MyError { - InvalidValue, - ReadBodyError(ReadBodyError), -} - -impl IntoResponse for MyError { - fn into_response(self) -> Response { - match self { - MyError::InvalidValue => Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("invalid value"), - MyError::ReadBodyError(err) => err.into(), // ReadBodyError has implemented `IntoResponse`. - } - } -} - -#[handler] -fn index(data: Result) -> Result<(), MyError> { - let data = data?; - if data.len() > 10 { - return Err(MyError::InvalidValue); - } -} -``` diff --git a/docs/en/src/poem/listeners.md b/docs/en/src/poem/listeners.md deleted file mode 100644 index fe26089b..00000000 --- a/docs/en/src/poem/listeners.md +++ /dev/null @@ -1,56 +0,0 @@ -# Listeners - -`Poem` provides some commonly used listeners. - -- TcpListener - - Listens for incoming TCP connections. - -- UnixListener - - Listens for incoming Unix domain socket connections. - -## TLS - -You can call the `Listener::tls` function to wrap a listener and make it support TLS connections. - -```rust -let listener = TcpListener::bind("127.0.0.1:3000") - .tls(TlsConfig::new().key(KEY).cert(CERT)); -``` - -## TLS reload - -You can use a stream to pass the latest Tls config to `Poem`. - -The following example loads the latest TLS config from file every 1 minute: - -```rust -use async_trait::async_trait; - -fn load_tls_config() -> Result { - Ok(TlsConfig::new() - .cert(std::fs::read("cert.pem")?) - .key(std::fs::read("key.pem")?)) -} - -let listener = TcpListener::bind("127.0.0.1:3000") - .tls(async_stream::stream! { - loop { - if let Ok(tls_config) = load_tls_config() { - yield tls_config; - } - tokio::time::sleep(Duration::from_secs(60)).await; - } - }); -``` - -## Combine multiple listeners. - -Call `Listener::combine` to combine two listeners into one, or you can call this function multiple times to combine more listeners. - -```rust -let listener = TcpListener::bind("127.0.0.1:3000") - .combine(TcpListener::bind("127.0.0.1:3001")) - .combine(TcpListener::bind("127.0.0.1:3002")); -``` \ No newline at end of file diff --git a/docs/en/src/poem/middleware.md b/docs/en/src/poem/middleware.md deleted file mode 100644 index 4bf1fe86..00000000 --- a/docs/en/src/poem/middleware.md +++ /dev/null @@ -1,114 +0,0 @@ -# Middleware - -The middleware can do something before or after the request is processed. - -`Poem` provides some commonly used middleware implementations. - -- `AddData` - - Used to attach a status to the request, such as a token for authentication. - -- `SetHeader` - - Used to add some specific HTTP headers to the response. - -- `Cors` - - Used for Cross-Origin Resource Sharing. - -- `Tracing` - - Use [`tracing`](https://crates.io/crates/tracing) to record all requests and responses. - -- `Compression` - - Used for decompress request body and compress response body. - -## Custom middleware - -It is easy to implement your own middleware, you only need to implement the `Middleware` trait, which is a converter to -convert an input endpoint to another endpoint. - -The following example creates a custom middleware that reads the value of the HTTP request header named `X-Token` and -adds it as the status of the request. - -```rust -use poem::{handler, web::Data, Endpoint, EndpointExt, Middleware, Request}; - -/// A middleware that extract token from HTTP headers. -struct TokenMiddleware; - -impl Middleware for TokenMiddleware { - type Output = TokenMiddlewareImpl; - - fn transform(&self, ep: E) -> Self::Output { - TokenMiddlewareImpl { ep } - } -} - -/// The new endpoint type generated by the TokenMiddleware. -struct TokenMiddlewareImpl { - ep: E, -} - -const TOKEN_HEADER: &str = "X-Token"; - -/// Token data -struct Token(String); - -#[poem::async_trait] -impl Endpoint for TokenMiddlewareImpl { - type Output = E::Output; - - async fn call(&self, mut req: Request) -> Self::Output { - if let Some(value) = req - .headers() - .get(TOKEN_HEADER) - .and_then(|value| value.to_str().ok()) - { - // Insert token data to extensions of request. - let token = value.to_string(); - req.extensions_mut().insert(Token(token)); - } - - // call the inner endpoint. - self.ep.call(req).await - } -} - -#[handler] -async fn index(Data(token): Data<&Token>) -> String { - token.0.clone() -} - -// Use the `TokenMiddleware` middleware to convert the `index` endpoint. -let ep = index.with(TokenMiddleware); -``` - -## Custom middleware with function - -You can also use a function to implement a middleware. - -```rust -async fn extract_token(next: E, mut req: Request) -> Response { - if let Some(value) = req - .headers() - .get(TOKEN_HEADER) - .and_then(|value| value.to_str().ok()) - { - // Insert token data to extensions of request. - let token = value.to_string(); - req.extensions_mut().insert(Token(token)); - } - - // call the next endpoint. - next.call(req).await -} - -#[handler] -async fn index(Data(token): Data<&Token>) -> String { - token.0.clone() -} - -let ep = index.around(extract_token); -``` diff --git a/docs/en/src/poem/protocols.md b/docs/en/src/poem/protocols.md deleted file mode 100644 index 8d9f7317..00000000 --- a/docs/en/src/poem/protocols.md +++ /dev/null @@ -1 +0,0 @@ -# Protocols diff --git a/docs/en/src/poem/protocols/sse.md b/docs/en/src/poem/protocols/sse.md deleted file mode 100644 index dbfe534d..00000000 --- a/docs/en/src/poem/protocols/sse.md +++ /dev/null @@ -1,28 +0,0 @@ -# Server-Sent Events (SSE) - -SSE allows the server to continuously push data to the client. - -You need to create a `SSE` response with a type that implements `Stream`. - -The endpoint in the example below will send three events. - -```rust -use futures_util::stream; -use poem::{ - handler, Route, get, - http::StatusCode, - web::sse::{Event, SSE}, - Endpoint, Request, -}; - -#[handler] -fn index() -> SSE { - SSE::new(stream::iter(vec![ - Event::message("a"), - Event::message("b"), - Event::message("c"), - ])) -} - -let app = Route::new().at("/", get(index)); -``` diff --git a/docs/en/src/poem/protocols/websocket.md b/docs/en/src/poem/protocols/websocket.md deleted file mode 100644 index f288edc7..00000000 --- a/docs/en/src/poem/protocols/websocket.md +++ /dev/null @@ -1,32 +0,0 @@ -# Websocket - -Websocket allows a long connection for two-way communication between the client and the server. - -`Poem` provides a `WebSocket` extractor to create this connection. - -When the connection is successfully upgraded, a specified closure is called to send and receive data. - -The following example is an echo service, which always sends out the received data. - -**Note that the output of this endpoint must be the return value of the `WebSocket::on_upgrade` function, otherwise the -connection cannot be created correctly.** - -```rust -use futures_util::{SinkExt, StreamExt}; -use poem::{ - handler, Route, get, - web::websocket::{Message, WebSocket}, - IntoResponse, -}; - -#[handler] -async fn index(ws: WebSocket) -> impl IntoResponse { - ws.on_upgrade(|mut socket| async move { - if let Some(Ok(Message::Text(text))) = socket.next().await { - let _ = socket.send(Message::Text(text)).await; - } - }) -} - -let app = Route::new().at("/", get(index)); -``` \ No newline at end of file diff --git a/docs/en/src/poem/quickstart.md b/docs/en/src/poem/quickstart.md deleted file mode 100644 index 66740d63..00000000 --- a/docs/en/src/poem/quickstart.md +++ /dev/null @@ -1,65 +0,0 @@ -# Quickstart - -## Add dependency libraries - -```toml -[dependencies] -poem = "1.0" -serde = "1.0" -tokio = { version = "1.12.0", features = ["rt-multi-thread", "macros"] } -``` - -## Write a endpoint - -The `handler` macro converts a function into a type that implements `Endpoint`, and the `Endpoint` trait represents -a type that can handle HTTP requests. - -This function can receive one or more parameters, and each parameter is an extractor that can extract something from -the HTTP request. - -The extractor implements the `FromRequest` trait, and you can also implement this trait to create your own extractor. - -The return value of the function must be a type that implements the `IntoResponse` trait. It can convert itself into an -HTTP response through the `IntoResponse::into_response` method. - -The following function has an extractor, which extracts the `name` and `value` parameters from the query string of the -request uri and return a `String`, the string will be converted into an HTTP response. - -```rust -use serde::Deserialize; -use poem::{handler, listener::TcpListener, web::Query, Server}; - -#[derive(Deserialize)] -struct Params { - name: String, - value: i32, -} - -#[handler] -async fn index(Query(Params { name, value }): Query) -> String { - format!("{}={}", name, value) -} -``` - -## HTTP server - -Let's start a server, it listens to `127.0.0.1:3000`, please ignore these `unwrap` calls, this is just an example. - -The `Server::run` function accepts any type that implements the `Endpoint` trait. In this example we don't have a -routing object, so any request path will be handled by the `index` function. - -```rust - -#[tokio::main] -async fn main() { - let listener = TcpListener::bind("127.0.0.1:3000"); - Server::new(listener).run(index).await.unwrap(); -} -``` - -In this way, a simple example is implemented, we can run it and then use `curl` to do some tests. - -```shell -> curl http://localhost:3000?name=a&value=10 -name=10 -``` diff --git a/docs/en/src/poem/responses.md b/docs/en/src/poem/responses.md deleted file mode 100644 index c1815dc0..00000000 --- a/docs/en/src/poem/responses.md +++ /dev/null @@ -1,125 +0,0 @@ -# Responses - -All types that can be converted to HTTP response `Response` should implement `IntoResponse`, and they can be used as the -return value of the handler function. - -In the following example, the `string_response` and `status_response` functions return the `String` and `StatusCode` -types, because `Poem` has implemented the `IntoResponse` feature for them. - -The `no_response` function does not return a value. We can also think that its return type is `()`, and `Poem` also -implements `IntoResponse` for `()`, which is always converted to `200 OK`. - -```rust -use poem::handler; -use poem::http::StatusCode; - -#[handler] -fn string_response() -> String { - "hello".to_string() -} - -#[handler] -fn status_response() -> StatusCode {} - -#[handler] -fn no_response() {} - -``` - -# Built-in responses - -- **Result<T: IntoResponse, E: IntoResponse>** - - if the result is `Ok`, use the `Ok` value as the response, otherwise use the `Err` value. - -- **()** - - Sets the status to `OK` with an empty body. - -- **&'static str** - - Sets the status to `OK` and the `Content-Type` to `text/plain`. The -string is used as the body of the response. - -- **String** - - Sets the status to `OK` and the `Content-Type` to `text/plain`. The -string is used as the body of the response. - -- **&'static [u8]** - - Sets the status to `OK` and the `Content-Type` to -`application/octet-stream`. The slice is used as the body of the response. - -- **Html<T>** - - Sets the status to `OK` and the `Content-Type` to `text/html`. `T` is -used as the body of the response. - -- **Json<T>** - - Sets the status to `OK` and the `Content-Type` to `application/json`. Use -[`serde_json`](https://crates.io/crates/serde_json) to serialize `T` into a json string. - -- **Bytes** - - Sets the status to `OK` and the `Content-Type` to -`application/octet-stream`. The bytes is used as the body of the response. - -- **Vec<u8>** - - Sets the status to `OK` and the `Content-Type` to -`application/octet-stream`. The vector’s data is used as the body of the -response. - -- **Body** - - Sets the status to `OK` and use the specified body. - -- **StatusCode** - - Sets the status to the specified status code `StatusCode` with an empty -body. - -- **(StatusCode, T)** - - Convert `T` to response and set the specified status code `StatusCode`. - -- **(StatusCode, HeaderMap, T)** - - Convert `T` to response and set the specified status code `StatusCode`, -and then merge the specified `HeaderMap`. - -- **Response** - - The implementation for `Response` always returns itself. - -- **Compress<T>** - - Call `T::into_response` to get the response, then compress the response -body with the specified algorithm, and set the correct `Content-Encoding` -header. - -- **SSE** - - Sets the status to `OK` and the `Content-Type` to `text/event-stream` -with an event stream body. Use the `SSE::new` function to -create it. - -## Custom response - -In the following example, we wrap a response called `PDF`, which adds a `Content-Type: applicationn/pdf` header to the response. - -```rust -use poem::{IntoResponse, Response}; - -struct PDF(Vec); - -impl IntoResponse for PDF { - fn into_response(self) -> Response { - Response::builder() - .header("Content-Type", "application/pdf") - .body(self.0) - } -} -``` diff --git a/docs/en/src/poem/routing.md b/docs/en/src/poem/routing.md deleted file mode 100644 index f072469a..00000000 --- a/docs/en/src/poem/routing.md +++ /dev/null @@ -1,80 +0,0 @@ -# Routing - -The routing object is used to dispatch the request of the specified path and method to the specified endpoint. - -The route object is actually an endpoint, which implements the `Endpoint` trait. - -In the following example, we dispatch the requests of `/a` and `/b` to different endpoints. - -```rust -use poem::{handler, Route}; - -#[handler] -async fn a() -> &'static str { "a" } - -#[handler] -async fn b() -> &'static str { "b" } - -let ep = Route::new() - .at("/a", a) - .at("/b", b); -``` - -## Capture the variables - -Use `:NAME` to capture the value of the specified segment in the path, or use `*NAME` to capture all the values after -the specified prefix. - -In the following example, the captured values will be stored in the variable `value`, and you can use the path extractor to get them. - -```rust -#[handler] -async fn a(Path(String): Path) {} - -let ep = Route::new() - .at("/a/:value/b", handler) - .at("/prefix/*value", handler); -``` - -## Regular expressions - -You can use regular expressions to match, `` or `:NAME`, the second one can capture the matched value into a variable. - -```rust -let ep = Route::new() - .at("/a/<\\d+>", handler) - .at("/b/:value<\\d+>", handler); -``` - -## Nested - -Sometimes we want to assign a path with a specified prefix to a specified endpoint, so that some functionally independent -components can be created. - -In the following example, the request path of the `hello` endpoint is `/api/hello`. - -```rust -let api = Route::new().at("/hello", hello); -let ep = api.nest("/api", api); -``` - -Static file service is such an independent component. - -```rust -let ep = Route::new().nest("/files", Files::new("./static_files")); -``` - -## Method routing - -The routing objects introduced above can only be dispatched by some specified paths, but dispatch by paths and methods -is more common. `Poem` provides another route object `RouteMethod`, when it is combined with the `Route` object, it can -provide this ability. - -`Poem` provides some convenient functions to create `RouteMethod` objects, they are all named after HTTP standard methods. - -```rust -use poem::{Route, get, post}; - -let ep = Route::new() - .at("/users", get(get_user).post(create_user).delete(delete_user).put(update_user)); -``` diff --git a/docs/index.html b/docs/index.html deleted file mode 100644 index 4627a563..00000000 --- a/docs/index.html +++ /dev/null @@ -1,31 +0,0 @@ - - - - Poem Book - - - - - - - - -

Poem Book

-

This book is available in multiple languages:

- - - \ No newline at end of file diff --git a/docs/zh-CN/book.toml b/docs/zh-CN/book.toml deleted file mode 100644 index a3b87084..00000000 --- a/docs/zh-CN/book.toml +++ /dev/null @@ -1,6 +0,0 @@ -[book] -authors = ["sunli"] -description = "Poem 使用手册" -src = "src" -language = "zh-CN" -title = "Poem 使用手册" diff --git a/docs/zh-CN/src/SUMMARY.md b/docs/zh-CN/src/SUMMARY.md deleted file mode 100644 index 3e563ceb..00000000 --- a/docs/zh-CN/src/SUMMARY.md +++ /dev/null @@ -1,28 +0,0 @@ -# Poem Book - -## Poem - -- [Poem](poem.md) - - [快速开始](poem/quickstart.md) - - [Endpoint](poem/endpoint.md) - - [路由](poem/routing.md) - - [提取器](poem/extractors.md) - - [响应](poem/responses.md) - - [处理错误](poem/handling_errors.md) - - [中间件](poem/middleware.md) - - [协议](poem/protocols.md) - - [Websocket](poem/protocols/websocket.md) - - [服务端事件 (SSE)](poem/protocols/sse.md) - - [监听器](poem/listeners.md) -- [OpenAPI](openapi.md) - - [快速开始](openapi/quickstart.md) - - [类型系统](openapi/type_system.md) - - [基础类型](openapi/type_system/basic_types.md) - - [枚举](openapi/type_system/enum.md) - - [对象](openapi/type_system/object.md) - - [定义API](openapi/api.md) - - [自定义请求](openapi/custom_request.md) - - [自定义响应](openapi/custom_response.md) - - [文件上传](openapi/upload_files.md) - - [参数校验](openapi/validators.md) - - [认证](openapi/authentication.md) diff --git a/docs/zh-CN/src/openapi.md b/docs/zh-CN/src/openapi.md deleted file mode 100644 index cc1e31ef..00000000 --- a/docs/zh-CN/src/openapi.md +++ /dev/null @@ -1,7 +0,0 @@ -# OpenAPI - -[OpenAPI]((https://swagger.io/specification/)) 规范为`RESTful API`定义了一个标准的并且与语言无关的接口,它允许人类和计算机在不访问源代码、文档或通过网络流量检查的情况下发现和理解服务的功能。若经良好定义,使调用者可以很容易的理解远程服务并与之交互, 并只需要很少的代码即可实现期望逻辑. - -`Poem-openapi`是基于`Poem`的 [OpenAPI](https://swagger.io/specification/) 服务端框架。 - -通常,如果你希望让你的 API 支持该规范,首先需要创建一个 [接口定义文件](https://swagger.io/specification/) ,然后再按照接口定义编写对应的代码。或者创建接口定义文件后,用 `Swagger CodeGen` 来生成服务端代码框架。但`Poem-openapi`区别于这两种方法,它让你只需要编写 Rust 的业务代码,利用过程宏来自动生成符合 OpenAPI 规范的接口和接口定义文件(这相当于接口的文档)。 diff --git a/docs/zh-CN/src/openapi/api.md b/docs/zh-CN/src/openapi/api.md deleted file mode 100644 index 159d8396..00000000 --- a/docs/zh-CN/src/openapi/api.md +++ /dev/null @@ -1,77 +0,0 @@ -# 定义API - -下面定义一组API对宠物表进行增删改查的操作。 - -一个方法代表一个API操作,必须使用`path`和`method`属性指定操作的路径和方法。 - -方法的参数可以有多个,可以使用以下类型: - - - **poem_openapi::param::Query** 表示参数来自查询字符串 - - - **poem_openapi::param::Header** 表示参数来自请求头 - - - **poem_openapi::param::Path** 表示参数来自请求路径 - - - **poem_openapi::param::Cookie** 表示参数来自Cookie - - - **poem_openapi::param::CookiePrivate** 表示参数来自加密的Cookie - - - **poem_openapi::param::CookieSigned** 表示参数来自签名后的Cookie - - - **poem_openapi::payload::Binary** 表示请求内容是二进制数据 - - - **poem_openapi::payload::Json** 表示请求内容用Json编码 - - - **poem_openapi::payload::PlainText** 表示请求内容是UTF8文本 - - - **ApiRequest** 使用`ApiRequest`宏生成的请求体 - - - **SecurityScheme** 使用`SecurityScheme`宏生成认证方法 - - - **T: FromRequest** 使用Poem的提取器 - -返回值可以是任意实现了`ApiResponse`的类型。 - -```rust -use poem_api::{ - OpenApi, - poem_api::payload::Json, - param::{Path, Query}, -}; -use poem::Result; - -struct Api; - -#[OpenApi] -impl Api { - /// 添加新Pet - #[oai(path = "/pet", method = "post")] - async fn add_pet(&self, pet: Json) -> Result<()> { - todo!() - } - - /// 更新已有的Pet - #[oai(path = "/pet", method = "put")] - async fn update_pet(&self, pet: Json) -> Result<()> { - todo!() - } - - /// 删除一个Pet - #[oai(path = "/pet/:id", method = "delete")] - async fn delete_pet(&self, id: Path) -> Result<()> { - todo!() - } - - /// 根据ID查询Pet - #[oai(path = "/pet", method = "get")] - async fn find_pet_by_id(&self, id: Query) -> Result> { - todo!() - } - - /// 根据状态查询Pet - #[oai(path = "/pet/findByStatus", method = "get")] - async fn find_pets_by_status(&self, status: Query) -> Result>> { - todo!() - } -} -``` diff --git a/docs/zh-CN/src/openapi/authentication.md b/docs/zh-CN/src/openapi/authentication.md deleted file mode 100644 index decfeacd..00000000 --- a/docs/zh-CN/src/openapi/authentication.md +++ /dev/null @@ -1,80 +0,0 @@ -# 认证 - -OpenApi规范定义了`apikey`,`basic`,`bearer`,`oauth2`,`openIdConnect`五种认证模式,它们描述了指定的`API`接口需要的认证参数。 - -下面的例子是用`Github`登录,并提供一个获取所有公共仓库信息的接口。 - -```rust -use poem_openapi::{ - SecurityScheme, SecurityScope, OpenApi, - auth::Bearer, -}; - -#[derive(OAuthScopes)] -enum GithubScope { - /// 可访问公共仓库信息。 - #[oai(rename = "public_repo")] - PublicRepo, - - /// 可访问用户的个人资料数据。 - #[oai(rename = "read:user")] - ReadUser, -} - -/// Github 认证 -#[derive(SecurityScheme)] -#[oai( - type = "oauth2", - flows(authorization_code( - authorization_url = "https://github.com/login/oauth/authorize", - token_url = "https://github.com/login/oauth/token", - scopes = "GithubScope", - )) -)] -struct GithubAuthorization(Bearer); - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/repo", method = "get")] - async fn repo_list( - &self, - #[oai(scope("GithubScope::PublicRepo"))] auth: GithubAuthorization, - ) -> Result> { - // 使用GithubAuthorization得到的token向Github获取所有公共仓库信息。 - todo!() - } -} -``` - -完整的代码请参考[例子](https://github.com/poem-web/poem/tree/master/examples/openapi/auth-github)。 - -## 检查认证信息 - -您可以使用`checker`属性指定一个检查器函数来检查原始认证信息和将其转换为该函数的返回类型。 此函数必须返回`Option`,如果检查失败则返回`None`。 - -```rust -struct User { - username: String, -} - -/// ApiKey 认证 -#[derive(SecurityScheme)] -#[oai( - type = "api_key", - key_name = "X-API-Key", - in = "header", - checker = "api_checker" -)] -struct MyApiKeyAuthorization(User); - -async fn api_checker(req: &Request, api_key: ApiKey) -> Option { - let connection = req.data::().unwrap(); - - // 在数据库中检查 - todo!() -} -``` - -完整的代码请参考[例子](https://github.com/poem-web/poem/tree/master/examples/openapi/auth-apikey). \ No newline at end of file diff --git a/docs/zh-CN/src/openapi/custom_request.md b/docs/zh-CN/src/openapi/custom_request.md deleted file mode 100644 index 8cc70b72..00000000 --- a/docs/zh-CN/src/openapi/custom_request.md +++ /dev/null @@ -1,49 +0,0 @@ -# 自定义请求 - -`OpenAPI`规范允许同一个接口支持处理不同`Content-Type`的请求,例如一个接口可以同时接受`application/json`和`text/plain`类型的Payload。 - -在`Poem-openapi`中,要支持此类型请求,需要用`ApiRequest`宏自定义一个实现了`Payload trait`的请求对象。 - -在下面的例子中,`create_post`函数接受`CreatePostRequest`请求,当创建成功后,返回`id`。 - -```rust -use poem_open::{ - ApiRequest, Object, - payload::{PlainText, Json}, -}; -use poem::Result; - -#[derive(Object)] -struct Post { - title: String, - content: String, -} - -#[derive(ApiRequest)] -enum CreatePostRequest { - /// 从JSON创建 - Json(Json), - /// 从文本创建 - Text(PlainText), -} - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/hello", method = "post")] - async fn create_post( - &self, - req: CreatePostRequest, - ) -> Result> { - match req { - CreatePostRequest::Json(Json(blog)) => { - todo!(); - } - CreatePostRequest::Text(content) => { - todo!(); - } - } - } -} -``` diff --git a/docs/zh-CN/src/openapi/custom_response.md b/docs/zh-CN/src/openapi/custom_response.md deleted file mode 100644 index 87fc7ab4..00000000 --- a/docs/zh-CN/src/openapi/custom_response.md +++ /dev/null @@ -1,89 +0,0 @@ -# 自定义响应 - -在前面的例子中,我们的所有请求处理函数都返回的`Result`类型,当发生错误时返回一个`poem::Error`,它包含错误的原因以及状态码。但`OpenAPI`规范允许更详细的描述请求的响应,例如该接口可能会返回哪些状态码,以及状态码对应的原因和响应的内容。 - -在下面的例子中,我们修改`create_post`函数的返回值为`CreateBlogResponse`类型。 - -`Ok`,`Forbidden`和`InternalError`描述了特定状态码的响应类型。 - -```rust -use poem_openapi::ApiResponse; -use poem::http::StatusCode; - -#[derive(ApiResponse)] -enum CreateBlogResponse { - /// 创建完成 - #[oai(status = 200)] - Ok(Json), - - /// 没有权限 - #[oai(status = 403)] - Forbidden, - - /// 内部错误 - #[oai(status = 500)] - InternalError, -} - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/hello", method = "get")] - async fn create_post( - &self, - req: CreatePostRequest, - ) -> CreateBlogResponse { - match req { - CreatePostRequest::Json(Json(blog)) => { - todo!(); - } - CreatePostRequest::Text(content) => { - todo!(); - } - } - } -} -``` - -当请求解析失败时,默认会返回`400 Bad Request`错误,但有时候我们想返回一个自定义的错误内容,可以使用`bad_request_handler`属性设置一个错误处理函数,这个函数用于转换`ParseRequestError`到指定的响应类型。 - -```rust -use poem_openapi::{ - ApiResponse, Object, ParseRequestError, payload::Json, -}; - -#[derive(Object)] -struct ErrorMessage { - code: i32, - reason: String, -} - -#[derive(ApiResponse)] -#[oai(bad_request_handler = "bad_request_handler")] -enum CreateBlogResponse { - /// 创建完成 - #[oai(status = 200)] - Ok(Json), - - /// 没有权限 - #[oai(status = 403)] - Forbidden, - - /// 内部错误 - #[oai(status = 500)] - InternalError, - - /// 请求无效 - #[oai(status = 400)] - BadRequest(Json), -} - -fn bad_request_handler(err: ParseRequestError) -> CreateBlogResponse { - // 当解析请求失败时,返回一个自定义的错误内容,它是一个JSON - CreateBlogResponse::BadRequest(Json(ErrorMessage { - code: -1, - reason: err.to_string(), - })) -} -``` diff --git a/docs/zh-CN/src/openapi/quickstart.md b/docs/zh-CN/src/openapi/quickstart.md deleted file mode 100644 index 20968f5e..00000000 --- a/docs/zh-CN/src/openapi/quickstart.md +++ /dev/null @@ -1,61 +0,0 @@ -# 快速开始 - -下面这个例子,我们定义了一个路径为`/hello`的API,它接受一个名为`name`的URL参数,并且返回一个字符串作为响应内容。`name`参数的类型是`Option`,意味着这是一个可选参数。 - -运行以下代码后,用浏览器打开`http://localhost:3000`就能看到`Swagger UI`,你可以用它来浏览API的定义并且测试它们。 - -```rust -use poem::{listener::TcpListener, Route}; -use poem_openapi::{payload::PlainText, OpenApi, OpenApiService}; - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/hello", method = "get")] - async fn index( - &self, - #[oai(name = "name", in = "query")] name: Option, // in="query" 说明这个参数来自Url - ) -> PlainText { // PlainText是响应类型,它表明该API的响应类型是一个字符串,Content-Type是`text/plain` - match name { - Some(name) => PlainText(format!("hello, {}!", name)), - None => PlainText("hello!".to_string()), - } - } -} - -#[tokio::main] -async fn main() -> Result<(), std::io::Error> { - // 创建一个TCP监听器 - let listener = TcpListener::bind("127.0.0.1:3000"); - - // 创建API服务 - let api_service = OpenApiService::new(Api, "Demo", "0.1.0") - .title("Hello World") - .server("http://localhost:3000/api"); - - // 创建Swagger UI端点 - let ui = api_service.swagger_ui(); - - // 创建OpenApi输出规范的端点 - let spec = api_service.spec_endpoint(); - - // 启动服务器,并指定api的根路径为 /api,Swagger UI的路径为 / - poem::Server::new(listener) - .await? - .run( - Route::new() - .at("/openapi.json", spec) - .nest("/api", api_service) - .nest("/", ui) - ) - .await -} -``` - -这是`poem-openapi`的一个例子,所以你也可以直接执行以下命令来验证: - -```shell -git clone https://github.com/poem-web/poem -cargo run --bin example-openapi-hello-world -``` diff --git a/docs/zh-CN/src/openapi/type_system.md b/docs/zh-CN/src/openapi/type_system.md deleted file mode 100644 index 7f926bdb..00000000 --- a/docs/zh-CN/src/openapi/type_system.md +++ /dev/null @@ -1,5 +0,0 @@ -# 类型系统 - - -Poem-openapi 实现了 OpenAPI 类型到 Rust 类型的转换,简单易用。 - diff --git a/docs/zh-CN/src/openapi/type_system/basic_types.md b/docs/zh-CN/src/openapi/type_system/basic_types.md deleted file mode 100644 index 6f7f2d56..00000000 --- a/docs/zh-CN/src/openapi/type_system/basic_types.md +++ /dev/null @@ -1,17 +0,0 @@ -# Basic types - -基础类型可以作为请求的参数,请求内容或者请求响应内容。`Poem`定义了一个`Type trait`,实现了该`trait`的类型都是基础类型,它们能在运行时提供一些关于该类型的信息用于生成接口定义文件。 - -`Poem`为大部分常用类型实现了`Type`trait,你可以直接使用它们,同样也可以自定义新的类型,但你需要对 [Json Schema](https://json-schema.org/) 有一定了解。 - -下表是 Open API 中的数据类型对应的Rust数据类型(只是一小部分): - -| Open API | Rust | -|-----------------------------------------|-----------------------------------| -| `{type: "integer", format: "int32" }` | i32 | -| `{type: "integer", format: "float32" }` | f32 | -| `{type: "bool" }` | bool | -| `{type: "string" }` | String, &str | -| `{type: "string", format: "binary" }` | Binary | -| `{type: "string", format: "bytes" }` | Base64 | -| `{type: "array" }` | Vec | diff --git a/docs/zh-CN/src/openapi/type_system/enum.md b/docs/zh-CN/src/openapi/type_system/enum.md deleted file mode 100644 index a5827d6a..00000000 --- a/docs/zh-CN/src/openapi/type_system/enum.md +++ /dev/null @@ -1,16 +0,0 @@ -# 枚举 - -使用过程宏 `Enum` 来定义枚举类型。 - -**Poem-openapi 会自动将每一项的名称改为`SCREAMING_SNAKE_CASE` 约定。 您可以使用 `rename_all` 属性来重命名所有项目。** - -```rust -use poem_api::Enum; - -#[derive(Enum)] -enum PetStatus { - Available, - Pending, - Sold, -} -``` diff --git a/docs/zh-CN/src/openapi/type_system/object.md b/docs/zh-CN/src/openapi/type_system/object.md deleted file mode 100644 index 0334a8df..00000000 --- a/docs/zh-CN/src/openapi/type_system/object.md +++ /dev/null @@ -1,28 +0,0 @@ -# 对象类型 - -用过程宏`Object`来定义一个对象,对象的成员必须是实现了`Type trait`的类型(除非你用`#[oai(skip)]`来标注它,那么序列化和反序列化时降忽略该字段用默认值代替)。 - -以下代码定义了一个对象类型,它包含四个字段,其中有一个字段是枚举类型。 - -_对象类型也是基础类型的一种,它同样实现了`Type trait`,所以它也可以作为另一个对象的成员。_ - -**Poem-openapi 会自动将每个成员的名称更改为 `camelCase` 约定。 你可以使用 `rename_all` 属性来重命名所有项。** - -```rust -use poem_api::{Object, Enum}; - -#[derive(Enum)] -enum PetStatus { - Available, - Pending, - Sold, -} - -#[derive(Object)] -struct Pet { - id: u64, - name: String, - photo_urls: Vec, - status: PetStatus, -} -``` diff --git a/docs/zh-CN/src/openapi/upload_files.md b/docs/zh-CN/src/openapi/upload_files.md deleted file mode 100644 index 67423095..00000000 --- a/docs/zh-CN/src/openapi/upload_files.md +++ /dev/null @@ -1,27 +0,0 @@ -# 文件上传 - -`Multipart`宏通常用于文件上传,它可以定义一个表单来包含一个或者多个文件以及一些附加字段。下面的例子提供一个创建`Pet`对象的接口,它在创建`Pet`对象的同时上传一些图片文件。 - -```rust -use poem_openapi::{Multipart, OpenApi}; -use poem::Result; - -#[derive(Debug, Multipart)] -struct CreatePetPayload { - name: String, - status: PetStatus, - photos: Vec, // 多个照片文件 -} - -struct Api; - -#[OpenApi] -impl Api { - #[oai(path = "/pet", method = "post")] - async fn create_pet(&self, payload: CreatePetPayload) -> Result> { - todo!() - } -} -``` - -完整的代码请参考[文件上传例子](https://github.com/poem-web/poem/tree/master/examples/openapi/upload`)。 diff --git a/docs/zh-CN/src/openapi/validators.md b/docs/zh-CN/src/openapi/validators.md deleted file mode 100644 index b3f2a66b..00000000 --- a/docs/zh-CN/src/openapi/validators.md +++ /dev/null @@ -1,24 +0,0 @@ -# 参数校验 - -`OpenAPI`引用了`Json Schema`的校验规范,`Poem-openapi`同样支持它们。你可以在请求的参数,对象的成员和`Multipart`的字段三个地方应用校验器。校验器是类型安全的,如果待校验的数据类型和校验器所需要的不匹配,那么将无法编译通过。例如`maximum`只能用于数值类型,`max_items`只能用于数组类型。 - -更多的校验器请参考[文档](https://docs.rs/poem-openapi/*/poem_openapi/attr.OpenApi.html#operation-argument-parameters)。 - -```rust -use poem_openapi::{Object, OpenApi, Multipart}; - -#[derive(Object)] -struct Pet { - id: u64, - - /// 名字长度不能超过32 - #[oai(validator(max_length = "32"))] - name: String, - - /// 数组长度不能超过3,并且url长度不能超过256 - #[oai(validator(max_items = "3", max_length = "256"))] - photo_urls: Vec, - - status: PetStatus, -} -``` diff --git a/docs/zh-CN/src/poem.md b/docs/zh-CN/src/poem.md deleted file mode 100644 index 7ef4b881..00000000 --- a/docs/zh-CN/src/poem.md +++ /dev/null @@ -1,3 +0,0 @@ -# Poem - -`Poem` 是一个功能齐全且易于使用的 Web 框架,采用 Rust 编程语言。 diff --git a/docs/zh-CN/src/poem/endpoint.md b/docs/zh-CN/src/poem/endpoint.md deleted file mode 100644 index 547c1a85..00000000 --- a/docs/zh-CN/src/poem/endpoint.md +++ /dev/null @@ -1,84 +0,0 @@ -# Endpoint - -Endpoint 可以处理 HTTP 请求。您可以实现`Endpoint` trait 来创建您自己的Endpoint。 -`Poem` 还提供了一些方便的功能来轻松创建自定义 Endpoint 类型。 - -在上一章中,我们学习了如何使用 `handler` 宏将函数转换为 Endpoint。 - -现在让我们看看如何通过实现 `Endpoint` trait 来创建自己的 Endpoint。 - -这是 `Endpoint` trait 的定义,你需要指定 `Output` 的类型并实现 `call` 方法。 - -```rust -/// 一个 HTTP 请求处理程序。 -#[async_trait] -pub trait Endpoint: Send + Sync + 'static { - /// 代表 endpoint 的响应。 - type Output: IntoResponse; - - /// 获取对请求的响应。 - async fn call(&self, req: Request) -> Self::Output; -} -``` - -现在我们实现一个 `Endpoint`,它接收 HTTP 请求并输出一个包含请求方法和路径的字符串。 - -`Output` 关联类型必须是实现 `IntoResponse` trait 的类型。Poem 已为大多数常用类型实现了它。 - -由于 `Endpoint` 包含一个异步方法 `call`,我们需要用 `async_trait` 宏来修饰它。 - -```rust -struct MyEndpoint; - -#[async_trait] -impl Endpoint for MyEndpoint { - type Output = String; - - async fn call(&self, req: Request) -> Self::Output { - format!("method={} path={}", req.method(), req.uri().path()); - } -} -``` - -## 从函数创建 - -你可以使用 `poem::endpoint::make` 和 `poem::endpoint::make_sync` 从异步函数和同步函数创建 Endpoint。 - -以下 Endpoint 执行相同的操作: - -```rust -let ep = poem::endpoint::make(|req| async move { - format!("method={} path={}", req.method(), req.uri().path()) -}); -``` - -## EndpointExt - -`EndpointExt` trait 提供了一些方便的函数来转换 Endpoint 的输入或输出。 - -- `EndpointExt::before` 用于转换请求。 -- `EndpointExt::after` 用于转换输出。 -- `EndpointExt::map_ok`、`EndpointExt::map_err`、`EndpointExt::and_then` 用于处理 `Result` 类型的输出。 - -## 使用 Result 类型 - -`Poem` 还为 `poem::Result` 类型实现了 `IntoResponse`,因此它也可以用作 Endpoint,因此你可以在 `call` 方法中使用 `?`。 - -```rust -struct MyEndpoint; - -#[async_trait] -impl Endpoint for MyEndpoint { - type Output = poem::Result; - - async fn call(&self, req: Request) -> Self::Output { - Ok(req.take_body().into_string().await?) - } -} -``` - -你可以使用 `EndpointExt::map_to_response` 方法将 Endpoint 的输出转换为 `Response` 类型,或者使用 `EndpointExt::map_to_result` 将输出转换为 `poem::Result` 类型。 - -```rust -let ep = MyEndpoint.map_to_response() // impl Endpoint -``` diff --git a/docs/zh-CN/src/poem/extractors.md b/docs/zh-CN/src/poem/extractors.md deleted file mode 100644 index 7e239b71..00000000 --- a/docs/zh-CN/src/poem/extractors.md +++ /dev/null @@ -1,213 +0,0 @@ -# 提取器 - -提取器用于从 HTTP 请求中提取某些内容。 - -`Poem` 提供了一些常用的提取器来从 HTTP 请求中提取一些东西。 - -你可以使用一个或多个提取器作为函数的参数,最多 16 个。 - -在下面的例子中,`index` 函数使用 3 个提取器来提取远程地址、HTTP 方法和 URI。 - -```rust -#[handler] -fn index(remote_addr: SocketAddr, method: Method, uri: &Uri) {} -``` - -# 内置提取器 - - - **Option<T>** - - 从传入的请求中提取 `T`,如果失败就返回 `None`。 - - - **&Request** - - 从传入的请求中提取 `Request`. - - - **&RemoteAddr** - - 从请求中提取远端对等地址 [`RemoteAddr`]。 - - - **&LocalAddr** - - 从请求中提取本地服务器的地址 [`LocalAddr`]。 - - - **Method** - - 从传入的请求中提取 `Method`。 - - - **Version** - - 从传入的请求中提取 `Version`。 - - - **&Uri** - - 从传入的请求中提取 `Uri`。 - - - **&HeaderMap** - - 从传入的请求中提取 `HeaderMap`。 - - - **Data<&T>** - - 从传入的请求中提取 `Data` 。 - - - **TypedHeader<T>** - - 从传入的请求中提取 `TypedHeader`。 - - - **Path<T>** - - 从传入的请求中提取 `Path`。 - - - **Query<T>** - - 从传入的请求中提取 `Query`。 - - - **Form<T>** - - 从传入的请求中提取 `Form`。 - - - **Json<T>** - - 从传入的请求中提取 `Json` 。 - - _这个提取器将接管请求的主体,所以你应该避免在一个处理程序中使用多个这种类型的提取器。_ - - - **TempFile** - - 从传入的请求中提取 `TempFile`。 - - _这个提取器将接管请求的主体,所以你应该避免在一个处理程序中使用多个这种类型的提取器。_ - - - **Multipart** - - 从传入的请求中提取 `Multipart`。 - - _这个提取器将接管请求的主体,所以你应该避免在一个处理程序中使用多个这种类型的提取器。_ - - - **&CookieJar** - - 从传入的请求中提取 `CookieJar`](cookie::CookieJar)。 - - _需要 `CookieJarManager` 中间件。_ - - - **&Session** - - 从传入的请求中提取 [`Session`](crate::session::Session)。 - - _需要 `CookieSession` 或 `RedisSession` 中间件。_ - - - **Body** - - 从传入的请求中提取 `Body`。 - - _这个提取器将接管请求的主体,所以你应该避免在一个处理程序中使用多个这种类型的提取器。_ - - - **String** - - 从传入的请求中提取 body 并将其解析为 utf8 字符串。 - - _这个提取器将接管请求的主体,所以你应该避免在一个处理程序中使用多个这种类型的提取器。_ - - - **Vec<u8>** - - 从传入的请求中提取 body 并将其收集到 `Vec`. - - _这个提取器将接管请求的主体,所以你应该避免在一个处理程序中使用多个这种类型的提取器。_ - - - **Bytes** - - 从传入的请求中提取 body 并将其收集到 `Bytes`. - - _这个提取器将接管请求的主体,所以你应该避免在一个处理程序中使用多个这种类型的提取器。_ - - - **WebSocket** - - 准备接受 websocket 连接。 - -## 处理提取器错误 - -默认情况下,当发生错误时,提取器会返回`400 Bad Request`,但有时您可能想要更改这种行为,因此您可以自己处理错误。 - -在下面的例子中,当 `Query` 提取器失败时,它将返回一个 `500 Internal Server Error` 响应以及错误原因。 - -```rust -use poem::web::Query; -use poem::error::ParseQueryError; -use poem::{IntoResponse, Response}; -use poem::http::StatusCode; - -#[derive(Debug, Deserialize)] -struct Params { - name: String, -} - -#[handler] -fn index(res: Result, ParseQueryError>) -> Response { - match res { - Ok(Query(params)) => params.name.into_response(), - Err(err) => Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(err.to_string()), - } -} -``` - -## 自定义提取器 - -您还可以实现自己的提取器。 - -以下是自定义 token 提取器的示例,它提取来自 `MyToken` 标头的 token。 - -```rust -use poem::{ - get, handler, http::StatusCode, listener::TcpListener, FromRequest, Request, - RequestBody, Response, Route, Server, -}; - -struct Token(String); - -// Token 提取器的错误类型 -#[derive(Debug)] -struct MissingToken; - -/// 自定义错误也可以重用 -impl IntoResponse for MissingToken { - fn into_response(self) -> Response { - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("missing token") - } -} - -// 实现一个 token 提取器 -#[poem::async_trait] -impl<'a> FromRequest<'a> for Token { - type Error = MissingToken; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { - let token = req - .headers() - .get("MyToken") - .and_then(|value| value.to_str().ok()) - .ok_or(MissingToken)?; - Ok(Token(token.to_string())) - } -} - -#[handler] -async fn index(token: Token) { - assert_eq!(token.0, "token123"); -} - -#[tokio::main] -async fn main() -> Result<(), std::io::Error> { - if std::env::var_os("RUST_LOG").is_none() { - std::env::set_var("RUST_LOG", "poem=debug"); - } - tracing_subscriber::fmt::init(); - - let app = Route::new().at("/", get(index)); - let listener = TcpListener::bind("127.0.0.1:3000"); - let server = Server::new(listener).await?; - server.run(app).await -} -``` \ No newline at end of file diff --git a/docs/zh-CN/src/poem/handling_errors.md b/docs/zh-CN/src/poem/handling_errors.md deleted file mode 100644 index c4a9a60e..00000000 --- a/docs/zh-CN/src/poem/handling_errors.md +++ /dev/null @@ -1,114 +0,0 @@ -# 处理错误 - -在 `Poem` 中,我们根据响应状态代码处理错误。当状态码在`400-599`时,我们可以认为处理此请求时出错。 - -我们可以使用 `EndpointExt::after` 创建一个新的 Endpoint 类型来自定义错误响应。 - -在下面的例子中,`after`函数用于转换`index`函数的输出,并在发生服务器错误时输出错误响应。 - -**注意`handler`宏生成的 Endpoint 类型总是`Endpoint`,即使它返回一个 `Result`.** - -```rust -use poem::{handler, Result, Error}; -use poem::http::StatusCode; - -#[handler] -async fn index() -> Result<()> { - Err(Error::new(StatusCode::BAD_REQUEST)) -} - -let ep = index.after(|resp| { - if resp.status().is_server_error() { - Response::builder() - .status(resp.status()) - .body("custom error") - } else { - resp - } -}); -``` - -`EndpointExt::map_to_result` 函数可以帮助我们将任何类型的 Endpoint 转换为 `Endpoint`,所以我们只需要检查状态码就知道是否发生了错误。 - -```rust -use poem::endpoint::make; -use poem::{Error, EndpointExt}; -use poem::http::StatusCode; - -let ep = make(|_| Ok::<(), Error>(Error::new(StatusCode::new(Status::BAD_REQUEST)))) - .map_to_response(); - -let ep = ep.after(|resp| { - if resp.status().is_server_error() { - Response::builder() - .status(resp.status()) - .body("custom error") - } else { - resp - } -}); -``` - -## poem::Error - -`poem::Error` 是一个通用的错误类型,它实现了 `From`,所以你可以很容易地使用 `?` 运算符来将任何错误类型转换为它。默认状态代码是`503 Internal Server Error`。 - -```rust -use poem::Result; - -#[handler] -fn index(data: Vec) -> Result { - let value: i32 = serde_json::from_slice(&data)?; - Ok(value) -} -``` - -但是有时候我们不想总是使用 `503` 状态码,`Poem` 提供了一些辅助函数来转换错误类型。 - -```rust -use poem::{Result, web::Json, error::BadRequest}; - -#[handler] -fn index(data: Vec) -> Result> { - let value: i32 = serde_json::from_slice(&data).map_err(BadRequest)?; - Ok(Json(value)) -} -``` - -## 自定义错误类型 - -有时我们可以使用自定义错误类型来减少重复的代码。 - -注意:`Poem` 的错误类型通常只需要实现 `IntoResponse`。 - -```rust -use poem::{ - Response, - error::ReadBodyError, - http::StatusCode, -}; - -enum MyError { - InvalidValue, - ReadBodyError(ReadBodyError), -} - -impl IntoResponse for MyError { - fn into_response(self) -> Response { - match self { - MyError::InvalidValue => Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("invalid value"), - MyError::ReadBodyError(err) => err.into(), // ReadBodyError 已经实现了 `IntoResponse`. - } - } -} - -#[handler] -fn index(data: Result) -> Result<(), MyError> { - let data = data?; - if data.len() > 10 { - return Err(MyError::InvalidValue); - } -} -``` diff --git a/docs/zh-CN/src/poem/listeners.md b/docs/zh-CN/src/poem/listeners.md deleted file mode 100644 index 836eb202..00000000 --- a/docs/zh-CN/src/poem/listeners.md +++ /dev/null @@ -1,56 +0,0 @@ -# 监听器 - -`Poem` 提供了一些常用的监听器。 - -- TcpListener - - 侦听传入的 TCP 连接。 - -- UnixListener - - 侦听传入的 Unix 域套接字连接。 - -## TLS - -你可以调用`Listener::tls` 函数来包装一个侦听器并使其支持TLS 连接。 - -```rust -let listener = TcpListener::bind("127.0.0.1:3000") - .tls(TlsConfig::new().key(KEY).cert(CERT)); -``` - -## TLS配置热加载 - -你可以使用异步流将最新的 Tls 配置传递给 `Poem`。 - -以下示例每 1 分钟从文件中加载最新的 TLS 配置: - -```rust -use async_trait::async_trait; - -fn load_tls_config() -> Result { - Ok(TlsConfig::new() - .cert(std::fs::read("cert.pem")?) - .key(std::fs::read("key.pem")?)) -} - -let listener = TcpListener::bind("127.0.0.1:3000") - .tls(async_stream::stream! { - loop { - if let Ok(tls_config) = load_tls_config() { - yield tls_config; - } - tokio::time::sleep(Duration::from_secs(60)).await; - } - }); -``` - -## 组合多个监听器。 - -调用`Listener::combine`将两个监听器合二为一,也可以多次调用该函数来合并更多的监听器。 - -```rust -let listener = TcpListener::bind("127.0.0.1:3000") - .combine(TcpListener::bind("127.0.0.1:3001")) - .combine(TcpListener::bind("127.0.0.1:3002")); -``` \ No newline at end of file diff --git a/docs/zh-CN/src/poem/middleware.md b/docs/zh-CN/src/poem/middleware.md deleted file mode 100644 index 7d003ed2..00000000 --- a/docs/zh-CN/src/poem/middleware.md +++ /dev/null @@ -1,113 +0,0 @@ -# 中间件 - -中间件可以在处理请求之前或之后做一些事情。 - -`Poem` 提供了一些常用的中间件实现。 - -- `AddData` - - 用于将状态附加到请求,例如用于身份验证的 token。 - -- `SetHeader` - - 用于向响应添加一些特定的 HTTP 标头。 - -- `Cors` - - 用于 CORS 跨域资源共享。 - -- `Tracing` - - 使用 [`tracing`](https://crates.io/crates/tracing) 记录所有请求和响应。 - -- `Compression` - - 用于解压请求体和压缩响应体。 - -## 自定义中间件 - -实现你自己的中间件很容易,你只需要实现 `Middleware` trait,它是一个转换器 -将输入 Endpoint 转换为另一个 Endpoint。 - -以下示例创建一个自定义中间件,该中间件读取名为“X-Token”的 HTTP 请求标头的值和将其添加为请求的状态。 - -```rust -use poem::{handler, web::Data, Endpoint, EndpointExt, Middleware, Request}; - -/// 从 HTTP 标头中提取 token 的中间件。 -struct TokenMiddleware; - -impl Middleware for TokenMiddleware { - type Output = TokenMiddlewareImpl; - - fn transform(&self, ep: E) -> Self::Output { - TokenMiddlewareImpl { ep } - } -} - -/// TokenMiddleware 生成的新 Endpoint 类型。 -struct TokenMiddlewareImpl { - ep: E, -} - -const TOKEN_HEADER: &str = "X-Token"; - -/// Token 数据 -struct Token(String); - -#[poem::async_trait] -impl Endpoint for TokenMiddlewareImpl { - type Output = E::Output; - - async fn call(&self, mut req: Request) -> Self::Output { - if let Some(value) = req - .headers() - .get(TOKEN_HEADER) - .and_then(|value| value.to_str().ok()) - { - // 将 token 数据插入到请求的扩展中。 - let token = value.to_string(); - req.extensions_mut().insert(Token(token)); - } - - // 调用内部 endpoint。 - self.ep.call(req).await - } -} - -#[handler] -async fn index(Data(token): Data<&Token>) -> String { - token.0.clone() -} - -// 使用 `TokenMiddleware` 中间件转换 `index` endpoint。 -let ep = index.with(TokenMiddleware); -``` - -## 带函数的自定义中间件 - -您还可以使用函数来实现中间件。 - -```rust -async fn extract_token(next: E, mut req: Request) -> Response { - if let Some(value) = req - .headers() - .get(TOKEN_HEADER) - .and_then(|value| value.to_str().ok()) - { - // 将 token 数据插入到请求的扩展中。 - let token = value.to_string(); - req.extensions_mut().insert(Token(token)); - } - - // 调用下一个 endpoint。 - next.call(req).await -} - -#[handler] -async fn index(Data(token): Data<&Token>) -> String { - token.0.clone() -} - -let ep = index.around(extract_token); -``` diff --git a/docs/zh-CN/src/poem/protocols.md b/docs/zh-CN/src/poem/protocols.md deleted file mode 100644 index 40e50412..00000000 --- a/docs/zh-CN/src/poem/protocols.md +++ /dev/null @@ -1 +0,0 @@ -# 协议 diff --git a/docs/zh-CN/src/poem/protocols/sse.md b/docs/zh-CN/src/poem/protocols/sse.md deleted file mode 100644 index 044d0ef1..00000000 --- a/docs/zh-CN/src/poem/protocols/sse.md +++ /dev/null @@ -1,28 +0,0 @@ -# 服务器发送的事件 (SSE) - -SSE 允许服务器不断地向客户端推送数据。 - -你需要使用实现 `Stream` 的类型创建一个 `SSE` 响应。 - -下面示例中的端点将发送三个事件。 - -```rust -use futures_util::stream; -use poem::{ - handler, Route, get, - http::StatusCode, - web::sse::{Event, SSE}, - Endpoint, Request, -}; - -#[handler] -fn index() -> SSE { - SSE::new(stream::iter(vec![ - Event::message("a"), - Event::message("b"), - Event::message("c"), - ])) -} - -let app = Route::new().at("/", get(index)); -``` diff --git a/docs/zh-CN/src/poem/protocols/websocket.md b/docs/zh-CN/src/poem/protocols/websocket.md deleted file mode 100644 index 81fb68aa..00000000 --- a/docs/zh-CN/src/poem/protocols/websocket.md +++ /dev/null @@ -1,31 +0,0 @@ -# Websocket - -Websocket 允许在客户端和服务器之间进行双向通信的长连接。 - -`Poem` 提供了一个 `WebSocket` 提取器来创建这个连接。 - -当连接升级成功时,调用指定的闭包来发送和接收数据。 - -下面的例子是一个回显服务,它总是发送接收到的数据。 - -**注意这个 Endpoint 的输出必须是`WebSocket::on_upgrade`函数的返回值,否则无法正确创建连接。** - -```rust -use futures_util::{SinkExt, StreamExt}; -use poem::{ - handler, Route, get, - web::websocket::{Message, WebSocket}, - IntoResponse, -}; - -#[handler] -async fn index(ws: WebSocket) -> impl IntoResponse { - ws.on_upgrade(|mut socket| async move { - if let Some(Ok(Message::Text(text))) = socket.next().await { - let _ = socket.send(Message::Text(text)).await; - } - }) -} - -let app = Route::new().at("/", get(index)); -``` \ No newline at end of file diff --git a/docs/zh-CN/src/poem/quickstart.md b/docs/zh-CN/src/poem/quickstart.md deleted file mode 100644 index 5c89bcff..00000000 --- a/docs/zh-CN/src/poem/quickstart.md +++ /dev/null @@ -1,60 +0,0 @@ -# 快速开始 - -## 添加依赖库 - -```toml -[dependencies] -poem = "1.0" -serde = "1.0" -tokio = { version = "1.12.0", features = ["rt-multi-thread", "macros"] } -``` - -## 实现一个Endpoint - -`handler` 宏将函数转换为实现了 `Endpoint` 的类型,`Endpoint` trait 表示一种可以处理 HTTP 请求的类型。 - -这个函数可以接收一个或多个参数,每个参数都是一个提取器,可以从 HTTP 请求中提取你想要的信息。 - -提取器实现了 `FromRequest` trait,你也可以实现这个 trait 来创建你自己的提取器。 - -函数的返回值必须是实现了 `IntoResponse` trait 的类型。它可以通过 `IntoResponse::into_response` 方法将自己转化为一个 HTTP 响应。 - -下面的函数有一个提取器,它从 uri 请求的 query 中提取 `name` 和 `value` 参数并返回一个 `String`,该字符串将被转换为 HTTP 响应。 - -```rust -use serde::Deserialize; -use poem::{handler, listener::TcpListener, web::Query, Server}; - -#[derive(Deserialize)] -struct Params { - name: String, - value: i32, -} - -#[handler] -async fn index(Query(Params { name, value }): Query) -> String { - format!("{}={}", name, value) -} -``` - -## HTTP 服务器 - -让我们启动一个服务器,它监听 `127.0.0.1:3000`,请忽略这些 `unwrap` 调用,这只是一个例子。 - -`Server::run` 函数接受任何实现了 `Endpoint` Trait 的类型。在这个例子中,我们没有路由对象,因此任何请求路径都将由 `index` 函数处理。 - -```rust - -#[tokio::main] -async fn main() { - let listener = TcpListener::bind("127.0.0.1:3000"); - Server::new(listener).run(index).await.unwrap(); -} -``` - -这样,一个简单的例子就实现了,我们可以运行它,然后使用 `curl` 做一些测试。 - -```shell -> curl http://localhost:3000?name=a&value=10 -name=10 -``` diff --git a/docs/zh-CN/src/poem/responses.md b/docs/zh-CN/src/poem/responses.md deleted file mode 100644 index 8405d780..00000000 --- a/docs/zh-CN/src/poem/responses.md +++ /dev/null @@ -1,109 +0,0 @@ -# 响应 - -所有可以转换为 HTTP 响应 `Response` 的类型都应该实现 `IntoResponse`,它们可以用作处理函数的返回值。 - -在下面的例子中,`string_response` 和 `status_response` 函数返回 `String` 和 `StatusCode`类型,因为 `Poem` 已经为它们实现了 `IntoResponse` 功能。 - -`no_response` 函数不返回值。我们也可以认为它的返回类型是`()`,`Poem`也为 `()` 实现 `IntoResponse`,它总是转换为 `200 OK`。 - -```rust -use poem::handler; -use poem::http::StatusCode; - -#[handler] -fn string_response() -> String { - "hello".to_string() -} - -#[handler] -fn status_response() -> StatusCode {} - -#[handler] -fn no_response() {} - -``` - -# 内置响应类型 - -- **Result<T: IntoResponse, E: IntoResponse>** - - 如果结果为`Ok`, 则用`Ok`的值作为响应, 否则使用`Err`的值。 - -- **()** - - 将状态设置为`OK`,body 为空。 - -- **&'static str** - - 将状态设置为`OK`,将`Content-Type`设置为`text/plain`。字符串用作 body。 - -- **String** - - 将状态设置为`OK`,将`Content-Type`设置为`text/plain`。字符串用作 body。 - -- **&'static [u8]** - - 将状态设置为 `OK`,将 `Content-Type` 设置为 `application/octet-stream`。切片用作响应的 body。 - -- **Html<T>** - - 将状态设置为 `OK`,将 `Content-Type` 设置为 `text/html`. `T` 用作响应的 body。 - -- **Json<T>** - - 将状态设置为 `OK` ,将 `Content-Type` 设置为 `application/json`. 使用 [`serde_json`](https://crates.io/crates/serde_json) 将 `T` 序列化为 json 字符串。 - -- **Bytes** - - 将状态设置为 `OK` ,将 `Content-Type` 设置为 `application/octet-stream`。字节串用作响应的 body。 - -- **Vec<u8>** - - 将状态设置为 `OK` ,将 `Content-Type` 设置为 -`application/octet-stream`. vector 的数据用作 body。 - -- **Body** - - 将状态设置为 `OK` 并使用指定的 body。 - -- **StatusCode** - - 将状态设置为指定的状态代码 `StatusCode` ,body 为空。 - -- **(StatusCode, T)** - - 将 `T` 转换为响应并设置指定的状态代码 `StatusCode`。 - -- **(StatusCode, HeaderMap, T)** - - 将 `T` 转换为响应并设置指定的状态代码 `StatusCode`,然后合并指定的`HeaderMap`。 - -- **Response** - - `Response` 的实现者总是返回自身。 - -- **Compress<T>** - - 调用 `T::into_response` 获取响应,然后使用指定的算法压缩响应 body ,并设置正确的 `Content-Encoding`标头。 - -- **SSE** - - 将状态设置为 `OK` ,将 `Content-Type` 设置为 `text/event-stream`,并带有事件流 body。使用 `SSE::new` 函数来创建它。 - -## 自定义响应 - -在下面的示例中,我们包装了一个名为 `PDF` 的响应,它向响应添加了一个 `Content-Type: applicationn/pdf` 标头。 - -```rust -use poem::{IntoResponse, Response}; - -struct PDF(Vec); - -impl IntoResponse for PDF { - fn into_response(self) -> Response { - Response::builder() - .header("Content-Type", "application/pdf") - .body(self.0) - } -} -``` diff --git a/docs/zh-CN/src/poem/routing.md b/docs/zh-CN/src/poem/routing.md deleted file mode 100644 index eff7ed2b..00000000 --- a/docs/zh-CN/src/poem/routing.md +++ /dev/null @@ -1,76 +0,0 @@ -# 路由 - -路由对象用于将指定路径和方法的请求分派到指定 Endpoint。 - -路由对象实际上是一个 Endpoint,它实现了 `Endpoint` trait。 - -在下面的例子中,我们将 `/a` 和 `/b` 的请求分派到不同的 Endpoint。 - -```rust -use poem::{handler, Route}; - -#[handler] -async fn a() -> &'static str { "a" } - -#[handler] -async fn b() -> &'static str { "b" } - -let ep = Route::new() - .at("/a", a) - .at("/b", b); -``` - -## 捕获变量 - -使用`:NAME`捕获路径中指定段的值,或者使用`*NAME`捕获路径中的所有指定前缀的值。 - -在下面的示例中,捕获的值将存储在变量 `value` 中,你可以使用路径提取器来获取它们。 - -```rust -#[handler] -async fn a(Path(String): Path) {} - -let ep = Route::new() - .at("/a/:value/b", handler) - .at("/prefix/*value", handler); -``` - -## 正则表达式 - -可以使用正则表达式进行匹配,`` 或`:NAME`,第二个可以将匹配的值捕获到一个变量中。 - -```rust -let ep = Route::new() - .at("/a/<\\d+>", handler) - .at("/b/:value<\\d+>", handler); -``` - -## 嵌套 - -有时我们想为指定的 Endpoint 分配一个带有指定前缀的路径,以便创建一些功能独立的组件。 - -在下面的例子中,`hello` Endpoint 的请求路径是 `/api/hello`。 - -```rust -let api = Route::new().at("/hello", hello); -let ep = api.nest("/api", api); -``` - -静态文件服务就是这样一个独立的组件。 - -```rust -let ep = Route::new().nest("/files", Files::new("./static_files")); -``` - -## 方法路由 - -上面介绍的路由对象只能通过一些指定的路径进行调度,但是通过路径和方法进行调度更常见。 `Poem` 提供了另一个路由对象 `RouteMethod`,当它与 `Route` 对象结合时,它可以提供这种能力。 - -`Poem` 提供了一些方便的函数来创建 `RouteMethod` 对象,它们都以 HTTP 标准方法命名。 - -```rust -use poem::{Route, get, post}; - -let ep = Route::new() - .at("/users", get(get_user).post(create_user).delete(delete_user).put(update_user)); -``` diff --git a/examples/openapi/auth-basic/src/main.rs b/examples/openapi/auth-basic/src/main.rs index bb9cca83..7b43f2ae 100644 --- a/examples/openapi/auth-basic/src/main.rs +++ b/examples/openapi/auth-basic/src/main.rs @@ -16,7 +16,7 @@ impl Api { #[oai(path = "/basic", method = "get")] async fn auth_basic(&self, auth: MyBasicAuthorization) -> Result> { if auth.0.username != "test" || auth.0.password != "123456" { - return Err(Error::new(StatusCode::UNAUTHORIZED)); + return Err(Error::new_with_status(StatusCode::UNAUTHORIZED)); } Ok(PlainText(format!("hello: {}", auth.0.username))) } diff --git a/examples/poem/basic-auth/src/main.rs b/examples/poem/basic-auth/src/main.rs index b45060ff..4f6c3375 100644 --- a/examples/poem/basic-auth/src/main.rs +++ b/examples/poem/basic-auth/src/main.rs @@ -34,15 +34,15 @@ struct BasicAuthEndpoint { #[poem::async_trait] impl Endpoint for BasicAuthEndpoint { - type Output = Result; + type Output = E::Output; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { if let Some(auth) = req.headers().typed_get::>() { if auth.0.username() == self.username && auth.0.password() == self.password { - return Ok(self.ep.call(req).await); + return self.ep.call(req).await; } } - Err(Error::new(StatusCode::UNAUTHORIZED)) + Err(Error::new_with_status(StatusCode::UNAUTHORIZED)) } } diff --git a/examples/poem/csrf/src/main.rs b/examples/poem/csrf/src/main.rs index d1adfdfc..d6cf2ab6 100644 --- a/examples/poem/csrf/src/main.rs +++ b/examples/poem/csrf/src/main.rs @@ -44,7 +44,7 @@ struct LoginRequest { #[handler] async fn login(verifier: &CsrfVerifier, Form(req): Form) -> Result { if !verifier.is_valid(&req.csrf_token) { - return Err(Error::new(StatusCode::UNAUTHORIZED).with_reason("unauthorized")); + return Err(Error::new_with_status(StatusCode::UNAUTHORIZED)); } Ok(format!("login success: {}", req.username)) diff --git a/examples/poem/custom-error/Cargo.toml b/examples/poem/custom-error/Cargo.toml index 3a1f0810..11ffabdc 100644 --- a/examples/poem/custom-error/Cargo.toml +++ b/examples/poem/custom-error/Cargo.toml @@ -7,5 +7,5 @@ publish = false [dependencies] poem = { path = "../../../poem" } tokio = { version = "1.12.0", features = ["rt-multi-thread", "macros"] } -serde = { version = "1.0.130", features = ["derive"] } tracing-subscriber = "0.2.24" +thiserror = "1.0.30" diff --git a/examples/poem/custom-error/src/main.rs b/examples/poem/custom-error/src/main.rs index 42e2bf3e..269c1366 100644 --- a/examples/poem/custom-error/src/main.rs +++ b/examples/poem/custom-error/src/main.rs @@ -1,27 +1,26 @@ use poem::{ - get, handler, http::StatusCode, listener::TcpListener, web::Json, IntoResponse, Response, - Route, Server, + error::ResponseError, get, handler, http::StatusCode, listener::TcpListener, Result, Route, + Server, }; -use serde::Serialize; -#[derive(Serialize)] +#[derive(Debug, thiserror::Error)] +#[error("{message}")] struct CustomError { message: String, } -impl IntoResponse for CustomError { - fn into_response(self) -> Response { - Json(self) - .with_status(StatusCode::BAD_REQUEST) - .into_response() +impl ResponseError for CustomError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } #[handler] -fn hello() -> Result { +fn hello() -> Result { Err(CustomError { message: "custom error".to_string(), - }) + } + .into()) } #[tokio::main] diff --git a/examples/poem/custom-extractor/src/main.rs b/examples/poem/custom-extractor/src/main.rs index 37bba23f..2278e966 100644 --- a/examples/poem/custom-extractor/src/main.rs +++ b/examples/poem/custom-extractor/src/main.rs @@ -1,34 +1,21 @@ use poem::{ - get, handler, http::StatusCode, listener::TcpListener, FromRequest, IntoResponse, Request, - RequestBody, Response, Route, Server, + get, handler, http::StatusCode, listener::TcpListener, Error, FromRequest, Request, + RequestBody, Result, Route, Server, }; struct Token(String); -// Error type for Token extractor -#[derive(Debug)] -struct MissingToken; - -/// custom-error can also be reused -impl IntoResponse for MissingToken { - fn into_response(self) -> Response { - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("missing token") - } -} - // Implements a token extractor #[poem::async_trait] impl<'a> FromRequest<'a> for Token { - type Error = MissingToken; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { let token = req .headers() .get("MyToken") .and_then(|value| value.to_str().ok()) - .ok_or(MissingToken)?; + .ok_or_else(|| { + Error::new_with_string("missing token").with_status(StatusCode::BAD_REQUEST) + })?; Ok(Token(token.to_string())) } } diff --git a/examples/poem/handling-404/src/main.rs b/examples/poem/handling-404/src/main.rs index d203ee38..2bb73593 100644 --- a/examples/poem/handling-404/src/main.rs +++ b/examples/poem/handling-404/src/main.rs @@ -1,6 +1,6 @@ use poem::{ - get, handler, http::StatusCode, listener::TcpListener, web::Path, EndpointExt, Response, Route, - Server, + error::NotFoundError, get, handler, http::StatusCode, listener::TcpListener, web::Path, + EndpointExt, Response, Route, Server, }; #[handler] @@ -15,17 +15,14 @@ async fn main() -> Result<(), std::io::Error> { } tracing_subscriber::fmt::init(); - let app = Route::new() - .at("/hello/:name", get(hello)) - .after(|resp| async move { - if resp.status() == StatusCode::NOT_FOUND { + let app = + Route::new() + .at("/hello/:name", get(hello)) + .catch_error(|_: NotFoundError| async move { Response::builder() .status(StatusCode::NOT_FOUND) .body("haha") - } else { - resp - } - }); + }); Server::new(TcpListener::bind("127.0.0.1:3000")) .run(app) diff --git a/examples/poem/middleware/src/main.rs b/examples/poem/middleware/src/main.rs index 0a3bff8c..55feb4a9 100644 --- a/examples/poem/middleware/src/main.rs +++ b/examples/poem/middleware/src/main.rs @@ -1,6 +1,6 @@ use poem::{ async_trait, get, handler, listener::TcpListener, Endpoint, EndpointExt, IntoResponse, - Middleware, Request, Response, Route, Server, + Middleware, Request, Response, Result, Route, Server, }; struct Log; @@ -19,15 +19,21 @@ struct LogImpl(E); impl Endpoint for LogImpl { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { println!("request: {}", req.uri().path()); - let resp = self.0.call(req).await.into_response(); - if resp.status().is_success() { - println!("response: {}", resp.status()); - } else { - println!("error: {}", resp.status()); + let res = self.0.call(req).await; + + match res { + Ok(resp) => { + let resp = resp.into_response(); + println!("response: {}", resp.status()); + Ok(resp) + } + Err(err) => { + println!("error: {}", err); + Err(err) + } } - resp } } diff --git a/examples/poem/middleware_fn/src/main.rs b/examples/poem/middleware_fn/src/main.rs index cbed12d0..8d01cb04 100644 --- a/examples/poem/middleware_fn/src/main.rs +++ b/examples/poem/middleware_fn/src/main.rs @@ -1,6 +1,6 @@ use poem::{ get, handler, listener::TcpListener, Endpoint, EndpointExt, IntoResponse, Request, Response, - Route, Server, + Result, Route, Server, }; #[handler] @@ -8,15 +8,21 @@ fn index() -> String { "hello".to_string() } -async fn log(next: E, req: Request) -> Response { +async fn log(next: E, req: Request) -> Result { println!("request: {}", req.uri().path()); - let resp = next.call(req).await.into_response(); - if resp.status().is_success() { - println!("response: {}", resp.status()); - } else { - println!("error: {}", resp.status()); + let res = next.call(req).await; + + match res { + Ok(resp) => { + let resp = resp.into_response(); + println!("response: {}", resp.status()); + Ok(resp) + } + Err(err) => { + println!("error: {}", err); + Err(err) + } } - resp } #[tokio::main] diff --git a/docs/assets/favicon.ico b/favicon.ico similarity index 100% rename from docs/assets/favicon.ico rename to favicon.ico diff --git a/docs/assets/logo.png b/logo.png similarity index 100% rename from docs/assets/logo.png rename to logo.png diff --git a/poem-dbsession/src/lib.rs b/poem-dbsession/src/lib.rs index d1869fa5..11e48719 100644 --- a/poem-dbsession/src/lib.rs +++ b/poem-dbsession/src/lib.rs @@ -32,8 +32,8 @@ //! let route = Route::new().at("/", index).with(ServerSession::new(CookieConfig::new(),storage)); //! ``` -#![doc(html_favicon_url = "https://poem.rs/assets/favicon.ico")] -#![doc(html_logo_url = "https://poem.rs/en/assets/logo.png")] +#![doc(html_favicon_url = "https://raw.githubusercontent.com/poem-web/poem/master/favicon.ico")] +#![doc(html_logo_url = "https://raw.githubusercontent.com/poem-web/poem/master/logo.png")] #![forbid(unsafe_code)] #![deny(private_in_public, unreachable_pub)] #![cfg_attr(docsrs, feature(doc_cfg))] diff --git a/poem-dbsession/src/sqlx/mysql.rs b/poem-dbsession/src/sqlx/mysql.rs index d1d0d9e7..7d84d41e 100644 --- a/poem-dbsession/src/sqlx/mysql.rs +++ b/poem-dbsession/src/sqlx/mysql.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, time::Duration}; use chrono::Utc; -use poem::{session::SessionStorage, Result}; +use poem::{error::InternalServerError, session::SessionStorage, Result}; use sqlx::{mysql::MySqlStatement, types::Json, Executor, MySqlPool, Statement}; use crate::DatabaseConfig; @@ -28,6 +28,10 @@ const CLEANUP_SQL: &str = r#" /// Session storage using Mysql. /// +/// # Errors +/// +/// - [`sqlx::Error`] +/// /// # Create the table for session storage /// /// ```sql @@ -103,14 +107,15 @@ impl MysqlSessionStorage { #[poem::async_trait] impl SessionStorage for MysqlSessionStorage { async fn load_session(&self, session_id: &str) -> Result>> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; let res: Option<(Json>,)> = self .load_stmt .query_as() .bind(session_id) .bind(Utc::now()) .fetch_optional(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(res.map(|(value,)| value.0)) } @@ -120,10 +125,12 @@ impl SessionStorage for MysqlSessionStorage { entries: &BTreeMap, expires: Option, ) -> Result<()> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; let expires = match expires { - Some(expires) => Some(chrono::Duration::from_std(expires)?), + Some(expires) => { + Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?) + } None => None, }; @@ -133,17 +140,19 @@ impl SessionStorage for MysqlSessionStorage { .bind(Json(entries)) .bind(expires.map(|expires| Utc::now() + expires)) .execute(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(()) } async fn remove_session(&self, session_id: &str) -> Result<()> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; self.remove_stmt .query() .bind(session_id) .execute(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(()) } } diff --git a/poem-dbsession/src/sqlx/postgres.rs b/poem-dbsession/src/sqlx/postgres.rs index f463ba80..7665dc4d 100644 --- a/poem-dbsession/src/sqlx/postgres.rs +++ b/poem-dbsession/src/sqlx/postgres.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, time::Duration}; use chrono::Utc; -use poem::{session::SessionStorage, Result}; +use poem::{error::InternalServerError, session::SessionStorage, Result}; use sqlx::{postgres::PgStatement, types::Json, Executor, PgPool, Statement}; use crate::DatabaseConfig; @@ -28,6 +28,10 @@ const CLEANUP_SQL: &str = r#" /// Session storage using Postgres. /// +/// # Errors +/// +/// - [`sqlx::Error`] +/// /// # Create the table for session storage /// /// ```sql @@ -101,14 +105,15 @@ impl PgSessionStorage { #[poem::async_trait] impl SessionStorage for PgSessionStorage { async fn load_session(&self, session_id: &str) -> Result>> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; let res: Option<(Json>,)> = self .load_stmt .query_as() .bind(session_id) .bind(Utc::now()) .fetch_optional(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(res.map(|(value,)| value.0)) } @@ -118,10 +123,12 @@ impl SessionStorage for PgSessionStorage { entries: &BTreeMap, expires: Option, ) -> Result<()> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; let expires = match expires { - Some(expires) => Some(chrono::Duration::from_std(expires)?), + Some(expires) => { + Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?) + } None => None, }; @@ -131,17 +138,19 @@ impl SessionStorage for PgSessionStorage { .bind(Json(entries)) .bind(expires.map(|expires| Utc::now() + expires)) .execute(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(()) } async fn remove_session(&self, session_id: &str) -> Result<()> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; self.remove_stmt .query() .bind(session_id) .execute(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(()) } } diff --git a/poem-dbsession/src/sqlx/sqlite.rs b/poem-dbsession/src/sqlx/sqlite.rs index 0f662aed..876ff0bd 100644 --- a/poem-dbsession/src/sqlx/sqlite.rs +++ b/poem-dbsession/src/sqlx/sqlite.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, time::Duration}; use chrono::Utc; -use poem::{session::SessionStorage, Result}; +use poem::{error::InternalServerError, session::SessionStorage, Result}; use sqlx::{sqlite::SqliteStatement, types::Json, Executor, SqlitePool, Statement}; use crate::DatabaseConfig; @@ -28,6 +28,10 @@ const CLEANUP_SQL: &str = r#" /// Session storage using Sqlite. /// +/// # Errors +/// +/// - [`sqlx::Error`] +/// /// # Create the table for session storage /// /// ```sql @@ -99,14 +103,15 @@ impl SqliteSessionStorage { #[poem::async_trait] impl SessionStorage for SqliteSessionStorage { async fn load_session(&self, session_id: &str) -> Result>> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; let res: Option<(Json>,)> = self .load_stmt .query_as() .bind(session_id) .bind(Utc::now()) .fetch_optional(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(res.map(|(value,)| value.0)) } @@ -116,10 +121,12 @@ impl SessionStorage for SqliteSessionStorage { entries: &BTreeMap, expires: Option, ) -> Result<()> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; let expires = match expires { - Some(expires) => Some(chrono::Duration::from_std(expires)?), + Some(expires) => { + Some(chrono::Duration::from_std(expires).map_err(InternalServerError)?) + } None => None, }; @@ -129,17 +136,19 @@ impl SessionStorage for SqliteSessionStorage { .bind(Json(entries)) .bind(expires.map(|expires| Utc::now() + expires)) .execute(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(()) } async fn remove_session(&self, session_id: &str) -> Result<()> { - let mut conn = self.pool.acquire().await?; + let mut conn = self.pool.acquire().await.map_err(InternalServerError)?; self.remove_stmt .query() .bind(session_id) .execute(&mut conn) - .await?; + .await + .map_err(InternalServerError)?; Ok(()) } } diff --git a/poem-derive/src/lib.rs b/poem-derive/src/lib.rs index fe304c9e..bc0b3ce6 100644 --- a/poem-derive/src/lib.rs +++ b/poem-derive/src/lib.rs @@ -1,7 +1,7 @@ //! Macros for poem -#![doc(html_favicon_url = "https://poem.rs/assets/favicon.ico")] -#![doc(html_logo_url = "https://poem.rs/en/assets/logo.png")] +#![doc(html_favicon_url = "https://raw.githubusercontent.com/poem-web/poem/master/favicon.ico")] +#![doc(html_logo_url = "https://raw.githubusercontent.com/poem-web/poem/master/logo.png")] #![forbid(unsafe_code)] #![deny(private_in_public, unreachable_pub)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -64,10 +64,7 @@ fn generate_handler(internal: bool, input: TokenStream) -> Result { let id = quote::format_ident!("p{}", idx); args.push(id.clone()); extractors.push(quote! { - let #id = match <#ty as #crate_name::FromRequest>::from_request(&req, &mut body).await { - Ok(value) => value, - Err(err) => return #crate_name::IntoResponse::into_response(err), - }; + let #id = <#ty as #crate_name::FromRequest>::from_request(&req, &mut body).await?; }); } } @@ -82,11 +79,13 @@ fn generate_handler(internal: bool, input: TokenStream) -> Result { type Output = #crate_name::Response; #[allow(unused_mut)] - async fn call(&self, mut req: #crate_name::Request) -> Self::Output { + async fn call(&self, mut req: #crate_name::Request) -> #crate_name::Result { let (req, mut body) = req.split(); #(#extractors)* #item_fn - #crate_name::IntoResponse::into_response(#ident(#(#args),*)#call_await) + let res = #ident(#(#args),*)#call_await; + let res = #crate_name::error::IntoResult::into_result(res); + std::result::Result::map(res, #crate_name::IntoResponse::into_response) } } }; diff --git a/poem-lambda/src/lib.rs b/poem-lambda/src/lib.rs index ced1e2d4..73312455 100644 --- a/poem-lambda/src/lib.rs +++ b/poem-lambda/src/lib.rs @@ -1,17 +1,17 @@ //! Poem for AWS Lambda. -#![doc(html_favicon_url = "https://poem.rs/assets/favicon.ico")] -#![doc(html_logo_url = "https://poem.rs/en/assets/logo.png")] +#![doc(html_favicon_url = "https://raw.githubusercontent.com/poem-web/poem/master/favicon.ico")] +#![doc(html_logo_url = "https://raw.githubusercontent.com/poem-web/poem/master/logo.png")] #![forbid(unsafe_code)] #![deny(private_in_public, unreachable_pub)] #![cfg_attr(docsrs, feature(doc_cfg))] #![warn(missing_docs)] -use std::{convert::Infallible, io::ErrorKind, ops::Deref, sync::Arc}; +use std::{io::ErrorKind, ops::Deref, sync::Arc}; pub use lambda_http::lambda_runtime::Error; use lambda_http::{handler, lambda_runtime, Body as LambdaBody, Request as LambdaRequest}; -use poem::{Body, Endpoint, EndpointExt, FromRequest, IntoEndpoint, Request, RequestBody}; +use poem::{Body, Endpoint, EndpointExt, FromRequest, IntoEndpoint, Request, RequestBody, Result}; /// The Lambda function execution context. /// @@ -65,7 +65,10 @@ pub async fn run(ep: impl IntoEndpoint) -> Result<(), Error> { let mut req: Request = from_lambda_request(req); req.extensions_mut().insert(Context(ctx)); - let resp = ep.call(req).await; + let resp = match ep.call(req).await { + Ok(resp) => resp, + Err(err) => err.as_response(), + }; let (parts, body) = resp.into_parts(); let data = body @@ -111,9 +114,7 @@ fn from_lambda_request(req: LambdaRequest) -> Request { #[poem::async_trait] impl<'a> FromRequest<'a> for &'a Context { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { let ctx = match req.extensions().get::() { Some(ctx) => ctx, None => panic!("Lambda runtime is required."), diff --git a/poem-openapi-derive/src/api.rs b/poem-openapi-derive/src/api.rs index d448c4f1..58780304 100644 --- a/poem-openapi-derive/src/api.rs +++ b/poem-openapi-derive/src/api.rs @@ -289,10 +289,11 @@ fn generate_operation( let #pname = match <#arg_ty as #crate_name::ApiExtractor>::from_request(&request, &mut body, param_opts).await { ::std::result::Result::Ok(value) => value, ::std::result::Result::Err(err) if <#res_ty as #crate_name::ApiResponse>::BAD_REQUEST_HANDLER => { - let resp = <#res_ty as #crate_name::ApiResponse>::from_parse_request_error(err); - return #crate_name::__private::poem::IntoResponse::into_response(resp); + let res = <#res_ty as #crate_name::ApiResponse>::from_parse_request_error(err); + let res = #crate_name::__private::poem::error::IntoResult::into_result(res); + return ::std::result::Result::map(res, #crate_name::__private::poem::IntoResponse::into_response); } - ::std::result::Result::Err(err) => return #crate_name::__private::poem::IntoResponse::into_response(err), + ::std::result::Result::Err(err) => return ::std::result::Result::Err(::std::convert::Into::into(err)), }; #param_checker }); @@ -358,8 +359,9 @@ fn generate_operation( async move { let (request, mut body) = request.split(); #(#parse_args)* - let resp = api_obj.#fn_ident(#(#use_args),*).await; - #crate_name::__private::poem::IntoResponse::into_response(resp) + let res = api_obj.#fn_ident(#(#use_args),*).await; + let res = #crate_name::__private::poem::error::IntoResult::into_result(res); + ::std::result::Result::map(res, #crate_name::__private::poem::IntoResponse::into_response) } }); #transform diff --git a/poem-openapi-derive/src/lib.rs b/poem-openapi-derive/src/lib.rs index dacf41ad..ce96ca55 100644 --- a/poem-openapi-derive/src/lib.rs +++ b/poem-openapi-derive/src/lib.rs @@ -1,7 +1,7 @@ //! Macros for poem-openapi -#![doc(html_favicon_url = "https://poem.rs/assets/favicon.ico")] -#![doc(html_logo_url = "https://poem.rs/en/assets/logo.png")] +#![doc(html_favicon_url = "https://raw.githubusercontent.com/poem-web/poem/master/favicon.ico")] +#![doc(html_logo_url = "https://raw.githubusercontent.com/poem-web/poem/master/logo.png")] #![forbid(unsafe_code)] #![deny(private_in_public, unreachable_pub)] diff --git a/poem-openapi-derive/src/multipart.rs b/poem-openapi-derive/src/multipart.rs index eae57855..5da06ff3 100644 --- a/poem-openapi-derive/src/multipart.rs +++ b/poem-openapi-derive/src/multipart.rs @@ -95,10 +95,9 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { fields.push(field_ident); let parse_err = quote! {{ - let resp = #crate_name::__private::poem::Response::builder() - .status(#crate_name::__private::poem::http::StatusCode::BAD_REQUEST) - .body(::std::format!("failed to parse field `{}`: {}", #field_name, err.into_message())); - #crate_name::ParseRequestError::ParseRequestBody(resp) + #crate_name::error::ParseMultipartError { + reason: ::std::format!("failed to parse field `{}`: {}", #field_name, err.into_message()), + } }}; deserialize_fields.push(quote! { @@ -143,11 +142,9 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { }, ::std::option::Option::None => { <#field_ty as #crate_name::types::ParseFromMultipartField>::parse_from_multipart(::std::option::Option::None).await.map_err(|_| - #crate_name::ParseRequestError::ParseRequestBody( - #crate_name::__private::poem::Response::builder() - .status(#crate_name::__private::poem::http::StatusCode::BAD_REQUEST) - .body(::std::format!("field `{}` is required", #field_name)) - ) + #crate_name::error::ParseMultipartError { + reason: ::std::format!("field `{}` is required", #field_name), + } )? } }; @@ -211,10 +208,9 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { let deny_unknown_fields = if args.deny_unknown_fields { Some(quote! { if let ::std::option::Option::Some(name) = field.name() { - let resp = #crate_name::__private::poem::Response::builder() - .status(#crate_name::__private::poem::http::StatusCode::BAD_REQUEST) - .body(::std::format!("unknown field `{}`", name)); - return ::std::result::Result::Err(#crate_name::ParseRequestError::ParseRequestBody(resp)); + return ::std::result::Result::Err(::std::convert::Into::into(#crate_name::error::ParseMultipartError { + reason: ::std::format!("unknown field `{}`", name), + })); } }) } else { @@ -248,12 +244,11 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { impl #impl_generics #crate_name::payload::ParsePayload for #ident #ty_generics #where_clause { const IS_REQUIRED: bool = true; - async fn from_request(request: &#crate_name::__private::poem::Request, body: &mut #crate_name::__private::poem::RequestBody) -> ::std::result::Result { - let mut multipart = <#crate_name::__private::poem::web::Multipart as #crate_name::__private::poem::FromRequest>::from_request(request, body).await - .map_err(|err| #crate_name::ParseRequestError::ParseRequestBody(#crate_name::__private::poem::IntoResponse::into_response(err)))?; + async fn from_request(request: &#crate_name::__private::poem::Request, body: &mut #crate_name::__private::poem::RequestBody) -> #crate_name::__private::poem::Result { + let mut multipart = <#crate_name::__private::poem::web::Multipart as #crate_name::__private::poem::FromRequest>::from_request(request, body).await?; #(#skip_fields)* #(let mut #fields = ::std::option::Option::None;)* - while let ::std::option::Option::Some(field) = multipart.next_field().await.map_err(|err| #crate_name::ParseRequestError::ParseRequestBody(#crate_name::__private::poem::IntoResponse::into_response(err)))? { + while let ::std::option::Option::Some(field) = multipart.next_field().await? { #(#deserialize_fields)* #deny_unknown_fields } @@ -288,27 +283,27 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { request: &'__request #crate_name::__private::poem::Request, body: &mut #crate_name::__private::poem::RequestBody, _param_opts: #crate_name::ExtractParamOptions, - ) -> ::std::result::Result { + ) -> #crate_name::__private::poem::Result { match request.content_type() { ::std::option::Option::Some(content_type) => { let mime: #crate_name::__private::mime::Mime = match content_type.parse() { ::std::result::Result::Ok(mime) => mime, ::std::result::Result::Err(_) => { - return ::std::result::Result::Err(#crate_name::ParseRequestError::ContentTypeNotSupported { + return ::std::result::Result::Err(::std::convert::Into::into(#crate_name::error::ContentTypeError::NotSupported { content_type: ::std::string::ToString::to_string(&content_type), - }); + })); } }; if mime.essence_str() != ::CONTENT_TYPE { - return ::std::result::Result::Err(#crate_name::ParseRequestError::ContentTypeNotSupported { + return ::std::result::Result::Err(::std::convert::Into::into(#crate_name::error::ContentTypeError::NotSupported { content_type: ::std::string::ToString::to_string(&content_type), - }); + })); } ::from_request(request, body).await } - ::std::option::Option::None => ::std::result::Result::Err(#crate_name::ParseRequestError::ExpectContentType), + ::std::option::Option::None => ::std::result::Result::Err(::std::convert::Into::into(#crate_name::error::ContentTypeError::ExpectContentType)), } } } diff --git a/poem-openapi-derive/src/request.rs b/poem-openapi-derive/src/request.rs index f57130e7..a7c85115 100644 --- a/poem-openapi-derive/src/request.rs +++ b/poem-openapi-derive/src/request.rs @@ -117,14 +117,15 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { request: &'__request #crate_name::__private::poem::Request, body: &mut #crate_name::__private::poem::RequestBody, _param_opts: #crate_name::ExtractParamOptions, - ) -> ::std::result::Result { + ) -> #crate_name::__private::poem::Result { let content_type = request.content_type(); match content_type { #(#from_requests)* - ::std::option::Option::Some(content_type) => ::std::result::Result::Err(#crate_name::ParseRequestError::ContentTypeNotSupported { - content_type: ::std::string::ToString::to_string(content_type), - }), - ::std::option::Option::None => ::std::result::Result::Err(#crate_name::ParseRequestError::ExpectContentType), + ::std::option::Option::Some(content_type) => ::std::result::Result::Err( + ::std::convert::Into::into(#crate_name::error::ContentTypeError::NotSupported { + content_type: ::std::string::ToString::to_string(content_type), + })), + ::std::option::Option::None => ::std::result::Result::Err(::std::convert::Into::into(#crate_name::error::ContentTypeError::ExpectContentType)), } } } diff --git a/poem-openapi-derive/src/response.rs b/poem-openapi-derive/src/response.rs index c6e1feb2..6c2a1727 100644 --- a/poem-openapi-derive/src/response.rs +++ b/poem-openapi-derive/src/response.rs @@ -210,7 +210,7 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { }; let bad_request_handler = args.bad_request_handler.as_ref().map(|path| { quote! { - fn from_parse_request_error(err: #crate_name::ParseRequestError) -> Self { + fn from_parse_request_error(err: #crate_name::__private::poem::Error) -> Self { #path(err) } } diff --git a/poem-openapi-derive/src/security_scheme.rs b/poem-openapi-derive/src/security_scheme.rs index c26161de..f6297a1c 100644 --- a/poem-openapi-derive/src/security_scheme.rs +++ b/poem-openapi-derive/src/security_scheme.rs @@ -446,7 +446,7 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { args.generate_register_security_scheme(&crate_name, &oai_typename)?; let from_request = args.generate_from_request(&crate_name); let checker = args.checker.as_ref().map(|path| quote! { - let output = ::std::option::Option::ok_or(#path(&req, output).await, #crate_name::ParseRequestError::Authorization)?; + let output = ::std::option::Option::ok_or(#path(&req, output).await, #crate_name::error::AuthorizationError)?; }); let expanded = quote! { @@ -469,7 +469,7 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { req: &'a #crate_name::__private::poem::Request, body: &mut #crate_name::__private::poem::RequestBody, _param_opts: #crate_name::ExtractParamOptions, - ) -> ::std::result::Result { + ) -> #crate_name::__private::poem::Result { let query = req.extensions().get::<#crate_name::__private::UrlQuery>().unwrap(); let output = #from_request?; #checker diff --git a/poem-openapi-derive/src/validators.rs b/poem-openapi-derive/src/validators.rs index 4cdd492d..1e075a57 100644 --- a/poem-openapi-derive/src/validators.rs +++ b/poem-openapi-derive/src/validators.rs @@ -1,5 +1,5 @@ use darling::{util::SpannedValue, FromMeta}; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::TokenStream; use quote::quote; use regex::Regex; use syn::{Error, Type}; @@ -134,16 +134,6 @@ impl Validators { Ok((container_validators, elem_validators)) } - fn first_container_validator_span(&self) -> Option { - self.max_items - .as_ref() - .map(SpannedValue::span) - .or_else(|| self.min_items.as_ref().map(SpannedValue::span)) - .or_else(|| self.unique_items.as_ref().map(SpannedValue::span)) - .or_else(|| self.max_properties.as_ref().map(SpannedValue::span)) - .or_else(|| self.min_properties.as_ref().map(SpannedValue::span)) - } - pub(crate) fn create_obj_field_checker( &self, crate_name: &TokenStream, @@ -162,8 +152,8 @@ impl Validators { )* #( - if let ::std::option::Option::Some(value) = #crate_name::types::Type::as_raw_value(&value) { let - validator = #container_validators; + if let ::std::option::Option::Some(value) = #crate_name::types::Type::as_raw_value(&value) { + let validator = #container_validators; if !#crate_name::validation::Validator::check(&validator, value) { return Err(#crate_name::types::ParseError::::custom(format!("field `{}` verification failed. {}", #field_name, validator))); } @@ -180,33 +170,40 @@ impl Validators { ) -> GeneratorResult> { let (container_validators, elem_validators) = self.create_validators(crate_name)?; - if !container_validators.is_empty() { - return Err(Error::new( - self.first_container_validator_span().unwrap(), - "The `container` validators is not supported for parameters.", - ) - .into()); - } - - if elem_validators.is_empty() { - return Ok(None); - } - Ok(Some(quote! { #( if let ::std::option::Option::Some(value) = #crate_name::types::Type::as_raw_value(&value) { - let validator = #elem_validators; + let validator = #container_validators; if !#crate_name::validation::Validator::check(&validator, value) { - let err = #crate_name::ParseRequestError::ParseParam { + let err = #crate_name::error::ParseParamError { name: #arg_name, reason: ::std::format!("verification failed. {}", validator), }; if <#res_ty as #crate_name::ApiResponse>::BAD_REQUEST_HANDLER { - let resp = <#res_ty as #crate_name::ApiResponse>::from_parse_request_error(err); - return #crate_name::__private::poem::IntoResponse::into_response(resp); + let resp = <#res_ty as #crate_name::ApiResponse>::from_parse_request_error(std::convert::Into::into(err)); + return ::std::result::Result::Ok(#crate_name::__private::poem::IntoResponse::into_response(resp)); } else { - return #crate_name::__private::poem::IntoResponse::into_response(err); + return ::std::result::Result::Err(std::convert::Into::into(err)); + } + } + } + )* + + #( + if let ::std::option::Option::Some(value) = #crate_name::types::Type::as_raw_value(&value) { + let validator = #elem_validators; + if !#crate_name::validation::Validator::check(&validator, value) { + let err = #crate_name::error::ParseParamError { + name: #arg_name, + reason: ::std::format!("verification failed. {}", validator), + }; + + if <#res_ty as #crate_name::ApiResponse>::BAD_REQUEST_HANDLER { + let resp = <#res_ty as #crate_name::ApiResponse>::from_parse_request_error(std::convert::Into::into(err)); + return ::std::result::Result::Ok(#crate_name::__private::poem::IntoResponse::into_response(resp)); + } else { + return ::std::result::Result::Err(std::convert::Into::into(err)); } } } @@ -226,11 +223,9 @@ impl Validators { for item in #crate_name::types::Type::raw_element_iter(&value) { let validator = #elem_validators; if !#crate_name::validation::Validator::check(&validator, item) { - return Err(#crate_name::ParseRequestError::ParseRequestBody( - #crate_name::__private::poem::Response::builder() - .status(#crate_name::__private::poem::http::StatusCode::BAD_REQUEST) - .body(::std::format!("field `{}` verification failed. {}", #field_name, validator)) - )); + return Err(::std::convert::Into::into(#crate_name::error::ParseMultipartError { + reason: ::std::format!("field `{}` verification failed. {}", #field_name, validator), + })); } } )* @@ -239,13 +234,11 @@ impl Validators { if let ::std::option::Option::Some(value) = #crate_name::types::Type::as_raw_value(&value) { let validator = #container_validators; if !#crate_name::validation::Validator::check(&validator, value) { - return Err(#crate_name::ParseRequestError::ParseRequestBody( - #crate_name::__private::poem::Response::builder() - .status(#crate_name::__private::poem::http::StatusCode::BAD_REQUEST) - .body(::std::format!("field `{}` verification failed. {}", #field_name, validator)) - )); - } + return Err(::std::convert::Into::into(#crate_name::error::ParseMultipartError { + reason: ::std::format!("field `{}` verification failed. {}", #field_name, validator), + })); } + } )* }) } diff --git a/poem-openapi/src/auth/api_key.rs b/poem-openapi/src/auth/api_key.rs index f9300d9a..f2a714bd 100644 --- a/poem-openapi/src/auth/api_key.rs +++ b/poem-openapi/src/auth/api_key.rs @@ -1,6 +1,8 @@ -use poem::Request; +use poem::{Request, Result}; -use crate::{auth::ApiKeyAuthorization, base::UrlQuery, registry::MetaParamIn, ParseRequestError}; +use crate::{ + auth::ApiKeyAuthorization, base::UrlQuery, error::AuthorizationError, registry::MetaParamIn, +}; /// Used to extract the Api Key from the request. pub struct ApiKey { @@ -14,13 +16,13 @@ impl ApiKeyAuthorization for ApiKey { query: &UrlQuery, name: &str, in_type: MetaParamIn, - ) -> Result { + ) -> Result { match in_type { MetaParamIn::Query => query .get(name) .cloned() .map(|value| Self { key: value }) - .ok_or(ParseRequestError::Authorization), + .ok_or_else(|| AuthorizationError.into()), MetaParamIn::Header => req .headers() .get(name) @@ -28,7 +30,7 @@ impl ApiKeyAuthorization for ApiKey { .map(|value| Self { key: value.to_string(), }) - .ok_or(ParseRequestError::Authorization), + .ok_or_else(|| AuthorizationError.into()), MetaParamIn::Cookie => req .cookie() .get(name) @@ -36,7 +38,7 @@ impl ApiKeyAuthorization for ApiKey { .map(|cookie| Self { key: cookie.value_str().to_string(), }) - .ok_or(ParseRequestError::Authorization), + .ok_or_else(|| AuthorizationError.into()), _ => unreachable!(), } } diff --git a/poem-openapi/src/auth/basic.rs b/poem-openapi/src/auth/basic.rs index c54af554..c6fad234 100644 --- a/poem-openapi/src/auth/basic.rs +++ b/poem-openapi/src/auth/basic.rs @@ -1,7 +1,7 @@ -use poem::Request; +use poem::{Request, Result}; use typed_headers::{AuthScheme, Authorization, HeaderMapExt}; -use crate::{auth::BasicAuthorization, ParseRequestError}; +use crate::{auth::BasicAuthorization, error::AuthorizationError}; /// Used to extract the username/password from the request. pub struct Basic { @@ -13,7 +13,7 @@ pub struct Basic { } impl BasicAuthorization for Basic { - fn from_request(req: &Request) -> Result { + fn from_request(req: &Request) -> Result { if let Some(auth) = req.headers().typed_get::().ok().flatten() { if auth.0.scheme() == &AuthScheme::BASIC { if let Some(token68) = auth.token68() { @@ -34,6 +34,6 @@ impl BasicAuthorization for Basic { } } - Err(ParseRequestError::Authorization) + Err(AuthorizationError.into()) } } diff --git a/poem-openapi/src/auth/bearer.rs b/poem-openapi/src/auth/bearer.rs index 83026bdc..021f151c 100644 --- a/poem-openapi/src/auth/bearer.rs +++ b/poem-openapi/src/auth/bearer.rs @@ -1,7 +1,7 @@ -use poem::Request; +use poem::{Request, Result}; use typed_headers::{AuthScheme, Authorization, HeaderMapExt}; -use crate::{auth::BearerAuthorization, ParseRequestError}; +use crate::{auth::BearerAuthorization, error::AuthorizationError}; /// Used to extract the token68 from the request. pub struct Bearer { @@ -10,7 +10,7 @@ pub struct Bearer { } impl BearerAuthorization for Bearer { - fn from_request(req: &Request) -> Result { + fn from_request(req: &Request) -> Result { if let Some(auth) = req.headers().typed_get::().ok().flatten() { if auth.0.scheme() == &AuthScheme::BEARER { if let Some(token68) = auth.token68() { @@ -21,6 +21,6 @@ impl BearerAuthorization for Bearer { } } - Err(ParseRequestError::Authorization) + Err(AuthorizationError.into()) } } diff --git a/poem-openapi/src/auth/mod.rs b/poem-openapi/src/auth/mod.rs index 5b5b4b6d..8814123d 100644 --- a/poem-openapi/src/auth/mod.rs +++ b/poem-openapi/src/auth/mod.rs @@ -4,21 +4,21 @@ mod api_key; mod basic; mod bearer; -use poem::Request; +use poem::{Request, Result}; pub use self::{api_key::ApiKey, basic::Basic, bearer::Bearer}; -use crate::{base::UrlQuery, registry::MetaParamIn, ParseRequestError}; +use crate::{base::UrlQuery, registry::MetaParamIn}; /// Represents a basic authorization extractor. pub trait BasicAuthorization: Sized { /// Extract from the HTTP request. - fn from_request(req: &Request) -> Result; + fn from_request(req: &Request) -> Result; } /// Represents a bearer authorization extractor. pub trait BearerAuthorization: Sized { /// Extract from the HTTP request. - fn from_request(req: &Request) -> Result; + fn from_request(req: &Request) -> Result; } /// Represents an api key authorization extractor. @@ -29,5 +29,5 @@ pub trait ApiKeyAuthorization: Sized { query: &UrlQuery, name: &str, in_type: MetaParamIn, - ) -> Result; + ) -> Result; } diff --git a/poem-openapi/src/base.rs b/poem-openapi/src/base.rs index 857c7b2a..999b5419 100644 --- a/poem-openapi/src/base.rs +++ b/poem-openapi/src/base.rs @@ -1,13 +1,10 @@ use std::ops::Deref; -use poem::{FromRequest, IntoResponse, Request, RequestBody, Result, Route}; +use poem::{Error, FromRequest, Request, RequestBody, Result, Route}; -use crate::{ - registry::{ - MetaApi, MetaOAuthScope, MetaParamIn, MetaRequest, MetaResponse, MetaResponses, - MetaSchemaRef, Registry, - }, - ParseRequestError, +use crate::registry::{ + MetaApi, MetaOAuthScope, MetaParamIn, MetaRequest, MetaResponse, MetaResponses, MetaSchemaRef, + Registry, }; /// API extractor types. @@ -38,6 +35,7 @@ impl Deref for UrlQuery { } impl UrlQuery { + #[allow(missing_docs)] pub fn get_all<'a, 'b: 'a>(&'b self, name: &'a str) -> impl Iterator + 'a { self.0 .iter() @@ -45,6 +43,7 @@ impl UrlQuery { .map(|(_, value)| value) } + #[allow(missing_docs)] pub fn get(&self, name: &str) -> Option<&String> { self.get_all(name).next() } @@ -117,7 +116,7 @@ pub trait ApiExtractor<'a>: Sized { request: &'a Request, body: &mut RequestBody, param_opts: ExtractParamOptions, - ) -> Result; + ) -> Result; } #[poem::async_trait] @@ -131,18 +130,15 @@ impl<'a, T: FromRequest<'a>> ApiExtractor<'a> for T { request: &'a Request, body: &mut RequestBody, _param_opts: ExtractParamOptions, - ) -> Result { - match T::from_request(request, body).await { - Ok(value) => Ok(value), - Err(err) => Err(ParseRequestError::Extractor(err.into_response())), - } + ) -> Result { + T::from_request(request, body).await } } /// Represents a OpenAPI responses object. /// /// Reference: -pub trait ApiResponse: IntoResponse + Sized { +pub trait ApiResponse: Sized { /// If true, it means that the response object has a custom bad request /// handler. const BAD_REQUEST_HANDLER: bool = false; @@ -155,7 +151,7 @@ pub trait ApiResponse: IntoResponse + Sized { /// Convert [`ParseRequestError`] to this response object. #[allow(unused_variables)] - fn from_parse_request_error(err: ParseRequestError) -> Self { + fn from_parse_request_error(err: Error) -> Self { unreachable!() } } @@ -175,7 +171,9 @@ impl ApiResponse for () { fn register(_registry: &mut Registry) {} } -impl ApiResponse for Result { +impl ApiResponse for Result { + const BAD_REQUEST_HANDLER: bool = T::BAD_REQUEST_HANDLER; + fn meta() -> MetaResponses { T::meta() } @@ -183,6 +181,10 @@ impl ApiResponse for Result { fn register(registry: &mut Registry) { T::register(registry); } + + fn from_parse_request_error(err: Error) -> Self { + Ok(T::from_parse_request_error(err)) + } } /// Represents a OpenAPI tags. diff --git a/poem-openapi/src/docs/response.md b/poem-openapi/src/docs/response.md index 942654d4..a4419497 100644 --- a/poem-openapi/src/docs/response.md +++ b/poem-openapi/src/docs/response.md @@ -22,7 +22,8 @@ Define a OpenAPI response. # Examples ```rust -use poem_openapi::{payload::PlainText, ApiResponse, ParseRequestError}; +use poem::Error; +use poem_openapi::{payload::PlainText, ApiResponse}; #[derive(ApiResponse)] #[oai(bad_request_handler = "bad_request_handler")] @@ -39,7 +40,7 @@ enum CreateUserResponse { } // Convert error to `CreateUserResponse::BadRequest`. -fn bad_request_handler(err: ParseRequestError) -> CreateUserResponse { +fn bad_request_handler(err: Error) -> CreateUserResponse { CreateUserResponse::BadRequest(PlainText(format!("error: {}", err.to_string()))) } ``` \ No newline at end of file diff --git a/poem-openapi/src/error.rs b/poem-openapi/src/error.rs index a288eecc..c24d88b4 100644 --- a/poem-openapi/src/error.rs +++ b/poem-openapi/src/error.rs @@ -1,60 +1,81 @@ -use poem::{http::StatusCode, IntoResponse, Response}; +//! Some common error types. + +use poem::{error::ResponseError, http::StatusCode}; use thiserror::Error; -/// This type represents errors that occur when parsing the HTTP request. +/// Parameter error. #[derive(Debug, Error)] -pub enum ParseRequestError { - /// Failed to parse a parameter. - #[error("Failed to parse parameter `{name}`: {reason}")] - ParseParam { - /// The name of the parameter. - name: &'static str, +#[error("failed to parse parameter `{name}`: {reason}")] +pub struct ParseParamError { + /// The name of the parameter. + pub name: &'static str, - /// The reason for the error. - reason: String, - }, + /// The reason for the error. + pub reason: String, +} - /// Failed to parse a request body. - #[error("Failed to parse a request body")] - ParseRequestBody(Response), +impl ResponseError for ParseParamError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} - /// The `Content-Type` requested by the client is not supported. - #[error("The `Content-Type` requested by the client is not supported: {content_type}")] - ContentTypeNotSupported { +/// Parse JSON error. +#[derive(Debug, Error)] +#[error("parse JSON error: {reason}")] +pub struct ParseJsonError { + /// The reason for the error. + pub reason: String, +} + +impl ResponseError for ParseJsonError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// Parse multipart error. +#[derive(Debug, Error)] +#[error("parse multipart error: {reason}")] +pub struct ParseMultipartError { + /// The reason for the error. + pub reason: String, +} + +impl ResponseError for ParseMultipartError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// Content type error. +#[derive(Debug, Error)] +pub enum ContentTypeError { + /// Not supported. + #[error("the `Content-Type` requested by the client is not supported: {content_type}")] + NotSupported { /// The `Content-Type` header requested by the client. content_type: String, }, - /// The client request does not include the `Content-Type` header. - #[error("The client request does not include the `Content-Type` header")] + /// Expect content type header. + #[error("the client request does not include the `Content-Type` header")] ExpectContentType, - - /// Poem extractor error. - #[error("Poem extractor error")] - Extractor(Response), - - /// Authorization error. - #[error("Authorization error")] - Authorization, } -impl IntoResponse for ParseRequestError { - fn into_response(self) -> Response { - match self { - ParseRequestError::ParseParam { .. } => self - .to_string() - .with_status(StatusCode::BAD_REQUEST) - .into_response(), - ParseRequestError::ContentTypeNotSupported { .. } - | ParseRequestError::ExpectContentType => self - .to_string() - .with_status(StatusCode::METHOD_NOT_ALLOWED) - .into_response(), - ParseRequestError::ParseRequestBody(resp) | ParseRequestError::Extractor(resp) => resp, - ParseRequestError::Authorization => self - .to_string() - .with_status(StatusCode::UNAUTHORIZED) - .into_response(), - } +impl ResponseError for ContentTypeError { + fn status(&self) -> StatusCode { + StatusCode::METHOD_NOT_ALLOWED + } +} + +/// Authorization error. +#[derive(Debug, Error)] +#[error("authorization error")] +pub struct AuthorizationError; + +impl ResponseError for AuthorizationError { + fn status(&self) -> StatusCode { + StatusCode::UNAUTHORIZED } } diff --git a/poem-openapi/src/lib.rs b/poem-openapi/src/lib.rs index 3be10e45..5e132593 100644 --- a/poem-openapi/src/lib.rs +++ b/poem-openapi/src/lib.rs @@ -81,8 +81,8 @@ //! hello, sunli! //! ``` -#![doc(html_favicon_url = "https://poem.rs/assets/favicon.ico")] -#![doc(html_logo_url = "https://poem.rs/en/assets/logo.png")] +#![doc(html_favicon_url = "https://raw.githubusercontent.com/poem-web/poem/master/favicon.ico")] +#![doc(html_logo_url = "https://raw.githubusercontent.com/poem-web/poem/master/logo.png")] #![forbid(unsafe_code)] #![deny(private_in_public, unreachable_pub)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -92,6 +92,7 @@ mod macros; pub mod auth; +pub mod error; pub mod param; pub mod payload; #[doc(hidden)] @@ -101,7 +102,6 @@ pub mod types; pub mod validation; mod base; -mod error; mod openapi; #[cfg(any(feature = "swagger-ui", feature = "rapidoc", feature = "redoc"))] mod ui; @@ -110,7 +110,6 @@ pub use base::{ ApiExtractor, ApiExtractorType, ApiResponse, CombinedAPI, ExtractParamOptions, OAuthScopes, OpenApi, Tags, }; -pub use error::ParseRequestError; pub use openapi::{LicenseObject, OpenApiService, ServerObject}; #[doc = include_str!("docs/request.md")] pub use poem_openapi_derive::ApiRequest; diff --git a/poem-openapi/src/macros.rs b/poem-openapi/src/macros.rs index ac2cd24e..5707af12 100644 --- a/poem-openapi/src/macros.rs +++ b/poem-openapi/src/macros.rs @@ -30,27 +30,27 @@ macro_rules! impl_apirequest_for_payload { request: &'a poem::Request, body: &mut poem::RequestBody, _param_opts: $crate::ExtractParamOptions, - ) -> Result { + ) -> poem::Result { match request.content_type() { Some(content_type) => { let mime: mime::Mime = match content_type.parse() { Ok(mime) => mime, Err(_) => { - return Err($crate::ParseRequestError::ContentTypeNotSupported { + return Err($crate::error::ContentTypeError::NotSupported { content_type: content_type.to_string(), - }); + }.into()); } }; if mime.essence_str() != ::CONTENT_TYPE { - return Err($crate::ParseRequestError::ContentTypeNotSupported { + return Err($crate::error::ContentTypeError::NotSupported { content_type: content_type.to_string(), - }); + }.into()); } ::from_request(request, body).await } - None => Err($crate::ParseRequestError::ExpectContentType), + None => Err($crate::error::ContentTypeError::ExpectContentType.into()), } } } diff --git a/poem-openapi/src/openapi.rs b/poem-openapi/src/openapi.rs index a418674b..aa807398 100644 --- a/poem-openapi/src/openapi.rs +++ b/poem-openapi/src/openapi.rs @@ -2,7 +2,7 @@ use poem::{ endpoint::{make_sync, BoxEndpoint}, middleware::CookieJarManager, web::cookie::CookieKey, - Endpoint, EndpointExt, IntoEndpoint, Request, Response, Route, + Endpoint, EndpointExt, IntoEndpoint, Request, Response, Result, Route, }; use crate::{ @@ -234,10 +234,10 @@ impl IntoEndpoint for OpenApiService { type Endpoint = BoxEndpoint<'static, Response>; fn into_endpoint(self) -> Self::Endpoint { - async fn extract_query(mut req: Request) -> Request { + async fn extract_query(mut req: Request) -> Result { let url_query: Vec<(String, String)> = req.params().unwrap_or_default(); req.extensions_mut().insert(UrlQuery(url_query)); - req + Ok(req) } let cookie_jar_manager = match self.cookie_key { diff --git a/poem-openapi/src/param/cookie.rs b/poem-openapi/src/param/cookie.rs index 57f3143c..02606a24 100644 --- a/poem-openapi/src/param/cookie.rs +++ b/poem-openapi/src/param/cookie.rs @@ -1,11 +1,12 @@ use std::ops::{Deref, DerefMut}; -use poem::{Request, RequestBody}; +use poem::{Request, RequestBody, Result}; use crate::{ + error::ParseParamError, registry::{MetaParamIn, MetaSchemaRef, Registry}, types::ParseFromParameter, - ApiExtractor, ApiExtractorType, ExtractParamOptions, ParseRequestError, + ApiExtractor, ApiExtractorType, ExtractParamOptions, }; /// Represents the parameters passed by the cookie. @@ -53,7 +54,7 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for Cookie { request: &'a Request, _body: &mut RequestBody, param_opts: ExtractParamOptions, - ) -> Result { + ) -> Result { let value = request .cookie() .get(param_opts.name) @@ -67,9 +68,12 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for Cookie { ParseFromParameter::parse_from_parameters(value.as_deref()) .map(Self) - .map_err(|err| ParseRequestError::ParseParam { - name: param_opts.name, - reason: err.into_message(), + .map_err(|err| { + ParseParamError { + name: param_opts.name, + reason: err.into_message(), + } + .into() }) } } @@ -119,7 +123,7 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for CookiePrivate { request: &'a Request, _body: &mut RequestBody, param_opts: ExtractParamOptions, - ) -> Result { + ) -> Result { let value = request .cookie() .private() @@ -134,9 +138,12 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for CookiePrivate { ParseFromParameter::parse_from_parameters(value.as_deref()) .map(Self) - .map_err(|err| ParseRequestError::ParseParam { - name: param_opts.name, - reason: err.into_message(), + .map_err(|err| { + ParseParamError { + name: param_opts.name, + reason: err.into_message(), + } + .into() }) } } @@ -186,7 +193,7 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for CookieSigned { request: &'a Request, _body: &mut RequestBody, param_opts: ExtractParamOptions, - ) -> Result { + ) -> Result { let value = request .cookie() .signed() @@ -201,9 +208,12 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for CookieSigned { ParseFromParameter::parse_from_parameters(value.as_deref()) .map(Self) - .map_err(|err| ParseRequestError::ParseParam { - name: param_opts.name, - reason: err.into_message(), + .map_err(|err| { + ParseParamError { + name: param_opts.name, + reason: err.into_message(), + } + .into() }) } } diff --git a/poem-openapi/src/param/header.rs b/poem-openapi/src/param/header.rs index 6074fa97..bb71b654 100644 --- a/poem-openapi/src/param/header.rs +++ b/poem-openapi/src/param/header.rs @@ -1,11 +1,12 @@ use std::ops::{Deref, DerefMut}; -use poem::{Request, RequestBody}; +use poem::{Request, RequestBody, Result}; use crate::{ + error::ParseParamError, registry::{MetaParamIn, MetaSchemaRef, Registry}, types::ParseFromParameter, - ApiExtractor, ApiExtractorType, ExtractParamOptions, ParseRequestError, + ApiExtractor, ApiExtractorType, ExtractParamOptions, }; /// Represents the parameters passed by the request header. @@ -53,7 +54,7 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for Header { request: &'a Request, _body: &mut RequestBody, param_opts: ExtractParamOptions, - ) -> Result { + ) -> Result { let mut values = request .headers() .get_all(param_opts.name) @@ -70,9 +71,12 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for Header { ParseFromParameter::parse_from_parameters(values) .map(Self) - .map_err(|err| ParseRequestError::ParseParam { - name: param_opts.name, - reason: err.into_message(), + .map_err(|err| { + ParseParamError { + name: param_opts.name, + reason: err.into_message(), + } + .into() }) } } diff --git a/poem-openapi/src/param/path.rs b/poem-openapi/src/param/path.rs index 021294a4..6fbe40ef 100644 --- a/poem-openapi/src/param/path.rs +++ b/poem-openapi/src/param/path.rs @@ -1,11 +1,12 @@ use std::ops::{Deref, DerefMut}; -use poem::{Request, RequestBody}; +use poem::{Request, RequestBody, Result}; use crate::{ + error::ParseParamError, registry::{MetaParamIn, MetaSchemaRef, Registry}, types::ParseFromParameter, - ApiExtractor, ApiExtractorType, ExtractParamOptions, ParseRequestError, + ApiExtractor, ApiExtractorType, ExtractParamOptions, }; /// Represents the parameters passed by the URI path. @@ -53,7 +54,7 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for Path { request: &'a Request, _body: &mut RequestBody, param_opts: ExtractParamOptions, - ) -> Result { + ) -> Result { let value = match ( request.raw_path_param(param_opts.name), ¶m_opts.default_value, @@ -65,9 +66,12 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for Path { ParseFromParameter::parse_from_parameters(value) .map(Self) - .map_err(|err| ParseRequestError::ParseParam { - name: param_opts.name, - reason: err.into_message(), + .map_err(|err| { + ParseParamError { + name: param_opts.name, + reason: err.into_message(), + } + .into() }) } } diff --git a/poem-openapi/src/param/query.rs b/poem-openapi/src/param/query.rs index 9613a5dc..24837797 100644 --- a/poem-openapi/src/param/query.rs +++ b/poem-openapi/src/param/query.rs @@ -1,12 +1,13 @@ use std::ops::{Deref, DerefMut}; -use poem::{Request, RequestBody}; +use poem::{Request, RequestBody, Result}; use crate::{ base::UrlQuery, + error::ParseParamError, registry::{MetaParamIn, MetaSchemaRef, Registry}, types::ParseFromParameter, - ApiExtractor, ApiExtractorType, ExtractParamOptions, ParseRequestError, + ApiExtractor, ApiExtractorType, ExtractParamOptions, }; /// Represents the parameters passed by the query string. @@ -54,7 +55,7 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for Query { request: &'a Request, _body: &mut RequestBody, param_opts: ExtractParamOptions, - ) -> Result { + ) -> Result { let mut values = request .extensions() .get::() @@ -71,9 +72,12 @@ impl<'a, T: ParseFromParameter> ApiExtractor<'a> for Query { ParseFromParameter::parse_from_parameters(values) .map(Self) - .map_err(|err| ParseRequestError::ParseParam { - name: param_opts.name, - reason: err.into_message(), + .map_err(|err| { + ParseParamError { + name: param_opts.name, + reason: err.into_message(), + } + .into() }) } } diff --git a/poem-openapi/src/payload/binary.rs b/poem-openapi/src/payload/binary.rs index 93885203..db67f05b 100644 --- a/poem-openapi/src/payload/binary.rs +++ b/poem-openapi/src/payload/binary.rs @@ -1,12 +1,12 @@ use std::ops::{Deref, DerefMut}; use bytes::Bytes; -use poem::{Body, FromRequest, IntoResponse, Request, RequestBody, Response}; +use poem::{Body, FromRequest, IntoResponse, Request, RequestBody, Response, Result}; use crate::{ payload::{ParsePayload, Payload}, registry::{MetaMediaType, MetaResponse, MetaResponses, MetaSchema, MetaSchemaRef, Registry}, - ApiResponse, ParseRequestError, + ApiResponse, }; /// A binary payload. @@ -54,7 +54,8 @@ use crate::{ /// .uri(Uri::from_static("/upload")) /// .body("abcdef"), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "6"); /// @@ -66,7 +67,8 @@ use crate::{ /// .uri(Uri::from_static("/upload_stream")) /// .body("abcdef"), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "6"); /// # }); @@ -103,13 +105,8 @@ impl Payload for Binary { impl ParsePayload for Binary> { const IS_REQUIRED: bool = true; - async fn from_request( - request: &Request, - body: &mut RequestBody, - ) -> Result { - Ok(Self(>::from_request(request, body).await.map_err( - |err| ParseRequestError::ParseRequestBody(err.into_response()), - )?)) + async fn from_request(request: &Request, body: &mut RequestBody) -> Result { + Ok(Self(>::from_request(request, body).await?)) } } @@ -117,13 +114,8 @@ impl ParsePayload for Binary> { impl ParsePayload for Binary { const IS_REQUIRED: bool = true; - async fn from_request( - request: &Request, - body: &mut RequestBody, - ) -> Result { - Ok(Self(Bytes::from_request(request, body).await.map_err( - |err| ParseRequestError::ParseRequestBody(err.into_response()), - )?)) + async fn from_request(request: &Request, body: &mut RequestBody) -> Result { + Ok(Self(Bytes::from_request(request, body).await?)) } } @@ -131,13 +123,8 @@ impl ParsePayload for Binary { impl ParsePayload for Binary { const IS_REQUIRED: bool = true; - async fn from_request( - request: &Request, - body: &mut RequestBody, - ) -> Result { - Ok(Self(Body::from_request(request, body).await.map_err( - |err| ParseRequestError::ParseRequestBody(err.into_response()), - )?)) + async fn from_request(request: &Request, body: &mut RequestBody) -> Result { + Ok(Self(Body::from_request(request, body).await?)) } } diff --git a/poem-openapi/src/payload/json.rs b/poem-openapi/src/payload/json.rs index 15e744d7..7e9a5bb5 100644 --- a/poem-openapi/src/payload/json.rs +++ b/poem-openapi/src/payload/json.rs @@ -1,17 +1,14 @@ use std::ops::{Deref, DerefMut}; -use poem::{ - error::{ParseJsonError, ReadBodyError}, - http::StatusCode, - FromRequest, IntoResponse, Request, RequestBody, Response, -}; +use poem::{FromRequest, IntoResponse, Request, RequestBody, Response, Result}; use serde_json::Value; use crate::{ + error::ParseJsonError, payload::{ParsePayload, Payload}, registry::{MetaMediaType, MetaResponse, MetaResponses, MetaSchemaRef, Registry}, types::{ParseFromJSON, ToJSON, Type}, - ApiResponse, ParseRequestError, + ApiResponse, }; /// A JSON payload. @@ -49,30 +46,18 @@ impl Payload for Json { impl ParsePayload for Json { const IS_REQUIRED: bool = T::IS_REQUIRED; - async fn from_request( - request: &Request, - body: &mut RequestBody, - ) -> Result { - let data: Vec = - FromRequest::from_request(request, body) - .await - .map_err(|err: ReadBodyError| { - ParseRequestError::ParseRequestBody(err.into_response()) - })?; + async fn from_request(request: &Request, body: &mut RequestBody) -> Result { + let data: Vec = FromRequest::from_request(request, body).await?; let value = if data.is_empty() { Value::Null } else { - serde_json::from_slice(&data) - .map_err(ParseJsonError::Json) - .map_err(|err| ParseRequestError::ParseRequestBody(err.into_response()))? + serde_json::from_slice(&data).map_err(|err| ParseJsonError { + reason: err.to_string(), + })? }; - let value = T::parse_from_json(value).map_err(|err| { - ParseRequestError::ParseRequestBody( - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(err.into_message()), - ) + let value = T::parse_from_json(value).map_err(|err| ParseJsonError { + reason: err.into_message(), })?; Ok(Self(value)) } diff --git a/poem-openapi/src/payload/mod.rs b/poem-openapi/src/payload/mod.rs index a2c6f2d8..1451d805 100644 --- a/poem-openapi/src/payload/mod.rs +++ b/poem-openapi/src/payload/mod.rs @@ -11,10 +11,7 @@ use poem::{Request, RequestBody, Result}; pub use self::{ attachment::Attachment, binary::Binary, json::Json, plain_text::PlainText, response::Response, }; -use crate::{ - registry::{MetaSchemaRef, Registry}, - ParseRequestError, -}; +use crate::registry::{MetaSchemaRef, Registry}; /// Represents a payload type. pub trait Payload: Send { @@ -36,8 +33,5 @@ pub trait ParsePayload: Sized { const IS_REQUIRED: bool; /// Parse the payload object from the HTTP request. - async fn from_request( - request: &Request, - body: &mut RequestBody, - ) -> Result; + async fn from_request(request: &Request, body: &mut RequestBody) -> Result; } diff --git a/poem-openapi/src/payload/plain_text.rs b/poem-openapi/src/payload/plain_text.rs index 3693378f..48fe2f5e 100644 --- a/poem-openapi/src/payload/plain_text.rs +++ b/poem-openapi/src/payload/plain_text.rs @@ -1,12 +1,12 @@ use std::ops::{Deref, DerefMut}; -use poem::{FromRequest, IntoResponse, Request, RequestBody, Response}; +use poem::{FromRequest, IntoResponse, Request, RequestBody, Response, Result}; use crate::{ payload::{ParsePayload, Payload}, registry::{MetaMediaType, MetaResponse, MetaResponses, MetaSchemaRef, Registry}, types::Type, - ApiResponse, ParseRequestError, + ApiResponse, }; /// A UTF8 string payload. @@ -39,13 +39,8 @@ impl Payload for PlainText { impl ParsePayload for PlainText { const IS_REQUIRED: bool = false; - async fn from_request( - request: &Request, - body: &mut RequestBody, - ) -> Result { - Ok(Self(String::from_request(request, body).await.map_err( - |err| ParseRequestError::ParseRequestBody(err.into_response()), - )?)) + async fn from_request(request: &Request, body: &mut RequestBody) -> Result { + Ok(Self(String::from_request(request, body).await?)) } } diff --git a/poem-openapi/src/payload/response.rs b/poem-openapi/src/payload/response.rs index 4fa5c91b..cd948522 100644 --- a/poem-openapi/src/payload/response.rs +++ b/poem-openapi/src/payload/response.rs @@ -1,11 +1,11 @@ use poem::{ http::{header::HeaderName, HeaderMap, HeaderValue, StatusCode}, - IntoResponse, + Error, IntoResponse, }; use crate::{ registry::{MetaResponses, Registry}, - ApiResponse, ParseRequestError, + ApiResponse, }; /// A response type wrapper. @@ -53,7 +53,7 @@ impl Response { } } -impl IntoResponse for Response { +impl IntoResponse for Response { fn into_response(self) -> poem::Response { let mut resp = self.inner.into_response(); if let Some(status) = self.status { @@ -75,7 +75,7 @@ impl ApiResponse for Response { T::register(registry); } - fn from_parse_request_error(err: ParseRequestError) -> Self { + fn from_parse_request_error(err: Error) -> Self { Self::new(T::from_parse_request_error(err)) } } diff --git a/poem-openapi/tests/api.rs b/poem-openapi/tests/api.rs index 377b62ea..eb4b591b 100644 --- a/poem-openapi/tests/api.rs +++ b/poem-openapi/tests/api.rs @@ -1,14 +1,14 @@ use poem::{ http::{Method, StatusCode, Uri}, web::Data, - Endpoint, EndpointExt, IntoEndpoint, + Endpoint, EndpointExt, Error, IntoEndpoint, }; use poem_openapi::{ param::Query, payload::{Binary, Json, PlainText}, registry::{MetaApi, MetaSchema}, types::Type, - ApiRequest, ApiResponse, OpenApi, OpenApiService, ParseRequestError, Tags, + ApiRequest, ApiResponse, OpenApi, OpenApiService, Tags, }; #[tokio::test] @@ -33,7 +33,8 @@ async fn path_and_method() { .uri(Uri::from_static("/abc")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -112,7 +113,8 @@ async fn common_attributes() { .uri(Uri::from_static("/hello/world")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -178,7 +180,8 @@ async fn request() { .content_type("application/json") .body("100"), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); let resp = ep @@ -189,7 +192,8 @@ async fn request() { .content_type("text/plain") .body("abc"), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); let resp = ep @@ -200,7 +204,8 @@ async fn request() { .content_type("application/octet-stream") .body(vec![1, 2, 3]), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -232,10 +237,11 @@ async fn payload_request() { .content_type("application/json") .body("100"), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); - let resp = ep + let err = ep .call( poem::Request::builder() .method(Method::POST) @@ -243,8 +249,9 @@ async fn payload_request() { .content_type("text/plain") .body("100"), ) - .await; - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + .await + .unwrap_err(); + assert_eq!(err.status(), StatusCode::METHOD_NOT_ALLOWED); } #[tokio::test] @@ -275,7 +282,8 @@ async fn optional_payload_request() { .content_type("application/json") .body("100"), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.into_body().into_string().await.unwrap(), "100"); @@ -288,7 +296,8 @@ async fn optional_payload_request() { .content_type("application/json") .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.into_body().into_string().await.unwrap(), "999"); } @@ -361,7 +370,8 @@ async fn response() { .uri(Uri::from_static("/?code=200")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.take_body().into_string().await.unwrap(), ""); @@ -372,7 +382,8 @@ async fn response() { .uri(Uri::from_static("/?code=409")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::CONFLICT); assert_eq!(resp.content_type(), Some("application/json")); assert_eq!(resp.take_body().into_string().await.unwrap(), "409"); @@ -384,7 +395,8 @@ async fn response() { .uri(Uri::from_static("/?code=404")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); assert_eq!(resp.content_type(), Some("text/plain")); assert_eq!(resp.take_body().into_string().await.unwrap(), "code: 404"); @@ -403,7 +415,7 @@ async fn bad_request_handler() { BadRequest(PlainText), } - fn bad_request_handler(err: ParseRequestError) -> MyResponse { + fn bad_request_handler(err: Error) -> MyResponse { MyResponse::BadRequest(PlainText(format!("!!! {}", err.to_string()))) } @@ -426,7 +438,8 @@ async fn bad_request_handler() { .uri(Uri::from_static("/?code=200")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.content_type(), Some("text/plain")); assert_eq!(resp.take_body().into_string().await.unwrap(), "code: 200"); @@ -438,12 +451,13 @@ async fn bad_request_handler() { .uri(Uri::from_static("/")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.content_type(), Some("text/plain")); assert_eq!( resp.take_body().into_string().await.unwrap(), - r#"!!! Failed to parse parameter `code`: Type "integer(uint16)" expects an input value."# + r#"!!! failed to parse parameter `code`: Type "integer(uint16)" expects an input value."# ); } @@ -460,7 +474,7 @@ async fn bad_request_handler_for_validator() { BadRequest(PlainText), } - fn bad_request_handler(err: ParseRequestError) -> MyResponse { + fn bad_request_handler(err: Error) -> MyResponse { MyResponse::BadRequest(PlainText(format!("!!! {}", err.to_string()))) } @@ -486,7 +500,8 @@ async fn bad_request_handler_for_validator() { .uri(Uri::from_static("/?code=50")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.content_type(), Some("text/plain")); assert_eq!(resp.take_body().into_string().await.unwrap(), "code: 50"); @@ -498,12 +513,13 @@ async fn bad_request_handler_for_validator() { .uri(Uri::from_static("/?code=200")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.content_type(), Some("text/plain")); assert_eq!( resp.take_body().into_string().await.unwrap(), - r#"!!! Failed to parse parameter `code`: verification failed. maximum(100, exclusive: false)"# + r#"!!! failed to parse parameter `code`: verification failed. maximum(100, exclusive: false)"# ); } @@ -529,7 +545,8 @@ async fn poem_extract() { .uri(Uri::from_static("/")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -583,7 +600,8 @@ async fn returning_borrowed_value() { .uri(Uri::from_static("/value1")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.into_body().into_string().await.unwrap(), "999"); @@ -594,7 +612,8 @@ async fn returning_borrowed_value() { .uri(Uri::from_static("/value2")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.into_body().into_string().await.unwrap(), "\"abc\""); @@ -605,7 +624,8 @@ async fn returning_borrowed_value() { .uri(Uri::from_static("/value3")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.into_body().into_string().await.unwrap(), "888"); @@ -616,7 +636,8 @@ async fn returning_borrowed_value() { .uri(Uri::from_static("/values")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.into_body().into_string().await.unwrap(), "[1,2,3,4,5]"); } diff --git a/poem-openapi/tests/hygiene.rs b/poem-openapi/tests/hygiene.rs index 7360531d..251fa377 100644 --- a/poem-openapi/tests/hygiene.rs +++ b/poem-openapi/tests/hygiene.rs @@ -39,7 +39,7 @@ enum CreateUserResponse { BadRequest(::poem_openapi::payload::PlainText<::std::string::String>), } -fn bad_request_handler(err: ::poem_openapi::ParseRequestError) -> CreateUserResponse { +fn bad_request_handler(err: ::poem::Error) -> CreateUserResponse { CreateUserResponse::BadRequest(::poem_openapi::payload::PlainText(::std::format!( "error: {}", ::std::string::ToString::to_string(&err) diff --git a/poem-openapi/tests/multipart.rs b/poem-openapi/tests/multipart.rs index bb12111b..6d8fd4b8 100644 --- a/poem-openapi/tests/multipart.rs +++ b/poem-openapi/tests/multipart.rs @@ -1,6 +1,6 @@ use std::io::Write; -use poem::{IntoResponse, Request, RequestBody}; +use poem::{Request, RequestBody}; use poem_openapi::{ payload::{ParsePayload, Payload}, registry::{MetaSchema, MetaSchemaRef}, @@ -8,7 +8,7 @@ use poem_openapi::{ multipart::{JsonField, Upload}, Binary, }, - Enum, Multipart, Object, ParseRequestError, + Enum, Multipart, Object, }; fn create_multipart_payload(parts: &[(&str, Option<&str>, &[u8])]) -> Vec { @@ -103,15 +103,10 @@ async fn required_fields() { .await .unwrap_err(); - match err { - ParseRequestError::ParseRequestBody(resp) => { - assert_eq!( - resp.into_body().into_string().await.unwrap(), - "field `file` is required" - ); - } - _ => panic!(), - } + assert_eq!( + err.to_string(), + "parse multipart error: field `file` is required" + ); } #[tokio::test] @@ -273,15 +268,10 @@ async fn validator() { .await .unwrap_err(); - match err { - ParseRequestError::ParseRequestBody(resp) => { - assert_eq!( - resp.into_body().into_string().await.unwrap(), - r#"field `value` verification failed. maximum(32, exclusive: false)"# - ); - } - _ => panic!(), - } + assert_eq!( + err.to_string(), + "parse multipart error: field `value` verification failed. maximum(32, exclusive: false)" + ); } #[tokio::test] @@ -427,16 +417,10 @@ async fn repeated_error() { ) .await .unwrap_err(); - - match err { - ParseRequestError::ParseRequestBody(resp) => { - assert_eq!( - resp.into_body().into_string().await.unwrap(), - "failed to parse field `value`: failed to parse \"string\": repeated field" - ); - } - _ => panic!(), - } + assert_eq!( + err.to_string(), + "parse multipart error: failed to parse field `value`: failed to parse \"string\": repeated field" + ); } #[test] @@ -531,8 +515,5 @@ async fn deny_unknown_fields() { ) .await .unwrap_err(); - assert_eq!( - err.into_response().into_body().into_string().await.unwrap(), - "unknown field `c`" - ); + assert_eq!(err.to_string(), "parse multipart error: unknown field `c`"); } diff --git a/poem-openapi/tests/operation_param.rs b/poem-openapi/tests/operation_param.rs index 9f7b39e2..a7e47d71 100644 --- a/poem-openapi/tests/operation_param.rs +++ b/poem-openapi/tests/operation_param.rs @@ -38,7 +38,8 @@ async fn param_name() { .uri(Uri::from_static("/abc?a=10")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -64,7 +65,8 @@ async fn query() { let api = OpenApiService::new(Api, "test", "1.0").into_endpoint(); let resp = api .call(Request::builder().uri(Uri::from_static("/?v=10")).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -101,7 +103,8 @@ async fn query_multiple_values() { .uri(Uri::from_static("/?v=10&v=20&v=30")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -133,7 +136,7 @@ async fn query_default() { ); let api = OpenApiService::new(Api, "test", "1.0").into_endpoint(); - let resp = api.call(Request::default()).await; + let resp = api.call(Request::default()).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -150,7 +153,10 @@ async fn header() { } let api = OpenApiService::new(Api, "test", "1.0").into_endpoint(); - let resp = api.call(Request::builder().header("v", 10).finish()).await; + let resp = api + .call(Request::builder().header("v", 10).finish()) + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -175,7 +181,8 @@ async fn header_multiple_values() { .header("v", 30) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -192,7 +199,7 @@ async fn header_default() { } let api = OpenApiService::new(Api, "test", "1.0").into_endpoint(); - let resp = api.call(Request::default()).await; + let resp = api.call(Request::default()).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -211,7 +218,8 @@ async fn path() { let api = OpenApiService::new(Api, "test", "1.0").into_endpoint(); let resp = api .call(Request::builder().uri(Uri::from_static("/k/10")).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -251,7 +259,8 @@ async fn cookie() { let resp = api .call(Request::builder().header(header::COOKIE, cookie).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -268,7 +277,7 @@ async fn cookie_default() { } let api = OpenApiService::new(Api, "test", "1.0").into_endpoint(); - let resp = api.call(Request::builder().finish()).await; + let resp = api.call(Request::builder().finish()).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -356,6 +365,7 @@ async fn default_opt() { let api = OpenApiService::new(Api, "test", "1.0").into_endpoint(); let resp = api .call(Request::builder().uri(Uri::from_static("/")).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } diff --git a/poem-openapi/tests/payload.rs b/poem-openapi/tests/payload.rs index d165c498..ef672e91 100644 --- a/poem-openapi/tests/payload.rs +++ b/poem-openapi/tests/payload.rs @@ -1,11 +1,11 @@ use poem::{ http::{StatusCode, Uri}, - Endpoint, IntoEndpoint, Request, + Endpoint, Error, IntoEndpoint, Request, }; use poem_openapi::{ param::Query, payload::{Json, Response}, - ApiResponse, OpenApi, OpenApiService, ParseRequestError, + ApiResponse, OpenApi, OpenApiService, }; #[tokio::test] @@ -20,7 +20,7 @@ async fn response_wrapper() { BadRequest(#[oai(header = "MY-HEADER1")] String), } - fn bad_request_handler(_: ParseRequestError) -> CustomApiResponse { + fn bad_request_handler(_: Error) -> CustomApiResponse { CustomApiResponse::BadRequest("def".to_string()) } @@ -43,7 +43,8 @@ async fn response_wrapper() { let resp = ep .call(Request::builder().uri(Uri::from_static("/a")).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.header("myheader"), Some("abc")); @@ -53,13 +54,15 @@ async fn response_wrapper() { .uri(Uri::from_static("/b?p1=qwe")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.header("myheader"), Some("qwe")); let resp = ep .call(Request::builder().uri(Uri::from_static("/b")).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.header("MY-HEADER1"), Some("def")); } diff --git a/poem-openapi/tests/response.rs b/poem-openapi/tests/response.rs index eec08431..db0a4b02 100644 --- a/poem-openapi/tests/response.rs +++ b/poem-openapi/tests/response.rs @@ -2,13 +2,13 @@ mod request; use poem::{ http::{HeaderValue, StatusCode}, - IntoResponse, Response, + Error, IntoResponse, }; use poem_openapi::{ payload::{Json, PlainText}, registry::{MetaMediaType, MetaResponse, MetaResponses, MetaSchema, MetaSchemaRef}, types::ToJSON, - ApiResponse, Object, ParseRequestError, + ApiResponse, Object, }; use serde_json::Value; @@ -220,13 +220,13 @@ async fn bad_request_handler() { BadRequest, } - fn bad_request_handler(_: ParseRequestError) -> CustomApiResponse { + fn bad_request_handler(_: Error) -> CustomApiResponse { CustomApiResponse::BadRequest } assert_eq!( - CustomApiResponse::from_parse_request_error(ParseRequestError::ParseRequestBody( - Response::default() + CustomApiResponse::from_parse_request_error(Error::new_with_status( + StatusCode::BAD_GATEWAY )), CustomApiResponse::BadRequest ); diff --git a/poem-openapi/tests/security_scheme.rs b/poem-openapi/tests/security_scheme.rs index e14dbfce..74e70936 100644 --- a/poem-openapi/tests/security_scheme.rs +++ b/poem-openapi/tests/security_scheme.rs @@ -118,7 +118,8 @@ async fn basic_auth() { ) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.take_body().into_string().await.unwrap(), "abc/123456"); } @@ -166,7 +167,8 @@ async fn bearer_auth() { ) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.take_body().into_string().await.unwrap(), "abcdef"); } @@ -269,7 +271,8 @@ async fn api_key_auth() { .header("X-API-Key", "abcdef") .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.take_body().into_string().await.unwrap(), "abcdef"); @@ -279,7 +282,8 @@ async fn api_key_auth() { .uri(Uri::from_static("/query?key=abcdef")) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.take_body().into_string().await.unwrap(), "abcdef"); @@ -293,7 +297,8 @@ async fn api_key_auth() { ) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.take_body().into_string().await.unwrap(), "abcdef"); } diff --git a/poem-openapi/tests/validation.rs b/poem-openapi/tests/validation.rs index 23c9f455..4928c6d6 100644 --- a/poem-openapi/tests/validation.rs +++ b/poem-openapi/tests/validation.rs @@ -282,13 +282,14 @@ async fn param_validator() { } let api = OpenApiService::new(Api, "test", "1.0").into_endpoint(); - let mut resp = api + let err = api .call(Request::builder().uri(Uri::from_static("/?v=999")).finish()) - .await; - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + .await + .unwrap_err(); + assert_eq!(err.status(), StatusCode::BAD_REQUEST); assert_eq!( - resp.take_body().into_string().await.unwrap(), - "Failed to parse parameter `v`: verification failed. maximum(100, exclusive: true)" + err.to_string(), + "failed to parse parameter `v`: verification failed. maximum(100, exclusive: true)" ); let meta: MetaApi = Api::meta().remove(0); @@ -309,7 +310,8 @@ async fn param_validator() { let resp = api .call(Request::builder().uri(Uri::from_static("/?v=50")).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } diff --git a/poem/CHANGELOG.md b/poem/CHANGELOG.md index f1069cdb..e02d5c42 100644 --- a/poem/CHANGELOG.md +++ b/poem/CHANGELOG.md @@ -4,9 +4,20 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# [1.2.0] 2021-12-16 + +## Breaking changes + +- Refactor error handling. +- The return value type of the `Endpoint::call` function is changed from `Self::Output` to `Result`. +- Remove the associated type `Error` from `FromRequest`. +- The return value of the `FromRequest::from_request` function is changed from `Result` to `Result`. +- Add some helper methods to `EndpointExt`. + # [1.1.1] 2021-12-13 - Add `Body::from_bytes_stream` and `Body::to_bytes_stream` methods. +- Remove the `BinaryStream` type, use `poem::Body` instead. # [1.1.0] 2021-12-13 diff --git a/poem/Cargo.toml b/poem/Cargo.toml index 6467b45a..08fc3104 100644 --- a/poem/Cargo.toml +++ b/poem/Cargo.toml @@ -81,6 +81,7 @@ tokio-native-tls = { version = "0.3.0", optional = true } sha1 = { version = "0.6.0", optional = true } base64 = { version = "0.13.0", optional = true } libcsrf = { package = "csrf", version = "0.4.1", optional = true } +thiserror = "1.0.30" # Feature optional dependencies diff --git a/poem/src/body.rs b/poem/src/body.rs index 7b98c7d6..271886a0 100644 --- a/poem/src/body.rs +++ b/poem/src/body.rs @@ -11,7 +11,10 @@ use hyper::body::HttpBody; use serde::{de::DeserializeOwned, Serialize}; use tokio::io::AsyncRead; -use crate::error::{ParseJsonError, ReadBodyError}; +use crate::{ + error::{ParseJsonError, ReadBodyError}, + Result, +}; /// A body object for requests and responses. #[derive(Default)] @@ -141,8 +144,13 @@ impl Body { } /// Consumes this body object and parse it as `T`. - pub async fn into_json(self) -> Result { - Ok(serde_json::from_slice(&self.into_vec().await?)?) + /// + /// # Errors + /// + /// - [`ReadBodyError`] + /// - [`ParseJsonError`] + pub async fn into_json(self) -> Result { + Ok(serde_json::from_slice(&self.into_vec().await?).map_err(ParseJsonError)?) } /// Consumes this body object to return a reader. diff --git a/poem/src/endpoint/after.rs b/poem/src/endpoint/after.rs index 6bc0749c..5612662f 100644 --- a/poem/src/endpoint/after.rs +++ b/poem/src/endpoint/after.rs @@ -1,6 +1,6 @@ use std::future::Future; -use crate::{Endpoint, IntoResponse, Request}; +use crate::{Endpoint, IntoResponse, Request, Result}; /// Endpoint for the [`after`](super::EndpointExt::after) method. pub struct After { @@ -16,16 +16,16 @@ impl After { } #[async_trait::async_trait] -impl Endpoint for After +impl Endpoint for After where E: Endpoint, - F: Fn(E::Output) -> Fut + Send + Sync, - Fut: Future + Send, - R: IntoResponse, + F: Fn(Result) -> Fut + Send + Sync, + Fut: Future> + Send, + T: IntoResponse, { - type Output = R; + type Output = T; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { (self.f)(self.inner.call(req).await).await } } diff --git a/poem/src/endpoint/and_then.rs b/poem/src/endpoint/and_then.rs index 2a4a5047..5f70f77f 100644 --- a/poem/src/endpoint/and_then.rs +++ b/poem/src/endpoint/and_then.rs @@ -16,21 +16,18 @@ impl AndThen { } #[async_trait::async_trait] -impl Endpoint for AndThen +impl Endpoint for AndThen where - E: Endpoint>, + E: Endpoint, F: Fn(R) -> Fut + Send + Sync, - Fut: Future> + Send, - Err: IntoResponse, + Fut: Future> + Send, R: IntoResponse, R2: IntoResponse, { - type Output = Result; + type Output = R2; - async fn call(&self, req: Request) -> Self::Output { - match self.inner.call(req).await { - Ok(resp) => (self.f)(resp).await, - Err(err) => Err(err), - } + async fn call(&self, req: Request) -> Result { + let resp = self.inner.call(req).await?; + (self.f)(resp).await } } diff --git a/poem/src/endpoint/around.rs b/poem/src/endpoint/around.rs index 36787de6..21a2552a 100644 --- a/poem/src/endpoint/around.rs +++ b/poem/src/endpoint/around.rs @@ -1,6 +1,6 @@ use std::{future::Future, sync::Arc}; -use crate::{Endpoint, IntoResponse, Request}; +use crate::{Endpoint, IntoResponse, Request, Result}; /// Endpoint for the [`around`](super::EndpointExt::around) method. pub struct Around { @@ -19,16 +19,16 @@ impl Around { } #[async_trait::async_trait] -impl Endpoint for Around +impl Endpoint for Around where E: Endpoint, F: Fn(Arc, Request) -> Fut + Send + Sync + 'static, - Fut: Future + Send, - R: IntoResponse, + Fut: Future> + Send, + T: IntoResponse, { - type Output = R; + type Output = T; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { (self.f)(self.inner.clone(), req).await } } diff --git a/poem/src/endpoint/before.rs b/poem/src/endpoint/before.rs index 59e23271..3aa652a6 100644 --- a/poem/src/endpoint/before.rs +++ b/poem/src/endpoint/before.rs @@ -1,6 +1,6 @@ use std::future::Future; -use crate::{Endpoint, Request}; +use crate::{Endpoint, Request, Result}; /// Endpoint for the [`before`](super::EndpointExt::before) method. pub struct Before { @@ -20,11 +20,11 @@ impl Endpoint for Before where E: Endpoint, F: Fn(Request) -> Fut + Send + Sync, - Fut: Future + Send, + Fut: Future> + Send, { type Output = E::Output; - async fn call(&self, req: Request) -> Self::Output { - self.inner.call((self.f)(req).await).await + async fn call(&self, req: Request) -> Result { + self.inner.call((self.f)(req).await?).await } } diff --git a/poem/src/endpoint/catch_error.rs b/poem/src/endpoint/catch_error.rs new file mode 100644 index 00000000..681a0a10 --- /dev/null +++ b/poem/src/endpoint/catch_error.rs @@ -0,0 +1,42 @@ +use std::{future::Future, marker::PhantomData}; + +use crate::{Endpoint, IntoResponse, Request, Response, Result}; + +/// Endpoint for the [`catch_error`](super::EndpointExt::catch_error) method. +pub struct CatchError { + inner: E, + f: F, + _mark: PhantomData, +} + +impl CatchError { + #[inline] + pub(crate) fn new(inner: E, f: F) -> CatchError { + Self { + inner, + f, + _mark: PhantomData, + } + } +} + +#[async_trait::async_trait] +impl Endpoint for CatchError +where + E: Endpoint, + F: Fn(ErrType) -> Fut + Send + Sync, + Fut: Future + Send, + ErrType: std::error::Error + Send + Sync + 'static, +{ + type Output = Response; + + async fn call(&self, req: Request) -> Result { + match self.inner.call(req).await { + Ok(resp) => Ok(resp.into_response()), + Err(err) if err.is::() => { + Ok((self.f)(err.downcast::().unwrap()).await) + } + Err(err) => Err(err), + } + } +} diff --git a/poem/src/endpoint/endpoint.rs b/poem/src/endpoint/endpoint.rs index 438d14bd..4891a93a 100644 --- a/poem/src/endpoint/endpoint.rs +++ b/poem/src/endpoint/endpoint.rs @@ -1,10 +1,10 @@ -use std::{future::Future, sync::Arc}; +use std::{future::Future, marker::PhantomData, sync::Arc}; -use super::{After, AndThen, Before, MapErr, MapOk, MapToResponse, MapToResult}; +use super::{After, AndThen, Around, Before, CatchError, InspectError, Map, MapToResponse}; use crate::{ - endpoint::Around, + error::IntoResult, middleware::{AddData, AddDataEndpoint}, - IntoResponse, Middleware, Request, Response, Result, + Error, IntoResponse, Middleware, Request, Response, Result, }; /// An HTTP request handler. @@ -14,37 +14,45 @@ pub trait Endpoint: Send + Sync { type Output: IntoResponse; /// Get the response to the request. - async fn call(&self, req: Request) -> Self::Output; + async fn call(&self, req: Request) -> Result; } -struct SyncFnEndpoint(F); +struct SyncFnEndpoint { + _mark: PhantomData, + f: F, +} #[async_trait::async_trait] -impl Endpoint for SyncFnEndpoint +impl Endpoint for SyncFnEndpoint where F: Fn(Request) -> R + Send + Sync, - R: IntoResponse, + T: IntoResponse + Sync, + R: IntoResult, { - type Output = R; + type Output = T; - async fn call(&self, req: Request) -> Self::Output { - (self.0)(req) + async fn call(&self, req: Request) -> Result { + (self.f)(req).into_result() } } -struct AsyncFnEndpoint(F); +struct AsyncFnEndpoint { + _mark: PhantomData, + f: F, +} #[async_trait::async_trait] -impl Endpoint for AsyncFnEndpoint +impl Endpoint for AsyncFnEndpoint where F: Fn(Request) -> Fut + Sync + Send, Fut: Future + Send, - R: IntoResponse, + T: IntoResponse + Sync, + R: IntoResult, { - type Output = R; + type Output = T; - async fn call(&self, req: Request) -> Self::Output { - (self.0)(req).await + async fn call(&self, req: Request) -> Result { + (self.f)(req).await.into_result() } } @@ -62,16 +70,18 @@ where { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { match self { - EitherEndpoint::A(a) => a.call(req).await.into_response(), - EitherEndpoint::B(b) => b.call(req).await.into_response(), + EitherEndpoint::A(a) => a.call(req).await.map(IntoResponse::into_response), + EitherEndpoint::B(b) => b.call(req).await.map(IntoResponse::into_response), } } } /// Create an endpoint with a function. /// +/// The output can be any type that implements [`IntoResult`]. +/// /// # Example /// /// ``` @@ -82,20 +92,27 @@ where /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let resp = ep /// .call(Request::builder().method(Method::GET).finish()) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp, "GET"); /// # }); /// ``` -pub fn make_sync(f: F) -> impl Endpoint +pub fn make_sync(f: F) -> impl Endpoint where F: Fn(Request) -> R + Send + Sync, - R: IntoResponse, + T: IntoResponse + Sync, + R: IntoResult, { - SyncFnEndpoint(f) + SyncFnEndpoint { + _mark: PhantomData, + f, + } } /// Create an endpoint with a asyncness function. /// +/// The output can be any type that implements [`IntoResult`]. +/// /// # Example /// /// ``` @@ -106,24 +123,29 @@ where /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let resp = ep /// .call(Request::builder().method(Method::GET).finish()) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp, "GET"); /// # }); /// ``` -pub fn make(f: F) -> impl Endpoint +pub fn make(f: F) -> impl Endpoint where F: Fn(Request) -> Fut + Send + Sync, Fut: Future + Send, - R: IntoResponse, + T: IntoResponse + Sync, + R: IntoResult, { - AsyncFnEndpoint(f) + AsyncFnEndpoint { + _mark: PhantomData, + f, + } } #[async_trait::async_trait] impl Endpoint for &T { type Output = T::Output; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { T::call(self, req).await } } @@ -132,7 +154,7 @@ impl Endpoint for &T { impl Endpoint for Box { type Output = T::Output; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { self.as_ref().call(req).await } } @@ -141,7 +163,7 @@ impl Endpoint for Box { impl Endpoint for Arc { type Output = T::Output; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { self.as_ref().call(req).await } } @@ -177,7 +199,7 @@ pub trait EndpointExt: IntoEndpoint { /// /// let app = Route::new().at("/", get(index)).with(AddData::new(100i32)); /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let resp = app.call(Request::default()).await; + /// let resp = app.call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "100"); /// # }); @@ -218,13 +240,15 @@ pub trait EndpointExt: IntoEndpoint { /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let resp = app /// .call(Request::builder().uri(Uri::from_static("/a")).finish()) - /// .await; + /// .await + /// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "100"); /// /// let resp = app /// .call(Request::builder().uri(Uri::from_static("/b")).finish()) - /// .await; + /// .await + /// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "none"); /// # }); @@ -254,7 +278,7 @@ pub trait EndpointExt: IntoEndpoint { /// } /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let mut resp = index.data(100i32).call(Request::default()).await; + /// let mut resp = index.data(100i32).call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.take_body().into_string().await.unwrap(), "100"); /// # }); @@ -283,17 +307,18 @@ pub trait EndpointExt: IntoEndpoint { /// let mut resp = index /// .before(|mut req| async move { /// req.set_body("abc"); - /// req + /// Ok(req) /// }) /// .call(Request::default()) - /// .await; + /// .await + /// .unwrap(); /// assert_eq!(resp.take_body().into_string().await.unwrap(), "abc"); /// # }); /// ``` fn before(self, f: F) -> Before where F: Fn(Request) -> Fut + Send + Sync, - Fut: Future + Send, + Fut: Future> + Send, Self: Sized, { Before::new(self, f) @@ -313,17 +338,23 @@ pub trait EndpointExt: IntoEndpoint { /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let mut resp = index - /// .after(|mut resp| async move { resp.take_body().into_string().await.unwrap() + "def" }) + /// .after(|res| async move { + /// match res { + /// Ok(resp) => Ok(resp.into_body().into_string().await.unwrap() + "def"), + /// Err(err) => Err(err), + /// } + /// }) /// .call(Request::default()) - /// .await; + /// .await + /// .unwrap(); /// assert_eq!(resp, "abcdef"); /// # }); /// ``` - fn after(self, f: F) -> After + fn after(self, f: F) -> After where - F: Fn(::Output) -> Fut + Send + Sync, - Fut: Future + Send, - R: IntoResponse, + F: Fn(Result<::Output>) -> Fut + Send + Sync, + Fut: Future> + Send, + T: IntoResponse, Self: Sized, { After::new(self.into_endpoint(), f) @@ -355,18 +386,19 @@ pub trait EndpointExt: IntoEndpoint { /// .around(|ep, mut req| async move { /// req.headers_mut() /// .insert("x-value", HeaderValue::from_static("hello")); - /// let mut resp = ep.call(req).await; - /// resp.take_body().into_string().await.unwrap() + "world" + /// let mut resp = ep.call(req).await?; + /// Ok(resp.take_body().into_string().await.unwrap() + "world") /// }) /// .call(Request::default()) - /// .await; + /// .await + /// .unwrap(); /// assert_eq!(resp, "hello,world"); /// # }); /// ``` fn around(self, f: F) -> Around where F: Fn(Arc, Request) -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, + Fut: Future> + Send + 'static, R: IntoResponse, Self: Sized, { @@ -383,12 +415,16 @@ pub trait EndpointExt: IntoEndpoint { /// endpoint::make, http::StatusCode, Endpoint, EndpointExt, Error, Request, Response, Result, /// }; /// - /// let ep = - /// make(|_| async { Err::<(), Error>(Error::new(StatusCode::BAD_REQUEST)) }).map_to_response(); + /// let ep1 = make(|_| async { "hello" }).map_to_response(); + /// let ep2 = make(|_| async { Err::<(), Error>(Error::new_with_status(StatusCode::BAD_REQUEST)) }) + /// .map_to_response(); /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let resp: Response = ep.call(Request::default()).await; - /// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + /// let resp = ep1.call(Request::default()).await.unwrap(); + /// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello"); + /// + /// let err = ep2.call(Request::default()).await.unwrap_err(); + /// assert_eq!(err.status(), StatusCode::BAD_REQUEST); /// # }); /// ``` fn map_to_response(self) -> MapToResponse @@ -398,8 +434,7 @@ pub trait EndpointExt: IntoEndpoint { MapToResponse::new(self.into_endpoint()) } - /// Convert the output of this endpoint into a result `Result`. - /// [`Response`](crate::Response). + /// Maps the response of this endpoint. /// /// # Example /// @@ -408,19 +443,23 @@ pub trait EndpointExt: IntoEndpoint { /// endpoint::make, http::StatusCode, Endpoint, EndpointExt, Error, Request, Response, Result, /// }; /// - /// let ep = make(|_| async { Response::builder().status(StatusCode::BAD_REQUEST).finish() }) - /// .map_to_result(); + /// let ep = make(|_| async { "hello" }).map(|value| async move { format!("{}, world!", value) }); /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let resp: Result = ep.call(Request::default()).await; - /// assert_eq!(resp.unwrap_err().status(), StatusCode::BAD_REQUEST); + /// let mut resp: String = ep.call(Request::default()).await.unwrap(); + /// assert_eq!(resp, "hello, world!"); /// # }); /// ``` - fn map_to_result(self) -> MapToResult + fn map(self, f: F) -> Map where + F: Fn(R) -> Fut + Send + Sync, + Fut: Future + Send, + R: IntoResponse, + R2: IntoResponse, Self: Sized, + Self::Endpoint: Endpoint + Sized, { - MapToResult::new(self.into_endpoint()) + Map::new(self.into_endpoint(), f) } /// Calls `f` if the result is `Ok`, otherwise returns the `Err` value of @@ -433,9 +472,9 @@ pub trait EndpointExt: IntoEndpoint { /// endpoint::make, http::StatusCode, Endpoint, EndpointExt, Error, Request, Response, Result, /// }; /// - /// let ep1 = make(|_| async { Ok::<_, Error>("hello") }) + /// let ep1 = make(|_| async { "hello" }) /// .and_then(|value| async move { Ok(format!("{}, world!", value)) }); - /// let ep2 = make(|_| async { Err::(Error::new(StatusCode::BAD_REQUEST)) }) + /// let ep2 = make(|_| async { Err::(Error::new_with_status(StatusCode::BAD_REQUEST)) }) /// .and_then(|value| async move { Ok(format!("{}, world!", value)) }); /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { @@ -446,87 +485,82 @@ pub trait EndpointExt: IntoEndpoint { /// assert_eq!(err.status(), StatusCode::BAD_REQUEST); /// # }); /// ``` - fn and_then(self, f: F) -> AndThen + fn and_then(self, f: F) -> AndThen where F: Fn(R) -> Fut + Send + Sync, - Fut: Future> + Send, - Err: IntoResponse, + Fut: Future> + Send, R: IntoResponse, R2: IntoResponse, Self: Sized, - Self::Endpoint: Endpoint> + Sized, + Self::Endpoint: Endpoint + Sized, { AndThen::new(self.into_endpoint(), f) } - /// Maps the response of this endpoint. + /// Catch the specified type of error and convert it into a response. /// /// # Example /// /// ``` + /// use http::Uri; /// use poem::{ - /// endpoint::make, http::StatusCode, Endpoint, EndpointExt, Error, Request, Response, Result, + /// error::NotFoundError, handler, http::StatusCode, Endpoint, EndpointExt, Request, Response, + /// Route, /// }; /// - /// let ep = - /// make(|_| async { Ok("hello") }).map_ok(|value| async move { format!("{}, world!", value) }); + /// #[handler] + /// async fn index() {} /// - /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let mut resp: String = ep.call(Request::default()).await.unwrap(); - /// assert_eq!(resp, "hello, world!"); - /// # }); - /// ``` - fn map_ok(self, f: F) -> MapOk - where - F: Fn(R) -> Fut + Send + Sync, - Fut: Future + Send, - R: IntoResponse, - R2: IntoResponse, - Self: Sized, - Self::Endpoint: Endpoint> + Sized, - { - MapOk::new(self.into_endpoint(), f) - } - - /// Maps the error of this endpoint. + /// let app = Route::new().at("/index", index).catch_error(custom_404); /// - /// # Example - /// - /// ``` - /// use poem::{ - /// endpoint::make, http::StatusCode, Endpoint, EndpointExt, Error, IntoResponse, Request, - /// Response, Result, - /// }; - /// - /// struct CustomError; - /// - /// impl IntoResponse for CustomError { - /// fn into_response(self) -> Response { - /// Response::builder() - /// .status(StatusCode::UNAUTHORIZED) - /// .finish() - /// } + /// async fn custom_404(_: NotFoundError) -> Response { + /// Response::builder() + /// .status(StatusCode::NOT_FOUND) + /// .body("custom not found") /// } /// - /// let ep = make(|_| async { Err::<(), _>(CustomError) }) - /// .map_err(|_| async move { Error::new(StatusCode::INTERNAL_SERVER_ERROR) }); - /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let err = ep.call(Request::default()).await.unwrap_err(); - /// assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR); - /// # }); + /// let resp = app + /// .call(Request::builder().uri(Uri::from_static("/abc")).finish()) + /// .await + /// .unwrap(); + /// assert_eq!(resp.status(), StatusCode::NOT_FOUND); + /// assert_eq!( + /// resp.into_body().into_string().await.unwrap(), + /// "custom not found" + /// ); + /// # }) /// ``` - fn map_err(self, f: F) -> MapErr + fn catch_error(self, f: F) -> CatchError where - F: Fn(InErr) -> Fut + Send + Sync, - Fut: Future + Send, - InErr: IntoResponse, - OutErr: IntoResponse, - R: IntoResponse, + F: Fn(ErrType) -> Fut + Send + Sync, + Fut: Future + Send, + ErrType: std::error::Error + Send + Sync + 'static, Self: Sized, - Self::Endpoint: Endpoint> + Sized, { - MapErr::new(self.into_endpoint(), f) + CatchError::new(self, f) + } + + /// Does something with each error. + /// + /// # Example + /// + /// ``` + /// use poem::{handler, EndpointExt, Route}; + /// + /// #[handler] + /// fn index() {} + /// + /// let app = Route::new().at("/", index).inspect_err(|err| { + /// println!("error: {}", err); + /// }); + /// ``` + fn inspect_err(self, f: F) -> InspectError + where + F: Fn(&Error) + Send + Sync, + Self: Sized, + { + InspectError::new(self, f) } } @@ -565,7 +599,8 @@ mod test { let ep = make(|req| async move { format!("method={}", req.method()) }).map_to_response(); let mut resp = ep .call(Request::builder().method(Method::DELETE).finish()) - .await; + .await + .unwrap(); assert_eq!( resp.take_body().into_string().await.unwrap(), "method=DELETE" @@ -578,10 +613,11 @@ mod test { make_sync(|req| req.method().to_string()) .before(|mut req| async move { req.set_method(Method::POST); - req + Ok(req) }) .call(Request::default()) - .await, + .await + .unwrap(), "POST" ); } @@ -590,18 +626,19 @@ mod test { async fn test_after() { assert_eq!( make_sync(|_| "abc") - .after(|_| async { "def" }) + .after(|_| async { Ok::<_, Error>("def") }) .call(Request::default()) - .await, + .await + .unwrap(), "def" ); } #[tokio::test] - async fn test_map_to_result() { + async fn test_map_to_response() { assert_eq!( - make_sync(|_| Response::builder().status(StatusCode::OK).body("abc")) - .map_to_result() + make_sync(|_| "abc") + .map_to_response() .call(Request::default()) .await .unwrap() @@ -611,51 +648,20 @@ mod test { .unwrap(), "abc" ); - - let err = make_sync(|_| Response::builder().status(StatusCode::BAD_REQUEST).finish()) - .map_to_result() - .call(Request::default()) - .await - .unwrap_err(); - assert_eq!(err.status(), StatusCode::BAD_REQUEST); - } - - #[tokio::test] - async fn test_map_to_response() { - assert_eq!( - make_sync(|_| Ok::<_, Error>("abc")) - .map_to_response() - .call(Request::default()) - .await - .take_body() - .into_string() - .await - .unwrap(), - "abc" - ); - - assert_eq!( - make_sync(|_| Err::<(), Error>(Error::new(StatusCode::BAD_REQUEST))) - .map_to_response() - .call(Request::default()) - .await - .status(), - StatusCode::BAD_REQUEST - ); } #[tokio::test] async fn test_and_then() { assert_eq!( - make_sync(|_| Ok("abc")) - .and_then(|resp| async move { Ok::<_, Error>(resp.to_string() + "def") }) + make_sync(|_| "abc") + .and_then(|resp| async move { Ok(resp.to_string() + "def") }) .call(Request::default()) .await .unwrap(), "abcdef" ); - let err = make_sync(|_| Err::(Error::new(StatusCode::BAD_REQUEST))) + let err = make_sync(|_| Err::(Error::new_with_status(StatusCode::BAD_REQUEST))) .and_then(|resp| async move { Ok(resp + "def") }) .call(Request::default()) .await @@ -664,54 +670,36 @@ mod test { } #[tokio::test] - async fn test_map_ok() { + async fn test_map() { assert_eq!( - make_sync(|_| Ok("abc")) - .map_ok(|resp| async move { resp.to_string() + "def" }) + make_sync(|_| "abc") + .map(|resp| async move { resp.to_string() + "def" }) .call(Request::default()) .await .unwrap(), "abcdef" ); - let err = make_sync(|_| Err::(Error::new(StatusCode::BAD_REQUEST))) - .map_ok(|resp| async move { resp.to_string() + "def" }) + let err = make_sync(|_| Err::(Error::new_with_status(StatusCode::BAD_REQUEST))) + .map(|resp| async move { resp.to_string() + "def" }) .call(Request::default()) .await .unwrap_err(); assert_eq!(err.status(), StatusCode::BAD_REQUEST); } - #[tokio::test] - async fn test_map_err() { - assert_eq!( - make_sync(|_| Ok::<_, Error>("abc")) - .map_err(|_| async move { Error::new(StatusCode::BAD_GATEWAY) }) - .call(Request::default()) - .await - .unwrap(), - "abc" - ); - - let err = make_sync(|_| Err::(Error::new(StatusCode::BAD_REQUEST))) - .map_err(|_| async move { Error::new(StatusCode::BAD_GATEWAY) }) - .call(Request::default()) - .await - .unwrap_err(); - assert_eq!(err.status(), StatusCode::BAD_GATEWAY); - } - #[tokio::test] async fn test_around() { let ep = make(|req| async move { req.into_body().into_string().await.unwrap() + "b" }); assert_eq!( ep.around(|ep, mut req| async move { req.set_body("a"); - let resp = ep.call(req).await; - resp + "c" + let resp = ep.call(req).await?; + Ok(resp + "c") }) .call(Request::default()) - .await, + .await + .unwrap(), "abc" ); } @@ -721,7 +709,8 @@ mod test { let resp = make_sync(|_| ()) .with_if(true, SetHeader::new().appending("a", 1)) .call(Request::default()) - .await; + .await + .unwrap(); assert_eq!( resp.headers().get("a"), Some(&HeaderValue::from_static("1")) @@ -730,7 +719,8 @@ mod test { let resp = make_sync(|_| ()) .with_if(false, SetHeader::new().appending("a", 1)) .call(Request::default()) - .await; + .await + .unwrap(); assert_eq!(resp.headers().get("a"), None); } @@ -753,6 +743,7 @@ mod test { assert_eq!( app.call(Request::builder().uri(Uri::from_static("/api/a")).finish()) .await + .unwrap() .take_body() .into_string() .await @@ -763,6 +754,7 @@ mod test { assert_eq!( app.call(Request::builder().uri(Uri::from_static("/api/b")).finish()) .await + .unwrap() .take_body() .into_string() .await diff --git a/poem/src/endpoint/files.rs b/poem/src/endpoint/files.rs index 1e87d5b8..3879098e 100644 --- a/poem/src/endpoint/files.rs +++ b/poem/src/endpoint/files.rs @@ -7,8 +7,9 @@ use mime::Mime; use tokio::fs::File; use crate::{ - http::{header, HeaderValue, Method, StatusCode}, - Body, Endpoint, Request, Response, + error::StaticFileError, + http::{header, HeaderValue, Method}, + Body, Endpoint, Request, Response, Result, }; struct DirectoryTemplate<'a> { @@ -61,6 +62,10 @@ struct FileRef { } /// Static files handling service. +/// +/// # Errors +/// +/// - [`StaticFileError`] pub struct Files { path: PathBuf, show_files_listing: bool, @@ -127,13 +132,10 @@ impl Files { } } -#[async_trait::async_trait] -impl Endpoint for Files { - type Output = Response; - - async fn call(&self, req: Request) -> Self::Output { +impl Files { + async fn internal_call(&self, req: Request) -> Result { if req.method() != Method::GET { - return StatusCode::METHOD_NOT_ALLOWED.into(); + return Err(StaticFileError::MethodNotAllowed(req.method().clone())); } let path = req @@ -142,10 +144,9 @@ impl Endpoint for Files { .trim_start_matches('/') .trim_end_matches('/'); - let path = match percent_encoding::percent_decode_str(path).decode_utf8() { - Ok(path) => path, - Err(_) => return StatusCode::BAD_REQUEST.into(), - }; + let path = percent_encoding::percent_decode_str(path) + .decode_utf8() + .map_err(|_| StaticFileError::InvalidPath)?; let mut file_path = self.path.clone(); for p in Path::new(&*path) { @@ -159,11 +160,11 @@ impl Endpoint for Files { } if !file_path.starts_with(&self.path) { - return StatusCode::FORBIDDEN.into(); + return Err(StaticFileError::Forbidden(file_path.display().to_string())); } if !file_path.exists() { - return StatusCode::NOT_FOUND.into(); + return Err(StaticFileError::NotFound(file_path.display().to_string())); } if file_path.is_file() { @@ -177,22 +178,14 @@ impl Endpoint for Files { } if self.show_files_listing { - let read_dir = match file_path.read_dir() { - Ok(d) => d, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into(), - }; + let read_dir = file_path.read_dir()?; let mut template = DirectoryTemplate { path: &*path, files: Vec::new(), }; for res in read_dir { - let entry = match res { - Ok(entry) => entry, - Err(err) => { - return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into() - } - }; + let entry = res?; if let Some(filename) = entry.file_name().to_str() { let mut base_url = req.original_uri().path().to_string(); @@ -208,22 +201,28 @@ impl Endpoint for Files { } let html = template.render(); - Response::builder() + Ok(Response::builder() .header(header::CONTENT_TYPE, mime::TEXT_HTML_UTF_8.as_ref()) - .body(Body::from_string(html)) + .body(Body::from_string(html))) } else { - StatusCode::NOT_FOUND.into() + Err(StaticFileError::NotFound(file_path.display().to_string())) } } } } -async fn create_file_response(path: &Path, prefer_utf8: bool) -> Response { +#[async_trait::async_trait] +impl Endpoint for Files { + type Output = Response; + + async fn call(&self, req: Request) -> Result { + self.internal_call(req).await.map_err(Into::into) + } +} + +async fn create_file_response(path: &Path, prefer_utf8: bool) -> Result { let guess = mime_guess::from_path(path); - let file = match File::open(path).await { - Ok(file) => file, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into(), - }; + let file = File::open(path).await?; let mut resp = Response::builder().body(Body::from_async_read(file)); if let Some(mut mime) = guess.first() { if prefer_utf8 { @@ -234,7 +233,7 @@ async fn create_file_response(path: &Path, prefer_utf8: bool) -> Response { .insert(header::CONTENT_TYPE, header_value); } } - resp + Ok(resp) } fn equiv_utf8_text(ct: Mime) -> Mime { diff --git a/poem/src/endpoint/inspect_err.rs b/poem/src/endpoint/inspect_err.rs new file mode 100644 index 00000000..38a987f3 --- /dev/null +++ b/poem/src/endpoint/inspect_err.rs @@ -0,0 +1,34 @@ +use crate::{Endpoint, Error, Request, Result}; + +/// Endpoint for the [`inspect_error`](super::EndpointExt::inspect_error) +/// method. +pub struct InspectError { + inner: E, + f: F, +} + +impl InspectError { + #[inline] + pub(crate) fn new(inner: E, f: F) -> InspectError { + Self { inner, f } + } +} + +#[async_trait::async_trait] +impl Endpoint for InspectError +where + E: Endpoint, + F: Fn(&Error) + Send + Sync, +{ + type Output = E::Output; + + async fn call(&self, req: Request) -> Result { + match self.inner.call(req).await { + Ok(resp) => Ok(resp), + Err(err) => { + (self.f)(&err); + Err(err) + } + } + } +} diff --git a/poem/src/endpoint/map.rs b/poem/src/endpoint/map.rs new file mode 100644 index 00000000..55bfa320 --- /dev/null +++ b/poem/src/endpoint/map.rs @@ -0,0 +1,33 @@ +use std::future::Future; + +use crate::{Endpoint, IntoResponse, Request, Result}; + +/// Endpoint for the [`map_ok`](super::EndpointExt::map) method. +pub struct Map { + inner: E, + f: F, +} + +impl Map { + #[inline] + pub(crate) fn new(inner: E, f: F) -> Map { + Self { inner, f } + } +} + +#[async_trait::async_trait] +impl Endpoint for Map +where + E: Endpoint, + F: Fn(R) -> Fut + Send + Sync, + Fut: Future + Send, + R: IntoResponse, + R2: IntoResponse, +{ + type Output = R2; + + async fn call(&self, req: Request) -> Result { + let resp = self.inner.call(req).await?; + Ok((self.f)(resp).await) + } +} diff --git a/poem/src/endpoint/map_err.rs b/poem/src/endpoint/map_err.rs deleted file mode 100644 index ae7c5563..00000000 --- a/poem/src/endpoint/map_err.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::future::Future; - -use crate::{Endpoint, IntoResponse, Request, Result}; - -/// Endpoint for the [`map_err`](super::EndpointExt::map_err) method. -pub struct MapErr { - inner: E, - f: F, -} - -impl MapErr { - #[inline] - pub(crate) fn new(inner: E, f: F) -> MapErr { - Self { inner, f } - } -} - -#[async_trait::async_trait] -impl Endpoint for MapErr -where - E: Endpoint>, - InErr: IntoResponse, - OutErr: IntoResponse, - F: Fn(InErr) -> Fut + Send + Sync, - Fut: Future + Send, - R: IntoResponse, -{ - type Output = Result; - - async fn call(&self, req: Request) -> Self::Output { - match self.inner.call(req).await { - Ok(resp) => Ok(resp), - Err(err) => Err((self.f)(err).await), - } - } -} diff --git a/poem/src/endpoint/map_ok.rs b/poem/src/endpoint/map_ok.rs deleted file mode 100644 index 87510d70..00000000 --- a/poem/src/endpoint/map_ok.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::future::Future; - -use crate::{Endpoint, IntoResponse, Request, Result}; - -/// Endpoint for the [`map_ok`](super::EndpointExt::map_ok) method. -pub struct MapOk { - inner: E, - f: F, -} - -impl MapOk { - #[inline] - pub(crate) fn new(inner: E, f: F) -> MapOk { - Self { inner, f } - } -} - -#[async_trait::async_trait] -impl Endpoint for MapOk -where - E: Endpoint>, - F: Fn(R) -> Fut + Send + Sync, - Fut: Future + Send, - R: IntoResponse, - R2: IntoResponse, -{ - type Output = Result; - - async fn call(&self, req: Request) -> Self::Output { - match self.inner.call(req).await { - Ok(resp) => Ok((self.f)(resp).await), - Err(err) => Err(err), - } - } -} diff --git a/poem/src/endpoint/map_to_response.rs b/poem/src/endpoint/map_to_response.rs index 1e942a4d..5395aeaa 100644 --- a/poem/src/endpoint/map_to_response.rs +++ b/poem/src/endpoint/map_to_response.rs @@ -1,4 +1,4 @@ -use crate::{Endpoint, IntoResponse, Request, Response}; +use crate::{Endpoint, IntoResponse, Request, Response, Result}; /// Endpoint for the [`map_to_response`](super::EndpointExt::map_to_response) /// method. @@ -17,7 +17,7 @@ impl MapToResponse { impl Endpoint for MapToResponse { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { - self.inner.call(req).await.into_response() + async fn call(&self, req: Request) -> Result { + self.inner.call(req).await.map(IntoResponse::into_response) } } diff --git a/poem/src/endpoint/map_to_result.rs b/poem/src/endpoint/map_to_result.rs deleted file mode 100644 index 028fde84..00000000 --- a/poem/src/endpoint/map_to_result.rs +++ /dev/null @@ -1,28 +0,0 @@ -use crate::{Endpoint, Error, IntoResponse, Request, Response, Result}; - -/// Endpoint for the [`map_to_result`](super::EndpointExt::map_to_result) -/// method. -pub struct MapToResult { - inner: E, -} - -impl MapToResult { - #[inline] - pub(crate) fn new(inner: E) -> MapToResult { - Self { inner } - } -} - -#[async_trait::async_trait] -impl Endpoint for MapToResult { - type Output = Result; - - async fn call(&self, req: Request) -> Self::Output { - let resp = self.inner.call(req).await.into_response(); - if !resp.status().is_server_error() && !resp.status().is_client_error() { - Ok(resp) - } else { - Err(Error::new(resp.status())) - } - } -} diff --git a/poem/src/endpoint/mod.rs b/poem/src/endpoint/mod.rs index 1cc3dd5c..8f4aef3b 100644 --- a/poem/src/endpoint/mod.rs +++ b/poem/src/endpoint/mod.rs @@ -4,13 +4,13 @@ mod after; mod and_then; mod around; mod before; +mod catch_error; #[allow(clippy::module_inception)] mod endpoint; mod files; -mod map_err; -mod map_ok; +mod inspect_err; +mod map; mod map_to_response; -mod map_to_result; #[cfg(feature = "prometheus")] mod prometheus_exporter; #[cfg(feature = "tower-compat")] @@ -20,12 +20,12 @@ pub use after::After; pub use and_then::AndThen; pub use around::Around; pub use before::Before; +pub use catch_error::CatchError; pub use endpoint::{make, make_sync, BoxEndpoint, Endpoint, EndpointExt, IntoEndpoint}; pub use files::Files; -pub use map_err::MapErr; -pub use map_ok::MapOk; +pub use inspect_err::InspectError; +pub use map::Map; pub use map_to_response::MapToResponse; -pub use map_to_result::MapToResult; #[cfg(feature = "prometheus")] pub use prometheus_exporter::PrometheusExporter; #[cfg(feature = "tower-compat")] diff --git a/poem/src/endpoint/prometheus_exporter.rs b/poem/src/endpoint/prometheus_exporter.rs index bce49813..10087088 100644 --- a/poem/src/endpoint/prometheus_exporter.rs +++ b/poem/src/endpoint/prometheus_exporter.rs @@ -3,7 +3,7 @@ use libprometheus::{Encoder, TextEncoder}; use crate::{ http::{Method, StatusCode}, - Endpoint, IntoEndpoint, Request, Response, + Endpoint, IntoEndpoint, Request, Response, Result, }; /// An endpoint that exports metrics for Prometheus. @@ -47,17 +47,17 @@ pub struct PrometheusExporterEndpoint { impl Endpoint for PrometheusExporterEndpoint { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { if req.method() != Method::GET { - return StatusCode::METHOD_NOT_ALLOWED.into(); + return Ok(StatusCode::METHOD_NOT_ALLOWED.into()); } let encoder = TextEncoder::new(); let metric_families = self.exporter.registry().gather(); let mut result = Vec::new(); match encoder.encode(&metric_families, &mut result) { - Ok(()) => Response::builder().content_type("text/plain").body(result), - Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into(), + Ok(()) => Ok(Response::builder().content_type("text/plain").body(result)), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR.into()), } } } diff --git a/poem/src/endpoint/tower_compat.rs b/poem/src/endpoint/tower_compat.rs index a8b41efd..584880a9 100644 --- a/poem/src/endpoint/tower_compat.rs +++ b/poem/src/endpoint/tower_compat.rs @@ -4,7 +4,7 @@ use bytes::Bytes; use hyper::body::HttpBody; use tower::{Service, ServiceExt}; -use crate::{body::BodyStream, Endpoint, Request, Response, Result}; +use crate::{body::BodyStream, error::InternalServerError, Endpoint, Request, Response, Result}; /// Extension trait for tower service compat. #[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))] @@ -56,15 +56,18 @@ where + 'static, Fut: Future, Err>> + Send + 'static, { - type Output = Result; + type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { let mut svc = self.0.clone(); - svc.ready().await?; + svc.ready().await.map_err(InternalServerError)?; let hyper_req: http::Request = req.into(); - let hyper_resp = svc.call(hyper_req.map(Into::into)).await?; + let hyper_resp = svc + .call(hyper_req.map(Into::into)) + .await + .map_err(InternalServerError)?; Ok(hyper_resp .map(|body| hyper::Body::wrap_stream(BodyStream::new(body))) diff --git a/poem/src/error.rs b/poem/src/error.rs index 59f7d1d1..0a90290d 100644 --- a/poem/src/error.rs +++ b/poem/src/error.rs @@ -1,10 +1,14 @@ //! Some common error types. use std::{ - fmt::{Debug, Display}, + convert::Infallible, + error::Error as StdError, + fmt::{self, Debug, Display, Formatter}, string::FromUtf8Error, }; +use http::Method; + use crate::{http::StatusCode, IntoResponse, Response}; macro_rules! define_http_error { @@ -13,54 +17,141 @@ macro_rules! define_http_error { $(#[$docs])* #[allow(non_snake_case)] #[inline] - pub fn $name(err: impl Display) -> Error { - Error::new(StatusCode::$status).with_reason(err) + pub fn $name(err: impl StdError + Send + Sync + 'static) -> Error { + Error::new(err).with_status(StatusCode::$status) } )* }; } +/// Represents a type that can be converted to [`Error`]. +pub trait ResponseError: StdError + Send + Sync + 'static { + /// The status code of this error. + fn status(&self) -> StatusCode; +} + /// General response error. +/// +/// # Create from any error types +/// +/// ``` +/// use poem::{error::InternalServerError, handler, Result}; +/// +/// #[handler] +/// async fn index() -> Result { +/// Ok(std::fs::read_to_string("example.txt").map_err(InternalServerError)?) +/// } +/// ``` +/// +/// # Create you own error type +/// +/// ``` +/// use poem::{error::ResponseError, handler, http::StatusCode, Result}; +/// +/// #[derive(Debug, thiserror::Error)] +/// #[error("my error")] +/// struct MyError; +/// +/// impl ResponseError for MyError { +/// fn status(&self) -> StatusCode { +/// StatusCode::BAD_GATEWAY +/// } +/// } +/// +/// #[handler] +/// async fn index() -> Result { +/// Ok(std::fs::read_to_string("example.txt").map_err(|_| MyError)?) +/// } +/// ``` +/// +/// # Downcast the error to concrete error type +/// +/// ``` +/// use poem::{error::NotFoundError, Error}; +/// +/// let err: Error = NotFoundError.into(); +/// +/// assert!(err.is::()); +/// assert_eq!(err.downcast_ref::(), Some(&NotFoundError)); +/// ``` #[derive(Debug)] pub struct Error { status: StatusCode, - reason: Option, + source: Box, } -impl From for Error { - #[inline] - fn from(err: T) -> Self { - Self { - status: StatusCode::INTERNAL_SERVER_ERROR, - reason: Some(err.to_string()), - } +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.source, f) } } -impl IntoResponse for Error { - #[inline] - fn into_response(self) -> Response { - self.as_response() +impl From for Error { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +impl From for Error { + fn from(err: T) -> Self { + let status = err.status(); + Error::new(err).with_status(status) + } +} + +impl From for Error { + fn from(status: StatusCode) -> Self { + Error::new_with_status(status) } } impl Error { - /// Create a new error with status code. + /// Create a new error object from any error type with `503 + /// INTERNAL_SERVER_ERROR` status code. #[inline] - pub fn new(status: StatusCode) -> Self { + pub fn new(err: T) -> Self { Self { - status, - reason: None, + status: StatusCode::INTERNAL_SERVER_ERROR, + source: Box::new(err), } } - /// Sets the reason for this error. - #[inline] - pub fn with_reason(self, reason: impl Display) -> Self { - Self { - reason: Some(reason.to_string()), - ..self + /// create a new error object from status code. + pub fn new_with_status(status: StatusCode) -> Self { + #[derive(Debug, thiserror::Error)] + #[error("{0}")] + struct StatusError(StatusCode); + + impl ResponseError for StatusError { + fn status(&self) -> StatusCode { + self.0 + } } + + StatusError(status).into() + } + + /// Create a new error object from string with `503 INTERNAL_SERVER_ERROR` + /// status code. + pub fn new_with_string(msg: impl Into) -> Self { + #[derive(Debug, thiserror::Error)] + #[error("{0}")] + struct StringError(String); + + impl ResponseError for StringError { + fn status(&self) -> StatusCode { + StatusCode::INTERNAL_SERVER_ERROR + } + } + + StringError(msg.into()).into() + } + + /// Specifies the status code of this error. + #[inline] + #[must_use] + pub fn with_status(self, status: StatusCode) -> Self { + Self { status, ..self } } /// Returns the status code of this error. @@ -69,22 +160,37 @@ impl Error { self.status } - /// Returns the reason of this error. + /// Downcast this error object by reference. #[inline] - pub fn reason(&self) -> Option<&str> { - self.reason.as_deref() + pub fn downcast_ref(&self) -> Option<&T> { + self.source.downcast_ref() } - /// Creates full response for this error. + /// Attempts to downcast the error to a concrete error type. #[inline] - pub fn as_response(&self) -> Response { - match &self.reason { - Some(reason) => Response::builder() - .status(self.status) - .body(reason.to_string()), - None => Response::builder().status(self.status).finish(), + pub fn downcast(self) -> Result { + let status = self.status; + match self.source.downcast::() { + Ok(err) => Ok(*err), + Err(err) => Err(Error { + status, + source: err, + }), } } + + /// Returns `true` if the error type is the same as `T`. + #[inline] + pub fn is(&self) -> bool { + self.source.is::() + } + + /// Consumes this to return a response object. + pub fn as_response(&self) -> Response { + Response::builder() + .status(self.status) + .body(self.source.to_string()) + } } define_http_error!( @@ -171,25 +277,57 @@ define_http_error!( /// A specialized Result type for Poem. pub type Result = ::std::result::Result; +/// Represents a type that can be converted to `poem::Result`. +/// +/// # Example +/// +/// ``` +/// use poem::error::{IntoResult, NotFoundError}; +/// +/// let res = "abc".into_result(); +/// assert!(matches!(res, Ok("abc"))); +/// +/// let res = Err::<(), _>(NotFoundError).into_result(); +/// assert!(res.is_err()); +/// let err = res.unwrap_err(); +/// assert!(err.is::()); +/// ``` +pub trait IntoResult { + /// Consumes this value returns a `poem::Result`. + fn into_result(self) -> Result; +} + +impl IntoResult for Result +where + T: IntoResponse, + E: Into + Debug + Send + Sync + 'static, +{ + #[inline] + fn into_result(self) -> Result { + self.map_err(Into::into) + } +} + +impl IntoResult for T { + #[inline] + fn into_result(self) -> Result { + Ok(self) + } +} + macro_rules! define_simple_errors { ($($(#[$docs:meta])* ($name:ident, $status:ident, $err_msg:literal);)*) => { $( $(#[$docs])* - #[derive(Debug, Copy, Clone, Eq, PartialEq)] + #[derive(Debug, thiserror::Error, Copy, Clone, Eq, PartialEq)] + #[error($err_msg)] pub struct $name; - impl From<$name> for Error { - fn from(_: $name) -> Self { - Error::new(StatusCode::$status).with_reason($err_msg) + impl ResponseError for $name { + fn status(&self) -> StatusCode { + StatusCode::$status } } - - impl IntoResponse for $name { - fn into_response(self) -> Response { - Into::::into(self).as_response() - } - } - )* }; } @@ -197,400 +335,266 @@ macro_rules! define_simple_errors { define_simple_errors!( /// Only the endpoints under the router can get the path parameters, otherwise this error will occur. (ParsePathError, BAD_REQUEST, "invalid path params"); + + /// Error occurred in the router. + (NotFoundError, NOT_FOUND, "not found"); + + /// Error occurred in the `Cors` middleware. + (CorsError, UNAUTHORIZED, "unauthorized"); ); /// A possible error value when reading the body. -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum ReadBodyError { /// Body has been taken by other extractors. + #[error("the body has been taken")] BodyHasBeenTaken, /// Body is not a valid utf8 string. - Utf8(FromUtf8Error), + #[error("parse utf8: {0}")] + Utf8(#[from] FromUtf8Error), /// Io error. - Io(std::io::Error), + #[error("io: {0}")] + Io(#[from] std::io::Error), } -impl From for ReadBodyError { - fn from(err: FromUtf8Error) -> Self { - Self::Utf8(err) - } -} - -impl From for ReadBodyError { - fn from(err: std::io::Error) -> Self { - Self::Io(err) - } -} - -impl From for Error { - fn from(err: ReadBodyError) -> Self { - match err { - ReadBodyError::BodyHasBeenTaken => { - Error::new(StatusCode::INTERNAL_SERVER_ERROR).with_reason("the body has been taken") - } - ReadBodyError::Utf8(err) => { - Error::new(StatusCode::BAD_REQUEST).with_reason(format!("parse utf8: {}", err)) - } - ReadBodyError::Io(err) => { - Error::new(StatusCode::BAD_REQUEST).with_reason(format!("io: {}", err)) - } +impl ResponseError for ReadBodyError { + fn status(&self) -> StatusCode { + match self { + ReadBodyError::BodyHasBeenTaken => StatusCode::INTERNAL_SERVER_ERROR, + ReadBodyError::Utf8(_) => StatusCode::BAD_REQUEST, + ReadBodyError::Io(_) => StatusCode::BAD_REQUEST, } } } -impl IntoResponse for ReadBodyError { - fn into_response(self) -> Response { - Into::::into(self).as_response() - } -} - /// A possible error value when parsing cookie. #[cfg(feature = "cookie")] #[cfg_attr(docsrs, doc(cfg(feature = "cookie")))] -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum ParseCookieError { /// Cookie value is illegal. + #[error("cookie is illegal")] CookieIllegal, /// A `Cookie` header is required. + #[error("`Cookie` header is required")] CookieHeaderRequired, /// Cookie value is illegal. - ParseJsonValue(serde_json::Error), + #[error("cookie is illegal: {0}")] + ParseJsonValue(#[from] serde_json::Error), } #[cfg(feature = "cookie")] -impl From for Error { - fn from(err: ParseCookieError) -> Self { - match err { - ParseCookieError::CookieIllegal => { - Error::new(StatusCode::BAD_REQUEST).with_reason("cookie is illegal") - } - ParseCookieError::CookieHeaderRequired => { - Error::new(StatusCode::BAD_REQUEST).with_reason("`Cookie` header is required") - } - ParseCookieError::ParseJsonValue(_) => { - Error::new(StatusCode::BAD_REQUEST).with_reason("cookie is illegal") - } - } - } -} - -#[cfg(feature = "cookie")] -impl IntoResponse for ParseCookieError { - fn into_response(self) -> Response { - Into::::into(self).as_response() +impl ResponseError for ParseCookieError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } /// A possible error value when extracts data from request fails. -#[derive(Debug)] +#[derive(Debug, thiserror::Error, Eq, PartialEq)] +#[error("data of type `{0}` was not found.")] pub struct GetDataError(pub &'static str); -impl From for Error { - fn from(err: GetDataError) -> Self { - Error::new(StatusCode::INTERNAL_SERVER_ERROR) - .with_reason(format!("data of type `{}` was not found.", err.0)) - } -} - -impl IntoResponse for GetDataError { - fn into_response(self) -> Response { - Into::::into(self).as_response() +impl ResponseError for GetDataError { + fn status(&self) -> StatusCode { + StatusCode::INTERNAL_SERVER_ERROR } } /// A possible error value when parsing form. -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum ParseFormError { - /// Read body error. - ReadBody(ReadBodyError), - /// Invalid content type. + #[error("invalid content type `{0}`, expect: `application/x-www-form-urlencoded`")] InvalidContentType(String), /// `Content-Type` header is required. + #[error("expect content type `application/x-www-form-urlencoded`")] ContentTypeRequired, /// Url decode error. - UrlDecode(serde_urlencoded::de::Error), + #[error("url decode: {0}")] + UrlDecode(#[from] serde_urlencoded::de::Error), } -impl From for ParseFormError { - fn from(err: ReadBodyError) -> Self { - Self::ReadBody(err) - } -} - -impl From for ParseFormError { - fn from(err: serde_urlencoded::de::Error) -> Self { - Self::UrlDecode(err) - } -} - -impl From for Error { - fn from(err: ParseFormError) -> Self { - match err { - ParseFormError::ReadBody(err) => err.into(), - ParseFormError::InvalidContentType(content_type) => Error::new(StatusCode::BAD_REQUEST) - .with_reason(format!( - "invalid content type `{}`, expect: `application/x-www-form-urlencoded`", - content_type - )), - ParseFormError::ContentTypeRequired => Error::new(StatusCode::BAD_REQUEST) - .with_reason("expect content type `application/x-www-form-urlencoded`"), - ParseFormError::UrlDecode(err) => { - Error::new(StatusCode::BAD_REQUEST).with_reason(format!("url decode: {}", err)) - } +impl ResponseError for ParseFormError { + fn status(&self) -> StatusCode { + match self { + ParseFormError::InvalidContentType(_) => StatusCode::BAD_REQUEST, + ParseFormError::ContentTypeRequired => StatusCode::BAD_REQUEST, + ParseFormError::UrlDecode(_) => StatusCode::BAD_REQUEST, } } } -impl IntoResponse for ParseFormError { - fn into_response(self) -> Response { - Into::::into(self).as_response() - } -} - /// A possible error value when parsing JSON. -#[derive(Debug)] -pub enum ParseJsonError { - /// Read body error. - ReadBody(ReadBodyError), +#[derive(Debug, thiserror::Error)] +#[error("parse: {0}")] +pub struct ParseJsonError(#[from] pub serde_json::Error); - /// Parse error. - Json(serde_json::Error), -} - -impl From for ParseJsonError { - fn from(err: ReadBodyError) -> Self { - Self::ReadBody(err) - } -} - -impl From for ParseJsonError { - fn from(err: serde_json::Error) -> Self { - Self::Json(err) - } -} - -impl From for Error { - fn from(err: ParseJsonError) -> Self { - match err { - ParseJsonError::ReadBody(err) => err.into(), - ParseJsonError::Json(err) => { - Error::new(StatusCode::BAD_REQUEST).with_reason(format!("parse: {}", err)) - } - } - } -} - -impl IntoResponse for ParseJsonError { - fn into_response(self) -> Response { - Into::::into(self).as_response() +impl ResponseError for ParseJsonError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } /// A possible error value when parsing query. -#[derive(Debug)] -pub struct ParseQueryError(pub serde_urlencoded::de::Error); +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct ParseQueryError(#[from] pub serde_urlencoded::de::Error); -impl From for ParseQueryError { - fn from(err: serde::de::value::Error) -> Self { - ParseQueryError(err) - } -} - -impl From for Error { - fn from(err: ParseQueryError) -> Self { - Error::new(StatusCode::BAD_REQUEST).with_reason(err.0.to_string()) - } -} - -impl IntoResponse for ParseQueryError { - fn into_response(self) -> Response { - Into::::into(self).as_response() +impl ResponseError for ParseQueryError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } /// A possible error value when parsing multipart. #[cfg(feature = "multipart")] #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum ParseMultipartError { - /// Read body error. - ReadBody(ReadBodyError), - /// Invalid content type. + #[error("invalid content type `{0}`, expect: `multipart/form-data`")] InvalidContentType(String), /// `Content-Type` header is required. + #[error("expect content type `multipart/form-data`")] ContentTypeRequired, /// Parse error. - Multipart(multer::Error), + #[error("parse: {0}")] + Multipart(#[from] multer::Error), } #[cfg(feature = "multipart")] -impl From for ParseMultipartError { - fn from(err: ReadBodyError) -> Self { - Self::ReadBody(err) - } -} -#[cfg(feature = "multipart")] -impl From for ParseMultipartError { - fn from(err: multer::Error) -> Self { - Self::Multipart(err) - } -} - -#[cfg(feature = "multipart")] -impl From for Error { - fn from(err: ParseMultipartError) -> Self { - match err { - ParseMultipartError::ReadBody(err) => err.into(), - ParseMultipartError::InvalidContentType(content_type) => { - Error::new(StatusCode::BAD_REQUEST).with_reason(format!( - "invalid content type `{}`, expect: `multipart/form-data`", - content_type - )) - } - ParseMultipartError::ContentTypeRequired => Error::new(StatusCode::BAD_REQUEST) - .with_reason("expect content type `multipart/form-data`"), - ParseMultipartError::Multipart(err) => { - Error::new(StatusCode::BAD_REQUEST).with_reason(format!("parse: {}", err)) - } +impl ResponseError for ParseMultipartError { + fn status(&self) -> StatusCode { + match self { + ParseMultipartError::InvalidContentType(_) => StatusCode::BAD_REQUEST, + ParseMultipartError::ContentTypeRequired => StatusCode::BAD_REQUEST, + ParseMultipartError::Multipart(_) => StatusCode::BAD_REQUEST, } } } -#[cfg(feature = "multipart")] -impl IntoResponse for ParseMultipartError { - fn into_response(self) -> Response { - Into::::into(self).as_response() - } -} - /// A possible error value when parsing typed headers. -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum ParseTypedHeaderError { /// A specified header is required. + #[error("header `{0}` is required")] HeaderRequired(String), /// Parse error. - TypedHeader(headers::Error), + #[error("parse: {0}")] + TypedHeader(#[from] headers::Error), } -impl From for ParseTypedHeaderError { - fn from(err: headers::Error) -> Self { - Self::TypedHeader(err) - } -} - -impl From for Error { - fn from(err: ParseTypedHeaderError) -> Self { - match err { - ParseTypedHeaderError::HeaderRequired(header_name) => { - Error::new(StatusCode::BAD_REQUEST) - .with_reason(format!("header `{}` is required", header_name)) - } - ParseTypedHeaderError::TypedHeader(err) => { - Error::new(StatusCode::BAD_REQUEST).with_reason(format!("parse: {}", err)) - } - } - } -} - -impl IntoResponse for ParseTypedHeaderError { - fn into_response(self) -> Response { - Into::::into(self).as_response() +impl ResponseError for ParseTypedHeaderError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } /// A possible error value when handling websocket. #[cfg(feature = "websocket")] #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum WebSocketError { /// Invalid protocol + #[error("invalid protocol")] InvalidProtocol, /// Upgrade Error - UpgradeError(UpgradeError), + #[error(transparent)] + UpgradeError(#[from] UpgradeError), } #[cfg(feature = "websocket")] -impl From for WebSocketError { - fn from(err: UpgradeError) -> Self { - Self::UpgradeError(err) - } -} - -#[cfg(feature = "websocket")] -impl From for Error { - fn from(err: WebSocketError) -> Self { - match err { - WebSocketError::InvalidProtocol => { - Error::new(StatusCode::BAD_REQUEST).with_reason("invalid protocol") - } - WebSocketError::UpgradeError(err) => err.into(), +impl ResponseError for WebSocketError { + fn status(&self) -> StatusCode { + match self { + WebSocketError::InvalidProtocol => StatusCode::BAD_REQUEST, + WebSocketError::UpgradeError(err) => err.status(), } } } -#[cfg(feature = "websocket")] -impl IntoResponse for WebSocketError { - fn into_response(self) -> Response { - Into::::into(self).as_response() - } -} - /// A possible error value when upgrading connection. -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum UpgradeError { /// No upgrade + #[error("no upgrade")] NoUpgrade, /// Other error + #[error("{0}")] Other(String), } -impl From for Error { - fn from(err: UpgradeError) -> Self { - match err { - UpgradeError::NoUpgrade => { - Error::new(StatusCode::INTERNAL_SERVER_ERROR).with_reason("no upgrade") - } - UpgradeError::Other(err) => Error::new(StatusCode::BAD_REQUEST).with_reason(err), +impl ResponseError for UpgradeError { + fn status(&self) -> StatusCode { + match self { + UpgradeError::NoUpgrade => StatusCode::INTERNAL_SERVER_ERROR, + UpgradeError::Other(_) => StatusCode::BAD_REQUEST, } } } -impl IntoResponse for UpgradeError { - fn into_response(self) -> Response { - Into::::into(self).as_response() +/// A possible error value when processing static files. +#[derive(Debug, thiserror::Error)] +pub enum StaticFileError { + /// Method not allow + #[error("method not found")] + MethodNotAllowed(Method), + + /// Invalid path + #[error("invalid path")] + InvalidPath, + + /// Forbidden + #[error("forbidden: {0}")] + Forbidden(String), + + /// File not found + #[error("not found: {0}")] + NotFound(String), + + /// Io error + #[error("io: {0}")] + Io(#[from] std::io::Error), +} + +impl ResponseError for StaticFileError { + fn status(&self) -> StatusCode { + match self { + StaticFileError::MethodNotAllowed(_) => StatusCode::METHOD_NOT_ALLOWED, + StaticFileError::InvalidPath => StatusCode::BAD_REQUEST, + StaticFileError::Forbidden(_) => StatusCode::FORBIDDEN, + StaticFileError::NotFound(_) => StatusCode::NOT_FOUND, + StaticFileError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR, + } } } -#[cfg(test)] -mod tests { - use super::*; +/// A possible error value occurred in the `SizeLimit` middleware. +#[derive(Debug, thiserror::Error, Eq, PartialEq)] +pub enum SizedLimitError { + /// Missing `Content-Length` header. + #[error("missing `Content-Length` header")] + MissingContentLength, - #[test] - fn display_into_error() { - let err: Error = "a".into(); - assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(err.reason(), Some("a")); - } + /// Payload too large + #[error("payload too large")] + PayloadTooLarge, +} - #[test] - fn extractor_err_into_error() { - let err: Error = ReadBodyError::BodyHasBeenTaken.into(); - assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR); - - let err: Error = ParseFormError::ContentTypeRequired.into(); - assert_eq!(err.status(), StatusCode::BAD_REQUEST); +impl ResponseError for SizedLimitError { + fn status(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } diff --git a/poem/src/lib.rs b/poem/src/lib.rs index d22b86a0..d078eabb 100644 --- a/poem/src/lib.rs +++ b/poem/src/lib.rs @@ -1,15 +1,17 @@ //! Poem is a full-featured and easy-to-use web framework with the Rust //! programming language. //! -//! # Usage +//! # Table of contents //! -//! Depend on poem in Cargo.toml: +//! - [Quickstart](#quickstart) +//! - [Routing](#routing) +//! - [Extractors](#extractors) +//! - [Responses](#responses) +//! - [Handling errors](#handling-errors) +//! - [Middleware](#middleware) +//! - [Crate features](#crate-features) //! -//! ```toml -//! poem = "*" -//! ``` -//! -//! # Example +//! # Quickstart //! //! ```no_run //! use poem::{get, handler, listener::TcpListener, web::Path, IntoResponse, Route, Server}; @@ -28,10 +30,169 @@ //! } //! ``` //! +//! # Routing +//! +//! There are three available routes. +//! +//! - [`Route`] Routing for path +//! - [`RouteDomain`] Routing for domain +//! - [`RouteMethod`] Routing for HTTP method +//! +//! ``` +//! use poem::{get, handler, post, web::Path, Route}; +//! +//! #[handler] +//! async fn get_user(id: Path) {} +//! +//! #[handler] +//! async fn delete_user(id: Path) {} +//! +//! #[handler] +//! async fn create_user() {} +//! +//! let app = Route::new() +//! .at("/user/:id", get(get_user).delete(delete_user)) +//! .at("/user", post(create_user)); +//! ``` +//! +//! # Extractors +//! +//! The extractor is used to extract something from the HTTP request. +//! +//! `Poem` provides some commonly used extractors for extracting something from +//! HTTP requests. +//! +//! In the following example, the `index` function uses 3 extractors to extract +//! the remote address, HTTP method and URI. +//! +//! ```rust +//! use poem::{ +//! handler, +//! http::{Method, Uri}, +//! web::RemoteAddr, +//! }; +//! +//! #[handler] +//! fn index(remote_addr: &RemoteAddr, method: Method, uri: &Uri) {} +//! ``` +//! +//! By default, the extractor will return a `400 Bad Request` when an error +//! occurs, but sometimes you may want to change this behavior, so you can +//! handle the error yourself. +//! +//! In the following example, when the [`Query`](web::Query) extractor fails, it +//! will return a `503 Internal Server` response and the reason for the error. +//! +//! ``` +//! use poem::{ +//! error::ParseQueryError, handler, http::StatusCode, web::Query, IntoResponse, Response, +//! Result, +//! }; +//! use serde::Deserialize; +//! +//! #[derive(Debug, Deserialize)] +//! struct Params { +//! name: String, +//! } +//! +//! #[handler] +//! fn index(res: Result>) -> Result { +//! match res { +//! Ok(Query(params)) => Ok(params.name.into_response()), +//! Err(err) if err.is::() => Ok(Response::builder() +//! .status(StatusCode::INTERNAL_SERVER_ERROR) +//! .body(err.to_string())), +//! Err(err) => Err(err), +//! } +//! } +//! ``` +//! +//! You can create custom extractors, see also [`FromRequest`]. +//! +//! # Responses +//! +//! All types that can be converted to HTTP response [`Response`] should +//! implement [`IntoResponse`]. +//! +//! In the following example, the `string_response` and `status_response` +//! functions return the `String` and `StatusCode` types, because `Poem` has +//! implemented the [`IntoResponse`] trait for them. +//! +//! The `no_response` function does not return a value. We can think that +//! its return type is `()`, and `Poem` also implements [`IntoResponse`] for +//! `()`, which is always converted to `200 OK`. +//! +//! The `result_response` function returns a `Result` type, which means that an +//! error may occur. +//! ``` +//! use poem::{handler, http::StatusCode, Result}; +//! +//! #[handler] +//! fn string_response() -> String { +//! todo!() +//! } +//! +//! #[handler] +//! fn status_response() -> StatusCode { +//! todo!() +//! } +//! +//! #[handler] +//! fn no_response() {} +//! +//! #[handler] +//! fn result_response() -> Result { +//! todo!() +//! } +//! ``` +//! +//! # Handling errors +//! +//! The following example returns customized content when +//! [`NotFoundError`](error::NotFoundError) occurs. +//! +//! ``` +//! use poem::{ +//! error::NotFoundError, handler, http::StatusCode, EndpointExt, IntoResponse, Response, Route, +//! }; +//! +//! #[handler] +//! fn foo() {} +//! +//! #[handler] +//! fn bar() {} +//! +//! let app = +//! Route::new() +//! .at("/foo", foo) +//! .at("/bar", bar) +//! .catch_error(|err: NotFoundError| async move { +//! Response::builder() +//! .status(StatusCode::NOT_FOUND) +//! .body("custom not found") +//! }); +//! ``` +//! +//! # Middleware +//! +//! You can call the [`with`](EndpointExt::with) method on the [`Endpoint`] to +//! apply a middleware to an endpoint. It actually converts the original +//! endpoint to a new endpoint. +//! ``` +//! use poem::{handler, middleware::Tracing, EndpointExt, Route}; +//! +//! #[handler] +//! fn index() {} +//! +//! let app = Route::new().at("/", index).with(Tracing); +//! ``` +//! +//! You can create your own middleware, see also [`Middleware`]. +//! //! # Crate features //! -//! To avoid compiling unused dependencies, Poem gates certain features, all of -//! which are disabled by default: +//! To avoid compiling unused dependencies, `Poem` gates certain features, all +//! of which are disabled by default: //! //! |Feature |Description | //! |------------------|--------------------------------| @@ -50,8 +211,8 @@ //! |tower-compat | Adapters for `tower::Layer` and `tower::Service`. | //! |websocket | Support for WebSocket | -#![doc(html_favicon_url = "https://poem.rs/assets/favicon.ico")] -#![doc(html_logo_url = "https://poem.rs/en/assets/logo.png")] +#![doc(html_favicon_url = "https://raw.githubusercontent.com/poem-web/poem/master/favicon.ico")] +#![doc(html_logo_url = "https://raw.githubusercontent.com/poem-web/poem/master/logo.png")] #![forbid(unsafe_code)] #![deny(private_in_public, unreachable_pub)] #![cfg_attr(docsrs, feature(doc_cfg))] diff --git a/poem/src/middleware/add_data.rs b/poem/src/middleware/add_data.rs index 25644b8c..278dee93 100644 --- a/poem/src/middleware/add_data.rs +++ b/poem/src/middleware/add_data.rs @@ -1,4 +1,4 @@ -use crate::{Endpoint, Middleware, Request}; +use crate::{Endpoint, Middleware, Request, Result}; /// Middleware for add any data to request. pub struct AddData { @@ -41,7 +41,7 @@ where { type Output = E::Output; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { req.extensions_mut().insert(self.value.clone()); self.inner.call(req).await } @@ -60,6 +60,6 @@ mod tests { } let app = index.with(AddData::new(100i32)); - app.call(Request::default()).await; + app.call(Request::default()).await.unwrap(); } } diff --git a/poem/src/middleware/compression.rs b/poem/src/middleware/compression.rs index daf3b627..04d857fb 100644 --- a/poem/src/middleware/compression.rs +++ b/poem/src/middleware/compression.rs @@ -5,7 +5,7 @@ use typed_headers::{AcceptEncoding, ContentCoding, HeaderMapExt}; use crate::{ http::header, web::{Compress, CompressionAlgo}, - Body, Endpoint, IntoResponse, Middleware, Request, Response, + Body, Endpoint, IntoResponse, Middleware, Request, Response, Result, }; /// Middleware for decompress request body and compress response body. @@ -43,7 +43,7 @@ pub struct CompressionEndpoint { impl Endpoint for CompressionEndpoint { type Output = Response; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { // decompress request body if let Some(algo) = req .headers() @@ -71,8 +71,8 @@ impl Endpoint for CompressionEndpoint { } match compress_algo { - Some(algo) => Compress::new(self.ep.call(req).await, algo).into_response(), - None => self.ep.call(req).await.into_response(), + Some(algo) => Ok(Compress::new(self.ep.call(req).await?, algo).into_response()), + None => Ok(self.ep.call(req).await?.into_response()), } } } @@ -101,7 +101,8 @@ mod tests { .header("Accept-Encoding", algo.as_str()) .body(Body::from_async_read(algo.compress(DATA.as_bytes()))), ) - .await; + .await + .unwrap(); assert_eq!( resp.headers() @@ -132,7 +133,8 @@ mod tests { .header("Accept-Encoding", "identity; q=0.5, gzip;q=1.0, br;q=0.3") .body(DATA), ) - .await; + .await + .unwrap(); assert_eq!( resp.headers() @@ -156,7 +158,8 @@ mod tests { .header("Accept-Encoding", "identity; q=0.5, *;q=1.0, br;q=0.3") .body(DATA), ) - .await; + .await + .unwrap(); assert_eq!( resp.headers() diff --git a/poem/src/middleware/cookie_jar_manager.rs b/poem/src/middleware/cookie_jar_manager.rs index 57093655..46aae454 100644 --- a/poem/src/middleware/cookie_jar_manager.rs +++ b/poem/src/middleware/cookie_jar_manager.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use crate::{ web::cookie::{CookieJar, CookieKey}, - Endpoint, IntoResponse, Middleware, Request, Response, + Endpoint, IntoResponse, Middleware, Request, Response, Result, }; /// Middleware for CookieJar support. @@ -53,16 +53,16 @@ pub struct CookieJarManagerEndpoint { impl Endpoint for CookieJarManagerEndpoint { type Output = Response; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { if req.state().cookie_jar.is_none() { let mut cookie_jar = CookieJar::extract_from_headers(req.headers()); cookie_jar.key = self.key.clone(); req.state_mut().cookie_jar = Some(cookie_jar.clone()); - let mut resp = self.inner.call(req).await.into_response(); + let mut resp = self.inner.call(req).await?.into_response(); cookie_jar.append_delta_to_headers(resp.headers_mut()); - resp + Ok(resp) } else { - self.inner.call(req).await.into_response() + self.inner.call(req).await.map(IntoResponse::into_response) } } } @@ -82,7 +82,8 @@ mod tests { let ep = index.with(CookieJarManager::new()); let resp = ep .call(Request::builder().header("Cookie", "value=88").finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -120,7 +121,8 @@ mod tests { ) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/poem/src/middleware/cors.rs b/poem/src/middleware/cors.rs index 34e163d3..06715dd4 100644 --- a/poem/src/middleware/cors.rs +++ b/poem/src/middleware/cors.rs @@ -6,19 +6,24 @@ use headers::{ use crate::{ endpoint::Endpoint, + error::CorsError, http::{ header, header::{HeaderName, HeaderValue}, - Method, StatusCode, + Method, }, middleware::Middleware, request::Request, response::Response, - Error, IntoResponse, Result, + IntoResponse, Result, }; /// Middleware for CORS /// +/// # Errors +/// +/// - [`CorsError`] +/// /// # Example /// /// ``` @@ -224,7 +229,7 @@ pub struct CorsEndpoint { max_age: i32, } -impl CorsEndpoint { +impl CorsEndpoint { fn is_valid_origin(&self, origin: &HeaderValue) -> (bool, bool) { if self.allow_origins.contains(origin) { return (true, false); @@ -322,20 +327,20 @@ impl CorsEndpoint { #[async_trait::async_trait] impl Endpoint for CorsEndpoint { - type Output = Result; + type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { let origin = match req.headers().get(header::ORIGIN) { Some(origin) => origin.clone(), None => { // This is not a CORS request if there is no Origin header - return Ok(self.inner.call(req).await.into_response()); + return self.inner.call(req).await.map(IntoResponse::into_response); } }; let (origin_is_allow, vary_header) = self.is_valid_origin(&origin); if !origin_is_allow { - return Err(Error::new(StatusCode::UNAUTHORIZED)); + return Err(CorsError.into()); } if req.method() == Method::OPTIONS { @@ -352,19 +357,19 @@ impl Endpoint for CorsEndpoint { } }); if !matches!(allow_method, Some(true)) { - return Err(Error::new(StatusCode::UNAUTHORIZED)); + return Err(CorsError.into()); } let (allow_headers, request_headers) = self.check_allow_headers(&req); if !allow_headers { - return Err(Error::new(StatusCode::UNAUTHORIZED)); + return Err(CorsError.into()); } return Ok(self.build_preflight_response(&origin, request_headers)); } - let mut resp = self.inner.call(req).await.into_response(); + let mut resp = self.inner.call(req).await?.into_response(); resp.headers_mut() .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); @@ -392,6 +397,8 @@ impl Endpoint for CorsEndpoint { #[cfg(test)] mod tests { + use http::StatusCode; + use super::*; use crate::{endpoint::make_sync, EndpointExt}; @@ -427,7 +434,7 @@ mod tests { #[tokio::test] async fn preflight_request() { let ep = make_sync(|_| "hello").with(cors()); - let resp = ep.map_to_response().call(opt_request()).await; + let resp = ep.map_to_response().call(opt_request()).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( @@ -485,7 +492,8 @@ mod tests { .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token") .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( @@ -527,7 +535,8 @@ mod tests { .header(header::ORIGIN, ALLOW_ORIGIN) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN,), @@ -552,7 +561,8 @@ mod tests { .header(header::ORIGIN, ALLOW_ORIGIN) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN,), @@ -581,7 +591,8 @@ mod tests { .header(header::ORIGIN, ALLOW_ORIGIN) .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN,), @@ -596,7 +607,8 @@ mod tests { .header(header::ORIGIN, "https://abc.com") .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN,), @@ -614,21 +626,22 @@ mod tests { .with(Cors::new().allow_origins_fn(|_| false)) .map_to_response(); - let resp = ep + assert!(ep .call( Request::builder() .method(Method::GET) .header(header::ORIGIN, ALLOW_ORIGIN) .finish(), ) - .await; - assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + .await + .unwrap_err() + .is::()); } #[tokio::test] async fn default_cors_middleware() { let ep = make_sync(|_| "hello").with(Cors::new()).map_to_response(); - let resp = ep.call(get_request()).await; + let resp = ep.call(get_request()).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers() @@ -641,22 +654,23 @@ mod tests { #[tokio::test] async fn unauthorized_origin() { let ep = make_sync(|_| "hello").with(cors()).map_to_response(); - let resp = ep + assert!(ep .call( Request::builder() .method(Method::GET) .header(header::ORIGIN, "https://foo.com") .finish(), ) - .await; - assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + .await + .unwrap_err() + .is::()); } #[tokio::test] async fn unauthorized_options() { let ep = make_sync(|_| "hello").with(cors()).map_to_response(); - let resp = ep + assert!(ep .call( Request::builder() .method(Method::OPTIONS) @@ -665,10 +679,11 @@ mod tests { .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token") .finish(), ) - .await; - assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + .await + .unwrap_err() + .is::()); - let resp = ep + assert!(ep .call( Request::builder() .method(Method::OPTIONS) @@ -677,10 +692,11 @@ mod tests { .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token") .finish(), ) - .await; - assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + .await + .unwrap_err() + .is::()); - let resp = ep + assert!(ep .call( Request::builder() .method(Method::OPTIONS) @@ -689,8 +705,9 @@ mod tests { .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-ABC") .finish(), ) - .await; - assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + .await + .unwrap_err() + .is::()); let resp = ep .call( @@ -701,7 +718,8 @@ mod tests { .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Token") .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -719,15 +737,15 @@ mod tests { } let ep = index.with(CookieJarManager::new()).with(cors()); - let resp = ep.map_to_response().call(get_request()).await; + let resp = ep.map_to_response().call(get_request()).await.unwrap(); assert_eq!(resp.headers().get(header::SET_COOKIE).unwrap(), "foo=bar"); } #[tokio::test] async fn set_cors_headers_to_error_responses() { - let ep = make_sync(|_| Err::<(), Error>(Error::new(StatusCode::BAD_REQUEST))).with(cors()); - let resp = ep.map_to_response().call(get_request()).await; + let ep = make_sync(|_| StatusCode::BAD_REQUEST).with(cors()); + let resp = ep.map_to_response().call(get_request()).await.unwrap(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!( resp.headers() @@ -750,7 +768,8 @@ mod tests { let resp = ep .map_to_response() .call(Request::builder().method(Method::GET).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert!(resp .headers() @@ -775,7 +794,8 @@ mod tests { .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type") .finish(), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers() diff --git a/poem/src/middleware/csrf.rs b/poem/src/middleware/csrf.rs index cb96e3e3..01ad8bdb 100644 --- a/poem/src/middleware/csrf.rs +++ b/poem/src/middleware/csrf.rs @@ -11,7 +11,7 @@ use crate::{ cookie::{Cookie, SameSite}, CsrfToken, CsrfVerifier, }, - Endpoint, Middleware, Request, + Endpoint, Middleware, Request, Result, }; /// Middleware for Cross-Site Request Forgery (CSRF) protection. @@ -38,10 +38,10 @@ use crate::{ /// async fn login(verifier: &CsrfVerifier, req: &Request) -> Result { /// let csrf_token = req /// .header("X-CSRF-Token") -/// .ok_or_else(|| Error::new(StatusCode::UNAUTHORIZED))?; +/// .ok_or_else(|| Error::new_with_status(StatusCode::UNAUTHORIZED))?; /// /// if !verifier.is_valid(&csrf_token) { -/// return Err(Error::new(StatusCode::UNAUTHORIZED)); +/// return Err(Error::new_with_status(StatusCode::UNAUTHORIZED)); /// } /// /// Ok(format!("login success")) @@ -52,7 +52,7 @@ use crate::{ /// .at("/", get(login_ui).post(login)) /// .with(Csrf::new()); /// -/// let resp = app.call(Request::default()).await; +/// let resp = app.call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// let cookie = resp.headers().get(header::SET_COOKIE).unwrap(); /// let cookie = Cookie::parse(cookie.to_str().unwrap()).unwrap(); @@ -69,7 +69,8 @@ use crate::{ /// ) /// .finish(), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!( /// resp.into_body().into_string().await.unwrap(), @@ -200,7 +201,7 @@ impl CsrfEndpoint { impl Endpoint for CsrfEndpoint { type Output = E::Output; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { let existing_cookie = req .cookie() .get(&self.cookie_name) @@ -246,19 +247,22 @@ mod tests { #[handler(internal)] fn login(verifier: &CsrfVerifier, req: &Request) -> Result { - let token = req - .header(CSRF_TOKEN_NAME) - .ok_or_else(|| Error::new(StatusCode::BAD_REQUEST).with_reason("missing token"))?; + let token = req.header(CSRF_TOKEN_NAME).ok_or_else(|| { + Error::new_with_string("missing token").with_status(StatusCode::BAD_REQUEST) + })?; match verifier.is_valid(token) { true => Ok("ok"), - false => Err(Error::new(StatusCode::BAD_REQUEST).with_reason("invalid token")), + false => { + Err(Error::new_with_string("invalid token") + .with_status(StatusCode::BAD_REQUEST)) + } } } let app = get(login_ui).post(login).with(Csrf::new()); for _ in 0..5 { - let resp = app.call(Request::default()).await; + let resp = app.call(Request::default()).await.unwrap(); let cookie = resp .header(header::SET_COOKIE) .map(|cookie| cookie.to_string()) @@ -274,6 +278,7 @@ mod tests { .finish(), ) .await + .unwrap() .into_body() .into_string() .await @@ -281,7 +286,7 @@ mod tests { assert_eq!(resp, "ok"); } - let resp = app.call(Request::default()).await; + let resp = app.call(Request::default()).await.unwrap(); let cookie = resp .header(header::SET_COOKIE) .map(|cookie| cookie.to_string()) @@ -291,8 +296,8 @@ mod tests { let mut token = base64::decode(token).unwrap(); token[0] = token[0].wrapping_add(1); - let resp = app - .call( + assert_eq!( + app.call( Request::builder() .method(Method::POST) .header(CSRF_TOKEN_NAME, base64::encode(token)) @@ -300,10 +305,9 @@ mod tests { .finish(), ) .await - .into_body() - .into_string() - .await - .unwrap(); - assert_eq!(resp, "invalid token"); + .unwrap_err() + .to_string(), + "invalid token" + ); } } diff --git a/poem/src/middleware/mod.rs b/poem/src/middleware/mod.rs index a87acff7..70b70676 100644 --- a/poem/src/middleware/mod.rs +++ b/poem/src/middleware/mod.rs @@ -40,16 +40,14 @@ pub use size_limit::{SizeLimit, SizeLimitEndpoint}; pub use tower_compat::TowerLayerCompatExt; pub use tracing_mw::{Tracing, TracingEndpoint}; -#[cfg(feature = "tracing")] -pub use self::tracing_mw::{Tracing, TracingEndpoint}; use crate::endpoint::Endpoint; /// Represents a middleware trait. /// -/// # Example +/// # Create you own middleware /// /// ``` -/// use poem::{handler, web::Data, Endpoint, EndpointExt, Middleware, Request}; +/// use poem::{handler, web::Data, Endpoint, EndpointExt, Middleware, Request, Result}; /// /// /// A middleware that extract token from HTTP headers. /// struct TokenMiddleware; @@ -76,7 +74,7 @@ use crate::endpoint::Endpoint; /// impl Endpoint for TokenMiddlewareImpl { /// type Output = E::Output; /// -/// async fn call(&self, mut req: Request) -> Self::Output { +/// async fn call(&self, mut req: Request) -> Result { /// if let Some(value) = req /// .headers() /// .get(TOKEN_HEADER) @@ -103,7 +101,8 @@ use crate::endpoint::Endpoint; /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let mut resp = ep /// .call(Request::builder().header(TOKEN_HEADER, "abc").finish()) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.take_body().into_string().await.unwrap(), "abc"); /// # }); /// ``` @@ -149,7 +148,7 @@ mod tests { handler, http::{header::HeaderName, HeaderValue, StatusCode}, web::Data, - EndpointExt, IntoResponse, Request, Response, + EndpointExt, IntoResponse, Request, Response, Result, }; #[tokio::test] @@ -169,11 +168,11 @@ mod tests { impl Endpoint for AddHeader { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { - let mut resp = self.ep.call(req).await.into_response(); + async fn call(&self, req: Request) -> Result { + let mut resp = self.ep.call(req).await?.into_response(); resp.headers_mut() .insert(self.header.clone(), self.value.clone()); - resp + Ok(resp) } } @@ -182,7 +181,7 @@ mod tests { header: HeaderName::from_static("hello"), value: HeaderValue::from_static("world"), })); - let mut resp = ep.call(Request::default()).await; + let mut resp = ep.call(Request::default()).await.unwrap(); assert_eq!( resp.headers() .get(HeaderName::from_static("hello")) @@ -205,7 +204,7 @@ mod tests { SetHeader::new().appending("myheader-2", "b"), )); - let mut resp = ep.call(Request::default()).await; + let mut resp = ep.call(Request::default()).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get("myheader-1"), diff --git a/poem/src/middleware/normalize_path.rs b/poem/src/middleware/normalize_path.rs index a30149b8..c9d1b3f5 100644 --- a/poem/src/middleware/normalize_path.rs +++ b/poem/src/middleware/normalize_path.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use http::{uri::PathAndQuery, Uri}; use regex::Regex; -use crate::{Endpoint, Middleware, Request}; +use crate::{Endpoint, Middleware, Request, Result}; /// Determines the behavior of the [`NormalizePath`] middleware. #[derive(Debug, Clone, Copy)] @@ -53,7 +53,8 @@ impl Default for TrailingSlash { /// .uri(Uri::from_static("/foo/bar/")) /// .finish(), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello"); /// # }); @@ -91,7 +92,7 @@ pub struct NormalizePathEndpoint { impl Endpoint for NormalizePathEndpoint { type Output = E::Output; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { let original_path = req .uri() .path_and_query() @@ -132,7 +133,7 @@ impl Endpoint for NormalizePathEndpoint { #[cfg(test)] mod tests { use super::*; - use crate::{endpoint::make_sync, EndpointExt, Route}; + use crate::{endpoint::make_sync, error::NotFoundError, http::StatusCode, EndpointExt, Route}; #[tokio::test] async fn trim_trailing_slashes() { @@ -167,7 +168,7 @@ mod tests { for uri in test_uris { let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish(); - let res = ep.call(req).await; + let res = ep.call(req).await.unwrap(); assert!(res.status().is_success(), "Failed uri: {}", uri); } } @@ -190,7 +191,7 @@ mod tests { for uri in test_uris { let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish(); - let res = ep.call(req).await; + let res = ep.call(req).await.unwrap(); assert!(res.status().is_success(), "Failed uri: {}", uri); } } @@ -228,7 +229,7 @@ mod tests { for uri in test_uris { let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish(); - let res = ep.call(req).await; + let res = ep.call(req).await.unwrap(); assert!(res.status().is_success(), "Failed uri: {}", uri); } } @@ -251,7 +252,7 @@ mod tests { for uri in test_uris { let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish(); - let res = ep.call(req).await; + let res = ep.call(req).await.unwrap(); assert!(res.status().is_success(), "Failed uri: {}", uri); } } @@ -294,7 +295,16 @@ mod tests { for (uri, success) in test_uris { let req = Request::builder().uri(Uri::from_str(uri).unwrap()).finish(); let res = ep.call(req).await; - assert_eq!(res.status().is_success(), success, "Failed uri: {}", uri); + + if success { + assert_eq!(res.unwrap().status(), StatusCode::OK, "Failed uri: {}", uri); + } else { + assert!( + res.unwrap_err().is::(), + "Failed uri: {}", + uri + ); + } } } } diff --git a/poem/src/middleware/opentelemetry_metrics.rs b/poem/src/middleware/opentelemetry_metrics.rs index fdf72eec..cb53b810 100644 --- a/poem/src/middleware/opentelemetry_metrics.rs +++ b/poem/src/middleware/opentelemetry_metrics.rs @@ -6,7 +6,7 @@ use libopentelemetry::{ Key, }; -use crate::{Endpoint, IntoResponse, Middleware, Request, Response}; +use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result}; const METHOD_KEY: Key = Key::from_static_str("request_method"); const PATH_KEY: Key = Key::from_static_str("request_path"); @@ -76,13 +76,13 @@ pub struct OpenTelemetryMetricsEndpoint { impl Endpoint for OpenTelemetryMetricsEndpoint { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { let mut labels = Vec::with_capacity(3); labels.push(METHOD_KEY.string(req.method().to_string())); labels.push(PATH_KEY.string(req.uri().path().to_string())); let s = Instant::now(); - let resp = self.inner.call(req).await.into_response(); + let resp = self.inner.call(req).await?.into_response(); let elapsed = s.elapsed(); labels.push(STATUS_KEY.i64(resp.status().as_u16() as i64)); @@ -94,6 +94,6 @@ impl Endpoint for OpenTelemetryMetricsEndpoint { self.duration .record(elapsed.as_secs_f64() / 1000.0, &labels); - resp + Ok(resp) } } diff --git a/poem/src/middleware/opentelemetry_tracing.rs b/poem/src/middleware/opentelemetry_tracing.rs index 9dc8bd7a..c12a6d35 100644 --- a/poem/src/middleware/opentelemetry_tracing.rs +++ b/poem/src/middleware/opentelemetry_tracing.rs @@ -8,7 +8,9 @@ use libopentelemetry::{ use opentelemetry_http::HeaderExtractor; use opentelemetry_semantic_conventions::{resource, trace}; -use crate::{web::headers::HeaderMapExt, Endpoint, IntoResponse, Middleware, Request, Response}; +use crate::{ + web::headers::HeaderMapExt, Endpoint, IntoResponse, Middleware, Request, Response, Result, +}; /// Middleware for tracing with OpenTelemetry. #[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))] @@ -55,7 +57,7 @@ where { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { let parent_cx = global::get_text_map_propagator(|propagator| { propagator.extract(&HeaderExtractor(req.headers())) }); @@ -80,7 +82,7 @@ where span.add_event("request.started".to_string(), vec![]); async move { - let resp = self.inner.call(req).await.into_response(); + let resp = self.inner.call(req).await?.into_response(); let cx = Context::current(); let span = cx.span(); @@ -91,7 +93,7 @@ where trace::HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length.0 as i64), ); } - resp + Ok(resp) } .with_context(Context::current_with_span(span)) .await diff --git a/poem/src/middleware/propagate_header.rs b/poem/src/middleware/propagate_header.rs index 21fbf175..69a3b77d 100644 --- a/poem/src/middleware/propagate_header.rs +++ b/poem/src/middleware/propagate_header.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use http::{header::HeaderName, HeaderMap}; -use crate::{Endpoint, IntoResponse, Middleware, Request, Response}; +use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result}; /// Middleware for propagate a header from the request to the response. #[derive(Default)] @@ -51,7 +51,7 @@ pub struct PropagateHeaderEndpoint { impl Endpoint for PropagateHeaderEndpoint { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { let mut headers = HeaderMap::new(); for header in &self.headers { @@ -60,10 +60,9 @@ impl Endpoint for PropagateHeaderEndpoint { } } - let mut resp = self.inner.call(req).await.into_response(); + let mut resp = self.inner.call(req).await?.into_response(); resp.headers_mut().extend(headers); - - resp + Ok(resp) } } @@ -80,7 +79,8 @@ mod tests { let resp = index .with(PropagateHeader::new().header("x-request-id")) .call(Request::builder().header("x-request-id", "100").finish()) - .await; + .await + .unwrap(); assert_eq!( resp.headers() diff --git a/poem/src/middleware/set_header.rs b/poem/src/middleware/set_header.rs index f630aa26..bbe09c02 100644 --- a/poem/src/middleware/set_header.rs +++ b/poem/src/middleware/set_header.rs @@ -2,7 +2,7 @@ use std::convert::TryInto; use crate::{ http::{header::HeaderName, HeaderValue}, - Endpoint, IntoResponse, Middleware, Request, Response, + Endpoint, IntoResponse, Middleware, Request, Response, Result, }; #[derive(Debug, Clone)] @@ -37,7 +37,7 @@ enum Action { /// ); /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { -/// let resp = app.call(Request::default()).await; +/// let resp = app.call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!( /// resp.headers() @@ -124,8 +124,8 @@ pub struct SetHeaderEndpoint { impl Endpoint for SetHeaderEndpoint { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { - let mut resp = self.inner.call(req).await.into_response(); + async fn call(&self, req: Request) -> Result { + let mut resp = self.inner.call(req).await?.into_response(); let headers = resp.headers_mut(); for action in &self.actions { @@ -139,7 +139,7 @@ impl Endpoint for SetHeaderEndpoint { } } - resp + Ok(resp) } } @@ -162,7 +162,8 @@ mod tests { .appending("custom-b", "b"), ) .call(Request::default()) - .await; + .await + .unwrap(); assert_eq!( resp.headers() diff --git a/poem/src/middleware/size_limit.rs b/poem/src/middleware/size_limit.rs index 2b65159c..61bfad55 100644 --- a/poem/src/middleware/size_limit.rs +++ b/poem/src/middleware/size_limit.rs @@ -1,11 +1,15 @@ use crate::{ - http::StatusCode, web::headers::HeaderMapExt, Endpoint, Error, Middleware, Request, Result, + error::SizedLimitError, web::headers::HeaderMapExt, Endpoint, Middleware, Request, Result, }; /// Middleware for limit the request payload size. /// /// If the incoming request does not contain the `Content-Length` header, it /// will return `BAD_REQUEST` status code. +/// +/// # Errors +/// +/// - [`SizedLimitError`] pub struct SizeLimit { max_size: usize, } @@ -36,24 +40,26 @@ pub struct SizeLimitEndpoint { #[async_trait::async_trait] impl Endpoint for SizeLimitEndpoint { - type Output = Result; + type Output = E::Output; - async fn call(&self, req: Request) -> Self::Output { - let content_length = match req.headers().typed_get::() { - Some(content_length) => content_length.0 as usize, - None => return Err(Error::new(StatusCode::BAD_REQUEST)), - }; + async fn call(&self, req: Request) -> Result { + let content_length = req + .headers() + .typed_get::() + .ok_or(SizedLimitError::MissingContentLength)?; - if content_length > self.max_size { - return Err(Error::new(StatusCode::PAYLOAD_TOO_LARGE)); + if content_length.0 as usize > self.max_size { + return Err(SizedLimitError::PayloadTooLarge.into()); } - Ok(self.inner.call(req).await) + self.inner.call(req).await } } #[cfg(test)] mod tests { + use http::StatusCode; + use super::*; use crate::{ endpoint::{make_sync, EndpointExt}, @@ -64,6 +70,14 @@ mod tests { async fn size_limit() { let ep = make_sync(|_| ()).with(SizeLimit::new(5)); + assert_eq!( + ep.call(Request::builder().body(&b"123456"[..])) + .await + .unwrap_err() + .downcast_ref::(), + Some(&SizedLimitError::MissingContentLength) + ); + assert_eq!( ep.call( Request::builder() @@ -71,9 +85,9 @@ mod tests { .body(&b"123456"[..]) ) .await - .into_response() - .status(), - StatusCode::PAYLOAD_TOO_LARGE + .unwrap_err() + .downcast_ref::(), + Some(&SizedLimitError::PayloadTooLarge) ); assert_eq!( @@ -83,6 +97,7 @@ mod tests { .body(&b"1234"[..]) ) .await + .unwrap() .into_response() .status(), StatusCode::OK @@ -95,6 +110,7 @@ mod tests { .body(&b"12345"[..]) ) .await + .unwrap() .into_response() .status(), StatusCode::OK diff --git a/poem/src/middleware/tower_compat.rs b/poem/src/middleware/tower_compat.rs index 85d795bb..709e214f 100644 --- a/poem/src/middleware/tower_compat.rs +++ b/poem/src/middleware/tower_compat.rs @@ -1,13 +1,24 @@ use std::{ - convert::Infallible, sync::Arc, task::{Context, Poll}, }; use futures_util::{future::BoxFuture, FutureExt}; -use tower::{buffer::Buffer, Layer, Service, ServiceExt}; +use tower::{buffer::Buffer, BoxError, Layer, Service, ServiceExt}; -use crate::{Endpoint, IntoResponse, Middleware, Request, Result}; +use crate::{Endpoint, Error, IntoResponse, Middleware, Request, Result}; + +#[doc(hidden)] +#[derive(Debug, thiserror::Error)] +#[error("wrapper error")] +pub struct WrapperError(Error); + +fn boxed_err_to_poem_err(err: BoxError) -> Error { + match err.downcast::() { + Ok(err) => (*err).0, + Err(err) => Error::new_with_string(err.to_string()), + } +} /// Extension trait for tower layer compat. #[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))] @@ -34,15 +45,14 @@ where L::Service: Service + Send + 'static, >::Future: Send, >::Response: IntoResponse + Send + 'static, - >::Error: Into + Send + Sync, + >::Error: Into + Send + Sync, { type Output = TowerServiceToEndpoint; fn transform(&self, ep: E) -> Self::Output { - TowerServiceToEndpoint(Buffer::new( - self.0.layer(EndpointToTowerService(Arc::new(ep))), - 32, - )) + let new_svc = self.0.layer(EndpointToTowerService(Arc::new(ep))); + let buffer = Buffer::new(new_svc, 32); + TowerServiceToEndpoint(buffer) } } @@ -54,7 +64,7 @@ where E: Endpoint + 'static, { type Response = E::Output; - type Error = Infallible; + type Error = WrapperError; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -63,7 +73,7 @@ where fn call(&mut self, req: Request) -> Self::Future { let ep = self.0.clone(); - async move { Ok(ep.call(req).await) }.boxed() + async move { ep.call(req).await.map_err(WrapperError) }.boxed() } } @@ -75,15 +85,15 @@ impl Endpoint for TowerServiceToEndpoint where Svc: Service + Send + 'static, Svc::Future: Send, - Svc::Response: IntoResponse + Send + 'static, - Svc::Error: Into + Send + Sync, + Svc::Response: IntoResponse + 'static, + Svc::Error: Into + Send + Sync, { - type Output = Result; + type Output = Svc::Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { let mut svc = self.0.clone(); - svc.ready().await?; - let res = svc.call(req).await?; + svc.ready().await.map_err(boxed_err_to_poem_err)?; + let res = svc.call(req).await.map_err(boxed_err_to_poem_err)?; Ok(res) } } @@ -129,7 +139,7 @@ mod tests { } let ep = make_sync(|_| ()).with(MyServiceLayer.compat()); - let resp = ep.call(Request::default()).await.into_response(); + let resp = ep.call(Request::default()).await.unwrap().into_response(); assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/poem/src/middleware/tracing_mw.rs b/poem/src/middleware/tracing_mw.rs index 1876c9e4..3df5c08f 100644 --- a/poem/src/middleware/tracing_mw.rs +++ b/poem/src/middleware/tracing_mw.rs @@ -2,7 +2,7 @@ use std::time::SystemTime; use tracing::{Instrument, Level}; -use crate::{Endpoint, IntoResponse, Middleware, Request, Response}; +use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result}; /// Middleware for [`tracing`](https://crates.io/crates/tracing). #[derive(Default)] @@ -25,7 +25,7 @@ pub struct TracingEndpoint { impl Endpoint for TracingEndpoint { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { let span = tracing::span!( target: module_path!(), Level::INFO, @@ -38,7 +38,7 @@ impl Endpoint for TracingEndpoint { async move { let now = SystemTime::now(); - let resp = self.inner.call(req).await.into_response(); + let resp = self.inner.call(req).await?.into_response(); match now.elapsed() { Ok(duration) => tracing::info!( status = %resp.status(), @@ -50,7 +50,7 @@ impl Endpoint for TracingEndpoint { "response" ), } - resp + Ok(resp) } .instrument(span) .await diff --git a/poem/src/request.rs b/poem/src/request.rs index 28710ab5..6ad2b95e 100644 --- a/poem/src/request.rs +++ b/poem/src/request.rs @@ -262,7 +262,8 @@ impl Request { /// .uri(Uri::from_static("/100/abc")) /// .finish(), /// ) - /// .await; + /// .await + /// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "100:abc"); /// # }); @@ -307,7 +308,8 @@ impl Request { /// .uri(Uri::from_static("/?a=100&b=abc")) /// .finish(), /// ) - /// .await; + /// .await + /// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "100:abc"); /// # }); diff --git a/poem/src/response.rs b/poem/src/response.rs index 29e71340..b36d2239 100644 --- a/poem/src/response.rs +++ b/poem/src/response.rs @@ -11,7 +11,7 @@ use crate::{ Extensions, StatusCode, Version, }, web::headers::Header, - Body, Error, + Body, }; /// Component parts of an HTTP Response. @@ -71,12 +71,6 @@ impl From for Response { } } -impl From for Response { - fn from(err: Error) -> Self { - err.as_response() - } -} - impl> From<(StatusCode, T)> for Response { fn from((status, body): (StatusCode, T)) -> Self { Response::builder().status(status).body(body.into()) @@ -342,10 +336,6 @@ mod tests { assert_eq!(resp.status(), StatusCode::BAD_GATEWAY); assert!(resp.body.into_string().await.unwrap().is_empty()); - let resp = Response::from(Error::new(StatusCode::BAD_GATEWAY).with_reason("bad gateway")); - assert_eq!(resp.status(), StatusCode::BAD_GATEWAY); - assert_eq!(resp.body.into_string().await.unwrap(), "bad gateway"); - let resp = Response::from((StatusCode::BAD_GATEWAY, Body::from("abc"))); assert_eq!(resp.status(), StatusCode::BAD_GATEWAY); assert_eq!(resp.body.into_string().await.unwrap(), "abc"); diff --git a/poem/src/route/router.rs b/poem/src/route/router.rs index 0d006cff..e2fbf7dd 100644 --- a/poem/src/route/router.rs +++ b/poem/src/route/router.rs @@ -1,16 +1,157 @@ use std::{str::FromStr, sync::Arc}; -use http::StatusCode; use regex::Regex; use crate::{ endpoint::BoxEndpoint, + error::NotFoundError, http::{uri::PathAndQuery, Uri}, route::internal::radix_tree::RadixTree, - Endpoint, EndpointExt, IntoEndpoint, IntoResponse, Request, Response, + Endpoint, EndpointExt, IntoEndpoint, IntoResponse, Request, Response, Result, }; /// Routing object +/// +/// You can match the full path or wildcard path, and use the +/// [`Path`](crate::web::Path) extractor to get the path parameters. +/// +/// # Example +/// +/// ``` +/// use poem::{ +/// get, handler, +/// http::{StatusCode, Uri}, +/// web::Path, +/// Endpoint, Request, Route, +/// }; +/// +/// #[handler] +/// async fn a() {} +/// +/// #[handler] +/// async fn b(Path((group, name)): Path<(String, String)>) { +/// assert_eq!(group, "foo"); +/// assert_eq!(name, "bar"); +/// } +/// +/// #[handler] +/// async fn c(Path(path): Path) { +/// assert_eq!(path, "d/e"); +/// } +/// +/// let app = Route::new() +/// // full path +/// .at("/a/b", get(a)) +/// // capture parameters +/// .at("/b/:group/:name", get(b)) +/// // capture tail path +/// .at("/c/*path", get(c)) +/// // match regex +/// .at("/d/<\\d+>", get(a)) +/// // capture with regex +/// .at("/e/:name<\\d+>", get(a)); +/// +/// # tokio::runtime::Runtime::new().unwrap().block_on(async { +/// // /a/b +/// let resp = app +/// .call(Request::builder().uri(Uri::from_static("/a/b")).finish()) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// +/// // /b/:group/:name +/// let resp = app +/// .call( +/// Request::builder() +/// .uri(Uri::from_static("/b/foo/bar")) +/// .finish(), +/// ) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// +/// // /c/*path +/// let resp = app +/// .call(Request::builder().uri(Uri::from_static("/c/d/e")).finish()) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// +/// // /d/<\\d> +/// let resp = app +/// .call(Request::builder().uri(Uri::from_static("/d/123")).finish()) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// +/// // /e/:name<\\d> +/// let resp = app +/// .call(Request::builder().uri(Uri::from_static("/e/123")).finish()) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// # }); +/// ``` +/// +/// # Nested +/// +/// ``` +/// use poem::{ +/// handler, +/// http::{StatusCode, Uri}, +/// Endpoint, Request, Route, +/// }; +/// +/// #[handler] +/// fn index() -> &'static str { +/// "hello" +/// } +/// +/// let app = Route::new().nest("/foo", Route::new().at("/bar", index)); +/// +/// # tokio::runtime::Runtime::new().unwrap().block_on(async { +/// let resp = app +/// .call( +/// Request::builder() +/// .uri(Uri::from_static("/foo/bar")) +/// .finish(), +/// ) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello"); +/// # }); +/// ``` +/// +/// # Nested no strip +/// +/// ``` +/// use poem::{ +/// handler, +/// http::{StatusCode, Uri}, +/// Endpoint, Request, Route, +/// }; +/// +/// #[handler] +/// fn index() -> &'static str { +/// "hello" +/// } +/// +/// let app = Route::new().nest_no_strip("/foo", Route::new().at("/foo/bar", index)); +/// +/// # tokio::runtime::Runtime::new().unwrap().block_on(async { +/// let resp = app +/// .call( +/// Request::builder() +/// .uri(Uri::from_static("/foo/bar")) +/// .finish(), +/// ) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello"); +/// # }); +/// ``` #[derive(Default)] pub struct Route { tree: RadixTree>, @@ -23,125 +164,18 @@ impl Route { } /// Add an [Endpoint] to the specified path. - /// - /// You can match the full path or wildcard path, and use the - /// [`Path`](crate::web::Path) extractor to get the path parameters. - /// - /// # Example - /// - /// ``` - /// use poem::{ - /// get, handler, - /// http::{StatusCode, Uri}, - /// web::Path, - /// Endpoint, Request, Route, - /// }; - /// - /// #[handler] - /// async fn a() {} - /// - /// #[handler] - /// async fn b(Path((group, name)): Path<(String, String)>) { - /// assert_eq!(group, "foo"); - /// assert_eq!(name, "bar"); - /// } - /// - /// #[handler] - /// async fn c(Path(path): Path) { - /// assert_eq!(path, "d/e"); - /// } - /// - /// let app = Route::new() - /// // full path - /// .at("/a/b", get(a)) - /// // capture parameters - /// .at("/b/:group/:name", get(b)) - /// // capture tail path - /// .at("/c/*path", get(c)) - /// // match regex - /// .at("/d/<\\d+>", get(a)) - /// // capture with regex - /// .at("/e/:name<\\d+>", get(a)); - /// - /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// // /a/b - /// let resp = app - /// .call(Request::builder().uri(Uri::from_static("/a/b")).finish()) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// - /// // /b/:group/:name - /// let resp = app - /// .call( - /// Request::builder() - /// .uri(Uri::from_static("/b/foo/bar")) - /// .finish(), - /// ) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// - /// // /c/*path - /// let resp = app - /// .call(Request::builder().uri(Uri::from_static("/c/d/e")).finish()) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// - /// // /d/<\\d> - /// let resp = app - /// .call(Request::builder().uri(Uri::from_static("/d/123")).finish()) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// - /// // /e/:name<\\d> - /// let resp = app - /// .call(Request::builder().uri(Uri::from_static("/e/123")).finish()) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// # }); - /// ``` #[must_use] pub fn at(mut self, path: impl AsRef, ep: E) -> Self where E: IntoEndpoint, E::Endpoint: 'static, { - self.tree.add( - &normalize_path(path.as_ref()), - Box::new(ep.into_endpoint().map_to_response()), - ); + self.tree + .add(&normalize_path(path.as_ref()), ep.map_to_response().boxed()); self } /// Nest a `Endpoint` to the specified path and strip the prefix. - /// - /// # Example - /// - /// ``` - /// use poem::{ - /// handler, - /// http::{StatusCode, Uri}, - /// Endpoint, Request, Route, - /// }; - /// - /// #[handler] - /// fn index() -> &'static str { - /// "hello" - /// } - /// - /// let app = Route::new().nest("/foo", Route::new().at("/bar", index)); - /// - /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let resp = app - /// .call( - /// Request::builder() - /// .uri(Uri::from_static("/foo/bar")) - /// .finish(), - /// ) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello"); - /// # }); - /// ``` #[must_use] pub fn nest(self, path: impl AsRef, ep: E) -> Self where @@ -152,35 +186,6 @@ impl Route { } /// Nest a `Endpoint` to the specified path, but do not strip the prefix. - /// - /// # Example - /// - /// ``` - /// use poem::{ - /// handler, - /// http::{StatusCode, Uri}, - /// Endpoint, Request, Route, - /// }; - /// - /// #[handler] - /// fn index() -> &'static str { - /// "hello" - /// } - /// - /// let app = Route::new().nest_no_strip("/foo", Route::new().at("/foo/bar", index)); - /// - /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let resp = app - /// .call( - /// Request::builder() - /// .uri(Uri::from_static("/foo/bar")) - /// .finish(), - /// ) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello"); - /// # }); - /// ``` #[must_use] pub fn nest_no_strip(self, path: impl AsRef, ep: E) -> Self where @@ -211,7 +216,7 @@ impl Route { impl Endpoint for Nest { type Output = Response; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { if !self.root { let idx = req.state().match_params.len() - 1; let (name, _) = req.state_mut().match_params.remove(idx); @@ -231,7 +236,8 @@ impl Route { Uri::from_parts(uri_parts).unwrap() }; *req.uri_mut() = new_uri; - self.inner.call(req).await.into_response() + + Ok(self.inner.call(req).await?.into_response()) } } @@ -269,13 +275,13 @@ impl Route { impl Endpoint for Route { type Output = Response; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { match self.tree.matches(req.uri().path()) { Some(matches) => { req.state_mut().match_params.extend(matches.params); matches.data.call(req).await } - None => StatusCode::NOT_FOUND.into(), + None => Err(NotFoundError.into()), } } } @@ -313,6 +319,7 @@ mod tests { route .call(Request::builder().uri(Uri::from_static(path)).finish()) .await + .unwrap() .take_body() .into_string() .await diff --git a/poem/src/route/router_domain.rs b/poem/src/route/router_domain.rs index 7a4b3939..274de4be 100644 --- a/poem/src/route/router_domain.rs +++ b/poem/src/route/router_domain.rs @@ -1,11 +1,43 @@ use crate::{ - endpoint::BoxEndpoint, - http::{header, StatusCode}, - route::internal::trie::Trie, - Endpoint, EndpointExt, IntoEndpoint, Request, Response, + endpoint::BoxEndpoint, error::NotFoundError, http::header, route::internal::trie::Trie, + Endpoint, EndpointExt, IntoEndpoint, Request, Response, Result, }; /// Routing object for `HOST` header +/// +/// # Example +/// +/// ``` +/// use poem::{endpoint::make_sync, handler, http::header, Endpoint, Request, RouteDomain}; +/// +/// let app = RouteDomain::new() +/// .add("example.com", make_sync(|_| "1")) +/// .add("www.+.com", make_sync(|_| "2")) +/// .add("*.example.com", make_sync(|_| "3")) +/// .add("*", make_sync(|_| "4")); +/// +/// fn make_request(host: &str) -> Request { +/// Request::builder().header(header::HOST, host).finish() +/// } +/// +/// async fn do_request(app: &RouteDomain, req: Request) -> String { +/// app.call(req) +/// .await +/// .unwrap() +/// .into_body() +/// .into_string() +/// .await +/// .unwrap() +/// } +/// +/// # tokio::runtime::Runtime::new().unwrap().block_on(async { +/// assert_eq!(do_request(&app, make_request("example.com")).await, "1"); +/// assert_eq!(do_request(&app, make_request("www.abc.com")).await, "2"); +/// assert_eq!(do_request(&app, make_request("a.b.example.com")).await, "3"); +/// assert_eq!(do_request(&app, make_request("rust-lang.org")).await, "4"); +/// assert_eq!(do_request(&app, Request::default()).await, "4"); +/// # }); +/// ``` #[derive(Default)] pub struct RouteDomain { tree: Trie>, @@ -18,34 +50,6 @@ impl RouteDomain { } /// Add an [Endpoint] to the specified domain pattern. - /// - /// # Example - /// - /// ``` - /// use poem::{endpoint::make_sync, handler, http::header, Endpoint, Request, RouteDomain}; - /// - /// let app = RouteDomain::new() - /// .add("example.com", make_sync(|_| "1")) - /// .add("www.+.com", make_sync(|_| "2")) - /// .add("*.example.com", make_sync(|_| "3")) - /// .add("*", make_sync(|_| "4")); - /// - /// fn make_request(host: &str) -> Request { - /// Request::builder().header(header::HOST, host).finish() - /// } - /// - /// async fn do_request(app: &RouteDomain, req: Request) -> String { - /// app.call(req).await.into_body().into_string().await.unwrap() - /// } - /// - /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// assert_eq!(do_request(&app, make_request("example.com")).await, "1"); - /// assert_eq!(do_request(&app, make_request("www.abc.com")).await, "2"); - /// assert_eq!(do_request(&app, make_request("a.b.example.com")).await, "3"); - /// assert_eq!(do_request(&app, make_request("rust-lang.org")).await, "4"); - /// assert_eq!(do_request(&app, Request::default()).await, "4"); - /// # }); - /// ``` pub fn add(mut self, pattern: impl AsRef, ep: E) -> Self where E: IntoEndpoint, @@ -53,7 +57,7 @@ impl RouteDomain { { self.tree.add( pattern.as_ref(), - Box::new(ep.into_endpoint().map_to_response()), + ep.into_endpoint().map_to_response().boxed(), ); self } @@ -63,7 +67,7 @@ impl RouteDomain { impl Endpoint for RouteDomain { type Output = Response; - async fn call(&self, req: Request) -> Self::Output { + async fn call(&self, req: Request) -> Result { let host = req .headers() .get(header::HOST) @@ -71,7 +75,7 @@ impl Endpoint for RouteDomain { .unwrap_or_default(); match self.tree.matches(host) { Some(ep) => ep.call(req).await, - None => StatusCode::NOT_FOUND.into(), + None => Err(NotFoundError.into()), } } } @@ -89,6 +93,7 @@ mod tests { assert_eq!( r.call(req.finish()) .await + .unwrap() .into_body() .into_string() .await @@ -131,19 +136,20 @@ mod tests { .add("www.+.com", make_sync(|_| "3")) .add("*.com", make_sync(|_| "4")); - assert_eq!( - r.call( + assert!(r + .call( Request::builder() .header(header::HOST, "rust-lang.org") .finish() ) .await - .status(), - StatusCode::NOT_FOUND, - ); - assert_eq!( - r.call(Request::default()).await.status(), - StatusCode::NOT_FOUND, - ); + .unwrap_err() + .is::()); + + assert!(r + .call(Request::default()) + .await + .unwrap_err() + .is::()); } } diff --git a/poem/src/route/router_method.rs b/poem/src/route/router_method.rs index a665d3bb..e948da23 100644 --- a/poem/src/route/router_method.rs +++ b/poem/src/route/router_method.rs @@ -1,10 +1,47 @@ use crate::{ - endpoint::BoxEndpoint, - http::{Method, StatusCode}, - Endpoint, EndpointExt, IntoEndpoint, Request, Response, + endpoint::BoxEndpoint, error::NotFoundError, http::Method, Endpoint, EndpointExt, IntoEndpoint, + Request, Response, Result, }; /// Routing object for HTTP methods +/// +/// # Example +/// +/// ``` +/// use poem::{ +/// handler, +/// http::{Method, StatusCode}, +/// Endpoint, Request, RouteMethod, +/// }; +/// +/// #[handler] +/// fn handle_get() -> &'static str { +/// "get" +/// } +/// +/// #[handler] +/// fn handle_post() -> &'static str { +/// "post" +/// } +/// +/// # tokio::runtime::Runtime::new().unwrap().block_on(async { +/// let route_method = RouteMethod::new().get(handle_get).post(handle_post); +/// +/// let resp = route_method +/// .call(Request::builder().method(Method::GET).finish()) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// assert_eq!(resp.into_body().into_string().await.unwrap(), "get"); +/// +/// let resp = route_method +/// .call(Request::builder().method(Method::POST).finish()) +/// .await +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// assert_eq!(resp.into_body().into_string().await.unwrap(), "post"); +/// # }); +/// ``` #[derive(Default)] pub struct RouteMethod { methods: Vec<(Method, BoxEndpoint<'static, Response>)>, @@ -12,42 +49,6 @@ pub struct RouteMethod { impl RouteMethod { /// Create a `RouteMethod` object. - /// - /// # Example - /// - /// ``` - /// use poem::{ - /// handler, - /// http::{Method, StatusCode}, - /// Endpoint, Request, RouteMethod, - /// }; - /// - /// #[handler] - /// fn handle_get() -> &'static str { - /// "get" - /// } - /// - /// #[handler] - /// fn handle_post() -> &'static str { - /// "post" - /// } - /// - /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// let route_method = RouteMethod::new().get(handle_get).post(handle_post); - /// - /// let resp = route_method - /// .call(Request::builder().method(Method::GET).finish()) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// assert_eq!(resp.into_body().into_string().await.unwrap(), "get"); - /// - /// let resp = route_method - /// .call(Request::builder().method(Method::POST).finish()) - /// .await; - /// assert_eq!(resp.status(), StatusCode::OK); - /// assert_eq!(resp.into_body().into_string().await.unwrap(), "post"); - /// # }); - /// ``` pub fn new() -> Self { Default::default() } @@ -59,7 +60,7 @@ impl RouteMethod { E::Endpoint: 'static, { self.methods - .push((method, Box::new(ep.into_endpoint().map_to_response()))); + .push((method, ep.into_endpoint().map_to_response().boxed())); self } @@ -149,7 +150,7 @@ impl RouteMethod { impl Endpoint for RouteMethod { type Output = Response; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { match self .methods .iter() @@ -160,11 +161,11 @@ impl Endpoint for RouteMethod { None => { if req.method() == Method::HEAD { req.set_method(Method::GET); - let mut resp = self.call(req).await; + let mut resp = self.call(req).await?; resp.set_body(()); - return resp; + return Ok(resp); } - StatusCode::NOT_FOUND.into() + Err(NotFoundError.into()) } } } @@ -281,7 +282,8 @@ mod tests { let route = RouteMethod::new().method(method.clone(), index).post(index); let resp = route .call(Request::builder().method(method.clone()).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.into_body().into_string().await.unwrap(), "hello"); } @@ -292,7 +294,8 @@ mod tests { let route = RouteMethod::new().$id(index).post(index); let resp = route .call(Request::builder().method(Method::$method).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.into_body().into_string().await.unwrap(), "hello"); )* @@ -322,7 +325,8 @@ mod tests { let route = RouteMethod::new().get(index); let resp = route .call(Request::builder().method(Method::HEAD).finish()) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert!(resp.into_body().into_vec().await.unwrap().is_empty()); } diff --git a/poem/src/server.rs b/poem/src/server.rs index 25763e18..ffc8540e 100644 --- a/poem/src/server.rs +++ b/poem/src/server.rs @@ -176,8 +176,10 @@ async fn serve_connection( let local_addr = local_addr.clone(); let remote_addr = remote_addr.clone(); async move { - let resp = ep.call((req, local_addr, remote_addr).into()).await.into(); - Ok::<_, Infallible>(resp) + match ep.call((req, local_addr, remote_addr).into()).await { + Ok(resp) => Ok::, Infallible>(resp.into()), + Err(err) => Ok(err.as_response().into()), + } } } }); diff --git a/poem/src/session/cookie_session.rs b/poem/src/session/cookie_session.rs index e574fd78..e6e53794 100644 --- a/poem/src/session/cookie_session.rs +++ b/poem/src/session/cookie_session.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeMap, sync::Arc}; use crate::{ middleware::{CookieJarManager, CookieJarManagerEndpoint}, session::{CookieConfig, Session, SessionStatus}, - Endpoint, Middleware, Request, + Endpoint, Middleware, Request, Result, }; /// Middleware for client-side(cookie) session. @@ -44,7 +44,7 @@ pub struct CookieSessionEndpoint { impl Endpoint for CookieSessionEndpoint { type Output = E::Output; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { let cookie_jar = req.cookie().clone(); let session = self .config @@ -54,7 +54,7 @@ impl Endpoint for CookieSessionEndpoint { .unwrap_or_else(Session::default); req.extensions_mut().insert(session.clone()); - let resp = self.inner.call(req).await; + let resp = self.inner.call(req).await?; match session.status() { SessionStatus::Changed | SessionStatus::Renewed => { @@ -69,7 +69,7 @@ impl Endpoint for CookieSessionEndpoint { SessionStatus::Unchanged => {} }; - resp + Ok(resp) } } diff --git a/poem/src/session/redis_storage.rs b/poem/src/session/redis_storage.rs index e46b7994..811dd8bc 100644 --- a/poem/src/session/redis_storage.rs +++ b/poem/src/session/redis_storage.rs @@ -5,6 +5,10 @@ use redis::{aio::ConnectionLike, AsyncCommands, Cmd}; use crate::{error::InternalServerError, session::session_storage::SessionStorage, Result}; /// A session storage using redis. +/// +/// # Errors +/// +/// - [`redis::RedisError`] #[cfg_attr(docsrs, doc(cfg(feature = "redis-session")))] pub struct RedisStorage { connection: T, diff --git a/poem/src/session/server_session.rs b/poem/src/session/server_session.rs index 3ebb5745..43e445fe 100644 --- a/poem/src/session/server_session.rs +++ b/poem/src/session/server_session.rs @@ -57,9 +57,9 @@ where T: SessionStorage, E: Endpoint, { - type Output = Result; + type Output = E::Output; - async fn call(&self, mut req: Request) -> Self::Output { + async fn call(&self, mut req: Request) -> Result { let cookie_jar = req.cookie().clone(); let mut session_id = self.config.get_cookie_value(&cookie_jar); let session = match &session_id { @@ -74,7 +74,7 @@ where }; req.extensions_mut().insert(session.clone()); - let resp = self.inner.call(req).await; + let resp = self.inner.call(req).await?; match session.status() { SessionStatus::Changed => match session_id { diff --git a/poem/src/session/session.rs b/poem/src/session/session.rs index d280e2a1..464a34e6 100644 --- a/poem/src/session/session.rs +++ b/poem/src/session/session.rs @@ -1,6 +1,5 @@ use std::{ collections::BTreeMap, - convert::Infallible, fmt::{self, Debug, Formatter}, sync::Arc, }; @@ -8,7 +7,7 @@ use std::{ use parking_lot::RwLock; use serde::{de::DeserializeOwned, Serialize}; -use crate::{FromRequest, Request, RequestBody}; +use crate::{FromRequest, Request, RequestBody, Result}; /// Status of the Session. #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -142,9 +141,7 @@ impl Session { #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a Session { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req .extensions() .get::() diff --git a/poem/src/session/test_harness.rs b/poem/src/session/test_harness.rs index 3f58216d..9a7c0c1b 100644 --- a/poem/src/session/test_harness.rs +++ b/poem/src/session/test_harness.rs @@ -28,7 +28,7 @@ impl TestClient { .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap()); } - let resp = ep.call(req).await.into_response(); + let resp = ep.call(req).await.unwrap().into_response(); for s in resp.headers().get_all(header::SET_COOKIE) { if let Ok(s) = s.to_str() { let cookie = Cookie::parse(&s).unwrap(); diff --git a/poem/src/web/compress.rs b/poem/src/web/compress.rs index 9a8b51d7..25dc94b6 100644 --- a/poem/src/web/compress.rs +++ b/poem/src/web/compress.rs @@ -158,9 +158,10 @@ mod tests { } let mut resp = index - .after(move |resp| async move { Compress::new(resp, algo) }) + .and_then(move |resp| async move { Ok(Compress::new(resp, algo)) }) .call(Request::default()) .await + .unwrap() .into_response(); assert_eq!( resp.headers().get(header::CONTENT_ENCODING), diff --git a/poem/src/web/cookie.rs b/poem/src/web/cookie.rs index 5a2f2b6c..d62a9fce 100644 --- a/poem/src/web/cookie.rs +++ b/poem/src/web/cookie.rs @@ -22,7 +22,11 @@ use crate::{ /// The `SameSite` cookie attribute. pub type SameSite = libcookie::SameSite; -/// Representation of an HTTP cookie. +/// HTTP cookie extractor. +/// +/// # Errors +/// +/// - [`ParseCookieError`] #[derive(Clone, Debug, PartialEq)] pub struct Cookie(libcookie::Cookie<'static>); @@ -289,9 +293,7 @@ impl Cookie { #[async_trait::async_trait] impl<'a> FromRequest<'a> for Cookie { - type Error = ParseCookieError; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { let value = req .headers() .get(header::COOKIE) @@ -333,14 +335,15 @@ impl<'a> FromRequest<'a> for Cookie { /// .at("/", get(index)) /// .with(CookieJarManager::new()); /// -/// let resp = app.call(Request::default()).await; +/// let resp = app.call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// let cookie = resp.headers().get(header::SET_COOKIE).cloned().unwrap(); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "count: 1"); /// /// let resp = app /// .call(Request::builder().header(header::COOKIE, cookie).finish()) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "count: 2"); /// # }); /// ``` @@ -489,9 +492,7 @@ impl FromStr for CookieJar { #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a CookieJar { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req.cookie()) } } diff --git a/poem/src/web/csrf.rs b/poem/src/web/csrf.rs index 61c6f5ff..a19b739b 100644 --- a/poem/src/web/csrf.rs +++ b/poem/src/web/csrf.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, ops::Deref, sync::Arc}; +use std::{ops::Deref, sync::Arc}; use libcsrf::{AesGcmCsrfProtection, CsrfProtection, UnencryptedCsrfCookie}; @@ -21,9 +21,7 @@ impl Deref for CsrfToken { #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a CsrfToken { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req .extensions() .get::() @@ -51,9 +49,7 @@ impl CsrfVerifier { #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a CsrfVerifier { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req .extensions() .get::() diff --git a/poem/src/web/data.rs b/poem/src/web/data.rs index ea4dd281..db556381 100644 --- a/poem/src/web/data.rs +++ b/poem/src/web/data.rs @@ -4,6 +4,10 @@ use crate::{error::GetDataError, FromRequest, Request, RequestBody, Result}; /// An extractor that can extract data from the request extension. /// +/// # Errors +/// +/// - [`GetDataError`] +/// /// # Example /// /// ``` @@ -19,7 +23,7 @@ use crate::{error::GetDataError, FromRequest, Request, RequestBody, Result}; /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let app = Route::new().at("/", get(index)).with(AddData::new(10)); -/// let resp = app.call(Request::default()).await; +/// let resp = app.call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// # }); /// ``` @@ -35,20 +39,19 @@ impl Deref for Data { #[async_trait::async_trait] impl<'a, T: Send + Sync + 'static> FromRequest<'a> for Data<&'a T> { - type Error = GetDataError; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { - req.extensions() - .get::() - .ok_or_else(|| GetDataError(std::any::type_name::())) - .map(Data) + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + Ok(Data( + req.extensions() + .get::() + .ok_or_else(|| GetDataError(std::any::type_name::()))?, + )) } } #[cfg(test)] mod tests { use super::*; - use crate::{handler, http::StatusCode, middleware::AddData, Endpoint, EndpointExt}; + use crate::{handler, middleware::AddData, Endpoint, EndpointExt}; #[tokio::test] async fn test_data_extractor() { @@ -58,7 +61,7 @@ mod tests { } let app = index.with(AddData::new(100i32)); - app.call(Request::default()).await; + app.call(Request::default()).await.unwrap(); } #[tokio::test] @@ -69,11 +72,12 @@ mod tests { } let app = index; - let mut resp = app.call(Request::default()).await; - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!( - resp.take_body().into_string().await.unwrap(), - "data of type `i32` was not found." + app.call(Request::default()) + .await + .unwrap_err() + .downcast_ref::(), + Some(&GetDataError("i32")) ); } @@ -83,7 +87,8 @@ mod tests { async fn index(value: Data<&String>) { assert_eq!(value.to_uppercase(), "ABC"); } + let app = index.with(AddData::new("abc".to_string())); - app.call(Request::default()).await; + app.call(Request::default()).await.unwrap(); } } diff --git a/poem/src/web/form.rs b/poem/src/web/form.rs index 90014dca..aeae4ab7 100644 --- a/poem/src/web/form.rs +++ b/poem/src/web/form.rs @@ -20,6 +20,11 @@ use crate::{ /// If the `Content-Type` is not `application/x-www-form-urlencoded`, then a /// `Bad Request` response will be returned. /// +/// # Errors +/// +/// - [`ReadBodyError`] +/// - [`ParseFormError`] +/// /// # Example /// /// ``` @@ -51,7 +56,8 @@ use crate::{ /// .uri(Uri::from_static("/?title=foo&content=bar")) /// .finish(), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "foo:bar"); /// @@ -63,7 +69,8 @@ use crate::{ /// .content_type("application/x-www-form-urlencoded") /// .body("title=foo&content=bar"), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "foo:bar"); /// # }); @@ -86,11 +93,13 @@ impl DerefMut for Form { #[async_trait::async_trait] impl<'a, T: DeserializeOwned> FromRequest<'a> for Form { - type Error = ParseFormError; - - async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { if req.method() == Method::GET { - Ok(serde_urlencoded::from_str(req.uri().query().unwrap_or_default()).map(Self)?) + Ok( + serde_urlencoded::from_str(req.uri().query().unwrap_or_default()) + .map_err(ParseFormError::UrlDecode) + .map(Self)?, + ) } else { let content_type = req.headers().get(header::CONTENT_TYPE); if content_type @@ -99,14 +108,15 @@ impl<'a, T: DeserializeOwned> FromRequest<'a> for Form { )) { return match content_type.and_then(|value| value.to_str().ok()) { - Some(ty) => Err(ParseFormError::InvalidContentType(ty.to_string())), - None => Err(ParseFormError::ContentTypeRequired), + Some(ty) => Err(ParseFormError::InvalidContentType(ty.to_string()).into()), + None => Err(ParseFormError::ContentTypeRequired.into()), }; } - Ok(Self(serde_urlencoded::from_bytes( - &body.take()?.into_bytes().await?, - )?)) + Ok(Self( + serde_urlencoded::from_bytes(&body.take()?.into_vec().await?) + .map_err(ParseFormError::UrlDecode)?, + )) } } } @@ -116,11 +126,7 @@ mod tests { use serde::Deserialize; use super::*; - use crate::{ - handler, - http::{StatusCode, Uri}, - Endpoint, - }; + use crate::{handler, http::Uri, Endpoint}; #[tokio::test] async fn test_form_extractor() { @@ -142,7 +148,8 @@ mod tests { .uri(Uri::from_static("/?name=abc&value=100")) .finish(), ) - .await; + .await + .unwrap(); index .call( @@ -151,16 +158,18 @@ mod tests { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body("name=abc&value=100"), ) - .await; + .await + .unwrap(); - let resp = index + assert!(index .call( Request::builder() .method(Method::POST) .header(header::CONTENT_TYPE, "application/json") .body("name=abc&value=100"), ) - .await; - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + .await + .unwrap_err() + .is::()); } } diff --git a/poem/src/web/json.rs b/poem/src/web/json.rs index 588a0083..9f46506a 100644 --- a/poem/src/web/json.rs +++ b/poem/src/web/json.rs @@ -1,21 +1,23 @@ use std::ops::{Deref, DerefMut}; +use http::StatusCode; use serde::{de::DeserializeOwned, Serialize}; use crate::{ - error::{InternalServerError, ParseJsonError}, - http::header, - web::RequestBody, - FromRequest, IntoResponse, Request, Response, Result, + error::ParseJsonError, http::header, web::RequestBody, FromRequest, IntoResponse, Request, + Response, Result, }; /// JSON extractor and response. /// -/// # Extractor -/// /// To extract the specified type of JSON from the body, `T` must implement /// [`serde::Deserialize`]. /// +/// # Errors +/// +/// - [`ReadBodyError`] +/// - [`ParseJsonError`] +/// /// ``` /// use poem::{ /// handler, @@ -44,7 +46,8 @@ use crate::{ /// .method(Method::POST) /// .body(r#"{"name": "foo"}"#), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!( /// resp.into_body().into_string().await.unwrap(), @@ -76,7 +79,7 @@ use crate::{ /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let app = Route::new().at("/", get(index)); -/// let resp = app.call(Request::default()).await; +/// let resp = app.call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!( /// resp.into_body().into_string().await.unwrap(), @@ -103,11 +106,9 @@ impl DerefMut for Json { #[async_trait::async_trait] impl<'a, T: DeserializeOwned> FromRequest<'a> for Json { - type Error = ParseJsonError; - - async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { + async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { let data = body.take()?.into_bytes().await?; - Ok(Self(serde_json::from_slice(&data)?)) + Ok(Self(serde_json::from_slice(&data).map_err(ParseJsonError)?)) } } @@ -115,7 +116,11 @@ impl IntoResponse for Json { fn into_response(self) -> Response { let data = match serde_json::to_vec(&self.0) { Ok(data) => data, - Err(err) => return InternalServerError(err).as_response(), + Err(err) => { + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(err.to_string()) + } }; Response::builder() .header(header::CONTENT_TYPE, "application/json") @@ -162,7 +167,8 @@ mod tests { "#, ), ) - .await; + .await + .unwrap(); } #[tokio::test] @@ -175,7 +181,7 @@ mod tests { }) } - let mut resp = index.call(Request::default()).await; + let mut resp = index.call(Request::default()).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( serde_json::from_str::(&resp.take_body().into_string().await.unwrap()) diff --git a/poem/src/web/mod.rs b/poem/src/web/mod.rs index 1356942b..9da23fb7 100644 --- a/poem/src/web/mod.rs +++ b/poem/src/web/mod.rs @@ -28,7 +28,7 @@ mod typed_header; #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] pub mod websocket; -use std::convert::Infallible; +use std::{convert::Infallible, fmt::Debug}; pub use addr::{LocalAddr, RemoteAddr}; use bytes::Bytes; @@ -218,7 +218,7 @@ impl RequestBody { /// Ready to accept a websocket [`WebSocket`](websocket::WebSocket) /// connection. /// -/// # Custom extractor +/// # Create you own extractor /// /// The following is an example of a custom token extractor, which extracts the /// token from the `MyToken` header. @@ -235,14 +235,14 @@ impl RequestBody { /// /// #[poem::async_trait] /// impl<'a> FromRequest<'a> for Token { -/// type Error = Error; -/// /// async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { /// let token = req /// .headers() /// .get("MyToken") /// .and_then(|value| value.to_str().ok()) -/// .ok_or_else(|| Error::new(StatusCode::BAD_REQUEST).with_reason("missing token"))?; +/// .ok_or_else(|| { +/// Error::new_with_string("missing token").with_status(StatusCode::BAD_REQUEST) +/// })?; /// Ok(Token(token.to_string())) /// } /// } @@ -259,28 +259,16 @@ impl RequestBody { /// .await; /// # }); /// ``` - #[async_trait::async_trait] pub trait FromRequest<'a>: Sized { - /// The error type of this extractor. - /// - /// If you don't know what type you should use, you can use - /// [`Error`](crate::Error). - type Error: IntoResponse; - /// Perform the extraction. - async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result; + async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result; } /// Represents a type that can convert into response. /// /// # Provided Implementations /// -/// - **Result<T: IntoResponse, E: IntoResponse>** -/// -/// if the result is `Ok`, use the `Ok` value as the response, otherwise use -/// the `Err` value. -/// /// - **()** /// /// Sets the status to `OK` with an empty body. @@ -355,7 +343,7 @@ pub trait FromRequest<'a>: Sized { /// with an event stream body. Use the [`SSE::new`](sse::SSE::new) function to /// create it. /// -/// # Custom response +/// # Create you own response /// /// ``` /// use poem::{handler, http::Uri, web::Query, Endpoint, IntoResponse, Request, Response}; @@ -392,6 +380,7 @@ pub trait FromRequest<'a>: Sized { /// .finish() /// ) /// .await +/// .unwrap() /// .take_body() /// .into_string() /// .await @@ -403,6 +392,7 @@ pub trait FromRequest<'a>: Sized { /// index /// .call(Request::builder().uri(Uri::from_static("/")).finish()) /// .await +/// .unwrap() /// .take_body() /// .into_string() /// .await @@ -411,7 +401,6 @@ pub trait FromRequest<'a>: Sized { /// ); /// # }); /// ``` - pub trait IntoResponse: Send { /// Consume itself and return [`Response`]. fn into_response(self) -> Response; @@ -630,19 +619,6 @@ impl IntoResponse for (HeaderMap, T) { } } -impl IntoResponse for std::result::Result -where - T: IntoResponse, - E: IntoResponse, -{ - fn into_response(self) -> Response { - match self { - Ok(resp) => resp.into_response(), - Err(err) => err.into_response(), - } - } -} - /// An HTML response. #[derive(Debug, Clone, Eq, PartialEq, Default)] pub struct Html(pub T); @@ -657,118 +633,92 @@ impl + Send> IntoResponse for Html { #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a Request { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a Uri { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req.uri()) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for Method { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req.method().clone()) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for Version { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req.version()) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a HeaderMap { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(req.headers()) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for Body { - type Error = ReadBodyError; - - async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { - body.take() + async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { + Ok(body.take()?) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for String { - type Error = ReadBodyError; - - async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { + async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { let data = body.take()?.into_bytes().await?; - String::from_utf8(data.to_vec()).map_err(ReadBodyError::Utf8) + Ok(String::from_utf8(data.to_vec()).map_err(ReadBodyError::Utf8)?) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for Bytes { - type Error = ReadBodyError; - - async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { + async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { Ok(body.take()?.into_bytes().await?) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for Vec { - type Error = ReadBodyError; - - async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { + async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { Ok(body.take()?.into_vec().await?) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a RemoteAddr { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(&req.state().remote_addr) } } #[async_trait::async_trait] impl<'a> FromRequest<'a> for &'a LocalAddr { - type Error = Infallible; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { Ok(&req.state().local_addr) } } #[async_trait::async_trait] impl<'a, T: FromRequest<'a>> FromRequest<'a> for Option { - type Error = T::Error; - - async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { Ok(T::from_request(req, body).await.ok()) } } #[async_trait::async_trait] -impl<'a, T: FromRequest<'a>> FromRequest<'a> for Result { - type Error = Infallible; - - async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { +impl<'a, T: FromRequest<'a>> FromRequest<'a> for Result { + async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { Ok(T::from_request(req, body).await) } } @@ -776,7 +726,7 @@ impl<'a, T: FromRequest<'a>> FromRequest<'a> for Result { #[cfg(test)] mod tests { use super::*; - use crate::{Addr, Error}; + use crate::Addr; #[tokio::test] async fn into_response() { @@ -853,38 +803,6 @@ mod tests { ); assert_eq!(resp.into_body().into_string().await.unwrap(), "abc"); - // Result - let resp = Ok::<_, Error>("abc").into_response(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.into_body().into_string().await.unwrap(), "abc"); - - let resp = Err::<(), _>(Error::new(StatusCode::BAD_GATEWAY).with_reason("bad gateway")) - .into_response(); - assert_eq!(resp.status(), StatusCode::BAD_GATEWAY); - assert_eq!(resp.into_body().into_string().await.unwrap(), "bad gateway"); - - struct CustomError; - - impl IntoResponse for CustomError { - fn into_response(self) -> Response { - Response::builder() - .status(StatusCode::CONFLICT) - .header("Value1", "123") - .body("custom error") - } - } - - let resp = Err::<(), _>(CustomError).into_response(); - assert_eq!(resp.status(), StatusCode::CONFLICT); - assert_eq!( - resp.headers().get("Value1"), - Some(&HeaderValue::from_static("123")) - ); - assert_eq!( - resp.into_body().into_string().await.unwrap(), - "custom error" - ); - // StatusCode let resp = StatusCode::CREATED.into_response(); assert_eq!(resp.status(), StatusCode::CREATED); diff --git a/poem/src/web/multipart.rs b/poem/src/web/multipart.rs index fbc505ea..b6c3b3fe 100644 --- a/poem/src/web/multipart.rs +++ b/poem/src/web/multipart.rs @@ -82,6 +82,11 @@ impl Field { /// An extractor that parses `multipart/form-data` requests commonly used with /// file uploads. /// +/// # Errors +/// +/// - [`ReadBodyError`] +/// - [`ParseMultipartError`] +/// /// # Example /// /// ``` @@ -106,9 +111,7 @@ pub struct Multipart { #[async_trait::async_trait] impl<'a> FromRequest<'a> for Multipart { - type Error = ParseMultipartError; - - async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { + async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result { let content_type = req .headers() .get(header::CONTENT_TYPE) @@ -119,10 +122,12 @@ impl<'a> FromRequest<'a> for Multipart { if content_type.essence_str() != mime::MULTIPART_FORM_DATA { return Err(ParseMultipartError::InvalidContentType( content_type.essence_str().to_string(), - )); + ) + .into()); } - let boundary = multer::parse_boundary(content_type.as_ref())?; + let boundary = multer::parse_boundary(content_type.as_ref()) + .map_err(ParseMultipartError::Multipart)?; Ok(Self { inner: multer::Multipart::new( tokio_util::io::ReaderStream::new(body.take()?.into_async_read()), @@ -154,18 +159,18 @@ mod tests { todo!() } - let mut resp = index + let err = index .call( Request::builder() .header("content-type", "multipart/json; boundary=X-BOUNDARY") .body(()), ) - .await; - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - assert_eq!( - resp.take_body().into_string().await.unwrap(), - "invalid content type `multipart/json`, expect: `multipart/form-data`" - ); + .await + .unwrap_err(); + match err.downcast_ref::().unwrap() { + ParseMultipartError::InvalidContentType(ct) if ct == "multipart/json" => {} + _ => panic!(), + } } #[tokio::test] @@ -193,7 +198,8 @@ mod tests { .header("content-type", "multipart/form-data; boundary=X-BOUNDARY") .body(data), ) - .await; + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/poem/src/web/path/mod.rs b/poem/src/web/path/mod.rs index 423e53a0..d4a6f7da 100644 --- a/poem/src/web/path/mod.rs +++ b/poem/src/web/path/mod.rs @@ -10,6 +10,10 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result}; /// An extractor that will get captures from the URL and parse them using /// `serde`. /// +/// # Errors +/// +/// - [`ParsePathError`] +/// /// # Example /// /// ``` @@ -33,7 +37,8 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result}; /// .uri(Uri::from_static("/users/100/team/300")) /// .finish(), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "100:300"); /// # }); @@ -62,7 +67,8 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result}; /// .uri(Uri::from_static("/users/100")) /// .finish(), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "100"); /// # }); @@ -99,7 +105,8 @@ use crate::{error::ParsePathError, FromRequest, Request, RequestBody, Result}; /// .uri(Uri::from_static("/users/foo/team/100")) /// .finish(), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "foo:100"); /// # }); @@ -121,13 +128,18 @@ impl DerefMut for Path { } } -#[async_trait::async_trait] -impl<'a, T: DeserializeOwned> FromRequest<'a> for Path { - type Error = ParsePathError; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { - T::deserialize(de::PathDeserializer::new(&req.state().match_params)) - .map_err(|_| ParsePathError) - .map(Path) +impl Path { + async fn internal_from_request(req: &Request) -> Result { + Ok(Path( + T::deserialize(de::PathDeserializer::new(&req.state().match_params)) + .map_err(|_| ParsePathError)?, + )) + } +} + +#[async_trait::async_trait] +impl<'a, T: DeserializeOwned> FromRequest<'a> for Path { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + Self::internal_from_request(req).await.map_err(Into::into) } } diff --git a/poem/src/web/query.rs b/poem/src/web/query.rs index 8ef7f779..e44a9435 100644 --- a/poem/src/web/query.rs +++ b/poem/src/web/query.rs @@ -6,6 +6,10 @@ use crate::{error::ParseQueryError, FromRequest, Request, RequestBody, Result}; /// An extractor that can deserialize some type from query string. /// +/// # Errors +/// +/// - [`ParseQueryError`] +/// /// # Example /// /// ``` @@ -37,7 +41,8 @@ use crate::{error::ParseQueryError, FromRequest, Request, RequestBody, Result}; /// .uri(Uri::from_static("/?title=foo&content=bar")) /// .finish(), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "foo:bar"); /// # }); @@ -59,12 +64,16 @@ impl DerefMut for Query { } } +impl Query { + async fn internal_from_request(req: &Request) -> Result { + Ok(serde_urlencoded::from_str(req.uri().query().unwrap_or_default()).map(Self)?) + } +} + #[async_trait::async_trait] impl<'a, T: DeserializeOwned> FromRequest<'a> for Query { - type Error = ParseQueryError; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { - Ok(serde_urlencoded::from_str(req.uri().query().unwrap_or_default()).map(Self)?) + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + Self::internal_from_request(req).await.map_err(Into::into) } } @@ -95,6 +104,7 @@ mod tests { .uri(Uri::from_static("/?name=abc&value=100")) .finish(), ) - .await; + .await + .unwrap(); } } diff --git a/poem/src/web/redirect.rs b/poem/src/web/redirect.rs index 2341f391..3bfa64f8 100644 --- a/poem/src/web/redirect.rs +++ b/poem/src/web/redirect.rs @@ -22,7 +22,7 @@ use crate::{ /// /// let app = Route::new().at("/", get(index)); /// # tokio::runtime::Runtime::new().unwrap().block_on(async { -/// let resp = app.call(Request::default()).await; +/// let resp = app.call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::MOVED_PERMANENTLY); /// assert_eq!( /// resp.headers().get(header::LOCATION), diff --git a/poem/src/web/sse/response.rs b/poem/src/web/sse/response.rs index 7e52023b..6d367664 100644 --- a/poem/src/web/sse/response.rs +++ b/poem/src/web/sse/response.rs @@ -28,7 +28,7 @@ use crate::{Body, IntoResponse, Response}; /// } /// /// # tokio::runtime::Runtime::new().unwrap().block_on(async { -/// let mut resp = index.call(Request::default()).await; +/// let mut resp = index.call(Request::default()).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!( /// resp.take_body().into_string().await.unwrap(), diff --git a/poem/src/web/tempfile.rs b/poem/src/web/tempfile.rs index b24939fe..52d443a8 100644 --- a/poem/src/web/tempfile.rs +++ b/poem/src/web/tempfile.rs @@ -8,18 +8,19 @@ use tokio::{ io::{AsyncRead, AsyncSeekExt, ReadBuf, SeekFrom}, }; -use crate::{error::ReadBodyError, FromRequest, Request, RequestBody}; +use crate::{error::ReadBodyError, FromRequest, Request, RequestBody, Result}; /// An extractor that extracts the body and writes the contents to a temporary /// file. +/// +/// # Errors +/// +/// - [`ReadBodyError`] #[cfg_attr(docsrs, doc(cfg(feature = "tempfile")))] pub struct TempFile(File); -#[async_trait::async_trait] -impl<'a> FromRequest<'a> for TempFile { - type Error = ReadBodyError; - - async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { +impl TempFile { + async fn internal_from_request(body: &mut RequestBody) -> Result { let body = body.take()?; let mut reader = body.into_async_read(); let mut file = tokio::fs::File::from_std(::libtempfile::tempfile()?); @@ -29,6 +30,13 @@ impl<'a> FromRequest<'a> for TempFile { } } +#[async_trait::async_trait] +impl<'a> FromRequest<'a> for TempFile { + async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result { + Self::internal_from_request(body).await.map_err(Into::into) + } +} + impl AsyncRead for TempFile { fn poll_read( mut self: Pin<&mut Self>, @@ -55,6 +63,9 @@ mod tests { assert_eq!(s, "abcdef"); } - index123.call(Request::builder().body("abcdef")).await; + index123 + .call(Request::builder().body("abcdef")) + .await + .unwrap(); } } diff --git a/poem/src/web/typed_header.rs b/poem/src/web/typed_header.rs index d58dfcf4..1a47b8ea 100644 --- a/poem/src/web/typed_header.rs +++ b/poem/src/web/typed_header.rs @@ -6,6 +6,10 @@ use crate::{error::ParseTypedHeaderError, FromRequest, Request, RequestBody, Res /// An extractor that extracts a typed header value. /// +/// # Errors +/// +/// - [`ParseTypedHeaderError`] +/// /// # Example /// /// ``` @@ -30,11 +34,13 @@ use crate::{error::ParseTypedHeaderError, FromRequest, Request, RequestBody, Res /// .header(header::HOST, "example.com") /// .finish(), /// ) -/// .await; +/// .await +/// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// assert_eq!(resp.into_body().into_string().await.unwrap(), "example.com"); /// # }); /// ``` +#[derive(Debug)] pub struct TypedHeader(pub T); impl Deref for TypedHeader { @@ -51,11 +57,8 @@ impl DerefMut for TypedHeader { } } -#[async_trait::async_trait] -impl<'a, T: Header> FromRequest<'a> for TypedHeader { - type Error = ParseTypedHeaderError; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { +impl TypedHeader { + async fn internal_from_request(req: &Request) -> Result { let value = req.headers().typed_try_get::()?; Ok(Self(value.ok_or_else(|| { ParseTypedHeaderError::HeaderRequired(T::name().to_string()) @@ -63,6 +66,13 @@ impl<'a, T: Header> FromRequest<'a> for TypedHeader { } } +#[async_trait::async_trait] +impl<'a, T: Header> FromRequest<'a> for TypedHeader { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + Self::internal_from_request(req).await.map_err(Into::into) + } +} + #[cfg(test)] mod tests { use super::*; @@ -81,7 +91,8 @@ mod tests { index .call(Request::builder().header("content-length", 3).body("abc")) - .await; + .await + .unwrap(); } #[tokio::test] @@ -89,8 +100,8 @@ mod tests { let (req, mut body) = Request::builder().body("abc").split(); let res = TypedHeader::::from_request(&req, &mut body).await; - match res { - Err(ParseTypedHeaderError::HeaderRequired(name)) if name == "host" => {} + match res.unwrap_err().downcast_ref::() { + Some(ParseTypedHeaderError::HeaderRequired(name)) if name == "host" => {} _ => panic!(), } } diff --git a/poem/src/web/websocket/extractor.rs b/poem/src/web/websocket/extractor.rs index 2cf28b4a..b3d14185 100644 --- a/poem/src/web/websocket/extractor.rs +++ b/poem/src/web/websocket/extractor.rs @@ -14,6 +14,10 @@ use crate::{ }; /// An extractor that can accept websocket connections. +/// +/// # Errors +/// +/// - [`WebSocketError`] pub struct WebSocket { key: HeaderValue, on_upgrade: OnUpgrade, @@ -21,11 +25,8 @@ pub struct WebSocket { sec_websocket_protocol: Option, } -#[async_trait::async_trait] -impl<'a> FromRequest<'a> for WebSocket { - type Error = WebSocketError; - - async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { +impl WebSocket { + async fn internal_from_request(req: &Request) -> Result { if req.method() != Method::GET || req.headers().get(header::UPGRADE) != Some(&HeaderValue::from_static("websocket")) || req.headers().get(header::SEC_WEBSOCKET_VERSION) @@ -60,6 +61,13 @@ impl<'a> FromRequest<'a> for WebSocket { } } +#[async_trait::async_trait] +impl<'a> FromRequest<'a> for WebSocket { + async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { + Self::internal_from_request(req).await.map_err(Into::into) + } +} + impl WebSocket { /// Set the known protocols. ///