From 660f76fdca1f66488cbc50c9daa2f5d72d9b2671 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 16 Apr 2025 15:24:28 +0200 Subject: [PATCH] channel: Add recv (#2947) --- futures-channel/src/mpsc/mod.rs | 61 ++++++++++++++++++++++++++++++++- futures-channel/tests/mpsc.rs | 45 +++++++++++++++++++++++- 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/futures-channel/src/mpsc/mod.rs b/futures-channel/src/mpsc/mod.rs index c5e6bada..69bb88ff 100644 --- a/futures-channel/src/mpsc/mod.rs +++ b/futures-channel/src/mpsc/mod.rs @@ -78,9 +78,11 @@ // happens-before semantics required for the acquire / release semantics used // by the queue structure. +use core::future::Future; use futures_core::stream::{FusedStream, Stream}; use futures_core::task::__internal::AtomicWaker; use futures_core::task::{Context, Poll, Waker}; +use futures_core::FusedFuture; use std::fmt; use std::pin::Pin; use std::sync::atomic::AtomicUsize; @@ -167,7 +169,7 @@ enum SendErrorKind { Disconnected, } -/// The error type returned from [`try_recv`](Receiver::try_recv). +/// Error returned by [`Receiver::try_recv`] or [`UnboundedReceiver::try_recv`]. #[derive(PartialEq, Eq, Clone, Copy, Debug)] pub enum TryRecvError { /// The channel is empty but not closed. @@ -177,6 +179,18 @@ pub enum TryRecvError { Closed, } +/// Error returned by the future returned by [`Receiver::recv()`] or [`UnboundedReceiver::recv()`]. +/// Received when the channel is empty and closed. +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub struct RecvError; + +/// Future returned by [`Receiver::recv()`] or [`UnboundedReceiver::recv()`]. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Recv<'a, St: ?Sized> { + stream: &'a mut St, +} + impl fmt::Display for SendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.is_full() { @@ -189,6 +203,14 @@ impl fmt::Display for SendError { impl std::error::Error for SendError {} +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "receive failed because channel is empty and closed") + } +} + +impl std::error::Error for RecvError {} + impl SendError { /// Returns `true` if this error is a result of the channel being full. pub fn is_full(&self) -> bool { @@ -979,6 +1001,12 @@ impl fmt::Debug for UnboundedSender { */ impl Receiver { + /// Waits for a message from the channel. + /// If the channel is empty and closed, returns [`RecvError`]. + pub fn recv(&mut self) -> Recv<'_, Self> { + Recv::new(self) + } + /// Closes the receiving half of a channel, without dropping it. /// /// This prevents any further messages from being sent on the channel while @@ -1121,6 +1149,31 @@ impl Stream for Receiver { } } +impl Unpin for Recv<'_, St> {} +impl<'a, St: ?Sized + Stream + Unpin> Recv<'a, St> { + fn new(stream: &'a mut St) -> Self { + Self { stream } + } +} + +impl FusedFuture for Recv<'_, St> { + fn is_terminated(&self) -> bool { + self.stream.is_terminated() + } +} + +impl Future for Recv<'_, St> { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)), + Poll::Ready(None) => Poll::Ready(Err(RecvError)), + Poll::Pending => Poll::Pending, + } + } +} + impl Drop for Receiver { fn drop(&mut self) { // Drain the channel of all pending messages @@ -1164,6 +1217,12 @@ impl fmt::Debug for Receiver { } impl UnboundedReceiver { + /// Waits for a message from the channel. + /// If the channel is empty and closed, returns [`RecvError`]. + pub fn recv(&mut self) -> Recv<'_, Self> { + Recv::new(self) + } + /// Closes the receiving half of a channel, without dropping it. /// /// This prevents any further messages from being sent on the channel while diff --git a/futures-channel/tests/mpsc.rs b/futures-channel/tests/mpsc.rs index 1ef778b7..847f3a2b 100644 --- a/futures-channel/tests/mpsc.rs +++ b/futures-channel/tests/mpsc.rs @@ -4,7 +4,7 @@ use futures::future::{poll_fn, FutureExt}; use futures::sink::{Sink, SinkExt}; use futures::stream::{Stream, StreamExt}; use futures::task::{Context, Poll}; -use futures_channel::mpsc::TryRecvError; +use futures_channel::mpsc::{RecvError, TryRecvError}; use futures_test::task::{new_count_waker, noop_context}; use std::pin::pin; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -429,6 +429,49 @@ fn stress_poll_ready() { stress(16); } +#[test] +fn test_bounded_recv() { + let (dropped_tx, dropped_rx) = oneshot::channel(); + let (tx, mut rx) = mpsc::channel(1); + thread::spawn(move || { + block_on(async move { + send_one_two_three(tx).await; + dropped_tx.send(()).unwrap(); + }); + }); + + let res = block_on(async move { + let mut res = Vec::new(); + for _ in 0..3 { + res.push(rx.recv().await.unwrap()); + } + dropped_rx.await.unwrap(); + assert_eq!(rx.recv().await, Err(RecvError)); + res + }); + assert_eq!(res, [1, 2, 3]); +} + +#[test] +fn test_unbounded_recv() { + let (mut tx, mut rx) = mpsc::unbounded(); + + let res = block_on(async move { + let mut res = Vec::new(); + for i in 1..=3 { + tx.send(i).await.unwrap(); + } + drop(tx); + + for _ in 0..3 { + res.push(rx.recv().await.unwrap()); + } + assert_eq!(rx.recv().await, Err(RecvError)); + res + }); + assert_eq!(res, [1, 2, 3]); +} + #[test] fn try_send_1() { const N: usize = if cfg!(miri) { 100 } else { 3000 };