diff --git a/tests/server.rs b/tests/server.rs index 2ba6f92c..485a10e0 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -23,7 +23,7 @@ use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; use hyper::rt::Timer; use hyper::rt::{Read as AsyncRead, Write as AsyncWrite}; use support::{TokioExecutor, TokioIo, TokioTimer}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio::net::{TcpListener as TkTcpListener, TcpListener, TcpStream as TkTcpStream}; use hyper::body::{Body, Incoming as IncomingBody}; @@ -1004,6 +1004,21 @@ fn setup_tcp_listener() -> (TcpListener, SocketAddr) { (listener, addr) } +fn setup_duplex_test_server() -> (DuplexStream, DuplexStream, SocketAddr) { + use std::net::{IpAddr, Ipv6Addr}; + let _ = pretty_env_logger::try_init(); + + const BUF_SIZE: usize = 1024; + let (ioa, iob) = tokio::io::duplex(BUF_SIZE); + + /// A test address inside the 'documentation' address range. + /// See: + const TEST_ADDR: IpAddr = IpAddr::V6(Ipv6Addr::new(0x3fff, 0, 0, 0, 0, 0, 0, 1)); + const TEST_SOCKET: SocketAddr = SocketAddr::new(TEST_ADDR, 8080); + + (ioa, iob, TEST_SOCKET) +} + #[tokio::test] async fn expect_continue_waits_for_body_poll() { let (listener, addr) = setup_tcp_listener(); @@ -1189,19 +1204,26 @@ fn http_11_uri_too_long() { #[tokio::test] async fn disable_keep_alive_mid_request() { - let (listener, addr) = setup_tcp_listener(); + let (client_io, server_io, _) = setup_duplex_test_server(); let (tx1, rx1) = oneshot::channel(); - let (tx2, rx2) = mpsc::channel(); + let (tx2, rx2) = oneshot::channel(); - let child = thread::spawn(move || { - let mut req = connect(&addr); - req.write_all(b"GET / HTTP/1.1\r\n").unwrap(); - thread::sleep(Duration::from_millis(10)); + let client_task = tokio::spawn(async move { + let mut client_io = client_io; + // Send partial request + client_io.write_all(b"GET / HTTP/1.1\r\n").await.unwrap(); + // Signal server that partial request sent tx1.send(()).unwrap(); - rx2.recv().unwrap(); - req.write_all(b"Host: localhost\r\n\r\n").unwrap(); + // Wait for server to be ready for rest of request + rx2.await.unwrap(); + // Send rest of request + client_io + .write_all(b"Host: localhost\r\n\r\n") + .await + .unwrap(); + // Read response let mut buf = vec![]; - req.read_to_end(&mut buf).unwrap(); + client_io.read_to_end(&mut buf).await.unwrap(); assert!( buf.starts_with(b"HTTP/1.1 200 OK\r\n"), "should receive OK response, but buf: {:?}", @@ -1215,9 +1237,8 @@ async fn disable_keep_alive_mid_request() { ); }); - let (socket, _) = listener.accept().await.unwrap(); - let socket = TokioIo::new(socket); - let srv = http1::Builder::new().serve_connection(socket, HelloWorld); + let server_io = TokioIo::new(server_io); + let srv = http1::Builder::new().serve_connection(server_io, HelloWorld); future::try_select(srv, rx1) .then(|r| match r { Ok(Either::Left(_)) => panic!("expected rx first"), @@ -1232,7 +1253,7 @@ async fn disable_keep_alive_mid_request() { .await .unwrap(); - child.join().unwrap(); + client_task.await.unwrap(); } #[tokio::test]