Browse Source

better queue handling

Josip Benko-Đaković 1 year ago
parent
commit
126b03b8cc
6 changed files with 90 additions and 34 deletions
  1. 1 0
      Cargo.toml
  2. 7 3
      src/lib.rs
  3. 12 10
      src/message.rs
  4. 11 3
      src/runtime.rs
  5. 53 14
      src/ws.rs
  6. 6 4
      tests/websocket.rs

+ 1 - 0
Cargo.toml

@@ -10,6 +10,7 @@ name = "test-ws"
 path = "tests/websocket.rs"
 
 [dependencies]
+async-trait = "0.1.68"
 flume = "0.10.14"
 futures = "0.3.28"
 pin-project = "1.1.0"

+ 7 - 3
src/lib.rs

@@ -1,4 +1,5 @@
 use crate::runtime::{ActorRuntime, Runtime};
+use async_trait::async_trait;
 use flume::{SendError, Sender};
 use message::{Envelope, Enveloper, Message, MessageRequest};
 use std::fmt::Debug;
@@ -21,11 +22,12 @@ pub trait Actor {
 }
 
 /// The main trait to implement on an [Actor] to enable it to handle messages.
+#[async_trait]
 pub trait Handler<M>: Actor
 where
     M: Message,
 {
-    fn handle(&mut self, message: M) -> Result<M::Response, Error>;
+    async fn handle(&mut self, message: M) -> Result<M::Response, Error>;
 }
 
 /// A handle to a spawned actor. Obtained when calling `start` on an [Actor] and is used to send messages
