Browse Source

implement runtime through future

biblius 1 year ago
parent
commit
721f645740
6 changed files with 195 additions and 167 deletions
  1. 1 0
      Cargo.toml
  2. 42 13
      src/lib.rs
  3. 2 2
      src/message.rs
  4. 35 32
      src/runtime.rs
  5. 102 47
      src/ws.rs
  6. 13 73
      tests/websocket.rs

+ 1 - 0
Cargo.toml

@@ -10,6 +10,7 @@ name = "test-ws"
 path = "tests/websocket.rs"
 
 [dependencies]
+flume = "0.10.14"
 futures = "0.3.28"
 pin-project = "1.1.0"
 thiserror = "1.0.40"

+ 42 - 13
src/lib.rs

@@ -1,8 +1,8 @@
-use crate::runtime::{DefaultActorRuntime, Runtime};
+use crate::runtime::{ActorRuntime, Runtime};
+use flume::{SendError, Sender};
 use message::{Envelope, Message, MessagePacker, MessageRequest};
 use std::fmt::Debug;
-use tokio::sync::{mpsc::Sender, oneshot};
-
+use tokio::sync::oneshot;
 pub mod debug;
 pub mod message;
 pub mod runtime;
@@ -14,7 +14,7 @@ pub trait Actor {
         Self: Sized + Send + 'static,
     {
         println!("Starting actor");
-        DefaultActorRuntime::start(self)
+        ActorRuntime::run(self)
     }
 }
 
@@ -26,9 +26,21 @@ where
     command_tx: Sender<ActorCommand>,
 }
 
-impl<A> ActorHandle<A>
+impl<A> Clone for ActorHandle<A>
 where
     A: Actor,
+{
+    fn clone(&self) -> Self {
+        Self {
+            message_tx: self.message_tx.clone(),
+            command_tx: self.command_tx.clone(),
+        }
+    }
+}
+
+impl<A> ActorHandle<A>
+where
+    A: Actor + 'static,
 {
     pub async fn send_sync<M>(&self, message: M) -> Result<M::Response, Error>
     where
@@ -37,7 +49,9 @@ where
     {
         let (tx, rx) = oneshot::channel();
         let packed = A::pack(message, Some(tx));
-        self.message_tx.send(packed).await.unwrap(); // TODO
+        self.message_tx
+            .send(packed)
+            .map_err(Error::send_err_boxed)?;
         MessageRequest { response_rx: rx }.await
     }
 
@@ -47,12 +61,11 @@ where
         A: Handler<M> + MessagePacker<A, M> + 'static,
     {
         let packed = A::pack(message, None);
-        self.message_tx.send(packed).await.unwrap(); // TODO
-        Ok(())
+        self.message_tx.send(packed).map_err(Error::send_err_boxed)
     }
 
     pub async fn send_cmd(&self, cmd: ActorCommand) -> Result<(), Error> {
-        self.command_tx.send(cmd).await.unwrap(); // TODO
+        self.command_tx.send(cmd).unwrap();
         Ok(())
     }
 }
