feat(http1): Add support for sending HTTP/1.1 Chunked Trailer Fields (#3375)

Closes #2719
Author: Herman J. Radtke III
Date: 2023-12-15 13:37:48 -05:00
Committed by: GitHub
Parent: 0f2929b944
Commit: 31b4180752
8 changed files with 611 additions and 31 deletions
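For orientation, a minimal usage sketch of the feature this commit adds, assuming hyper 1.x together with the http, http-body-util, bytes, and futures-util crates; the "chunky-trailer" field name mirrors the tests below and is illustrative only. A handler returns a chunked body whose stream ends with a trailers frame and declares the field in a Trailer header; the server only writes the trailer section when the request carried TE: trailers.

use bytes::Bytes;
use http::{HeaderMap, HeaderValue, Response};
use http_body_util::StreamBody;
use hyper::body::Frame;
use std::convert::Infallible;

// Sketch: build a chunked response whose body ends with a trailer section.
fn response_with_trailers() -> Response<impl hyper::body::Body<Data = Bytes, Error = Infallible>> {
    let mut trailers = HeaderMap::new();
    trailers.insert("chunky-trailer", HeaderValue::from_static("header data"));

    // One data frame, then a trailers frame once the data is done.
    let frames: Vec<Result<Frame<Bytes>, Infallible>> = vec![
        Ok(Frame::data(Bytes::from("hello"))),
        Ok(Frame::trailers(trailers)),
    ];
    let body = StreamBody::new(futures_util::stream::iter(frames));

    Response::builder()
        // Declaring the field in the Trailer header is what allows the encoder to emit it.
        .header("Trailer", "chunky-trailer")
        .body(body)
        .expect("valid response")
}

The client side works the same way in reverse: a request body that yields a trailers frame, plus a Trailer request header, produces the trailer section after the final chunk (see the client_post_req_body_chunked_with_trailer test below).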

src/proto/h1/conn.rs

@@ -8,7 +8,7 @@ use std::time::Duration;
use crate::rt::{Read, Write};
use bytes::{Buf, Bytes};
-use http::header::{HeaderValue, CONNECTION};
+use http::header::{HeaderValue, CONNECTION, TE};
use http::{HeaderMap, Method, Version};
use httparse::ParserConfig;
@@ -75,6 +75,7 @@ where
// We assume a modern world where the remote speaks HTTP/1.1.
// If they tell us otherwise, we'll downgrade in `read_head`.
version: Version::HTTP_11,
allow_trailer_fields: false,
},
_marker: PhantomData,
}
@@ -264,6 +265,13 @@ where
self.state.reading = Reading::Body(Decoder::new(msg.decode));
}
self.state.allow_trailer_fields = msg
.head
.headers
.get(TE)
.map(|te_header| te_header == "trailers")
.unwrap_or(false);
Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
}
@@ -640,6 +648,31 @@ where
self.state.writing = state;
}
pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) {
if T::is_server() && self.state.allow_trailer_fields == false {
debug!("trailers not allowed to be sent");
return;
}
debug_assert!(self.can_write_body() && self.can_buffer_body());
match self.state.writing {
Writing::Body(ref encoder) => {
if let Some(enc_buf) =
encoder.encode_trailers(trailers, self.state.title_case_headers)
{
self.io.buffer(enc_buf);
self.state.writing = if encoder.is_last() || encoder.is_close_delimited() {
Writing::Closed
} else {
Writing::KeepAlive
};
}
}
_ => unreachable!("write_trailers invalid state: {:?}", self.state.writing),
}
}
pub(crate) fn write_body_and_end(&mut self, chunk: B) {
debug_assert!(self.can_write_body() && self.can_buffer_body());
// empty chunks should be discarded at Dispatcher level
@@ -842,6 +875,8 @@ struct State {
upgrade: Option<crate::upgrade::Pending>,
/// Either HTTP/1.0 or 1.1 connection
version: Version,
/// Flag to track if trailer fields are allowed to be sent
allow_trailer_fields: bool,
}
#[derive(Debug)]

src/proto/h1/dispatch.rs

@@ -351,27 +351,33 @@ where
*clear_body = true;
crate::Error::new_user_body(e)
})?;
-let chunk = if let Ok(data) = frame.into_data() {
-    data
-} else {
-    trace!("discarding non-data frame");
-    continue;
-};
-let eos = body.is_end_stream();
-if eos {
-    *clear_body = true;
-    if chunk.remaining() == 0 {
-        trace!("discarding empty chunk");
-        self.conn.end_body()?;
-    } else {
-        self.conn.write_body_and_end(chunk);
-    }
-} else {
-    if chunk.remaining() == 0 {
-        trace!("discarding empty chunk");
-        continue;
-    }
-    self.conn.write_body(chunk);
-}
+if frame.is_data() {
+    let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
+    let eos = body.is_end_stream();
+    if eos {
+        *clear_body = true;
+        if chunk.remaining() == 0 {
+            trace!("discarding empty chunk");
+            self.conn.end_body()?;
+        } else {
+            self.conn.write_body_and_end(chunk);
+        }
+    } else {
+        if chunk.remaining() == 0 {
+            trace!("discarding empty chunk");
+            continue;
+        }
+        self.conn.write_body(chunk);
+    }
+} else if frame.is_trailers() {
+    *clear_body = true;
+    self.conn.write_trailers(
+        frame.into_trailers().unwrap_or_else(|_| unreachable!()),
+    );
+} else {
+    trace!("discarding unknown frame");
+    continue;
+}
} else {
*clear_body = true;

src/proto/h1/encode.rs

@@ -1,10 +1,19 @@
use std::collections::HashMap;
use std::fmt;
use std::io::IoSlice;
use bytes::buf::{Chain, Take};
-use bytes::Buf;
+use bytes::{Buf, Bytes};
use http::{
header::{
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
},
HeaderMap, HeaderName, HeaderValue,
};
use super::io::WriteBuf;
use super::role::{write_headers, write_headers_title_case};
type StaticBuf = &'static [u8];
@@ -26,7 +35,7 @@ pub(crate) struct NotEof(u64);
#[derive(Debug, PartialEq, Clone)]
enum Kind {
/// An Encoder for when Transfer-Encoding includes `chunked`.
-Chunked,
+Chunked(Option<Vec<HeaderValue>>),
/// An Encoder for when Content-Length is set.
///
/// Enforces that the body is not longer than the Content-Length header.
@@ -45,6 +54,7 @@ enum BufKind<B> {
Limited(Take<B>),
Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
ChunkedEnd(StaticBuf),
Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
}
impl Encoder {
@@ -55,7 +65,7 @@ impl Encoder {
}
}
pub(crate) fn chunked() -> Encoder {
-Encoder::new(Kind::Chunked)
+Encoder::new(Kind::Chunked(None))
}
pub(crate) fn length(len: u64) -> Encoder {
@@ -67,6 +77,16 @@ impl Encoder {
Encoder::new(Kind::CloseDelimited)
}
pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
match self.kind {
Kind::Chunked(_) => Encoder {
kind: Kind::Chunked(Some(trailers)),
is_last: self.is_last,
},
_ => self,
}
}
pub(crate) fn is_eof(&self) -> bool {
matches!(self.kind, Kind::Length(0))
}
@@ -89,10 +109,17 @@ impl Encoder {
}
}
pub(crate) fn is_chunked(&self) -> bool {
match self.kind {
Kind::Chunked(_) => true,
_ => false,
}
}
pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
match self.kind {
Kind::Length(0) => Ok(None),
-Kind::Chunked => Ok(Some(EncodedBuf {
+Kind::Chunked(_) => Ok(Some(EncodedBuf {
kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
})),
#[cfg(feature = "server")]
@@ -109,7 +136,7 @@ impl Encoder {
debug_assert!(len > 0, "encode() called with empty buf");
let kind = match self.kind {
-Kind::Chunked => {
+Kind::Chunked(_) => {
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
@@ -136,6 +163,53 @@ impl Encoder {
EncodedBuf { kind }
}
pub(crate) fn encode_trailers<B>(
&self,
trailers: HeaderMap,
title_case_headers: bool,
) -> Option<EncodedBuf<B>> {
match &self.kind {
Kind::Chunked(Some(ref allowed_trailer_fields)) => {
let allowed_trailer_field_map = allowed_trailer_field_map(&allowed_trailer_fields);
let mut cur_name = None;
let mut allowed_trailers = HeaderMap::new();
for (opt_name, value) in trailers {
if let Some(n) = opt_name {
cur_name = Some(n);
}
let name = cur_name.as_ref().expect("current header name");
if allowed_trailer_field_map.contains_key(name.as_str())
&& valid_trailer_field(name)
{
allowed_trailers.insert(name, value);
}
}
let mut buf = Vec::new();
if title_case_headers {
write_headers_title_case(&allowed_trailers, &mut buf);
} else {
write_headers(&allowed_trailers, &mut buf);
}
if buf.is_empty() {
return None;
}
Some(EncodedBuf {
kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
})
}
_ => {
debug!("attempted to encode trailers for non-chunked response");
None
}
}
}
pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
where
B: Buf,
@@ -144,7 +218,7 @@ impl Encoder {
debug_assert!(len > 0, "encode() called with empty buf");
match self.kind {
-Kind::Chunked => {
+Kind::Chunked(_) => {
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
@@ -181,6 +255,40 @@ impl Encoder {
}
}
fn valid_trailer_field(name: &HeaderName) -> bool {
match name {
&AUTHORIZATION => false,
&CACHE_CONTROL => false,
&CONTENT_ENCODING => false,
&CONTENT_LENGTH => false,
&CONTENT_RANGE => false,
&CONTENT_TYPE => false,
&HOST => false,
&MAX_FORWARDS => false,
&SET_COOKIE => false,
&TRAILER => false,
&TRANSFER_ENCODING => false,
&TE => false,
_ => true,
}
}
fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
let mut trailer_map = HashMap::new();
for header_value in allowed_trailer_fields {
if let Ok(header_str) = header_value.to_str() {
let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();
for item in items {
trailer_map.entry(item.to_string()).or_insert(());
}
}
}
trailer_map
}
impl<B> Buf for EncodedBuf<B>
where
B: Buf,
@@ -192,6 +300,7 @@ where
BufKind::Limited(ref b) => b.remaining(),
BufKind::Chunked(ref b) => b.remaining(),
BufKind::ChunkedEnd(ref b) => b.remaining(),
BufKind::Trailers(ref b) => b.remaining(),
}
}
@@ -202,6 +311,7 @@ where
BufKind::Limited(ref b) => b.chunk(),
BufKind::Chunked(ref b) => b.chunk(),
BufKind::ChunkedEnd(ref b) => b.chunk(),
BufKind::Trailers(ref b) => b.chunk(),
}
}
@@ -212,6 +322,7 @@ where
BufKind::Limited(ref mut b) => b.advance(cnt),
BufKind::Chunked(ref mut b) => b.advance(cnt),
BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
BufKind::Trailers(ref mut b) => b.advance(cnt),
}
}
@@ -222,6 +333,7 @@ where
BufKind::Limited(ref b) => b.chunks_vectored(dst),
BufKind::Chunked(ref b) => b.chunks_vectored(dst),
BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
BufKind::Trailers(ref b) => b.chunks_vectored(dst),
}
}
}
@@ -327,7 +439,16 @@ impl std::error::Error for NotEof {}
#[cfg(test)]
mod tests {
use std::iter::FromIterator;
use bytes::BufMut;
use http::{
header::{
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
},
HeaderMap, HeaderName, HeaderValue,
};
use super::super::io::Cursor;
use super::Encoder;
@@ -402,4 +523,145 @@ mod tests {
assert!(!encoder.is_eof());
encoder.end::<()>().unwrap();
}
#[test]
fn chunked_with_valid_trailers() {
let encoder = Encoder::chunked();
let trailers = vec![HeaderValue::from_static("chunky-trailer")];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
let headers = HeaderMap::from_iter(
vec![
(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data"),
),
(
HeaderName::from_static("should-not-be-included"),
HeaderValue::from_static("oops"),
),
]
.into_iter(),
);
let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();
let mut dst = Vec::new();
dst.put(buf1);
assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
}
#[test]
fn chunked_with_multiple_trailer_headers() {
let encoder = Encoder::chunked();
let trailers = vec![
HeaderValue::from_static("chunky-trailer"),
HeaderValue::from_static("chunky-trailer-2"),
];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
let headers = HeaderMap::from_iter(
vec![
(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data"),
),
(
HeaderName::from_static("chunky-trailer-2"),
HeaderValue::from_static("more header data"),
),
]
.into_iter(),
);
let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();
let mut dst = Vec::new();
dst.put(buf1);
assert_eq!(
dst,
b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n"
);
}
#[test]
fn chunked_with_no_trailer_header() {
let encoder = Encoder::chunked();
let headers = HeaderMap::from_iter(
vec![(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data"),
)]
.into_iter(),
);
assert!(encoder
.encode_trailers::<&[u8]>(headers.clone(), false)
.is_none());
let trailers = vec![];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
}
#[test]
fn chunked_with_invalid_trailers() {
let encoder = Encoder::chunked();
let trailers = format!(
"{},{},{},{},{},{},{},{},{},{},{},{}",
AUTHORIZATION,
CACHE_CONTROL,
CONTENT_ENCODING,
CONTENT_LENGTH,
CONTENT_RANGE,
CONTENT_TYPE,
HOST,
MAX_FORWARDS,
SET_COOKIE,
TRAILER,
TRANSFER_ENCODING,
TE,
);
let trailers = vec![HeaderValue::from_str(&trailers).unwrap()];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
headers.insert(HOST, HeaderValue::from_static("header data"));
headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
headers.insert(TRAILER, HeaderValue::from_static("header data"));
headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
headers.insert(TE, HeaderValue::from_static("header data"));
assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
}
#[test]
fn chunked_with_title_case_headers() {
let encoder = Encoder::chunked();
let trailers = vec![HeaderValue::from_static("chunky-trailer")];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
let headers = HeaderMap::from_iter(
vec![(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data"),
)]
.into_iter(),
);
let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();
let mut dst = Vec::new();
dst.put(buf1);
assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
}
}

src/proto/h1/role.rs

@@ -625,6 +625,7 @@ impl Server {
};
let mut encoder = Encoder::length(0);
let mut allowed_trailer_fields: Option<Vec<HeaderValue>> = None;
let mut wrote_date = false;
let mut cur_name = None;
let mut is_name_written = false;
@@ -811,6 +812,38 @@ impl Server {
header::DATE => {
wrote_date = true;
}
header::TRAILER => {
// check that we actually can send a chunked body...
if msg.head.version == Version::HTTP_10
|| !Server::can_chunked(msg.req_method, msg.head.subject)
{
continue;
}
if !is_name_written {
is_name_written = true;
header_name_writer.write_header_name_with_colon(
dst,
"trailer: ",
header::TRAILER,
);
extend(dst, value.as_bytes());
} else {
extend(dst, b", ");
extend(dst, value.as_bytes());
}
match allowed_trailer_fields {
Some(ref mut allowed_trailer_fields) => {
allowed_trailer_fields.push(value);
}
None => {
allowed_trailer_fields = Some(vec![value]);
}
}
continue 'headers;
}
_ => (),
}
//TODO: this should perhaps instead combine them into
@@ -895,6 +928,12 @@ impl Server {
extend(dst, b"\r\n");
}
if encoder.is_chunked() {
if let Some(allowed_trailer_fields) = allowed_trailer_fields {
encoder = encoder.into_chunked_with_trailing_fields(allowed_trailer_fields);
}
}
Ok(encoder.set_last(is_last))
}
}
@@ -1302,6 +1341,19 @@ impl Client {
}
};
let encoder = encoder.map(|enc| {
if enc.is_chunked() {
let allowed_trailer_fields: Vec<HeaderValue> =
headers.get_all(header::TRAILER).iter().cloned().collect();
if !allowed_trailer_fields.is_empty() {
return enc.into_chunked_with_trailing_fields(allowed_trailer_fields);
}
}
enc
});
// This is because we need a second mutable borrow to remove
// content-length header.
if let Some(encoder) = encoder {
@@ -1464,8 +1516,7 @@ fn title_case(dst: &mut Vec<u8>, name: &[u8]) {
}
}
#[cfg(feature = "client")]
fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) {
pub(crate) fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) {
for (name, value) in headers {
title_case(dst, name.as_str().as_bytes());
extend(dst, b": ");
@@ -1474,8 +1525,7 @@ fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) {
}
}
#[cfg(feature = "client")]
fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) {
pub(crate) fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) {
for (name, value) in headers {
extend(dst, name.as_str().as_bytes());
extend(dst, b": ");

tests/client.rs

@@ -5,6 +5,7 @@ use std::convert::Infallible;
use std::fmt;
use std::future::Future;
use std::io::{Read, Write};
use std::iter::FromIterator;
use std::net::{SocketAddr, TcpListener};
use std::pin::Pin;
use std::thread;
@@ -13,7 +14,7 @@ use std::time::Duration;
use http::uri::PathAndQuery;
use http_body_util::{BodyExt, StreamBody};
use hyper::body::Frame;
-use hyper::header::HeaderValue;
+use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use hyper::{Method, Request, StatusCode, Uri, Version};
use bytes::Bytes;
@@ -409,6 +410,15 @@ macro_rules! __client_req_prop {
Frame::data,
)));
}};
($req_builder:ident, $body:ident, $addr:ident, body_stream_with_trailers: $body_e:expr) => {{
use support::trailers::StreamBodyWithTrailers;
let (body, trailers) = $body_e;
$body = BodyExt::boxed(StreamBodyWithTrailers::with_trailers(
futures_util::TryStreamExt::map_ok(body, Frame::data),
trailers,
));
}};
}
macro_rules! __client_req_header {
@@ -632,6 +642,44 @@ test! {
body: &b"hello"[..],
}
test! {
name: client_post_req_body_chunked_with_trailer,
server:
expected: "\
POST / HTTP/1.1\r\n\
trailer: chunky-trailer\r\n\
host: {addr}\r\n\
transfer-encoding: chunked\r\n\
\r\n\
5\r\n\
hello\r\n\
0\r\n\
chunky-trailer: header data\r\n\
\r\n\
",
reply: REPLY_OK,
client:
request: {
method: POST,
url: "http://{addr}/",
headers: {
"trailer" => "chunky-trailer",
},
body_stream_with_trailers: (
(futures_util::stream::once(async { Ok::<_, Infallible>(Bytes::from("hello"))})),
HeaderMap::from_iter(vec![(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data")
)].into_iter())),
},
response:
status: OK,
headers: {},
body: None,
}
test! {
name: client_get_req_body_sized,

tests/server.rs

@@ -19,7 +19,7 @@ use futures_channel::oneshot;
use futures_util::future::{self, Either, FutureExt};
use h2::client::SendRequest;
use h2::{RecvStream, SendStream};
-use http::header::{HeaderName, HeaderValue};
+use http::header::{HeaderMap, HeaderName, HeaderValue};
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
use hyper::rt::Timer;
use hyper::rt::{Read as AsyncRead, Write as AsyncWrite};
@@ -2595,6 +2595,94 @@ async fn http2_keep_alive_count_server_pings() {
.expect("timed out waiting for pings");
}
#[test]
fn http1_trailer_fields() {
let body = futures_util::stream::once(async move { Ok("hello".into()) });
let mut headers = HeaderMap::new();
headers.insert("chunky-trailer", "header data".parse().unwrap());
// Invalid trailer field that should not be sent
headers.insert("Host", "www.example.com".parse().unwrap());
// Not specified in Trailer header, so should not be sent
headers.insert("foo", "bar".parse().unwrap());
let server = serve();
server
.reply()
.header("transfer-encoding", "chunked")
.header("trailer", "chunky-trailer")
.body_stream_with_trailers(body, headers);
let mut req = connect(server.addr());
req.write_all(
b"\
GET / HTTP/1.1\r\n\
Host: example.domain\r\n\
Connection: keep-alive\r\n\
TE: trailers\r\n\
\r\n\
",
)
.expect("writing");
let chunky_trailer_chunk = b"\r\nchunky-trailer: header data\r\n\r\n";
let res = read_until(&mut req, |buf| buf.ends_with(chunky_trailer_chunk)).expect("reading");
let sres = s(&res);
let expected_head =
"HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n";
assert_eq!(&sres[..expected_head.len()], expected_head);
// skip the date header
let date_fragment = "GMT\r\n\r\n";
let pos = sres.find(date_fragment).expect("find GMT");
let body = &sres[pos + date_fragment.len()..];
let expected_body = "5\r\nhello\r\n0\r\nchunky-trailer: header data\r\n\r\n";
assert_eq!(body, expected_body);
}
#[test]
fn http1_trailer_fields_not_allowed() {
let body = futures_util::stream::once(async move { Ok("hello".into()) });
let mut headers = HeaderMap::new();
headers.insert("chunky-trailer", "header data".parse().unwrap());
let server = serve();
server
.reply()
.header("transfer-encoding", "chunked")
.header("trailer", "chunky-trailer")
.body_stream_with_trailers(body, headers);
let mut req = connect(server.addr());
// TE: trailers is not specified in request headers
req.write_all(
b"\
GET / HTTP/1.1\r\n\
Host: example.domain\r\n\
Connection: keep-alive\r\n\
\r\n\
",
)
.expect("writing");
let last_chunk = b"\r\n0\r\n\r\n";
let res = read_until(&mut req, |buf| buf.ends_with(last_chunk)).expect("reading");
let sres = s(&res);
let expected_head =
"HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n";
assert_eq!(&sres[..expected_head.len()], expected_head);
// skip the date header
let date_fragment = "GMT\r\n\r\n";
let pos = sres.find(date_fragment).expect("find GMT");
let body = &sres[pos + date_fragment.len()..];
// no trailer fields should be sent because TE: trailers was not in request headers
let expected_body = "5\r\nhello\r\n0\r\n\r\n";
assert_eq!(body, expected_body);
}
// -------------------------------------------------
// the Server that is used to run all the tests with
// -------------------------------------------------
@@ -2700,6 +2788,19 @@ impl<'a> ReplyBuilder<'a> {
self.tx.lock().unwrap().send(Reply::Body(body)).unwrap();
}
fn body_stream_with_trailers<S>(self, stream: S, trailers: HeaderMap)
where
S: futures_util::Stream<Item = Result<Bytes, BoxError>> + Send + Sync + 'static,
{
use futures_util::TryStreamExt;
use hyper::body::Frame;
use support::trailers::StreamBodyWithTrailers;
let mut stream_body = StreamBodyWithTrailers::new(stream.map_ok(Frame::data));
stream_body.set_trailers(trailers);
let body = BodyExt::boxed(stream_body);
self.tx.lock().unwrap().send(Reply::Body(body)).unwrap();
}
#[allow(dead_code)]
fn error<E: Into<BoxError>>(self, err: E) {
self.tx

tests/support/mod.rs

@@ -24,6 +24,8 @@ mod tokiort;
#[allow(unused)]
pub use tokiort::{TokioExecutor, TokioIo, TokioTimer};
pub mod trailers;
#[allow(unused_macros)]
macro_rules! t {
(

tests/support/trailers.rs (new file)

@@ -0,0 +1,76 @@
use bytes::Buf;
use futures_util::stream::Stream;
use http::header::HeaderMap;
use http_body::{Body, Frame};
use pin_project_lite::pin_project;
use std::{
pin::Pin,
task::{Context, Poll},
};
pin_project! {
/// A body created from a [`Stream`].
#[derive(Clone, Debug)]
pub struct StreamBodyWithTrailers<S> {
#[pin]
stream: S,
trailers: Option<HeaderMap>,
}
}
impl<S> StreamBodyWithTrailers<S> {
/// Create a new `StreamBodyWithTrailers`.
pub fn new(stream: S) -> Self {
Self {
stream,
trailers: None,
}
}
pub fn with_trailers(stream: S, trailers: HeaderMap) -> Self {
Self {
stream,
trailers: Some(trailers),
}
}
pub fn set_trailers(&mut self, trailers: HeaderMap) {
self.trailers = Some(trailers);
}
}
impl<S, D, E> Body for StreamBodyWithTrailers<S>
where
S: Stream<Item = Result<Frame<D>, E>>,
D: Buf,
{
type Data = D;
type Error = E;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let project = self.project();
match project.stream.poll_next(cx) {
Poll::Ready(Some(result)) => Poll::Ready(Some(result)),
Poll::Ready(None) => match project.trailers.take() {
Some(trailers) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
None => Poll::Ready(None),
},
Poll::Pending => Poll::Pending,
}
}
}
impl<S: Stream> Stream for StreamBodyWithTrailers<S> {
type Item = S::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().stream.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.stream.size_hint()
}
}
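A rough usage sketch of the helper above (imports and the "chunky-trailer" field name are illustrative): wrap any stream of frames, and the trailer fields are emitted as one final frame once the stream is exhausted.

use bytes::Bytes;
use futures_util::stream;
use http::header::{HeaderMap, HeaderValue};
use hyper::body::Frame;
use std::convert::Infallible;

fn example_body() -> StreamBodyWithTrailers<impl futures_util::Stream<Item = Result<Frame<Bytes>, Infallible>>> {
    let mut trailers = HeaderMap::new();
    trailers.insert("chunky-trailer", HeaderValue::from_static("header data"));

    // One data frame; poll_frame yields the trailers frame after the stream ends.
    let data = stream::iter(vec![Ok::<_, Infallible>(Frame::data(Bytes::from("hello")))]);
    StreamBodyWithTrailers::with_trailers(data, trailers)
}

The server and client integration tests above use this pattern (via body_stream_with_trailers) to exercise the new trailer encoding.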