@@ -231,15 +233,17 @@ mod tests {
 
         impl Actor for Testor {}
 
+        #[async_trait]
         impl Handler<Foo> for Testor {
-            fn handle(&mut self, _: Foo) -> Result<usize, Error> {
+            async fn handle(&mut self, _: Foo) -> Result<usize, Error> {
                 println!("Handling Foo");
                 Ok(10)
             }
         }
 
+        #[async_trait]
         impl Handler<Bar> for Testor {
-            fn handle(&mut self, _: Bar) -> Result<isize, Error> {
+            async fn handle(&mut self, _: Bar) -> Result<isize, Error> {
                 println!("Handling Bar");
                 Ok(10)
             }

+ 12 - 10
src/message.rs

@@ -1,3 +1,5 @@
+use std::task::Poll;
+
 use crate::{Actor, Error, Handler};
 use tokio::sync::oneshot;
 
@@ -9,7 +11,7 @@ pub trait Message {
 /// Represents a type erased message that ultimately gets stored in an [Envelope]. We need this indirection so we can abstract away the concrete message
 /// type when creating an actor handle, otherwise we would only be able to send a single message type to the actor.
 pub trait ActorMessage<A: Actor> {
-    fn handle(&mut self, actor: &mut A);
+    fn handle(&mut self, actor: &mut A, cx: &mut std::task::Context<'_>) -> Poll<()>;
 }
 
 /// Used by [ActorHandle][super::ActorHandle]s to pack [Message]s into [Envelope]s so we have a type erased message to send to the actor.
@@ -54,8 +56,8 @@ impl<A> ActorMessage<A> for Envelope<A>
 where
     A: Actor,
 {
-    fn handle(&mut self, actor: &mut A) {
-        self.message.handle(actor)
+    fn handle(&mut self, actor: &mut A, cx: &mut std::task::Context<'_>) -> Poll<()> {
+        self.message.handle(actor, cx)
     }
 }
 
@@ -64,16 +66,16 @@ where
     M: Message,
     A: Actor + Handler<M>,
 {
-    fn handle(&mut self, actor: &mut A) {
+    fn handle(&mut self, actor: &mut A, cx: &mut std::task::Context<'_>) -> Poll<()> {
         let Some(message) = self.message.take() else { panic!("Message already processed") };
-        match actor.handle(message) {
-            Ok(result) => {
+        match actor.handle(message).as_mut().poll(cx) {
+            Poll::Ready(result) => {
                 let Some(res_tx) = self.tx.take() else { panic!("Message already processed") };
-                // TODO
-                let _ = res_tx.send(result);
+                let _ = res_tx.send(result.unwrap());
+                Poll::Ready(())
             }
-            Err(_) => todo!(),
-        };
+            Poll::Pending => Poll::Pending,
+        }
     }
 }
 

+ 11 - 3
src/runtime.rs

@@ -60,9 +60,17 @@ where
                 Poll::Pending => {}
             };
 
-            // Process all messages
-            while let Some(mut message) = this.message_queue.pop_front() {
-                message.handle(this.actor)
+            // Process all pending messages
+            let mut idx = 0;
+            while idx < this.message_queue.len() {
+                let pending = &mut this.message_queue[idx];
+                match pending.handle(this.actor, cx) {
+                    Poll::Ready(_) => {
+                        this.message_queue.swap_remove_front(idx);
+                        continue;
+                    }
+                    Poll::Pending => idx += 1,
+                }
             }
 
             // Poll message receiver and continue to process if anything comes up

+ 53 - 14
src/ws.rs

@@ -3,6 +3,7 @@ use crate::{
     runtime::Runtime,
     Actor, ActorCommand, ActorHandle, Error, Handler,
 };
+use async_trait::async_trait;
 use flume::Receiver;
 use futures::{
     stream::{SplitSink, SplitStream},
@@ -12,6 +13,7 @@ use pin_project::pin_project;
 use std::{
     collections::VecDeque,
     pin::Pin,
+    sync::atomic::AtomicUsize,
     task::{Context, Poll},
 };
 use warp::ws::WebSocket;
@@ -38,6 +40,8 @@ impl WebsocketActor {
     }
 }
 
+static PROCESSED: AtomicUsize = AtomicUsize::new(0);
+
 #[pin_project]
 pub struct WebsocketRuntime {
     actor: WebsocketActor,
@@ -61,7 +65,10 @@ pub struct WebsocketRuntime {
     message_queue: VecDeque<Envelope<WebsocketActor>>,
 
     /// Received, but not yet processed websocket messages
-    ws_queue: VecDeque<warp::ws::Message>,
+    request_queue: VecDeque<warp::ws::Message>,
+
+    /// Processed websocket messages ready to be flushed in the sink
+    response_queue: VecDeque<warp::ws::Message>,
 }
 
 impl WebsocketRuntime {
@@ -83,7 +90,8 @@ impl WebsocketRuntime {
             message_rx,
             command_rx,
             message_queue: VecDeque::new(),
-            ws_queue: VecDeque::new(),
+            request_queue: VecDeque::new(),
+            response_queue: VecDeque::new(),
         }
     }
 }
@@ -114,26 +122,56 @@ impl Future for WebsocketRuntime {
             while let Poll::Ready(Some(ws_message)) = Pin::new(&mut this.ws_stream.next()).poll(cx)
             {
                 match ws_message {
-                    Ok(message) => this.ws_queue.push_back(message),
+                    Ok(message) => this.request_queue.push_back(message),
                     Err(e) => {
                         eprintln!("WS error occurred {e}")
                     }
                 }
             }
 
-            // Respond to any queued websocket messages
-            while let Some(ws_message) = this.ws_queue.pop_front() {
-                if let Some(res) = this.actor.handle(ws_message)? {
-                    match Pin::new(&mut this.ws_sink.send(res)).poll(cx) {
-                        Poll::Ready(result) => result?,
-                        Poll::Pending => todo!(),
-                    }
+            // Respond to any queued and processed websocket messages
+            let mut idx = 0;
+            while idx < this.request_queue.len() {
+                let ws_message = &this.request_queue[idx];
+                match this.actor.handle(ws_message.to_owned()).as_mut().poll(cx) {
+                    Poll::Ready(result) => match result {
+                        Ok(response) => {
+                            if let Some(response) = response {
+                                match Pin::new(&mut this.ws_sink.feed(response)).poll(cx) {
+                                    Poll::Ready(result) => {
+                                        result?;
+                                        this.request_queue.swap_remove_front(idx);
+                                        PROCESSED
+                                            .fetch_add(1, std::sync::atomic::Ordering::Acquire);
+                                    }
+                                    Poll::Pending => idx += 1,
+                                }
+                            }
+                        }
+                        Err(e) => return Poll::Ready(Err(e)),
+                    },
+                    Poll::Pending => idx += 1,
                 }
             }
 
+            println!(
+                "PROCESSED {}",
+                PROCESSED.load(std::sync::atomic::Ordering::Acquire)
+            );
+
+            let _ = Pin::new(&mut this.ws_sink.flush()).poll(cx);
+
             // Process all messages
-            while let Some(mut message) = this.message_queue.pop_front() {
-                message.handle(this.actor)
+            let mut idx = 0;
+            while idx < this.message_queue.len() {
+                let pending = &mut this.message_queue[idx];
+                match pending.handle(this.actor, cx) {
+                    Poll::Ready(_) => {
+                        this.message_queue.swap_remove_front(idx);
+                        continue;
+                    }
+                    Poll::Pending => idx += 1,
+                }
             }
 
             // Poll message receiver and continue to process if anything comes up
@@ -179,12 +217,13 @@ impl crate::Message for warp::ws::Message {
     type Response = Option<warp::ws::Message>;
 }
 
+#[async_trait]
 impl Handler<warp::ws::Message> for WebsocketActor {
-    fn handle(
+    async fn handle(
         &mut self,
         message: warp::ws::Message,
     ) -> Result<<warp::ws::Message as crate::message::Message>::Response, crate::Error> {
-        println!("Actor received message {message:?}");
+        // println!("Actor received message {message:?}");
         if message.is_text() {
             Ok(Some(message))
         } else {

+ 6 - 4
tests/websocket.rs

@@ -54,9 +54,7 @@ static INDEX_HTML: &str = r#"<!DOCTYPE html>
         const ws = new WebSocket(uri);
 
         function message(data) {
-            const line = document.createElement('p');
-            line.innerText = data;
-            chat.appendChild(line);
+
         }
 
         ws.onopen = function() {
@@ -73,7 +71,11 @@ static INDEX_HTML: &str = r#"<!DOCTYPE html>
 
         send.onclick = function() {
             const msg = text.value;
-            ws.send(msg);
+            let i = 0;
+            while (i < 10000) {
+                ws.send(msg);
+                i += 1;
+            }
             text.value = '';
 
             message('<You>: ' + msg);