Browse Source

make websocket actor generic

biblius 1 year ago
parent
commit
963c68a9e9
5 changed files with 254 additions and 338 deletions
  1. 21 149
      src/lib.rs
  2. 29 28
      src/message.rs
  3. 7 15
      src/runtime.rs
  4. 120 142
      src/ws.rs
  5. 77 4
      tests/websocket.rs

+ 21 - 149
src/lib.rs

@@ -1,7 +1,7 @@
 use crate::runtime::{ActorRuntime, Runtime};
 use async_trait::async_trait;
 use flume::{SendError, Sender};
-use message::{Envelope, Enveloper, Message, MessageRequest};
+use message::{Envelope, Enveloper, MessageRequest};
 use std::{fmt::Debug, sync::Arc};
 use tokio::sync::oneshot;
 use tokio::sync::Mutex;
@@ -11,45 +11,23 @@ pub mod message;
 pub mod runtime;
 pub mod ws;
 
-pub struct Hello {}
-
-impl Actor for Hello {}
-
-pub struct Msg {
-    pub content: String,
-}
-impl Message for Msg {
-    type Response = usize;
-}
-
-#[async_trait]
-impl Handler<Msg> for Hello {
-    async fn handle(_: Arc<Mutex<Self>>, _: Box<Msg>) -> Result<usize, Error> {
-        println!("Handling message Hello");
-        Ok(10)
-    }
-}
-
 const DEFAULT_CHANNEL_CAPACITY: usize = 128;
 
-#[async_trait]
 pub trait Actor {
-    async fn start(self) -> ActorHandle<Self>
+    fn start(self) -> ActorHandle<Self>
     where
         Self: Sized + Send + 'static,
     {
         println!("Starting actor");
-        ActorRuntime::run(Arc::new(Mutex::new(self))).await
+        ActorRuntime::run(self)
     }
 }
 
 /// The main trait to implement on an [Actor] to enable it to handle messages.
 #[async_trait]
