Browse Source

init commit

Josip Benko-Đaković 1 year ago
commit
af6d036ff6
8 changed files with 686 additions and 0 deletions
  1. 2 0
      .gitignore
  2. 22 0
      Cargo.toml
  3. 14 0
      src/debug.rs
  4. 128 0
      src/lib.rs
  5. 133 0
      src/message.rs
  6. 112 0
      src/runtime.rs
  7. 131 0
      src/ws.rs
  8. 144 0
      tests/websocket.rs

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+/target
+/Cargo.lock

+ 22 - 0
Cargo.toml

@@ -0,0 +1,22 @@
+[package]
+edition = "2021"
+name = "actors"
+version = "0.1.0"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[[bin]]
+name = "test-ws"
+path = "tests/websocket.rs"
+
+[dependencies]
+futures = "0.3.28"
+pin-project = "1.1.0"
+thiserror = "1.0.40"
+tokio = { version = "1.28.2", features = [
+  "macros",
+  "rt-multi-thread",
+  "sync",
+  "time",
+] }
+warp = "0.3.5"

+ 14 - 0
src/debug.rs

@@ -0,0 +1,14 @@
+//! So we don't polute the main lib with debug impls
+
+use crate::{Actor, Envelope};
+
+impl<A> std::fmt::Debug for Envelope<A>
+where
+    A: Actor,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Envelope")
+            .field("message", &"{{..}}")
+            .finish()
+    }
+}

+ 128 - 0
src/lib.rs

@@ -0,0 +1,128 @@
+use crate::runtime::{DefaultActorRuntime, Runtime};
+use message::{Envelope, Message, MessagePacker, MessageRequest};
+use std::fmt::Debug;
+use tokio::sync::{mpsc::Sender, oneshot};
+
+pub mod debug;
+pub mod message;
+pub mod runtime;
+pub mod ws;
+
+pub trait Actor {
+    fn start(self) -> ActorHandle<Self>
+    where
+        Self: Sized + Send + 'static,
+    {
+        println!("Starting actor");
+        DefaultActorRuntime::start(self)
+    }
+}
+
+pub struct ActorHandle<A>
+where
+    A: Actor,
+{
+    message_tx: Sender<Envelope<A>>,
+    command_tx: Sender<ActorCommand>,
+}
+
+impl<A> ActorHandle<A>
+where
+    A: Actor,
+{
+    pub async fn send_sync<M>(&self, message: M) -> Result<M::Response, Error>
+    where
+        M: Message + Send + 'static,
+        A: Handler<M> + MessagePacker<A, M>,
+    {
+        let (tx, rx) = oneshot::channel();
+        let packed = A::pack(message, Some(tx));
+        self.message_tx.send(packed).await.unwrap(); // TODO
+        MessageRequest { response_rx: rx }.await
+    }
+
+    pub async fn send<M>(&self, message: M) -> Result<(), Error>
+    where
+        M: Message + Send + 'static,
+        A: Handler<M> + MessagePacker<A, M> + 'static,
+    {
+        let packed = A::pack(message, None);
+        self.message_tx.send(packed).await.unwrap(); // TODO
+        Ok(())
+    }
+
+    pub async fn send_cmd(&self, cmd: ActorCommand) -> Result<(), Error> {
+        self.command_tx.send(cmd).await.unwrap(); // TODO
+        Ok(())
+    }
+}
+
+pub trait Handler<M>: Actor
+where
+    M: Message,
+{
+    fn handle(&mut self, message: M) -> Result<M::Response, Error>;
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error("Actor channel closed")]
+    ActorChannelClosed,
+    #[error("Channel closed: {0}")]
+    ChannelClosed(#[from] oneshot::error::TryRecvError),
+}
+
+#[derive(Debug)]
+pub enum ActorCommand {
+    Stop,
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    #[tokio::test]
+    async fn it_works() {
+        #[derive(Debug)]
+        struct Testor {}
+
+        #[derive(Debug)]
+        struct Foo {}
+
+        #[derive(Debug)]
+        struct Bar {}
+
+        impl Message for Foo {
+            type Response = usize;
+        }
+
+        impl Message for Bar {
+            type Response = isize;
+        }
+
+        impl Actor for Testor {}
+
+        impl Handler<Foo> for Testor {
+            fn handle(&mut self, _: Foo) -> Result<usize, Error> {
+                Ok(10)
+            }
+        }
+
+        impl Handler<Bar> for Testor {
+            fn handle(&mut self, _: Bar) -> Result<isize, Error> {
+                Ok(10)
+            }
+        }
+
+        let handle = Testor {}.start();
+
+        let res = handle.send_sync(Foo {}).await.unwrap();
+        let res2 = handle.send_sync(Bar {}).await.unwrap();
+
+        handle.send_cmd(ActorCommand::Stop).await.unwrap();
+
+        assert_eq!(res, 10);
+        assert_eq!(res2, 10);
+    }
+}

