Browse Source

restructure using async await

biblius 1 year ago
parent
commit
5158afacbd
8 changed files with 277 additions and 570 deletions
  1. 1 0
      Cargo.toml
  2. 4 3
      src/debug.rs
  3. 37 51
      src/lib.rs
  4. 58 35
      src/message.rs
  5. 123 0
      src/relay.rs
  6. 26 177
      src/runtime.rs
  7. 0 260
      src/ws.rs
  8. 28 44
      tests/websocket.rs

+ 1 - 0
Cargo.toml

@@ -22,4 +22,5 @@ tokio = { version = "1.28.2", features = [
   "sync",
   "time",
 ] }
+tokio-tungstenite = "0.19.0"
 warp = "0.3.5"

+ 4 - 3
src/debug.rs

@@ -1,10 +1,11 @@
 //! So we don't polute the main lib with debug impls
 
-use crate::{Actor, Envelope};
+use crate::{Actor, Envelope, Handler};
 
-impl<A> std::fmt::Debug for Envelope<A>
+impl<M, A> std::fmt::Debug for Envelope<M, A>
 where
-    A: Actor,
+    A: Actor + Handler<M>,
+    M: Clone + Send,
 {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("Envelope")

+ 37 - 51
src/lib.rs

@@ -1,25 +1,25 @@
-use crate::runtime::{ActorRuntime, Runtime};
+use crate::runtime::ActorRuntime;
 use async_trait::async_trait;
 use flume::{SendError, Sender};
+use message::MailboxSender;
 use message::{Envelope, Enveloper, MessageRequest};
-use std::{fmt::Debug, sync::Arc};
+use std::fmt::Debug;
 use tokio::sync::oneshot;
-use tokio::sync::Mutex;
 
 pub mod debug;
 pub mod message;
+pub mod relay;
 pub mod runtime;
-pub mod ws;
 
 const DEFAULT_CHANNEL_CAPACITY: usize = 128;
 
-pub trait Actor {
-    fn start(self) -> ActorHandle<Self>
-    where
-        Self: Sized + Send + 'static,
-    {
+pub trait Actor: Sized + Send + Sync + 'static {
+    fn start(self) -> ActorHandle<Self> {
         println!("Starting actor");
-        ActorRuntime::run(self)
+        let (message_tx, message_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
+        let (command_tx, command_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
+        tokio::spawn(ActorRuntime::new(self, command_rx, message_rx).runt());
+        ActorHandle::new(message_tx, command_tx)
     }
 }
 
@@ -27,7 +27,7 @@ pub trait Actor {
 #[async_trait]
 pub trait Handler<M>: Actor {
     type Response;
-    async fn handle(this: Arc<Mutex<Self>>, message: Box<M>) -> Result<Self::Response, Error>;
+    async fn handle(&mut self, message: M) -> Result<Self::Response, Error>;
 }
 
 /// A handle to a spawned actor. Obtained when calling `start` on an [Actor] and is used to send messages
@@ -37,7 +37,7 @@ pub struct ActorHandle<A>
 where
     A: Actor,
 {
-    message_tx: Sender<Envelope<A>>,
+    message_tx: MailboxSender<A>,
     command_tx: Sender<ActorCommand>,
 }
 
@@ -57,7 +57,7 @@ impl<A> ActorHandle<A>
 where
     A: Actor,
 {
-    pub fn new(message_tx: Sender<Envelope<A>>, command_tx: Sender<ActorCommand>) -> Self {
+    pub fn new(message_tx: MailboxSender<A>, command_tx: Sender<ActorCommand>) -> Self {
         Self {
             message_tx,
             command_tx,
@@ -69,7 +69,7 @@ where
     /// actor.
     pub fn send_wait<M>(&self, message: M) -> Result<MessageRequest<A::Response>, SendError<M>>
     where
-        M: Send,
+        M: Send + Clone,
         A: Handler<M> + Enveloper<A, M>,
     {
         if self.message_tx.is_full() || self.message_tx.is_disconnected() {
@@ -84,6 +84,7 @@ where
     /// error if the channel is full or disconnected.
     pub fn send<M>(&self, message: M) -> Result<(), SendError<M>>
     where
+        M: Clone + Send,
         A: Handler<M> + Enveloper<A, M> + 'static,
     {
         if self.message_tx.is_full() || self.message_tx.is_disconnected() {
@@ -96,7 +97,7 @@ where
     /// Send a message ignoring any errors in the process. The true YOLO way to send messages.
     pub fn send_forget<M>(&self, message: M)
     where
-        M: Send + 'static,
+        M: Clone + Send + 'static,
         A: Handler<M> + Enveloper<A, M> + 'static,
     {
         let _ = self.message_tx.send(A::pack(message, None));
@@ -139,21 +140,18 @@ pub enum ActorCommand {
 #[cfg(test)]
 mod tests {
 
-    use std::{sync::atomic::AtomicUsize, time::Duration};
-
-    use tokio::task::LocalSet;
-
     use super::*;
+    use std::{sync::atomic::AtomicUsize, time::Duration};
 
     #[tokio::test]
     async fn it_works_sync() {
         #[derive(Debug)]
         struct Testor {}
 
-        #[derive(Debug)]
+        #[derive(Debug, Clone)]
         struct Foo {}
 
-        #[derive(Debug)]
+        #[derive(Debug, Clone)]
         struct Bar {}
 
         impl Actor for Testor {}
@@ -161,7 +159,7 @@ mod tests {
         #[async_trait]
         impl Handler<Foo> for Testor {
             type Response = usize;
-            async fn handle(_: Arc<Mutex<Self>>, _: Box<Foo>) -> Result<usize, Error> {
+            async fn handle(&mut self, _: Foo) -> Result<usize, Error> {
                 println!("Handling Foo");
                 Ok(10)
             }
@@ -170,7 +168,7 @@ mod tests {
         #[async_trait]
         impl Handler<Bar> for Testor {
             type Response = isize;
-            async fn handle(_: Arc<Mutex<Self>>, _: Box<Bar>) -> Result<isize, Error> {
+            async fn handle(&mut self, _: Bar) -> Result<isize, Error> {
                 for _ in 0..10_000 {
                     println!("Handling Bar");
                 }
@@ -190,22 +188,21 @@ mod tests {
 
         handle.send(Foo {}).unwrap();
         handle.send_forget(Bar {});
-
         handle.send_cmd(ActorCommand::Stop).unwrap();
 
         assert_eq!(res, 1000);
         assert_eq!(res2, 1000);
     }
 
-    #[test]
-    fn it_works_yolo() {
+    #[tokio::test]
+    async fn it_works_yolo() {
         #[derive(Debug)]
         struct Testor {}
 
-        #[derive(Debug)]
+        #[derive(Debug, Clone)]
         struct Foo {}
 
-        #[derive(Debug)]
+        #[derive(Debug, Clone)]
         struct Bar {}
 
         impl Actor for Testor {}
@@ -215,7 +212,7 @@ mod tests {
         #[async_trait]
         impl Handler<Foo> for Testor {
             type Response = usize;
-            async fn handle(_: Arc<Mutex<Testor>>, _: Box<Foo>) -> Result<usize, Error> {
+            async fn handle(&mut self, _: Foo) -> Result<usize, Error> {
                 println!("INCREMENTING COUNT FOO");
                 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                 Ok(10)
@@ -225,37 +222,26 @@ mod tests {
         #[async_trait]
         impl Handler<Bar> for Testor {
             type Response = isize;
-            async fn handle(_: Arc<Mutex<Testor>>, _: Box<Bar>) -> Result<isize, Error> {
+            async fn handle(&mut self, _: Bar) -> Result<isize, Error> {
                 println!("INCREMENTING COUNT BAR");
                 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                 Ok(10)
             }
         }
 
-        let rt = tokio::runtime::Builder::new_current_thread()
-            .enable_all()
-            .build()
-            .unwrap();
-        let local_set = LocalSet::new();
+        let handle = Testor {}.start();
 
-        let task = async {
-            let handle = Testor {}.start();
+        handle.send_wait(Bar {}).unwrap().await.unwrap();
+        handle.send(Foo {}).unwrap();
+        handle.send_forget(Bar {});
 
-            handle.send_wait(Bar {}).unwrap().await.unwrap();
-            handle.send(Foo {}).unwrap();
+        for _ in 0..100 {
+            let _ = handle.send(Foo {});
             handle.send_forget(Bar {});
+            tokio::time::sleep(Duration::from_micros(100)).await
+        }
 
-            for _ in 0..100 {
-                let _ = handle.send(Foo {});
-                handle.send_forget(Bar {});
-                tokio::time::sleep(Duration::from_micros(100)).await
-            }
-
-            assert_eq!(COUNT.load(std::sync::atomic::Ordering::Relaxed), 203);
-            handle.send_cmd(ActorCommand::Stop)
-        };
-
-        local_set.spawn_local(task);
-        rt.block_on(local_set)
+        assert_eq!(COUNT.load(std::sync::atomic::Ordering::Relaxed), 203);
+        handle.send_cmd(ActorCommand::Stop).unwrap();
     }
 }

+ 58 - 35
src/message.rs

@@ -1,94 +1,117 @@
-use std::sync::Arc;
-
 use crate::{Actor, Error, Handler};
 use async_trait::async_trait;
+use std::marker::PhantomData;
 use tokio::sync::oneshot;
-use tokio::sync::Mutex;
+
+#[async_trait]
+pub trait MessageHandler<A: Actor>: Send + Sync {
+    async fn handle(&mut self, actor: &mut A);
+}
+
+pub type BoxedMessageHandler<A> = Box<dyn MessageHandler<A>>;
+
+pub type MailboxReceiver<A> = flume::Receiver<BoxedMessageHandler<A>>;
+pub type MailboxSender<A> = flume::Sender<BoxedMessageHandler<A>>;
+
+pub struct ActorMailbox<M, A: Handler<M>> {
+    _phantom_actor: PhantomData<A>,
+    _phantom_msg: PhantomData<M>,
+}
 
 /// Represents a type erased message that ultimately gets stored in an [Envelope]. We need this indirection so we can abstract away the concrete message
 /// type when creating an actor handle, otherwise we would only be able to send a single message type to the actor.
-#[async_trait]
+/* #[async_trait]
 pub trait ActorMessage<A: Actor> {
-    async fn handle(self: Box<Self>, actor: Arc<Mutex<A>>);
-}
+    async fn handle(self, actor: &mut A);
+} */
 
 /// Used by [ActorHandle][super::ActorHandle]s to pack messages into [Envelope]s so we have a type erased message to send to the actor.
 pub trait Enveloper<A, M>
 where
     A: Handler<M>,
+    M: Clone + Send,
 {
     /// Wrap a message in an envelope with an optional response channel.
-    fn pack(message: M, tx: Option<oneshot::Sender<<A as Handler<M>>::Response>>) -> Envelope<A>;
+    fn pack(
+        message: M,
+        tx: Option<oneshot::Sender<<A as Handler<M>>::Response>>,
+    ) -> Box<dyn MessageHandler<A>>;
 }
 
 /// A type erased wrapper for messages. This wrapper essentially enables us to send any message to the actor
 /// so long as it implements the necessary handler.
-pub struct Envelope<A>
+pub struct Envelope<M, A>
 where
-    A: Actor,
+    A: Handler<M>,
+    M: Clone + Send + 'static,
 {
-    message: Box<dyn ActorMessage<A> + Send>,
+    message: M,
+    response_tx: Option<oneshot::Sender<A::Response>>,
 }
 
-impl<A> Envelope<A>
+impl<M, A> Envelope<M, A>
 where
-    A: Actor,
+    A: Handler<M> + Send + 'static,
+    A::Response: Send,
+    M: Clone + Send + Sync + 'static,
 {
-    pub fn new<M>(message: M, tx: Option<oneshot::Sender<A::Response>>) -> Self
-    where
-        A: Handler<M> + Send + 'static,
-        A::Response: Send,
-        M: Send + 'static,
-    {
+    pub fn new(message: M, tx: Option<oneshot::Sender<A::Response>>) -> Self {
         Self {
-            message: Box::new(EnvelopeInner {
-                message: Box::new(message),
-                tx,
-            }),
+            message,
+            response_tx: tx,
         }
     }
 }
 
 #[async_trait]
-impl<A> ActorMessage<A> for Envelope<A>
+impl<M, A> MessageHandler<A> for Envelope<M, A>
 where
-    A: Actor + Send,
+    A: Actor + Handler<M> + Send,
+    M: Clone + Send + Sync,
+    A::Response: Send,
 {
-    async fn handle(self: Box<Self>, actor: Arc<Mutex<A>>) {
-        ActorMessage::handle(self.message, actor).await
+    async fn handle(&mut self, actor: &mut A) {
+        let result = A::handle(actor, self.message.clone()).await;
+        if let Some(res_tx) = self.response_tx.take() {
+            let _ = res_tx.send(result.unwrap());
+        }
     }
 }
 
 /// The inner parts of the [Envelope] containing the actual message as well as an optional
 /// response channel.
-struct EnvelopeInner<M, R> {
-    message: Box<M>,
+/* struct EnvelopeInner<M, R> {
+    message: M,
     tx: Option<oneshot::Sender<R>>,
 }
 
 #[async_trait]
 impl<A, M> ActorMessage<A> for EnvelopeInner<M, <A as Handler<M>>::Response>
 where
-    A: Handler<M> + Send + 'static,
+    A: Handler<M>,
     A::Response: Send,
-    M: Send,
+    M: Clone + Send + Sync + 'static,
 {
-    async fn handle(self: Box<Self>, actor: Arc<Mutex<A>>) {
-        let result = A::handle(actor, self.message).await;
+    async fn handle(self, actor: &mut A) {
+        let result = A::handle(actor, self.message.clone()).await;
         if let Some(res_tx) = self.tx {
             let _ = res_tx.send(result.unwrap());
         }
     }
 }
+ */
 
 impl<A, M> Enveloper<A, M> for A
 where
     A: Handler<M> + Send + 'static,
     A::Response: Send,
-    M: Send + Sync + 'static,
+    M: Clone + Send + Sync + 'static,
 {
-    fn pack(message: M, tx: Option<oneshot::Sender<<A as Handler<M>>::Response>>) -> Envelope<A> {
-        Envelope::new(message, tx)
+    fn pack(
+        message: M,
+        tx: Option<oneshot::Sender<<A as Handler<M>>::Response>>,
+    ) -> Box<dyn MessageHandler<A>> {
+        Box::new(Envelope::new(message, tx))
     }
 }
 

+ 123 - 0
src/relay.rs

@@ -0,0 +1,123 @@
+use crate::{
+    message::MailboxReceiver, Actor, ActorCommand, ActorHandle, Error, DEFAULT_CHANNEL_CAPACITY,
+};
+use async_trait::async_trait;
+use flume::{r#async::RecvStream, Receiver, Sender};
+use futures::{Stream, StreamExt};
+use std::{fmt::Display, sync::atomic::AtomicUsize};
+
+/// Represents an actor that has access to a stream and a sender channel
+/// which it can respond to.
+///
+/// A websocket actor receives messages via the stream and processes them with
+/// its [Handler] implementation. The handler implementation should always return an
+/// `Option<M>` where M is the type used when implementing this trait. A handler that returns
+/// `None` will not forward any response to the sink. If the handler returns `Some(M)` it will
+/// be forwarded to the sink.
+pub trait RelayActor<M, Str>: Actor
+where
+    Self: Relay<M>,
+    Self::Error: Send,
+    Str: Stream<Item = Result<M, Self::Error>> + Unpin + Send + 'static,
+    M: Send + 'static,
+{
+    /// The error type of the underlying websocket implementation.
+    type Error: Display;
+
+    fn start_relay(self, stream: Str, sender: Sender<M>) -> ActorHandle<Self> {
+        println!("Starting actor");
+        let (message_tx, message_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
+        let (command_tx, command_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
+        tokio::spawn(RelayRuntime::new(self, command_rx, message_rx, stream, sender).runt());
+        ActorHandle::new(message_tx, command_tx)
+    }
+}
+
+#[async_trait]
+pub trait Relay<M>: Actor {
+    async fn handle(&mut self, message: M) -> Result<Option<M>, Error>;
+}
+
+pub struct RelayRuntime<A, M, Str>
+where
+    A: RelayActor<M, Str> + Relay<M> + 'static,
+    Str: Stream<Item = Result<M, A::Error>> + Send + Unpin + 'static,
+    M: Send + 'static,
+    A::Error: Send,
+{
+    actor: A,
+
+    /// The receiving end of the websocket
+    ws_stream: Str,
+
+    /// The sending end of the websocket. Hooked to a receiver that forwards any
+    /// response sent from here.
+    ws_sender: Sender<M>,
+
+    /// Actor command receiver
+    command_stream: RecvStream<'static, ActorCommand>,
+
+    mailbox: MailboxReceiver<A>,
+}
+
+static PROCESS: AtomicUsize = AtomicUsize::new(0);
+
+impl<A, M, Str> RelayRuntime<A, M, Str>
+where
+    Str: Stream<Item = Result<M, A::Error>> + Send + Unpin,
+    A: RelayActor<M, Str> + Send + 'static + Relay<M>,
+    M: Send,
+    A::Error: Send,
+{
+    pub fn new(
+        actor: A,
+        command_rx: Receiver<ActorCommand>,
+        mailbox: MailboxReceiver<A>,
+        stream: Str,
+        sender: Sender<M>,
+    ) -> Self {
+        Self {
+            actor,
+            ws_sender: sender,
+            ws_stream: stream,
+            mailbox,
+            command_stream: command_rx.into_stream(),
+        }
+    }
+
+    pub async fn runt(mut self) {
+        loop {
+            tokio::select! {
+            Some(command) = self.command_stream.next() => {
+               match command {
+                        ActorCommand::Stop => {
+                            println!("actor stopping");
+                            return
+                        },
+                    }
+            }
+            message = self.mailbox.recv_async() => {
+                if let Ok(mut message) = message {
+                    message.handle(&mut self.actor).await;
+                } else {
+                    break;
+                }
+            }
+            Some(ws_msg) = self.ws_stream.next() => {
+                    match ws_msg {
+                        Ok(msg) => {
+                            let res = self.actor.handle(msg).await.unwrap();
+                            PROCESS.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
+                            println!("PROCESSED {}", PROCESS.load(std::sync::atomic::Ordering::Relaxed));
+                            if let Some(res) = res {
+                                self.ws_sender.send_async(res).await.unwrap();
+                            }
+                        },
+                        Err(_) => todo!(),
+                    }
+            }
+            }
+        }
+        println!("actor stopping");
+    }
+}

+ 26 - 177
src/runtime.rs

@@ -1,138 +1,16 @@
-use crate::{
-    message::ActorMessage, Actor, ActorCommand, ActorHandle, Envelope, Error,
-    DEFAULT_CHANNEL_CAPACITY,
-};
+use crate::{message::MailboxReceiver, Actor, ActorCommand};
 use flume::{r#async::RecvStream, Receiver};
-use futures::{Future, StreamExt};
-use std::{
-    collections::VecDeque,
-    pin::Pin,
-    sync::Arc,
-    task::{ready, Context, Poll},
-};
-use tokio::sync::Mutex;
+use futures::StreamExt;
 
 pub const QUEUE_CAPACITY: usize = 128;
 
-pub trait Runtime<A>
-where
-    A: Actor + Send + 'static,
-{
-    fn command_stream(&mut self) -> &mut RecvStream<'static, ActorCommand>;
-
-    fn message_stream(&mut self) -> &mut RecvStream<'static, Envelope<A>>;
-
-    fn processing_queue(&mut self) -> &mut VecDeque<ActorJob<A>>;
-
-    fn actor(&self) -> Arc<Mutex<A>>;
-
-    fn at_capacity(&self) -> bool;
-
-    fn run(actor: A) -> ActorHandle<A> {
-        let (message_tx, message_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
-        let (command_tx, command_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
-        tokio::spawn(ActorRuntime::new(actor, command_rx, message_rx));
-        ActorHandle::new(message_tx, command_tx)
-    }
-
-    fn process_commands(&mut self, cx: &mut Context<'_>) -> Result<Option<ActorCommand>, Error> {
-        match self.command_stream().poll_next_unpin(cx) {
-            Poll::Ready(Some(command)) => Ok(Some(command)),
-            Poll::Ready(None) => {
-                println!("Command channel closed, ungracefully stopping actor");
-                Err(Error::ActorChannelClosed)
-            }
-            Poll::Pending => Ok(None),
-        }
-    }
-
-    fn process_messages(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
-        let actor = self.actor();
-
-        self.processing_queue()
-            .retain_mut(|job| job.poll(actor.clone(), cx).is_pending());
-
-        // Poll message receiver
-        if !self.at_capacity() {
-            while let Poll::Ready(message) = self.message_stream().poll_next_unpin(cx) {
-                let Some(message) = message else { return Err(Error::ActorChannelClosed) };
-                self.processing_queue().push_back(ActorJob::new(message));
-                if self.at_capacity() {
-                    break;
-                }
-            }
-        }
-
-        // Process pending futures again after we've potentially received some
-        self.processing_queue()
-            .retain_mut(|job| job.poll(actor.clone(), cx).is_pending());
-
-        if self.at_capacity() {
-            return Ok(());
-        }
-
-        match self.message_stream().poll_next_unpin(cx) {
-            Poll::Ready(Some(message)) => {
-                self.processing_queue().push_back(ActorJob::new(message));
-                Ok(())
-            }
-            Poll::Ready(None) => {
-                println!("Message channel closed, ungracefully stopping actor");
-                Err(Error::ActorChannelClosed)
-            }
-            Poll::Pending => Ok(()),
-        }
-    }
-}
-
-impl<A> Runtime<A> for ActorRuntime<A>
-where
-    A: Actor + Send + 'static,
-{
-    fn run(actor: A) -> ActorHandle<A> {
-        let (message_tx, message_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
-        let (command_tx, command_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
-        tokio::spawn(ActorRuntime::new(actor, command_rx, message_rx));
-        ActorHandle::new(message_tx, command_tx)
-    }
-
-    #[inline]
-    fn processing_queue(&mut self) -> &mut VecDeque<ActorJob<A>> {
-        &mut self.process_queue
-    }
-
-    #[inline]
-    fn command_stream(&mut self) -> &mut RecvStream<'static, ActorCommand> {
-        &mut self.command_stream
-    }
-
-    #[inline]
-    fn message_stream(&mut self) -> &mut RecvStream<'static, Envelope<A>> {
-        &mut self.message_stream
-    }
-
-    #[inline]
-    fn actor(&self) -> Arc<Mutex<A>> {
-        self.actor.clone()
-    }
-
-    #[inline]
-    fn at_capacity(&self) -> bool {
-        self.process_queue.len() >= QUEUE_CAPACITY
-    }
-}
-
-/// A future representing a message currently being handled. Created when polling an [ActorJob].
-pub type ActorFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
-
 pub struct ActorRuntime<A>
 where
     A: Actor + Send + 'static,
 {
-    actor: Arc<Mutex<A>>,
+    actor: A,
     command_stream: RecvStream<'static, ActorCommand>,
-    message_stream: RecvStream<'static, Envelope<A>>,
-    process_queue: VecDeque<ActorJob<A>>,
+    mailbox: MailboxReceiver<A>,
 }
 
 impl<A> ActorRuntime<A>
@@ -142,65 +20,36 @@ where
     pub fn new(
         actor: A,
         command_rx: Receiver<ActorCommand>,
-        message_rx: Receiver<Envelope<A>>,
+        message_rx: MailboxReceiver<A>,
     ) -> Self {
         println!("Building default runtime");
         Self {
-            actor: Arc::new(Mutex::new(actor)),
+            actor,
             command_stream: command_rx.into_stream(),
-            message_stream: message_rx.into_stream(),
-            process_queue: VecDeque::with_capacity(QUEUE_CAPACITY),
-        }
-    }
-}
-
-pub struct ActorJob<A>
-where
-    A: Actor,
-{
-    message: Option<Envelope<A>>,
-    future: Option<ActorFuture>,
-}
-
-impl<A> ActorJob<A>
-where
-    A: Actor + Send + 'static,
-{
-    fn new(message: Envelope<A>) -> Self {
-        Self {
-            message: Some(message),
-            future: None,
+            mailbox: message_rx,
         }
     }
 
-    fn poll(&mut self, actor: Arc<Mutex<A>>, cx: &mut std::task::Context<'_>) -> Poll<()> {
-        match self.message.take() {
-            Some(message) => {
-                let fut = Box::new(message).handle(actor);
-                self.future = Some(fut);
-                ready!(self.future.as_mut().unwrap().as_mut().poll(cx));
-                Poll::Ready(())
-            }
-            None => {
-                ready!(self.future.as_mut().unwrap().as_mut().poll(cx));
-                Poll::Ready(())
+    pub async fn runt(mut self) {
+        loop {
+            tokio::select! {
+                Some(command) = self.command_stream.next() => {
+                   match command {
+                            ActorCommand::Stop => {
+                                println!("actor stopping");
+                                return
+                            },
+                        }
+                }
+                message = self.mailbox.recv_async() => {
+                    if let Ok(mut message) = message {
+                        message.handle(&mut self.actor).await
+                    } else {
+                         break;
+                     }
+                }
             }
         }
-    }
-}
-
-impl<A> Future for ActorRuntime<A>
-where
-    A: Actor + Send + 'static,
-{
-    type Output = Result<(), Error>;
-
-    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let mut this = self.as_mut();
-
-        this.process_commands(cx)?;
-        this.process_messages(cx)?;
-        cx.waker().wake_by_ref();
-        Poll::Pending
+        println!("actor stopping");
     }
 }

+ 0 - 260
src/ws.rs

@@ -1,260 +0,0 @@
-use crate::{
-    message::Envelope,
-    runtime::{ActorJob, Runtime, QUEUE_CAPACITY},
-    Actor, ActorCommand, ActorHandle, Error, Handler,
-};
-use flume::{r#async::RecvStream, Receiver};
-use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
-use std::{
-    collections::VecDeque,
-    fmt::Display,
-    marker::PhantomData,
-    pin::Pin,
-    sync::atomic::AtomicUsize,
-    task::{Context, Poll},
-};
-use std::{sync::Arc, task::ready};
-use tokio::sync::Mutex;
-
-const WS_QUEUE_SIZE: usize = 128;
-
-/// Represents an actor that can get access to a websocket stream and sink.
-///
-/// A websocket actor receives messages via the stream and processes them with
-/// its [Handler] implementation. The handler implementation should always return an
-/// `Option<M>` where M is the type used when implementing this trait. A handler that returns
-/// `None` will not forward any response to the sink. If the handler returns `Some(M)` it will
-/// be forwarded to the sink.
-pub trait WsActor<M, Str, Sin>
-where
-    Str: Stream<Item = Result<M, Self::Error>>,
-    Sin: Sink<M>,
-{
-    /// The error type of the underlying websocket implementation.
-    type Error: Display;
-    fn websocket(&mut self) -> (Sin, Str);
-}
-
-pub struct WebsocketRuntime<A, M, Str, Sin>
-where
-    A: Actor + WsActor<M, Str, Sin> + Handler<M> + 'static,
-    Str: Stream<Item = Result<M, A::Error>>,
-    Sin: Sink<M>,
-{
-    actor: Arc<Mutex<A>>,
-
-    /// The receiving end of the websocket
-    ws_stream: Str,
-
-    /// The sending end of the websocket
-    ws_sink: Sin,
-
-    /// Actor message receiver
-    message_stream: RecvStream<'static, Envelope<A>>,
-
-    /// Actor command receiver
-    command_stream: RecvStream<'static, ActorCommand>,
-
-    /// Actor messages currently being processed
-    process_queue: VecDeque<ActorJob<A>>,
-
-    /// Received, but not yet processed websocket messages
-    response_queue: VecDeque<WebsocketJob<A, M>>,
-}
-
-impl<A, M, Str, Sin> WebsocketRuntime<A, M, Str, Sin>
-where
-    Str: Stream<Item = Result<M, A::Error>>,
-    Sin: Sink<M>,
-    A: Actor + WsActor<M, Str, Sin> + Send + 'static + Handler<M>,
-{
-    pub fn new(
-        mut actor: A,
-        command_rx: Receiver<ActorCommand>,
-        message_rx: Receiver<Envelope<A>>,
-    ) -> Self {
-        let (ws_sink, ws_stream) = actor.websocket();
-
-        Self {
-            actor: Arc::new(Mutex::new(actor)),
-            ws_sink,
-            ws_stream,
-            message_stream: message_rx.into_stream(),
-            command_stream: command_rx.into_stream(),
-            response_queue: VecDeque::new(),
-            process_queue: VecDeque::new(),
-        }
-    }
-}
-
-impl<A, M, Str, Sin> Future for WebsocketRuntime<A, M, Str, Sin>
-where
-    Self: Runtime<A>,
-    Str: Stream<Item = Result<M, A::Error>> + Unpin,
-    Sin: Sink<M> + Unpin,
-    A: Actor + WsActor<M, Str, Sin> + Handler<M, Response = Option<M>> + Send + Unpin + 'static,
-{
-    type Output = Result<(), Error>;
-
-    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let actor = self.actor();
-        let this = self.get_mut();
-
-        this.process_commands(cx)?;
-
-        // Poll the websocket stream for any messages and store them to the queue
-        if this.response_queue.len() < WS_QUEUE_SIZE {
-            while let Poll::Ready(Some(ws_message)) = this.ws_stream.poll_next_unpin(cx) {
-                match ws_message {
-                    Ok(message) => {
-                        this.response_queue.push_back(WebsocketJob::new(message));
-                        if this.response_queue.len() >= WS_QUEUE_SIZE {
-                            break;
-                        }
-                    }
-                    Err(e) => {
-                        eprintln!("WS error occurred {e}")
-                    }
-                }
-            }
-        }
-
-        let mut idx = 0;
-        while idx < this.response_queue.len() {
-            let job = &mut this.response_queue[idx];
-            match job.poll(actor.clone(), cx) {
-                Poll::Ready(result) => match result {
-                    Ok(response) => {
-                        if let Some(response) = response {
-                            let feed = &mut this.ws_sink.feed(response);
-                            let mut feed = Pin::new(feed);
-                            while feed.as_mut().poll(cx).is_pending() {
-                                // Yikes, but too dumb to figure out a better solution
-                                cx.waker().wake_by_ref();
-                            }
-                        }
-                        PROCESSED.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
-                        this.response_queue.swap_remove_front(idx);
-                    }
-                    Err(e) => {
-                        println!("Shit's fucked my dude {e}")
-                    }
-                },
-                Poll::Pending => idx += 1,
-            }
-        }
-        this.process_messages(cx)?;
-
-        println!(
-            "PROCESSED {} CURRENT IN QUEUE {}",
-            PROCESSED.load(std::sync::atomic::Ordering::Acquire),
-            this.response_queue.len(),
-        );
-
-        let _ = this.ws_sink.flush().poll_unpin(cx);
-
-        cx.waker().wake_by_ref();
-        Poll::Pending
-    }
-}
-
-impl<A, M, Str, Sin> Runtime<A> for WebsocketRuntime<A, M, Str, Sin>
-where
-    Str: Stream<Item = Result<M, A::Error>> + Unpin + Send + 'static,
-    Sin: Sink<M> + Unpin + Send + 'static,
-    A: Actor + WsActor<M, Str, Sin> + Send + 'static + Handler<M, Response = Option<M>> + Unpin,
-    M: Send + 'static,
-{
-    fn run(actor: A) -> ActorHandle<A> {
-        let (message_tx, message_rx) = flume::unbounded();
-        let (command_tx, command_rx) = flume::unbounded();
-        tokio::spawn(WebsocketRuntime::new(actor, command_rx, message_rx));
-        ActorHandle::new(message_tx, command_tx)
-    }
-
-    #[inline]
-    fn processing_queue(&mut self) -> &mut VecDeque<ActorJob<A>> {
-        &mut self.process_queue
-    }
-
-    #[inline]
-    fn command_stream(&mut self) -> &mut RecvStream<'static, ActorCommand> {
-        &mut self.command_stream
-    }
-
-    #[inline]
-    fn message_stream(&mut self) -> &mut RecvStream<'static, Envelope<A>> {
-        &mut self.message_stream
-    }
-
-    #[inline]
-    fn actor(&self) -> Arc<Mutex<A>> {
-        self.actor.clone()
-    }
-
-    #[inline]
-    fn at_capacity(&self) -> bool {
-        self.process_queue.len() >= QUEUE_CAPACITY
-    }
-}
-
-struct WebsocketJob<A, M>
-where
-    A: Handler<M>,
-{
-    message: Option<Box<M>>,
-    future: Option<WsFuture<A, M>>,
-    __a: PhantomData<A>,
-}
-
-impl<A, M> WebsocketJob<A, M>
-where
-    A: Handler<M> + 'static,
-{
-    pub fn new(message: M) -> Self {
-        Self {
-            message: Some(Box::new(message)),
-            future: None,
-            __a: PhantomData,
-        }
-    }
-
-    fn poll(
-        &mut self,
-        actor: Arc<Mutex<A>>,
-        cx: &mut std::task::Context<'_>,
-    ) -> Poll<Result<A::Response, Error>> {
-        let message = self.message.take();
-
-        match message {
-            Some(message) => {
-                let fut = A::handle(actor, message);
-                self.future = Some(fut);
-                let result = ready!(self.future.as_mut().unwrap().as_mut().poll(cx));
-                match result {
-                    Ok(response) => Poll::Ready(Ok(response)),
-                    Err(e) => {
-                        println!("Shit's fucked son {e}");
-                        Poll::Ready(Err(e))
-                    }
-                }
-            }
-            None => {
-                let Some(ref mut fut) = self.future else { panic!("Impossibru") };
-                let result = ready!(fut.as_mut().poll(cx));
-                match result {
-                    Ok(response) => Poll::Ready(Ok(response)),
-                    Err(e) => {
-                        println!("Shit's fucked son {e}");
-                        Poll::Ready(Err(e))
-                    }
-                }
-            }
-        }
-    }
-}
-
-type WsFuture<A, M> =
-    Pin<Box<dyn Future<Output = Result<<A as Handler<M>>::Response, Error>> + Send>>;
-
-static PROCESSED: AtomicUsize = AtomicUsize::new(0);

+ 28 - 44
tests/websocket.rs

@@ -1,66 +1,41 @@
 use async_trait::async_trait;
-use drama::runtime::Runtime;
-use drama::ws::{WebsocketRuntime, WsActor};
+use drama::relay::{Relay, RelayActor};
 use drama::{Actor, ActorHandle, Error, Handler};
-use futures::stream::{SplitSink, SplitStream};
+use futures::stream::SplitStream;
 use futures::StreamExt;
+use parking_lot::RwLock;
 use std::collections::HashMap;
 use std::sync::atomic::AtomicUsize;
-use std::sync::{Arc, RwLock};
-use tokio::sync::Mutex;
+use std::sync::Arc;
 use warp::ws::{Message, WebSocket};
 use warp::Filter;
 
-type Arbiter = Arc<RwLock<HashMap<usize, ActorHandle<WebsocketActor>>>>;
-
-static ID: AtomicUsize = AtomicUsize::new(0);
-
 struct WebsocketActor {
-    websocket: Option<WebSocket>,
     hello: ActorHandle<Hello>,
 }
 
 impl WebsocketActor {
-    fn new(ws: WebSocket, handle: ActorHandle<Hello>) -> Self {
-        Self {
-            websocket: Some(ws),
-            hello: handle,
-        }
+    fn new(handle: ActorHandle<Hello>) -> Self {
+        Self { hello: handle }
     }
 }
 
-impl Actor for WebsocketActor {
-    fn start(self) -> ActorHandle<Self> {
-        WebsocketRuntime::run(self)
-    }
-}
+impl Actor for WebsocketActor {}
 
-impl WsActor<Message, SplitStream<WebSocket>, SplitSink<WebSocket, Message>> for WebsocketActor {
+impl RelayActor<Message, SplitStream<WebSocket>> for WebsocketActor {
     type Error = warp::Error;
-    fn websocket(&mut self) -> (SplitSink<WebSocket, Message>, SplitStream<WebSocket>) {
-        self.websocket
-            .take()
-            .expect("Websocket already split")
-            .split()
-    }
 }
 
 #[async_trait]
-impl Handler<Message> for WebsocketActor {
-    type Response = Option<Message>;
-    async fn handle(
-        this: Arc<Mutex<Self>>,
-        message: Box<Message>,
-    ) -> Result<Self::Response, Error> {
-        this.lock()
-            .await
-            .hello
+impl Relay<Message> for WebsocketActor {
+    async fn handle(&mut self, message: Message) -> Result<Option<Message>, Error> {
+        self.hello
             .send(crate::Msg {
                 _content: message.to_str().unwrap().to_owned(),
             })
             .unwrap_or_else(|e| println!("FUKEN HELL M8 {e}"));
 
-        Ok(Some(*message.clone()))
+        Ok(Some(message))
     }
 }
 
@@ -68,6 +43,7 @@ struct Hello {}
 
 impl Actor for Hello {}
 
+#[derive(Clone)]
 struct Msg {
     pub _content: String,
 }
@@ -75,11 +51,15 @@ struct Msg {
 #[async_trait]
 impl Handler<Msg> for Hello {
     type Response = usize;
-    async fn handle(_: Arc<Mutex<Self>>, _: Box<Msg>) -> Result<usize, Error> {
+    async fn handle(&mut self, _: Msg) -> Result<usize, Error> {
         Ok(10)
     }
 }
 
+type Arbiter = Arc<RwLock<HashMap<usize, ActorHandle<WebsocketActor>>>>;
+
+static ID: AtomicUsize = AtomicUsize::new(0);
+
 #[tokio::main]
 async fn main() {
     let pool = Arc::new(RwLock::new(HashMap::new()));
@@ -97,11 +77,15 @@ async fn main() {
             |ws: warp::ws::Ws, pool: Arbiter, hello: ActorHandle<Hello>| {
                 // This will call our function if the handshake succeeds.
                 ws.on_upgrade(|socket| async move {
-                    let actor = WebsocketActor::new(socket, hello);
-                    let handle = actor.start();
+                    let (si, st) = socket.split();
+                    let (tx, rx) = flume::unbounded();
+
+                    let actor = WebsocketActor::new(hello);
+                    let handle = actor.start_relay(st, tx);
+                    tokio::spawn(rx.into_stream().map(Ok).forward(si));
                     let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                     println!("Adding actor {id}");
-                    pool.write().unwrap().insert(id, handle);
+                    pool.write().insert(id, handle);
                 })
             },
         );
@@ -156,10 +140,10 @@ static INDEX_HTML: &str = r#"<!DOCTYPE html>
         send.onclick = function() {
             const msg = text.value;
             let i = 0;
-            while (i < 100000) {
+             while (i < 100000) {
                 ws.send(msg);
-                i += 1;
-            }
+                 i += 1;
+             }
             // text.value = '';
 
             message('<You>: ' + msg);