@@ -70,6 +83,16 @@ pub enum Error {
     ActorChannelClosed,
     #[error("Channel closed: {0}")]
     ChannelClosed(#[from] oneshot::error::TryRecvError),
+    #[error("Send error: {0}")]
+    Send(Box<dyn std::error::Error + Send + 'static>),
+    #[error("Warp error: {0}")]
+    Warp(#[from] warp::Error),
+}
+
+impl Error {
+    fn send_err_boxed<T: Send + 'static>(error: SendError<T>) -> Self {
+        Self::Send(Box::new(error))
+    }
 }
 
 #[derive(Debug)]
@@ -105,24 +128,30 @@ mod tests {
 
         impl Handler<Foo> for Testor {
             fn handle(&mut self, _: Foo) -> Result<usize, Error> {
+                println!("Handling Foo");
                 Ok(10)
             }
         }
 
         impl Handler<Bar> for Testor {
             fn handle(&mut self, _: Bar) -> Result<isize, Error> {
+                println!("Handling Bar");
                 Ok(10)
             }
         }
 
         let handle = Testor {}.start();
 
-        let res = handle.send_sync(Foo {}).await.unwrap();
-        let res2 = handle.send_sync(Bar {}).await.unwrap();
+        let mut res = 0;
+        let mut res2 = 0;
+        for _ in 0..100 {
+            res += handle.send_sync(Foo {}).await.unwrap();
+            res2 += handle.send_sync(Bar {}).await.unwrap();
+        }
 
         handle.send_cmd(ActorCommand::Stop).await.unwrap();
 
-        assert_eq!(res, 10);
-        assert_eq!(res2, 10);
+        assert_eq!(res, 1000);
+        assert_eq!(res2, 1000);
     }
 }

+ 2 - 2
src/message.rs

@@ -1,4 +1,4 @@
-use crate::{runtime::DefaultActorRuntime, Actor, Error, Handler};
+use crate::{runtime::ActorRuntime, Actor, Error, Handler};
 use tokio::sync::oneshot;
 
 /// Represents a message that can be sent to an actor. The response type is what the actor must return in its handler implementation.
@@ -79,7 +79,7 @@ where
     }
 }
 
-impl<A, M> MessagePacker<A, M> for DefaultActorRuntime<A>
+impl<A, M> MessagePacker<A, M> for ActorRuntime<A>
 where
     A: Actor + Handler<M>,
     M: Message + Send + 'static,

+ 35 - 32
src/runtime.rs

@@ -1,17 +1,21 @@
 use crate::{message::PackedMessage, Actor, ActorCommand, ActorHandle, Envelope, Error};
+use flume::Receiver;
 use futures::Future;
 use pin_project::pin_project;