-pub trait Handler<M>: Actor
-where
-    M: Message,
-{
-    async fn handle(this: Arc<Mutex<Self>>, message: Box<M>) -> Result<M::Response, Error>;
+pub trait Handler<M>: Actor {
+    type Response;
+    async fn handle(this: Arc<Mutex<Self>>, message: Box<M>) -> Result<Self::Response, Error>;
 }
 
 /// A handle to a spawned actor. Obtained when calling `start` on an [Actor] and is used to send messages
@@ -89,9 +67,9 @@ where
     /// Sends a message to the actor and returns a [MessageRequest] that can
     /// be awaited. This method should be used when one needs a response from the
     /// actor.
-    pub fn send_wait<M>(&self, message: M) -> Result<MessageRequest<M::Response>, SendError<M>>
+    pub fn send_wait<M>(&self, message: M) -> Result<MessageRequest<A::Response>, SendError<M>>
     where
-        M: Message + Send,
+        M: Send,
         A: Handler<M> + Enveloper<A, M>,
     {
         if self.message_tx.is_full() || self.message_tx.is_disconnected() {
@@ -106,7 +84,6 @@ where
     /// error if the channel is full or disconnected.
     pub fn send<M>(&self, message: M) -> Result<(), SendError<M>>
     where
-        M: Message + Send + 'static,
         A: Handler<M> + Enveloper<A, M> + 'static,
     {
         if self.message_tx.is_full() || self.message_tx.is_disconnected() {
@@ -119,7 +96,7 @@ where
     /// Send a message ignoring any errors in the process. The true YOLO way to send messages.
     pub fn send_forget<M>(&self, message: M)
     where
-        M: Message + Send + 'static,
+        M: Send + 'static,
         A: Handler<M> + Enveloper<A, M> + 'static,
     {
         let _ = self.message_tx.send(A::pack(message, None));
@@ -128,102 +105,6 @@ where
     pub fn send_cmd(&self, cmd: ActorCommand) -> Result<(), SendError<ActorCommand>> {
         self.command_tx.send(cmd)
     }
-
-    pub fn recipient<M>(&self) -> Recipient<M>
-    where
-        M: Message + Send + 'static + Sync,
-        M::Response: Send,
-        A: Handler<M> + Send + 'static,
-    {
-        Recipient {
-            message_tx: Box::new(self.message_tx.clone()),
-            command_tx: self.command_tx.clone(),
-        }
-    }
-}
-
-/// The same as an [ActorHandle], but instead of being tied to a specific actor, it is only
-/// tied to the message type. Can be obtained from an [ActorHandle].
-///
-/// Useful for grouping different types of actors that can handle the same message.
-pub struct Recipient<M>
-where
-    M: Message,
-{
-    message_tx: Box<dyn MessageSender<M>>,
-    command_tx: Sender<ActorCommand>,
-}
-
-impl<M> Recipient<M>
-where
-    M: Message + Send,
-{
-    pub fn send_wait(&self, message: M) -> Result<MessageRequest<M::Response>, SendError<M>> {
-        self.message_tx.send_sync(message)
-    }
-
-    pub fn send(&self, message: M) -> Result<(), SendError<M>> {
-        self.message_tx.send(message)
-    }
-
-    pub fn send_forget(&self, message: M) {
-        let _ = self.message_tx.send(message);
-    }
-
-    pub fn send_cmd(&self, cmd: ActorCommand) -> Result<(), SendError<ActorCommand>> {
-        self.command_tx.send(cmd)
-    }
-}
-
-/// A helper trait used solely by [Recipient]'s message channel to erase the actor type.
-/// This is achieved by implementing it on [Sender<Envelope<A>].
-trait MessageSender<M>
-where
-    M: Message + Send,
-{
-    fn send_sync(&self, message: M) -> Result<MessageRequest<M::Response>, SendError<M>>;
-
-    fn send(&self, message: M) -> Result<(), SendError<M>>;
-}
-
-impl<A, M> MessageSender<M> for Sender<Envelope<A>>
-where
-    M: Message + Send + 'static,
-    M::Response: Send,
-    A: Actor + Handler<M> + Enveloper<A, M>,
-{
-    fn send(&self, message: M) -> Result<(), SendError<M>> {
-        if self.is_full() {
-            return Err(SendError(message));
-        }
-        let _ = self.send(A::pack(message, None));
-        Ok(())
-    }
-
-    fn send_sync(
-        &self,
-        message: M,
-    ) -> Result<MessageRequest<<M as Message>::Response>, SendError<M>> {
-        if self.is_full() {
-            return Err(SendError(message));
-        }
-        let (tx, rx) = oneshot::channel();
-        let _ = self.send(A::pack(message, Some(tx)));
-        Ok(MessageRequest { response_rx: rx })
-    }
-}
-
-impl<A, M> From<ActorHandle<A>> for Recipient<M>
-where
-    M: Message + Send + 'static + Sync,
-    M::Response: Send,
-    A: Actor + Handler<M> + Enveloper<A, M> + Send + 'static,
-{
-    /// Just calls `ActorHandler::recipient`, i.e. clones the underlying channels
-    /// into the recipient and boxes the message one.
-    fn from(handle: ActorHandle<A>) -> Self {
-        handle.recipient()
-    }
 }
 
 #[derive(Debug, PartialEq, Eq)]
@@ -245,6 +126,11 @@ pub enum Error {
     Warp(#[from] warp::Error),
 }
 
+pub enum SendErr<M> {
+    Full(M),
+    Closed(M),
+}
+
 #[derive(Debug)]
 pub enum ActorCommand {
     Stop,
@@ -270,18 +156,11 @@ mod tests {
         #[derive(Debug)]
         struct Bar {}
 
-        impl Message for Foo {
-            type Response = usize;
-        }
-
-        impl Message for Bar {
-            type Response = isize;
-        }
-
         impl Actor for Testor {}
 
         #[async_trait]
         impl Handler<Foo> for Testor {
+            type Response = usize;
             async fn handle(_: Arc<Mutex<Self>>, _: Box<Foo>) -> Result<usize, Error> {
                 println!("Handling Foo");
                 Ok(10)
@@ -290,6 +169,7 @@ mod tests {
 
         #[async_trait]
         impl Handler<Bar> for Testor {
+            type Response = isize;
             async fn handle(_: Arc<Mutex<Self>>, _: Box<Bar>) -> Result<isize, Error> {
                 for _ in 0..10_000 {
                     println!("Handling Bar");
@@ -301,7 +181,7 @@ mod tests {
         let mut res = 0;
         let mut res2 = 0;
 
-        let handle = Testor {}.start().await;
+        let handle = Testor {}.start();
         println!("HELLO WORLDS");
         for _ in 0..100 {
             res += handle.send_wait(Foo {}).unwrap().await.unwrap();
@@ -311,11 +191,9 @@ mod tests {
         handle.send(Foo {}).unwrap();
         handle.send_forget(Bar {});
 
-        let rec: Recipient<Foo> = handle.recipient();
-        res += rec.send_wait(Foo {}).unwrap().await.unwrap();
         handle.send_cmd(ActorCommand::Stop).unwrap();
 
-        assert_eq!(res, 1010);
+        assert_eq!(res, 1000);
         assert_eq!(res2, 1000);
     }
 
@@ -330,20 +208,13 @@ mod tests {
         #[derive(Debug)]
         struct Bar {}
 
-        impl Message for Foo {
-            type Response = usize;
-        }
-
-        impl Message for Bar {
-            type Response = isize;
-        }
-
         impl Actor for Testor {}
 
         static COUNT: AtomicUsize = AtomicUsize::new(0);
 
         #[async_trait]
         impl Handler<Foo> for Testor {
+            type Response = usize;
             async fn handle(_: Arc<Mutex<Testor>>, _: Box<Foo>) -> Result<usize, Error> {
                 println!("INCREMENTING COUNT FOO");
                 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
@@ -353,6 +224,7 @@ mod tests {
 
         #[async_trait]
         impl Handler<Bar> for Testor {
+            type Response = isize;
             async fn handle(_: Arc<Mutex<Testor>>, _: Box<Bar>) -> Result<isize, Error> {
                 println!("INCREMENTING COUNT BAR");
                 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
@@ -367,7 +239,7 @@ mod tests {
         let local_set = LocalSet::new();
 
         let task = async {
-            let handle = Testor {}.start().await;
+            let handle = Testor {}.start();
 
             handle.send_wait(Bar {}).unwrap().await.unwrap();
             handle.send(Foo {}).unwrap();

+ 29 - 28
src/message.rs

@@ -5,11 +5,6 @@ use async_trait::async_trait;
 use tokio::sync::oneshot;
 use tokio::sync::Mutex;
 
-/// Represents a message that can be sent to an actor. The response type is what the actor must return in its handler implementation.
-pub trait Message {
-    type Response;
-}
-
 /// Represents a type erased message that ultimately gets stored in an [Envelope]. We need this indirection so we can abstract away the concrete message
 /// type when creating an actor handle, otherwise we would only be able to send a single message type to the actor.
 #[async_trait]
@@ -17,27 +12,33 @@ pub trait ActorMessage<A: Actor> {
     async fn handle(self: Box<Self>, actor: Arc<Mutex<A>>);
 }
 
-/// Used by [ActorHandle][super::ActorHandle]s to pack [Message]s into [Envelope]s so we have a type erased message to send to the actor.
-pub trait Enveloper<A: Actor, M: Message> {
+/// Used by [ActorHandle][super::ActorHandle]s to pack messages into [Envelope]s so we have a type erased message to send to the actor.
+pub trait Enveloper<A, M>
+where
+    A: Handler<M>,
+{
     /// Wrap a message in an envelope with an optional response channel.
-    fn pack(message: M, tx: Option<oneshot::Sender<M::Response>>) -> Envelope<A>;
+    fn pack(message: M, tx: Option<oneshot::Sender<<A as Handler<M>>::Response>>) -> Envelope<A>;
 }
 
 /// A type erased wrapper for messages. This wrapper essentially enables us to send any message to the actor
 /// so long as it implements the necessary handler.
-pub struct Envelope<A: Actor> {
-    message: Box<dyn ActorMessage<A> + Send + Sync>,
+pub struct Envelope<A>
+where
+    A: Actor,
+{
+    message: Box<dyn ActorMessage<A> + Send>,
 }
 
 impl<A> Envelope<A>
 where
     A: Actor,
 {
-    pub fn new<M>(message: M, tx: Option<oneshot::Sender<M::Response>>) -> Self
+    pub fn new<M>(message: M, tx: Option<oneshot::Sender<A::Response>>) -> Self
     where
         A: Handler<M> + Send + 'static,
-        M: Message + Send + 'static + Sync,
-        M::Response: Send,
+        A::Response: Send,
+        M: Send + 'static,
     {
         Self {
             message: Box::new(EnvelopeInner {
@@ -48,13 +49,6 @@ where
     }
 }
 
-/// The inner parts of the [Envelope] containing the actual message as well as an optional
-/// response channel.
-struct EnvelopeInner<M: Message> {
-    message: Box<M>,
-    tx: Option<oneshot::Sender<M::Response>>,
-}
-
 #[async_trait]
 impl<A> ActorMessage<A> for Envelope<A>
 where
@@ -65,12 +59,19 @@ where
     }
 }
 
+/// The inner parts of the [Envelope] containing the actual message as well as an optional
+/// response channel.
+struct EnvelopeInner<M, R> {
+    message: Box<M>,
+    tx: Option<oneshot::Sender<R>>,
+}
+
 #[async_trait]
-impl<A, M> ActorMessage<A> for EnvelopeInner<M>
+impl<A, M> ActorMessage<A> for EnvelopeInner<M, <A as Handler<M>>::Response>
 where
-    M: Message + Send + Sync,
-    M::Response: Send,
-    A: Actor + Handler<M> + Send + 'static,
+    A: Handler<M> + Send + 'static,
+    A::Response: Send,
+    M: Send,
 {
     async fn handle(self: Box<Self>, actor: Arc<Mutex<A>>) {
         let result = A::handle(actor, self.message).await;
@@ -82,11 +83,11 @@ where
 
 impl<A, M> Enveloper<A, M> for A
 where
-    A: Actor + Handler<M> + Send + 'static,
-    M: Message + Send + 'static + Sync,
-    M::Response: Send,
+    A: Handler<M> + Send + 'static,
+    A::Response: Send,
+    M: Send + Sync + 'static,
 {
-    fn pack(message: M, tx: Option<oneshot::Sender<M::Response>>) -> Envelope<A> {
+    fn pack(message: M, tx: Option<oneshot::Sender<<A as Handler<M>>::Response>>) -> Envelope<A> {
         Envelope::new(message, tx)
     }
 }

+ 7 - 15
src/runtime.rs

@@ -2,7 +2,6 @@ use crate::{
     message::ActorMessage, Actor, ActorCommand, ActorHandle, Envelope, Error,
     DEFAULT_CHANNEL_CAPACITY,
 };
-use async_trait::async_trait;
 use flume::{r#async::RecvStream, Receiver};
 use futures::{Future, StreamExt};
 use std::{
@@ -15,7 +14,6 @@ use tokio::sync::Mutex;
 
 pub const QUEUE_CAPACITY: usize = 128;
 
-#[async_trait]
 pub trait Runtime<A>
 where
     A: Actor + Send + 'static,
@@ -30,26 +28,21 @@ where
 
     fn at_capacity(&self) -> bool;
 
-    async fn run(actor: Arc<Mutex<A>>) -> ActorHandle<A> {
+    fn run(actor: A) -> ActorHandle<A> {
         let (message_tx, message_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
         let (command_tx, command_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
         tokio::spawn(ActorRuntime::new(actor, command_rx, message_rx));
         ActorHandle::new(message_tx, command_tx)
     }
 
-    fn process_commands(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
+    fn process_commands(&mut self, cx: &mut Context<'_>) -> Result<Option<ActorCommand>, Error> {
         match self.command_stream().poll_next_unpin(cx) {
-            Poll::Ready(Some(command)) => match command {
-                ActorCommand::Stop => {
-                    println!("Actor stopping");
-                    Ok(()) // TODO drain the queue and all that graceful stuff
-                }
-            },
+            Poll::Ready(Some(command)) => Ok(Some(command)),
             Poll::Ready(None) => {
                 println!("Command channel closed, ungracefully stopping actor");
                 Err(Error::ActorChannelClosed)
             }
-            Poll::Pending => Ok(()),
+            Poll::Pending => Ok(None),
         }
     }
 
@@ -92,12 +85,11 @@ where
     }
 }
 
-#[async_trait]
 impl<A> Runtime<A> for ActorRuntime<A>
 where
     A: Actor + Send + 'static,
 {
-    async fn run(actor: Arc<Mutex<A>>) -> ActorHandle<A> {
+    fn run(actor: A) -> ActorHandle<A> {
         let (message_tx, message_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
         let (command_tx, command_rx) = flume::bounded(DEFAULT_CHANNEL_CAPACITY);
         tokio::spawn(ActorRuntime::new(actor, command_rx, message_rx));
@@ -148,13 +140,13 @@ where
     A: Actor + 'static + Send,
 {
     pub fn new(
-        actor: Arc<Mutex<A>>,
+        actor: A,
         command_rx: Receiver<ActorCommand>,
         message_rx: Receiver<Envelope<A>>,
     ) -> Self {
         println!("Building default runtime");
         Self {
-            actor,
+            actor: Arc::new(Mutex::new(actor)),
             command_stream: command_rx.into_stream(),
             message_stream: message_rx.into_stream(),
             process_queue: VecDeque::with_capacity(QUEUE_CAPACITY),

+ 120 - 142
src/ws.rs

@@ -1,164 +1,104 @@
 use crate::{
     message::Envelope,
     runtime::{ActorJob, Runtime, QUEUE_CAPACITY},
-    Actor, ActorCommand, ActorHandle, ActorStatus, Error, Handler, Hello,
+    Actor, ActorCommand, ActorHandle, Error, Handler,
 };
-use async_trait::async_trait;
 use flume::{r#async::RecvStream, Receiver};
-use futures::{
-    stream::{SplitSink, SplitStream},
-    Future, FutureExt, SinkExt, StreamExt,
-};
+use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
 use std::{
     collections::VecDeque,
+    fmt::Display,
+    marker::PhantomData,
     pin::Pin,
     sync::atomic::AtomicUsize,
     task::{Context, Poll},
-    time::Duration,
 };
 use std::{sync::Arc, task::ready};
 use tokio::sync::Mutex;
-use warp::ws::{Message, WebSocket};
 
 const WS_QUEUE_SIZE: usize = 128;
 
-pub struct WebsocketActor {
-    websocket: Option<WebSocket>,
-    hello: ActorHandle<Hello>,
-}
-
-impl WebsocketActor {
-    pub fn new(ws: WebSocket, handle: ActorHandle<Hello>) -> Self {
-        Self {
-            websocket: Some(ws),
-            hello: handle,
-        }
-    }
-}
-
-#[async_trait]
-impl Actor for WebsocketActor {
-    async fn start(self) -> ActorHandle<Self> {
-        WebsocketRuntime::run(Arc::new(Mutex::new(self))).await
-    }
-}
-
-type WsFuture = Pin<Box<dyn Future<Output = Result<Option<Message>, Error>> + Send>>;
-
-struct WebsocketJob {
-    message: Option<Box<Message>>,
-    future: Option<WsFuture>,
+/// Represents an actor that can get access to a websocket stream and sink.
+///
+/// A websocket actor receives messages via the stream and processes them with
+/// its [Handler] implementation. The handler implementation should always return an
+/// `Option<M>` where M is the type used when implementing this trait. A handler that returns
+/// `None` will not forward any response to the sink. If the handler returns `Some(M)` it will
+/// be forwarded to the sink.
+pub trait WsActor<M, Str, Sin>
+where
+    Str: Stream<Item = Result<M, Self::Error>>,
+    Sin: Sink<M>,
+{
+    /// The error type of the underlying websocket implementation.
+    type Error: Display;
+    fn websocket(&mut self) -> (Sin, Str);
 }
 
-impl WebsocketJob {
-    pub fn new(message: Message) -> Self {
-        Self {
-            message: Some(Box::new(message)),
-            future: None,
-        }
-    }
-
-    fn poll(
-        &mut self,
-        actor: Arc<Mutex<WebsocketActor>>,
-        cx: &mut std::task::Context<'_>,
-    ) -> Poll<Result<Option<Message>, warp::Error>> {
-        let message = self.message.take();
-
-        match message {
-            Some(message) => {
-                let fut = WebsocketActor::handle(actor, message);
-                self.future = Some(fut);
-                let result = ready!(self.future.as_mut().unwrap().as_mut().poll(cx));
-                match result {
-                    Ok(response) => Poll::Ready(Ok(response)),
-                    Err(e) => {
-                        println!("Shit's fucked son {e}");
-                        Poll::Ready(Ok(None))
-                    }
-                }
-            }
-            None => match self.future {
-                Some(ref mut fut) => match fut.as_mut().poll(cx) {
-                    Poll::Ready(result) => match result {
-                        Ok(response) => Poll::Ready(Ok(response)),
-                        Err(e) => {
-                            println!("Shit's fucked son {e}");
-                            Poll::Ready(Ok(None))
-                        }
-                    },
-                    Poll::Pending => {
-                        PENDING.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
-                        // println!("Websocket Future pending - COUNT {PENDING:?}");
-                        Poll::Pending
-                    }
-                },
-                None => panic!("Impossibru"),
-            },
-        }
-    }
-}
-
-static PROCESSED: AtomicUsize = AtomicUsize::new(0);
-static PENDING: AtomicUsize = AtomicUsize::new(0);
-
-pub struct WebsocketRuntime {
-    actor: Arc<Mutex<WebsocketActor>>,
-
-    status: ActorStatus,
+pub struct WebsocketRuntime<A, M, Str, Sin>
+where
+    A: Actor + WsActor<M, Str, Sin> + Handler<M> + 'static,
+    Str: Stream<Item = Result<M, A::Error>>,
+    Sin: Sink<M>,
+{
+    actor: Arc<Mutex<A>>,
 
     /// The receiving end of the websocket
-    ws_stream: SplitStream<WebSocket>,
+    ws_stream: Str,
 
     /// The sending end of the websocket
-    ws_sink: SplitSink<WebSocket, Message>,
+    ws_sink: Sin,
 
     /// Actor message receiver
-    message_stream: RecvStream<'static, Envelope<WebsocketActor>>,
+    message_stream: RecvStream<'static, Envelope<A>>,
 
     /// Actor command receiver
     command_stream: RecvStream<'static, ActorCommand>,
 
     /// Actor messages currently being processed
-    process_queue: VecDeque<ActorJob<WebsocketActor>>,
+    process_queue: VecDeque<ActorJob<A>>,
 
     /// Received, but not yet processed websocket messages
-    response_queue: VecDeque<WebsocketJob>,
+    response_queue: VecDeque<WebsocketJob<A, M>>,
 }
 
-impl WebsocketRuntime {
-    pub async fn new(
-        actor: Arc<Mutex<WebsocketActor>>,
+impl<A, M, Str, Sin> WebsocketRuntime<A, M, Str, Sin>
+where
+    Str: Stream<Item = Result<M, A::Error>>,
+    Sin: Sink<M>,
+    A: Actor + WsActor<M, Str, Sin> + Send + 'static + Handler<M>,
+{
+    pub fn new(
+        mut actor: A,
         command_rx: Receiver<ActorCommand>,
-        message_rx: Receiver<Envelope<WebsocketActor>>,
+        message_rx: Receiver<Envelope<A>>,
     ) -> Self {
-        let (ws_sink, ws_stream) = actor
-            .lock()
-            .await
-            .websocket
-            .take()
-            .expect("Websocket runtime already started")
-            .split();
+        let (ws_sink, ws_stream) = actor.websocket();
 
         Self {
-            actor,
+            actor: Arc::new(Mutex::new(actor)),
             ws_sink,
             ws_stream,
             message_stream: message_rx.into_stream(),
             command_stream: command_rx.into_stream(),
             response_queue: VecDeque::new(),
             process_queue: VecDeque::new(),
-            status: ActorStatus::Starting,
         }
     }
 }
 
-impl Future for WebsocketRuntime {
+impl<A, M, Str, Sin> Future for WebsocketRuntime<A, M, Str, Sin>
+where
+    Self: Runtime<A>,
+    Str: Stream<Item = Result<M, A::Error>> + Unpin,
+    Sin: Sink<M> + Unpin,
+    A: Actor + WsActor<M, Str, Sin> + Handler<M, Response = Option<M>> + Send + Unpin + 'static,
+{
     type Output = Result<(), Error>;
 
-    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         let actor = self.actor();
-        let mut this = self.as_mut();
+        let this = self.get_mut();
 
         this.process_commands(cx)?;
 
@@ -190,6 +130,7 @@ impl Future for WebsocketRuntime {
                             let mut feed = Pin::new(feed);
                             let _ = feed.as_mut().poll(cx);
                         }
+
                         PROCESSED.fetch_add(1, std::sync::atomic::Ordering::Acquire);
                         this.response_queue.swap_remove_front(idx);
                     }
@@ -209,24 +150,29 @@ impl Future for WebsocketRuntime {
             this.response_queue.len(),
         );
 
-        let _ = this.ws_sink.flush().poll_unpin(cx)?;
+        let _ = this.ws_sink.flush().poll_unpin(cx);
 
         cx.waker().wake_by_ref();
         Poll::Pending
     }
 }
 
-#[async_trait]
-impl Runtime<WebsocketActor> for WebsocketRuntime {
-    async fn run(actor: Arc<Mutex<WebsocketActor>>) -> ActorHandle<WebsocketActor> {
+impl<A, M, Str, Sin> Runtime<A> for WebsocketRuntime<A, M, Str, Sin>
+where
+    Str: Stream<Item = Result<M, A::Error>> + Unpin + Send + 'static,
+    Sin: Sink<M> + Unpin + Send + 'static,
+    A: Actor + WsActor<M, Str, Sin> + Send + 'static + Handler<M, Response = Option<M>> + Unpin,
+    M: Send + 'static,
+{
+    fn run(actor: A) -> ActorHandle<A> {
         let (message_tx, message_rx) = flume::unbounded();
         let (command_tx, command_rx) = flume::unbounded();
-        tokio::spawn(WebsocketRuntime::new(actor, command_rx, message_rx).await);
+        tokio::spawn(WebsocketRuntime::new(actor, command_rx, message_rx));
         ActorHandle::new(message_tx, command_tx)
     }
 
     #[inline]
-    fn processing_queue(&mut self) -> &mut VecDeque<ActorJob<WebsocketActor>> {
+    fn processing_queue(&mut self) -> &mut VecDeque<ActorJob<A>> {
         &mut self.process_queue
     }
 
@@ -236,12 +182,12 @@ impl Runtime<WebsocketActor> for WebsocketRuntime {
     }
 
     #[inline]
-    fn message_stream(&mut self) -> &mut RecvStream<'static, Envelope<WebsocketActor>> {
+    fn message_stream(&mut self) -> &mut RecvStream<'static, Envelope<A>> {
         &mut self.message_stream
     }
 
     #[inline]
-    fn actor(&self) -> Arc<Mutex<WebsocketActor>> {
+    fn actor(&self) -> Arc<Mutex<A>> {
         self.actor.clone()
     }
 
@@ -251,31 +197,63 @@ impl Runtime<WebsocketActor> for WebsocketRuntime {
     }
 }
 
-impl crate::Message for Message {
-    type Response = Option<Message>;
+struct WebsocketJob<A, M>
+where
+    A: Handler<M>,
+{
+    message: Option<Box<M>>,
+    future: Option<WsFuture<M, A>>,
+    __a: PhantomData<A>,
 }
 
-#[async_trait]
-impl Handler<Message> for WebsocketActor {
-    async fn handle(
-        this: Arc<Mutex<Self>>,
-        message: Box<Message>,
-    ) -> Result<<Message as crate::message::Message>::Response, crate::Error> {
-        //let mut act = this.lock().await;
-        if message.is_text() {
-            this.lock()
-                .await
-                .hello
-                .send(crate::Msg {
-                    content: message.to_str().unwrap().to_owned(),
-                })
-                .unwrap_or_else(|e| println!("{e}"));
-            //        println!("Actor retreived lock and sent message got response {res}");
-            tokio::time::sleep(Duration::from_micros(1)).await;
-            //act.wait().await;
-            Ok(Some(*message.clone()))
-        } else {
-            Ok(None)
+impl<A, M> WebsocketJob<A, M>
+where
+    A: Handler<M> + 'static,
+{
+    pub fn new(message: M) -> Self {
+        Self {
+            message: Some(Box::new(message)),
+            future: None,
+            __a: PhantomData,
+        }
+    }
+
+    fn poll(
+        &mut self,
+        actor: Arc<Mutex<A>>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<A::Response, Error>> {
+        let message = self.message.take();
+
+        match message {
+            Some(message) => {
+                let fut = A::handle(actor, message);
+                self.future = Some(fut);
+                let result = ready!(self.future.as_mut().unwrap().as_mut().poll(cx));
+                match result {
+                    Ok(response) => Poll::Ready(Ok(response)),
+                    Err(e) => {
+                        println!("Shit's fucked son {e}");
+                        Poll::Ready(Err(e))
+                    }
+                }
+            }
+            None => {
+                let Some(ref mut fut) = self.future else { panic!("Impossibru") };
+                let result = ready!(fut.as_mut().poll(cx));
+                match result {
+                    Ok(response) => Poll::Ready(Ok(response)),
+                    Err(e) => {
+                        println!("Shit's fucked son {e}");
+                        Poll::Ready(Err(e))
+                    }
+                }
+            }
         }
     }
 }
+
+type WsFuture<M, A> =
+    Pin<Box<dyn Future<Output = Result<<A as Handler<M>>::Response, Error>> + Send>>;
+
+static PROCESSED: AtomicUsize = AtomicUsize::new(0);

+ 77 - 4
tests/websocket.rs

@@ -1,20 +1,92 @@
-use drama::ws::WebsocketActor;
-use drama::{Actor, ActorHandle, Hello};
+use async_trait::async_trait;
+use drama::runtime::Runtime;
+use drama::ws::{WebsocketRuntime, WsActor};
+use drama::{Actor, ActorHandle, Error, Handler};
+use futures::stream::{SplitSink, SplitStream};
+use futures::StreamExt;
 use std::collections::HashMap;
 use std::sync::atomic::AtomicUsize;
 use std::sync::{Arc, RwLock};
+use tokio::sync::Mutex;
+use warp::ws::{Message, WebSocket};
 use warp::Filter;
 
 type Arbiter = Arc<RwLock<HashMap<usize, ActorHandle<WebsocketActor>>>>;
 
 static ID: AtomicUsize = AtomicUsize::new(0);
 
+struct WebsocketActor {
+    websocket: Option<WebSocket>,
+    hello: ActorHandle<Hello>,
+}
+
+impl WebsocketActor {
+    fn new(ws: WebSocket, handle: ActorHandle<Hello>) -> Self {
+        Self {
+            websocket: Some(ws),
+            hello: handle,
+        }
+    }
+}
+
+impl Actor for WebsocketActor {
+    fn start(self) -> ActorHandle<Self> {
+        WebsocketRuntime::run(self)
+    }
+}
+
+impl WsActor<Message, SplitStream<WebSocket>, SplitSink<WebSocket, Message>> for WebsocketActor {
+    type Error = warp::Error;
+    fn websocket(&mut self) -> (SplitSink<WebSocket, Message>, SplitStream<WebSocket>) {
+        self.websocket
+            .take()
+            .expect("Websocket already split")
+            .split()
+    }
+}
+
+#[async_trait]
+impl Handler<Message> for WebsocketActor {
+    type Response = Option<Message>;
+    async fn handle(
+        this: Arc<Mutex<Self>>,
+        message: Box<Message>,
+    ) -> Result<Self::Response, Error> {
+        this.lock()
+            .await
+            .hello
+            .send(crate::Msg {
+                _content: message.to_str().unwrap().to_owned(),
+            })
+            .unwrap_or_else(|e| println!("{e}"));
+
+        Ok(Some(*message.clone()))
+    }
+}
+
+struct Hello {}
+
+impl Actor for Hello {}
+
+struct Msg {
+    pub _content: String,
+}
+
+#[async_trait]
+impl Handler<Msg> for Hello {
+    type Response = usize;
+    async fn handle(_: Arc<Mutex<Self>>, _: Box<Msg>) -> Result<usize, Error> {
+        println!("Handling message Hello");
+        Ok(10)
+    }
+}
+
 #[tokio::main]
 async fn main() {
     let pool = Arc::new(RwLock::new(HashMap::new()));
     let pool = warp::any().map(move || pool.clone());
 
-    let hello = Hello {}.start().await;
+    let hello = Hello {}.start();
     let hello = warp::any().map(move || hello.clone());
     // GET /chat -> websocket upgrade
     let chat = warp::path("chat")
@@ -27,7 +99,7 @@ async fn main() {
                 // This will call our function if the handshake succeeds.
                 ws.on_upgrade(|socket| async move {
                     let actor = WebsocketActor::new(socket, hello);
-                    let handle = actor.start().await;
+                    let handle = actor.start();
                     let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                     println!("Adding actor {id}");
                     pool.write().unwrap().insert(id, handle);
@@ -73,6 +145,7 @@ static INDEX_HTML: &str = r#"<!DOCTYPE html>
         };
 
         ws.onmessage = function(msg) {
+            console.log(msg)
             num += 1;
             message(msg.data);
         };