+ 133 - 0
src/message.rs

@@ -0,0 +1,133 @@
+use crate::{runtime::DefaultActorRuntime, Actor, Error, Handler};
+use tokio::sync::oneshot;
+
+/// Represents a message that can be sent to an actor. The response type is what the actor must return in its handler implementation.
+pub trait Message {
+    type Response;
+}
+
+/// Represents a type erased message that ultimately gets stored in an [Envelope]. We need this indirection so we can abstract away the concrete message
+/// type when creating an actor handle - otherwise we would only be able to send a single message type to the actor.
+pub trait PackedMessage<A: Actor> {
+    fn handle(&mut self, actor: &mut A);
+}
+
+/// Used by [ActorHandle][super::ActorHandle]s to pack [Message]s into [Envelope]s so we have a type erased message to send to the actor.
+pub trait MessagePacker<A: Actor, M: Message + Send + 'static> {
+    fn pack(message: M, tx: Option<oneshot::Sender<M::Response>>) -> Envelope<A>;
+}
+
+/// A type erased wrapper for messages. This wrapper essentially enables us to send any message to the actor
+/// so long as it implements the necessary handler.
+pub struct Envelope<A: Actor> {
+    message: Box<dyn PackedMessage<A> + Send>,
+}
+
+impl<A> Envelope<A>
+where
+    A: Actor,
+{
+    pub fn new<M>(message: M, tx: Option<oneshot::Sender<M::Response>>) -> Self
+    where
+        A: Handler<M>,
+        M: Message + Send + 'static,
+        M::Response: Send,
+    {
+        Self {
+            message: Box::new(EnvelopeInner {
+                message: Some(message),
+                tx,
+            }),
+        }
+    }
+}
+
+/// The inner parts of the [Envelope] containing the actual message as well as an optional
+/// response channel.
+struct EnvelopeInner<M: Message + Send> {
+    message: Option<M>,
+    tx: Option<oneshot::Sender<M::Response>>,
+}
+
+impl<A> PackedMessage<A> for Envelope<A>
+where
+    A: Actor,
+{
+    fn handle(&mut self, actor: &mut A) {
+        self.message.handle(actor)
+    }
+}
+
+impl<A, M> PackedMessage<A> for EnvelopeInner<M>
+where
+    M: Message + Send + 'static,
+    M::Response: Send,
+    A: Actor + Handler<M>,
+{
+    fn handle(&mut self, actor: &mut A) {
+        if let Some(message) = self.message.take() {
+            match actor.handle(message) {
+                Ok(result) => {
+                    if let Some(res_tx) = self.tx.take() {
+                        // TODO
+                        let _ = res_tx.send(result);
+                    }
+                }
+                Err(_) => todo!(),
+            };
+        }
+    }
+}
+
+impl<A, M> MessagePacker<A, M> for DefaultActorRuntime<A>
+where
+    A: Actor + Handler<M>,
+    M: Message + Send + 'static,
+    M::Response: Send,
+{
+    fn pack(message: M, tx: Option<oneshot::Sender<<M as Message>::Response>>) -> Envelope<A> {
+        A::pack(message, tx)
+    }
+}
+
+impl<A, M> MessagePacker<A, M> for A
+where
+    A: Actor + Handler<M>,
+    M: Message + Send + 'static,
+    M::Response: Send,
+{
+    fn pack(message: M, tx: Option<oneshot::Sender<M::Response>>) -> Envelope<A> {
+        Envelope::new(message, tx)
+    }
+}
+
+pub struct MessageRequest<R> {
+    pub response_rx: oneshot::Receiver<R>,
+}
+
+impl<R> std::future::Future for MessageRequest<R> {
+    type Output = Result<R, Error>;
+
+    fn poll(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Self::Output> {
+        println!("Awaiting response");
+        match self.as_mut().response_rx.try_recv() {
+            Ok(msg) => {
+                println!("Future ready");
+                std::task::Poll::Ready(Ok(msg))
+            }
+            Err(e) => {
+                println!("Future pending {e}");
+                match e {
+                    oneshot::error::TryRecvError::Empty => {
+                        cx.waker().wake_by_ref();
+                        std::task::Poll::Pending
+                    }
+                    oneshot::error::TryRecvError::Closed => std::task::Poll::Ready(Err(e.into())),
+                }
+            }
+        }
+    }
+}

+ 112 - 0
src/runtime.rs

@@ -0,0 +1,112 @@
+use crate::{message::PackedMessage, Actor, ActorCommand, ActorHandle, Envelope, Error};
+use futures::Future;
+use pin_project::pin_project;
+use std::{collections::VecDeque, task::Poll};
+use tokio::sync::mpsc::Receiver;
+
+pub trait Runtime<A> {
+    fn start(actor: A) -> ActorHandle<A>
+    where
+        A: Actor + Send + 'static,
+    {
+        let (tx, rx) = tokio::sync::mpsc::channel(100);
+        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::channel(100);
+        let rt = DefaultActorRuntime::new(actor, cmd_rx, rx);
+        tokio::spawn(rt);
+        ActorHandle {
+            message_tx: tx,
+            command_tx: cmd_tx,
+        }
+    }
+}
+
+#[pin_project]
+pub struct DefaultActorRuntime<A>
+where
+    A: Actor,
+{
+    actor: A,
+    command_rx: Receiver<ActorCommand>,
+    message_rx: Receiver<Envelope<A>>,
+    message_queue: VecDeque<Envelope<A>>,
+}
+
+impl<A> Runtime<A> for DefaultActorRuntime<A> where A: Actor {}
+
+impl<A> Future for DefaultActorRuntime<A>
+where
+    A: Actor,
+{
+    type Output = Result<(), Error>;
+
+    fn poll(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Self::Output> {
+        loop {
+            // Poll command receiver
+            match self.command_rx.poll_recv(cx) {
+                std::task::Poll::Ready(Some(message)) => match message {
+                    ActorCommand::Stop => {
+                        println!("Actor stopping");
+                        break Poll::Ready(Ok(())); // TODO drain the queue and all that graceful stuff
+                    }
+                },
+                std::task::Poll::Ready(None) => {
+                    println!("Command channel closed, ungracefully stopping actor");
+                    break Poll::Ready(Err(Error::ActorChannelClosed));
+                }
+                std::task::Poll::Pending => {}
+            };
+
+            // Process all messages
+            while let Some(mut message) = self.message_queue.pop_front() {
+                message.handle(&mut self.actor)
+            }
+
+            // Poll message receiver and continue to process if anything comes up
+            let mut new_messages = false;
+            while let Poll::Ready(Some(message)) = self.message_rx.poll_recv(cx) {
+                self.message_queue.push_back(message);
+                new_messages = true;
+            }
+
+            match self.message_rx.poll_recv(cx) {
+                std::task::Poll::Ready(Some(message)) => {
+                    self.message_queue.push_back(message);
+                    continue;
+                }
+                std::task::Poll::Ready(None) => {
+                    println!("Message channel closed, ungracefully stopping actor");
+                    break Poll::Ready(Err(Error::ActorChannelClosed));
+                }
+                std::task::Poll::Pending => {
+                    if new_messages {
+                        continue;
+                    }
+                }
+            };
+
+            return std::task::Poll::Pending;
+        }
+    }
+}
+
+impl<A> DefaultActorRuntime<A>
+where
+    A: Actor,
+{
+    pub fn new(
+        actor: A,
+        command_rx: Receiver<ActorCommand>,
+        message_rx: Receiver<Envelope<A>>,
+    ) -> Self {
+        println!("Building default runtime");
+        Self {
+            actor,
+            command_rx,
+            message_rx,
+            message_queue: VecDeque::new(),
+        }
+    }
+}

+ 131 - 0
src/ws.rs

@@ -0,0 +1,131 @@
+use crate::{
+    message::{Envelope, PackedMessage},
+    runtime::Runtime,
+    Actor, ActorCommand, ActorHandle, Handler,
+};
+use futures::{SinkExt, StreamExt, TryFutureExt};
+use tokio::{select, sync::mpsc::Receiver};
+use warp::ws::WebSocket;
+
+pub struct WebsocketActor {
+    websocket: Option<WebSocket>,
+}
+
+impl Actor for WebsocketActor {
+    fn start(self) -> ActorHandle<Self>
+    where
+        Self: Sized + Send + 'static,
+    {
+        println!("Starting websocket actor");
+        WebsocketRuntime::start(self)
+    }
+}
+
+impl WebsocketActor {
+    pub fn new(ws: WebSocket) -> Self {
+        Self {
+            websocket: Some(ws),
+        }
+    }
+}
+
+pub struct WebsocketRuntime {
+    actor: WebsocketActor,
+    message_rx: Receiver<Envelope<WebsocketActor>>,
+    command_rx: Receiver<ActorCommand>,
+}
+
+impl WebsocketRuntime {
+    pub fn new(
+        actor: WebsocketActor,
+        command_rx: Receiver<ActorCommand>,
+        message_rx: Receiver<Envelope<WebsocketActor>>,
+    ) -> Self {
+        Self {
+            actor,
+            message_rx,
+            command_rx,
+        }
+    }
+}
+
+impl WebsocketRuntime {
+    pub async fn run(mut self) {
+        let (mut ws_sender, mut ws_receiver) = self
+            .actor
+            .websocket
+            .take()
+            .expect("Websocket runtime already started")
+            .split();
+
+        loop {
+            select! {
+                // Handle any pending commands
+                Some(msg) = self.command_rx.recv() => {
+                    match msg {
+                        ActorCommand::Stop => {
+                            println!("Actor stopping");
+                            break;
+                        }
+                    }
+                }
+                // Handle any in-process messages
+                Some(mut message) = self.message_rx.recv() => {
+                    println!("Processing Message");
+                    message.handle(&mut self.actor)
+                }
+                // Handle any messages from the websocket
+                Some(message) = ws_receiver.next() => {
+                    match message {
+                        Ok(message) => {
+                            if let Some(res) = self.actor.handle(message).unwrap() {// TODO
+                                ws_sender.send(res)
+                                    .unwrap_or_else(|e| {
+                                        eprintln!("websocket send error: {}", e);
+                                    })
+                                    .await;
+                            }
+                        },
+                        Err(e) => {
+                            eprintln!("WS error occurred {e}")
+                        },
+                    }
+                }
+                else => {
+                    println!("No messages")
+                }
+            }
+        }
+    }
+}
+
+impl Runtime<WebsocketActor> for WebsocketRuntime {
+    fn start(actor: WebsocketActor) -> ActorHandle<WebsocketActor> {
+        let (tx, rx) = tokio::sync::mpsc::channel(100);
+        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::channel(100);
+        let rt = WebsocketRuntime::new(actor, cmd_rx, rx);
+        tokio::spawn(rt.run());
+        ActorHandle {
+            message_tx: tx,
+            command_tx: cmd_tx,
+        }
+    }
+}
+
+impl crate::Message for warp::ws::Message {
+    type Response = Option<warp::ws::Message>;
+}
+
+impl Handler<warp::ws::Message> for WebsocketActor {
+    fn handle(
+        &mut self,
+        message: warp::ws::Message,
+    ) -> Result<<warp::ws::Message as crate::message::Message>::Response, crate::Error> {
+        println!("Actor received message {message:?}");
+        if message.is_text() {
+            Ok(Some(message))
+        } else {
+            Ok(None)
+        }
+    }
+}

+ 144 - 0
tests/websocket.rs

@@ -0,0 +1,144 @@
+use actors::ws::WebsocketActor;
+use actors::Actor;
+use warp::Filter;
+
+#[tokio::main]
+async fn main() {
+    // GET /chat -> websocket upgrade
+    let chat = warp::path("chat")
+        // The `ws()` filter will prepare Websocket handshake...
+        .and(warp::ws())
+        .map(|ws: warp::ws::Ws| {
+            // This will call our function if the handshake succeeds.
+            ws.on_upgrade(move |socket| {
+                let actor = WebsocketActor::new(socket);
+                actor.start();
+                futures::future::ready(())
+            })
+        });
+
+    // GET / -> index html
+    let index = warp::path::end().map(|| warp::reply::html(INDEX_HTML));
+
+    let routes = index.or(chat);
+
+    warp::serve(routes).run(([127, 0, 0, 1], 3030)).await;
+}
+/*
+async fn user_connected(ws: WebSocket, users: Users) {
+    // Use a counter to assign a new unique ID for this user.
+    let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);
+
+    eprintln!("new chat user: {}", my_id);
+
+    // Split the socket into a sender and receive of messages.
+    let (mut user_ws_tx, mut user_ws_rx) = ws.split();
+
+    // Use an unbounded channel to handle buffering and flushing of messages
+    // to the websocket...
+    let (tx, rx) = mpsc::unbounded_channel();
+    let mut rx = UnboundedReceiverStream::new(rx);
+
+    tokio::task::spawn(async move {
+        while let Some(message) = rx.next().await {
+            user_ws_tx
+                .send(message)
+                .unwrap_or_else(|e| {
+                    eprintln!("websocket send error: {}", e);
+                })
+                .await;
+        }
+    });
+
+    // Save the sender in our list of connected users.
+    users.write().await.insert(my_id, tx);
+
+    // Return a `Future` that is basically a state machine managing
+    // this specific user's connection.
+
+    // Every time the user sends a message, broadcast it to
+    // all other users...
+    while let Some(result) = user_ws_rx.next().await {
+        let msg = match result {
+            Ok(msg) => {
+                println!("MESSAGE {msg:?}");
+                msg
+            }
+            Err(e) => {
+                eprintln!("websocket error(uid={}): {}", my_id, e);
+                break;
+            }
+        };
+        // Skip any non-Text messages...
+        let msg = if let Ok(s) = msg.to_str() {
+            s
+        } else {
+            return;
+        };
+
+        let new_msg = format!("<User#{}>: {}", my_id, msg);
+
+        // New message from this user, send it to everyone else (except same uid)...
+        for (&uid, tx) in users.read().await.iter() {
+            if my_id != uid {
+                if let Err(_disconnected) = tx.send(Message::text(new_msg.clone())) {
+                    // The tx is disconnected, our `user_disconnected` code
+                    // should be happening in another task, nothing more to
+                    // do here.
+                }
+            }
+        }
+    }
+
+    // user_ws_rx stream will keep processing as long as the user stays
+    // connected. Once they disconnect, then...
+    user_disconnected(my_id, &users).await;
+}
+ */
+static INDEX_HTML: &str = r#"<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <title>Warp Chat</title>
+    </head>
+    <body>
+        <h1>Warp chat</h1>
+        <div id="chat">
+            <p><em>Connecting...</em></p>
+        </div>
+        <input type="text" id="text" />
+        <button type="button" id="send">Send</button>
+        <script type="text/javascript">
+        const chat = document.getElementById('chat');
+        const text = document.getElementById('text');
+        const uri = 'ws://' + location.host + '/chat';
+        const ws = new WebSocket(uri);
+
+        function message(data) {
+            const line = document.createElement('p');
+            line.innerText = data;
+            chat.appendChild(line);
+        }
+
+        ws.onopen = function() {
+            chat.innerHTML = '<p><em>Connected!</em></p>';
+        };
+
+        ws.onmessage = function(msg) {
+            message(msg.data);
+        };
+
+        ws.onclose = function() {
+            chat.getElementsByTagName('em')[0].innerText = 'Disconnected!';
+        };
+
+        send.onclick = function() {
+            const msg = text.value;
+            ws.send(msg);
+            text.value = '';
+
+            message('<You>: ' + msg);
+        };
+        </script>
+    </body>
+</html>
+"#;