Browse Source

tidy up runtime

biblius 1 year ago
parent
commit
9a4cfb137c
3 changed files with 216 additions and 178 deletions
  1. 3 4
      src/message.rs
  2. 140 86
      src/runtime.rs
  3. 73 88
      src/ws.rs

+ 3 - 4
src/message.rs

@@ -41,7 +41,7 @@ where
     {
         Self {
             message: Box::new(EnvelopeInner {
-                message: Some(Box::new(message)),
+                message: Box::new(message),
                 tx,
             }),
         }
@@ -51,7 +51,7 @@ where
 /// The inner parts of the [Envelope] containing the actual message as well as an optional
 /// response channel.
 struct EnvelopeInner<M: Message> {
-    message: Option<Box<M>>,
+    message: Box<M>,
     tx: Option<oneshot::Sender<M::Response>>,
 }
 
@@ -73,8 +73,7 @@ where
     A: Actor + Handler<M> + Send + 'static,
 {
     async fn handle(self: Box<Self>, actor: Arc<Mutex<A>>) {
-        let Some(message) = self.message else { panic!("Handle already called") };
-        let result = A::handle(actor, message).await;
+        let result = A::handle(actor, self.message).await;
         if let Some(res_tx) = self.tx {
             let _ = res_tx.send(result.unwrap());
         }

+ 140 - 86
src/runtime.rs

@@ -3,9 +3,8 @@ use crate::{
     DEFAULT_CHANNEL_CAPACITY,
 };
 use async_trait::async_trait;
-use flume::Receiver;
-use futures::Future;
-use pin_project::pin_project;
+use flume::{r#async::RecvStream, Receiver};
+use futures::{Future, StreamExt};
 use std::{
     collections::VecDeque,
     pin::Pin,
@@ -14,37 +13,156 @@ use std::{
 };
 use tokio::sync::Mutex;
 
-const QUEUE_CAPACITY: usize = 128;
+pub const QUEUE_CAPACITY: usize = 128;
 
 #[async_trait]
-pub trait Runtime<A> {
-    async fn run(actor: Arc<Mutex<A>>) -> ActorHandle<A>
-    where
-        A: Actor + Send + 'static,
-    {
+pub trait Runtime<A>
+where
+    A: Actor + Send + 'static,
+{
+    fn command_stream(&mut self) -> &mut RecvStream<'static, ActorCommand>;
+
+    fn message_stream(&mut self) -> &mut RecvStream<'static, Envelope<A>>;
+
+    fn processing_queue(&mut self) -> &mut VecDeque<ActorJob<A>>;
+
+    fn actor(&self) -> Arc<Mutex<A>>;
+
+    fn at_capacity(&self) -> bool;
+
+    async fn run(actor: Arc<Mutex<A>>) -> ActorHandle<A> {
         let (message_tx, message_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
         let (command_tx, command_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
         tokio::spawn(ActorRuntime::new(actor, command_rx, message_rx));
         ActorHandle::new(message_tx, command_tx)
     }
+
+    fn process_commands(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
+        match self.command_stream().poll_next_unpin(cx) {
+            Poll::Ready(Some(command)) => match command {
+                ActorCommand::Stop => {
+                    println!("Actor stopping");
+                    Ok(()) // TODO drain the queue and all that graceful stuff
+                }
+            },
+            Poll::Ready(None) => {
+                println!("Command channel closed, ungracefully stopping actor");
+                Err(Error::ActorChannelClosed)
+            }
+            Poll::Pending => Ok(()),
+        }
+    }
+
+    fn process_messages(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
+        let actor = self.actor();
+
+        self.processing_queue()
+            .retain_mut(|job| job.poll(actor.clone(), cx).is_pending());
+
+        // Poll message receiver
+        if !self.at_capacity() {
+            while let Poll::Ready(message) = self.message_stream().poll_next_unpin(cx) {
+                let Some(message) = message else { return Err(Error::ActorChannelClosed) };
+                self.processing_queue().push_back(ActorJob::new(message));
+                if self.at_capacity() {
+                    break;
+                }
+            }
+        }
+
+        // Process pending futures again after we've potentially received some
+        self.processing_queue()
+            .retain_mut(|job| job.poll(actor.clone(), cx).is_pending());
+
+        if self.at_capacity() {
+            return Ok(());
+        }
+
+        match self.message_stream().poll_next_unpin(cx) {
+            Poll::Ready(Some(message)) => {
+                self.processing_queue().push_back(ActorJob::new(message));
+                Ok(())
+            }
+            Poll::Ready(None) => {
+                println!("Message channel closed, ungracefully stopping actor");
+                Err(Error::ActorChannelClosed)
+            }
+            Poll::Pending => Ok(()),
+        }
+    }
+}
+
+#[async_trait]
+impl<A> Runtime<A> for ActorRuntime<A>
+where
+    A: Actor + Send + 'static,
+{
+    async fn run(actor: Arc<Mutex<A>>) -> ActorHandle<A> {
+        let (message_tx, message_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
+        let (command_tx, command_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
+        tokio::spawn(ActorRuntime::new(actor, command_rx, message_rx));
+        ActorHandle::new(message_tx, command_tx)
+    }
+
+    #[inline]
+    fn processing_queue(&mut self) -> &mut VecDeque<ActorJob<A>> {
+        &mut self.process_queue
+    }
+
+    #[inline]
+    fn command_stream(&mut self) -> &mut RecvStream<'static, ActorCommand> {
+        &mut self.command_stream
+    }
+
+    #[inline]
+    fn message_stream(&mut self) -> &mut RecvStream<'static, Envelope<A>> {
+        &mut self.message_stream
+    }
+
+    #[inline]
+    fn actor(&self) -> Arc<Mutex<A>> {
+        self.actor.clone()
+    }
+
+    #[inline]
+    fn at_capacity(&self) -> bool {
+        self.process_queue.len() >= QUEUE_CAPACITY
+    }
 }
 
-#[pin_project]
+/// A future representing a message currently being handled. Created when polling an [ActorJob].
+pub type ActorFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
+
 pub struct ActorRuntime<A>
 where
     A: Actor + Send + 'static,
 {
     actor: Arc<Mutex<A>>,
-    command_rx: Receiver<ActorCommand>,
-    message_rx: Receiver<Envelope<A>>,
+    command_stream: RecvStream<'static, ActorCommand>,
+    message_stream: RecvStream<'static, Envelope<A>>,
     process_queue: VecDeque<ActorJob<A>>,
 }
 
-impl<A> Runtime<A> for ActorRuntime<A> where A: Actor + Send + 'static {}
-
-type ActorFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
+impl<A> ActorRuntime<A>
+where
+    A: Actor + 'static + Send,
+{
+    pub fn new(
+        actor: Arc<Mutex<A>>,
+        command_rx: Receiver<ActorCommand>,
+        message_rx: Receiver<Envelope<A>>,
+    ) -> Self {
+        println!("Building default runtime");
+        Self {
+            actor,
+            command_stream: command_rx.into_stream(),
+            message_stream: message_rx.into_stream(),
+            process_queue: VecDeque::with_capacity(QUEUE_CAPACITY),
+        }
+    }
+}
 
-struct ActorJob<A>
+pub struct ActorJob<A>
 where
     A: Actor,
 {
@@ -63,12 +181,8 @@ where
         }
     }
 
-    fn poll(
-        mut self: Pin<&mut Self>,
-        actor: Arc<Mutex<A>>,
-        cx: &mut std::task::Context<'_>,
-    ) -> Poll<()> {
-        match self.as_mut().message.take() {
+    fn poll(&mut self, actor: Arc<Mutex<A>>, cx: &mut std::task::Context<'_>) -> Poll<()> {
+        match self.message.take() {
             Some(message) => {
                 let fut = Box::new(message).handle(actor);
                 self.future = Some(fut);
@@ -89,72 +203,12 @@ where
 {
     type Output = Result<(), Error>;
 
-    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let this = self.project();
-
-        // Poll command receiver
-        match Pin::new(&mut this.command_rx.recv_async()).poll(cx) {
-            Poll::Ready(Ok(message)) => match message {
-                ActorCommand::Stop => {
-                    println!("Actor stopping");
-                    return Poll::Ready(Ok(())); // TODO drain the queue and all that graceful stuff
-                }
-            },
-            Poll::Ready(Err(_)) => {
-                println!("Command channel closed, ungracefully stopping actor");
-                return Poll::Ready(Err(Error::ActorChannelClosed));
-            }
-            Poll::Pending => {}
-        };
-
-        // Process the pending futures
-        this.process_queue
-            .retain_mut(|job| Pin::new(job).poll(this.actor.clone(), cx).is_pending());
-
-        // Poll message receiver
-        while let Poll::Ready(Ok(message)) = Pin::new(&mut this.message_rx.recv_async()).poll(cx) {
-            this.process_queue.push_back(ActorJob::new(message));
-            if this.process_queue.len() >= QUEUE_CAPACITY {
-                break;
-            }
-        }
-
-        // Process pending futures again after we've potentially received some
-        this.process_queue
-            .retain_mut(|job| Pin::new(job).poll(this.actor.clone(), cx).is_pending());
-
-        // Poll again and process new messages if any
-        match Pin::new(&mut this.message_rx.recv_async()).poll(cx) {
-            Poll::Ready(Ok(message)) => {
-                this.process_queue.push_back(ActorJob::new(message));
-            }
-            Poll::Ready(Err(_)) => {
-                println!("Message channel closed, ungracefully stopping actor");
-                return Poll::Ready(Err(Error::ActorChannelClosed));
-            }
-            Poll::Pending => {}
-        };
+    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let mut this = self.as_mut();
 
+        this.process_commands(cx)?;
+        this.process_messages(cx)?;
         cx.waker().wake_by_ref();
         Poll::Pending
     }
 }
-
-impl<A> ActorRuntime<A>
-where
-    A: Actor + 'static + Send,
-{
-    pub fn new(
-        actor: Arc<Mutex<A>>,
-        command_rx: Receiver<ActorCommand>,
-        message_rx: Receiver<Envelope<A>>,
-    ) -> Self {
-        println!("Building default runtime");
-        Self {
-            actor,
-            command_rx,
-            message_rx,
-            process_queue: VecDeque::with_capacity(QUEUE_CAPACITY),
-        }
-    }
-}

+ 73 - 88
src/ws.rs

@@ -1,14 +1,14 @@
 use crate::{
-    message::Envelope, runtime::Runtime, Actor, ActorCommand, ActorHandle, ActorStatus, Error,
-    Handler, Hello,
+    message::Envelope,
+    runtime::{ActorJob, Runtime, QUEUE_CAPACITY},
+    Actor, ActorCommand, ActorHandle, ActorStatus, Error, Handler, Hello,
 };
 use async_trait::async_trait;
-use flume::Receiver;
+use flume::{r#async::RecvStream, Receiver};
 use futures::{
     stream::{SplitSink, SplitStream},
-    Future, SinkExt, Stream, StreamExt,
+    Future, FutureExt, SinkExt, StreamExt,
 };
-use pin_project::pin_project;
 use std::{
     collections::VecDeque,
     pin::Pin,
@@ -42,14 +42,15 @@ impl Actor for WebsocketActor {
         WebsocketRuntime::run(Arc::new(Mutex::new(self))).await
     }
 }
+
 type WsFuture = Pin<Box<dyn Future<Output = Result<Option<Message>, Error>> + Send>>;
 
-struct ActorItem {
+struct WebsocketJob {
     message: Option<Box<Message>>,
     future: Option<WsFuture>,
 }
 
-impl ActorItem {
+impl WebsocketJob {
     pub fn new(message: Message) -> Self {
         Self {
             message: Some(Box::new(message)),
@@ -58,11 +59,11 @@ impl ActorItem {
     }
 
     fn poll(
-        mut self: Pin<&mut Self>,
+        &mut self,
         actor: Arc<Mutex<WebsocketActor>>,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Result<Option<Message>, warp::Error>> {
-        let message = self.as_mut().message.take();
+        let message = self.message.take();
 
         match message {
             Some(message) => {
@@ -101,34 +102,28 @@ impl ActorItem {
 static PROCESSED: AtomicUsize = AtomicUsize::new(0);
 static PENDING: AtomicUsize = AtomicUsize::new(0);
 
-#[pin_project]
 pub struct WebsocketRuntime {
     actor: Arc<Mutex<WebsocketActor>>,
 
     status: ActorStatus,
 
     /// The receiving end of the websocket
-    #[pin]
     ws_stream: SplitStream<WebSocket>,
 
     /// The sending end of the websocket
-    #[pin]
     ws_sink: SplitSink<WebSocket, Message>,
 
     /// Actor message receiver
-    message_rx: Receiver<Envelope<WebsocketActor>>,
+    message_stream: RecvStream<'static, Envelope<WebsocketActor>>,
 
     /// Actor command receiver
-    command_rx: Receiver<ActorCommand>,
+    command_stream: RecvStream<'static, ActorCommand>,
 
-    /// Received, but not yet processed messages
-    message_queue: VecDeque<Envelope<WebsocketActor>>,
+    /// Actor messages currently being processed
+    process_queue: VecDeque<ActorJob<WebsocketActor>>,
 
     /// Received, but not yet processed websocket messages
-    processing_queue: VecDeque<ActorItem>,
-
-    /// Processed websocket messages being flushed to the sink
-    response_queue: VecDeque<Message>,
+    response_queue: VecDeque<WebsocketJob>,
 }
 
 impl WebsocketRuntime {
@@ -149,11 +144,10 @@ impl WebsocketRuntime {
             actor,
             ws_sink,
             ws_stream,
-            message_rx,
-            command_rx,
-            message_queue: VecDeque::new(),
-            processing_queue: VecDeque::new(),
+            message_stream: message_rx.into_stream(),
+            command_stream: command_rx.into_stream(),
             response_queue: VecDeque::new(),
+            process_queue: VecDeque::new(),
             status: ActorStatus::Starting,
         }
     }
@@ -162,34 +156,19 @@ impl WebsocketRuntime {
 impl Future for WebsocketRuntime {
     type Output = Result<(), Error>;
 
-    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let mut this = self.project();
-
-        // Poll command receiver and immediatelly process it
-        if let Poll::Ready(result) = Pin::new(&mut this.command_rx.recv_async()).poll(cx) {
-            match result {
-                Ok(command) => {
-                    match command {
-                        ActorCommand::Stop => {
-                            println!("Actor stopping");
-                            return Poll::Ready(Ok(())); // TODO drain the queue and all that graceful stuff
-                        }
-                    }
-                }
-                Err(e) => {
-                    println!("Actor stopping - {e}"); // TODO drain the queue and all that graceful stuff
-                    return Poll::Ready(Err(Error::ActorChannelClosed));
-                }
-            }
-        };
+    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let actor = self.actor();
+        let mut this = self.as_mut();
+
+        this.process_commands(cx)?;
 
         // Poll the websocket stream for any messages and store them to the queue
-        if this.processing_queue.is_empty() {
-            while let Poll::Ready(Some(ws_message)) = this.ws_stream.as_mut().poll_next(cx) {
+        if this.response_queue.len() < WS_QUEUE_SIZE {
+            while let Poll::Ready(Some(ws_message)) = this.ws_stream.poll_next_unpin(cx) {
                 match ws_message {
                     Ok(message) => {
-                        this.processing_queue.push_back(ActorItem::new(message));
-                        if this.processing_queue.len() >= WS_QUEUE_SIZE {
+                        this.response_queue.push_back(WebsocketJob::new(message));
+                        if this.response_queue.len() >= WS_QUEUE_SIZE {
                             break;
                         }
                     }
@@ -201,9 +180,9 @@ impl Future for WebsocketRuntime {
         }
 
         let mut idx = 0;
-        while idx < this.processing_queue.len() {
-            let job = Pin::new(&mut this.processing_queue[idx]);
-            match ActorItem::poll(job, this.actor.clone(), cx) {
+        while idx < this.response_queue.len() {
+            let job = &mut this.response_queue[idx];
+            match job.poll(actor.clone(), cx) {
                 Poll::Ready(result) => match result {
                     Ok(response) => {
                         if let Some(response) = response {
@@ -212,7 +191,7 @@ impl Future for WebsocketRuntime {
                             let _ = feed.as_mut().poll(cx);
                         }
                         PROCESSED.fetch_add(1, std::sync::atomic::Ordering::Acquire);
-                        this.processing_queue.swap_remove_front(idx);
+                        this.response_queue.swap_remove_front(idx);
                     }
                     Err(e) => {
                         println!("Shit's fucked my dude {e}")
@@ -222,34 +201,15 @@ impl Future for WebsocketRuntime {
             }
         }
 
-        // println!(
-        //     "PROCESSED {} CURRENT IN QUEUE {}",
-        //     PROCESSED.load(std::sync::atomic::Ordering::Acquire),
-        //     this.processing_queue.len(),
-        // );
-
-        let _ = Pin::new(&mut this.ws_sink.flush()).poll(cx);
+        this.process_messages(cx)?;
 
-        // Process all messages
-        /*             this.message_queue
-        .retain_mut(|message| message.handle(actor).as_mut().poll(cx).is_pending()); */
+        println!(
+            "PROCESSED {} CURRENT IN QUEUE {}",
+            PROCESSED.load(std::sync::atomic::Ordering::Acquire),
+            this.response_queue.len(),
+        );
 
-        // Poll message receiver and continue to process if anything comes up
-        while let Poll::Ready(Ok(message)) = Pin::new(&mut this.message_rx.recv_async()).poll(cx) {
-            this.message_queue.push_back(message);
-        }
-
-        // Poll again and process new messages if any
-        match Pin::new(&mut this.message_rx.recv_async()).poll(cx) {
-            Poll::Ready(Ok(message)) => {
-                this.message_queue.push_back(message);
-            }
-            Poll::Ready(Err(_)) => {
-                println!("Message channel closed, ungracefully stopping actor");
-                return Poll::Ready(Err(Error::ActorChannelClosed));
-            }
-            Poll::Pending => {}
-        };
+        let _ = this.ws_sink.flush().poll_unpin(cx)?;
 
         cx.waker().wake_by_ref();
         Poll::Pending
@@ -264,6 +224,31 @@ impl Runtime<WebsocketActor> for WebsocketRuntime {
         tokio::spawn(WebsocketRuntime::new(actor, command_rx, message_rx).await);
         ActorHandle::new(message_tx, command_tx)
     }
+
+    #[inline]
+    fn processing_queue(&mut self) -> &mut VecDeque<ActorJob<WebsocketActor>> {
+        &mut self.process_queue
+    }
+
+    #[inline]
+    fn command_stream(&mut self) -> &mut RecvStream<'static, ActorCommand> {
+        &mut self.command_stream
+    }
+
+    #[inline]
+    fn message_stream(&mut self) -> &mut RecvStream<'static, Envelope<WebsocketActor>> {
+        &mut self.message_stream
+    }
+
+    #[inline]
+    fn actor(&self) -> Arc<Mutex<WebsocketActor>> {
+        self.actor.clone()
+    }
+
+    #[inline]
+    fn at_capacity(&self) -> bool {
+        self.process_queue.len() >= QUEUE_CAPACITY
+    }
 }
 
 impl crate::Message for Message {
@@ -277,17 +262,17 @@ impl Handler<Message> for WebsocketActor {
         message: Box<Message>,
     ) -> Result<<Message as crate::message::Message>::Response, crate::Error> {
         //let mut act = this.lock().await;
-        this.lock()
-            .await
-            .hello
-            .send(crate::Msg {
-                content: message.to_str().unwrap().to_owned(),
-            })
-            .unwrap_or_else(|e| println!("{e}"));
-        //        println!("Actor retreived lock and sent message got response {res}");
-        tokio::time::sleep(Duration::from_micros(1)).await;
-        //act.wait().await;
         if message.is_text() {
+            this.lock()
+                .await
+                .hello
+                .send(crate::Msg {
+                    content: message.to_str().unwrap().to_owned(),
+                })
+                .unwrap_or_else(|e| println!("{e}"));
+            //        println!("Actor retreived lock and sent message got response {res}");
+            tokio::time::sleep(Duration::from_micros(1)).await;
+            //act.wait().await;
             Ok(Some(*message.clone()))
         } else {
             Ok(None)