-use std::{collections::VecDeque, task::Poll};
-use tokio::sync::mpsc::Receiver;
+use std::{
+    collections::VecDeque,
+    pin::Pin,
+    task::{Context, Poll},
+};
 
 pub trait Runtime<A> {
-    fn start(actor: A) -> ActorHandle<A>
+    fn run(actor: A) -> ActorHandle<A>
     where
         A: Actor + Send + 'static,
     {
-        let (tx, rx) = tokio::sync::mpsc::channel(100);
-        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::channel(100);
-        let rt = DefaultActorRuntime::new(actor, cmd_rx, rx);
+        let (tx, rx) = flume::unbounded();
+        let (cmd_tx, cmd_rx) = flume::unbounded();
+        let rt = ActorRuntime::new(actor, cmd_rx, rx);
         tokio::spawn(rt);
         ActorHandle {
             message_tx: tx,
@@ -21,7 +25,7 @@ pub trait Runtime<A> {
 }
 
 #[pin_project]
-pub struct DefaultActorRuntime<A>
+pub struct ActorRuntime<A>
 where
     A: Actor,
 {
@@ -31,68 +35,67 @@ where
     message_queue: VecDeque<Envelope<A>>,
 }
 
-impl<A> Runtime<A> for DefaultActorRuntime<A> where A: Actor {}
+impl<A> Runtime<A> for ActorRuntime<A> where A: Actor {}
 
-impl<A> Future for DefaultActorRuntime<A>
+impl<A> Future for ActorRuntime<A>
 where
     A: Actor,
 {
     type Output = Result<(), Error>;
 
-    fn poll(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Self::Output> {
+    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = self.project();
         loop {
             // Poll command receiver
-            match self.command_rx.poll_recv(cx) {
-                std::task::Poll::Ready(Some(message)) => match message {
+            match Pin::new(&mut this.command_rx.recv_async()).poll(cx) {
+                Poll::Ready(Ok(message)) => match message {
                     ActorCommand::Stop => {
                         println!("Actor stopping");
                         break Poll::Ready(Ok(())); // TODO drain the queue and all that graceful stuff
                     }
                 },
-                std::task::Poll::Ready(None) => {
+                Poll::Ready(Err(_)) => {
                     println!("Command channel closed, ungracefully stopping actor");
                     break Poll::Ready(Err(Error::ActorChannelClosed));
                 }
-                std::task::Poll::Pending => {}
+                Poll::Pending => {}
             };
 
             // Process all messages
-            while let Some(mut message) = self.message_queue.pop_front() {
-                message.handle(&mut self.actor)
+            while let Some(mut message) = this.message_queue.pop_front() {
+                message.handle(this.actor)
             }
 
             // Poll message receiver and continue to process if anything comes up
-            let mut new_messages = false;
-            while let Poll::Ready(Some(message)) = self.message_rx.poll_recv(cx) {
-                self.message_queue.push_back(message);
-                new_messages = true;
+            while let Poll::Ready(Ok(message)) =
+                Pin::new(&mut this.message_rx.recv_async()).poll(cx)
+            {
+                this.message_queue.push_back(message);
             }
 
-            match self.message_rx.poll_recv(cx) {
-                std::task::Poll::Ready(Some(message)) => {
-                    self.message_queue.push_back(message);
+            // Poll again and process new messages if any
+            match Pin::new(&mut this.message_rx.recv_async()).poll(cx) {
+                Poll::Ready(Ok(message)) => {
+                    this.message_queue.push_back(message);
                     continue;
                 }
-                std::task::Poll::Ready(None) => {
+                Poll::Ready(Err(_)) => {
                     println!("Message channel closed, ungracefully stopping actor");
                     break Poll::Ready(Err(Error::ActorChannelClosed));
                 }
-                std::task::Poll::Pending => {
-                    if new_messages {
+                Poll::Pending => {
+                    if !this.message_queue.is_empty() {
                         continue;
                     }
                 }
             };
-
-            return std::task::Poll::Pending;
+            cx.waker().wake_by_ref();
+            return Poll::Pending;
         }
     }
 }
 
-impl<A> DefaultActorRuntime<A>
+impl<A> ActorRuntime<A>
 where
     A: Actor,
 {

+ 102 - 47
src/ws.rs

@@ -1,10 +1,20 @@
+use std::{
+    collections::VecDeque,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
 use crate::{
     message::{Envelope, PackedMessage},
     runtime::Runtime,
-    Actor, ActorCommand, ActorHandle, Handler,
+    Actor, ActorCommand, ActorHandle, Error, Handler,
 };
-use futures::{SinkExt, StreamExt, TryFutureExt};
-use tokio::{select, sync::mpsc::Receiver};
+use flume::Receiver;
+use futures::{
+    stream::{SplitSink, SplitStream},
+    Future, SinkExt, StreamExt,
+};
+use pin_project::pin_project;
 use warp::ws::WebSocket;
 
 pub struct WebsocketActor {
@@ -17,7 +27,7 @@ impl Actor for WebsocketActor {
         Self: Sized + Send + 'static,
     {
         println!("Starting websocket actor");
-        WebsocketRuntime::start(self)
+        WebsocketRuntime::run(self)
     }
 }
 
@@ -29,82 +39,127 @@ impl WebsocketActor {
     }
 }
 
+#[pin_project]
 pub struct WebsocketRuntime {
     actor: WebsocketActor,
+
+    ws_stream: SplitStream<WebSocket>,
+    ws_sink: SplitSink<WebSocket, warp::ws::Message>,
+
     message_rx: Receiver<Envelope<WebsocketActor>>,
     command_rx: Receiver<ActorCommand>,
+
+    message_queue: VecDeque<Envelope<WebsocketActor>>,
+    ws_queue: VecDeque<warp::ws::Message>,
 }
 
 impl WebsocketRuntime {
     pub fn new(
-        actor: WebsocketActor,
+        mut actor: WebsocketActor,
         command_rx: Receiver<ActorCommand>,
         message_rx: Receiver<Envelope<WebsocketActor>>,
     ) -> Self {
+        let (ws_sink, ws_stream) = actor
+            .websocket
+            .take()
+            .expect("Websocket runtime already started")
+            .split();
+
         Self {
             actor,
+            ws_sink,
+            ws_stream,
             message_rx,
             command_rx,
+            message_queue: VecDeque::new(),
+            ws_queue: VecDeque::new(),
         }
     }
 }
 
-impl WebsocketRuntime {
-    pub async fn run(mut self) {
-        let (mut ws_sender, mut ws_receiver) = self
-            .actor
-            .websocket
-            .take()
-            .expect("Websocket runtime already started")
-            .split();
+impl Future for WebsocketRuntime {
+    type Output = Result<(), Error>;
+
+    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = self.project();
 
         loop {
-            select! {
-                // Handle any pending commands
-                Some(msg) = self.command_rx.recv() => {
-                    match msg {
-                        ActorCommand::Stop => {
-                            println!("Actor stopping");
-                            break;
-                        }
+            // Poll command receiver
+            match Pin::new(&mut this.command_rx.recv_async()).poll(cx) {
+                Poll::Ready(Ok(message)) => match message {
+                    ActorCommand::Stop => {
+                        println!("Actor stopping");
+                        break Poll::Ready(Ok(())); // TODO drain the queue and all that graceful stuff
                     }
+                },
+                Poll::Ready(Err(_)) => {
+                    println!("Command stream dropped, ungracefully stopping actor");
+                    break Poll::Ready(Err(Error::ActorChannelClosed));
                 }
-                // Handle any in-process messages
-                Some(mut message) = self.message_rx.recv() => {
-                    println!("Processing Message");
-                    message.handle(&mut self.actor)
-                }
-                // Handle any messages from the websocket
-                Some(message) = ws_receiver.next() => {
-                    match message {
-                        Ok(message) => {
-                            if let Some(res) = self.actor.handle(message).unwrap() {// TODO
-                                ws_sender.send(res)
-                                    .unwrap_or_else(|e| {
-                                        eprintln!("websocket send error: {}", e);
-                                    })
-                                    .await;
-                            }
-                        },
-                        Err(e) => {
-                            eprintln!("WS error occurred {e}")
-                        },
+                Poll::Pending => {}
+            };
+
+            // Poll the websocket stream for any messages and store them to the queue
+            while let Poll::Ready(Some(ws_message)) = Pin::new(&mut this.ws_stream.next()).poll(cx)
+            {
+                match ws_message {
+                    Ok(message) => this.ws_queue.push_back(message),
+                    Err(e) => {
+                        eprintln!("WS error occurred {e}")
                     }
                 }
-                else => {
-                    println!("No messages")
+            }
+
+            // Respond to any queued websocket messages
+            while let Some(ws_message) = this.ws_queue.pop_front() {
+                if let Some(res) = this.actor.handle(ws_message)? {
+                    match Pin::new(&mut this.ws_sink.send(res)).poll(cx) {
+                        Poll::Ready(result) => result?,
+                        Poll::Pending => todo!(),
+                    }
                 }
             }
+
+            // Process all messages
+            while let Some(mut message) = this.message_queue.pop_front() {
+                message.handle(this.actor)
+            }
+
+            // Poll message receiver and continue to process if anything comes up
+            while let Poll::Ready(Ok(message)) =
+                Pin::new(&mut this.message_rx.recv_async()).poll(cx)
+            {
+                this.message_queue.push_back(message);
+            }
+
+            // Poll again and process new messages if any
+            match Pin::new(&mut this.message_rx.recv_async()).poll(cx) {
+                Poll::Ready(Ok(message)) => {
+                    this.message_queue.push_back(message);
+                    continue;
+                }
+                Poll::Ready(Err(_)) => {
+                    println!("Message channel closed, ungracefully stopping actor");
+                    break Poll::Ready(Err(Error::ActorChannelClosed));
+                }
+                Poll::Pending => {
+                    if !this.message_queue.is_empty() {
+                        continue;
+                    }
+                }
+            };
+            cx.waker().wake_by_ref();
+            return Poll::Pending;
         }
     }
 }
 
 impl Runtime<WebsocketActor> for WebsocketRuntime {
-    fn start(actor: WebsocketActor) -> ActorHandle<WebsocketActor> {
-        let (tx, rx) = tokio::sync::mpsc::channel(100);
-        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::channel(100);
+    fn run(actor: WebsocketActor) -> ActorHandle<WebsocketActor> {
+        let (tx, rx) = flume::unbounded();
+        let (cmd_tx, cmd_rx) = flume::unbounded();
         let rt = WebsocketRuntime::new(actor, cmd_rx, rx);
-        tokio::spawn(rt.run());
+        tokio::spawn(rt);
         ActorHandle {
             message_tx: tx,
             command_tx: cmd_tx,

+ 13 - 73
tests/websocket.rs

@@ -1,100 +1,40 @@
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
+
 use actors::ws::WebsocketActor;
-use actors::Actor;
+use actors::{Actor, ActorHandle};
 use warp::Filter;
 
+type Arbiter = Arc<RwLock<HashMap<usize, ActorHandle<WebsocketActor>>>>;
+
 #[tokio::main]
 async fn main() {
+    let arbiter = Arc::new(RwLock::new(HashMap::new()));
+    let arbiter = warp::any().map(move || arbiter.clone());
     // GET /chat -> websocket upgrade
     let chat = warp::path("chat")
         // The `ws()` filter will prepare Websocket handshake...
         .and(warp::ws())
-        .map(|ws: warp::ws::Ws| {
+        .and(arbiter)
+        .map(|ws: warp::ws::Ws, arbiter: Arbiter| {
             // This will call our function if the handshake succeeds.
             ws.on_upgrade(move |socket| {
                 let actor = WebsocketActor::new(socket);
-                actor.start();
+                let handle = actor.start();
+                arbiter.write().unwrap().insert(0, handle);
                 futures::future::ready(())
             })
         });
 
     // GET / -> index html
+
     let index = warp::path::end().map(|| warp::reply::html(INDEX_HTML));
 
     let routes = index.or(chat);
 
     warp::serve(routes).run(([127, 0, 0, 1], 3030)).await;
 }
-/*
-async fn user_connected(ws: WebSocket, users: Users) {
-    // Use a counter to assign a new unique ID for this user.
-    let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);
-
-    eprintln!("new chat user: {}", my_id);
-
-    // Split the socket into a sender and receive of messages.
-    let (mut user_ws_tx, mut user_ws_rx) = ws.split();
-
-    // Use an unbounded channel to handle buffering and flushing of messages
-    // to the websocket...
-    let (tx, rx) = mpsc::unbounded_channel();
-    let mut rx = UnboundedReceiverStream::new(rx);
-
-    tokio::task::spawn(async move {
-        while let Some(message) = rx.next().await {
-            user_ws_tx
-                .send(message)
-                .unwrap_or_else(|e| {
-                    eprintln!("websocket send error: {}", e);
-                })
-                .await;
-        }
-    });
-
-    // Save the sender in our list of connected users.
-    users.write().await.insert(my_id, tx);
 
-    // Return a `Future` that is basically a state machine managing
-    // this specific user's connection.
-
-    // Every time the user sends a message, broadcast it to
-    // all other users...
-    while let Some(result) = user_ws_rx.next().await {
-        let msg = match result {
-            Ok(msg) => {
-                println!("MESSAGE {msg:?}");
-                msg
-            }
-            Err(e) => {
-                eprintln!("websocket error(uid={}): {}", my_id, e);
-                break;
-            }
-        };
-        // Skip any non-Text messages...
-        let msg = if let Ok(s) = msg.to_str() {
-            s
-        } else {
-            return;
-        };
-
-        let new_msg = format!("<User#{}>: {}", my_id, msg);
-
-        // New message from this user, send it to everyone else (except same uid)...
-        for (&uid, tx) in users.read().await.iter() {
-            if my_id != uid {
-                if let Err(_disconnected) = tx.send(Message::text(new_msg.clone())) {
-                    // The tx is disconnected, our `user_disconnected` code
-                    // should be happening in another task, nothing more to
-                    // do here.
-                }
-            }
-        }
-    }
-
-    // user_ws_rx stream will keep processing as long as the user stays
-    // connected. Once they disconnect, then...
-    user_disconnected(my_id, &users).await;
-}
- */
 static INDEX_HTML: &str = r#"<!DOCTYPE html>
 <html lang="en">
     <head>