ws.rs 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. use crate::{
  2. message::{ActorMessage, Envelope},
  3. runtime::Runtime,
  4. Actor, ActorCommand, ActorHandle, ActorStatus, Error, Handler,
  5. };
  6. use async_trait::async_trait;
  7. use flume::Receiver;
  8. use futures::{
  9. stream::{SplitSink, SplitStream},
  10. Future, SinkExt, Stream, StreamExt,
  11. };
  12. use pin_project::pin_project;
  13. use std::{
  14. collections::VecDeque,
  15. pin::Pin,
  16. sync::atomic::AtomicUsize,
  17. task::{Context, Poll},
  18. };
  19. use warp::ws::WebSocket;
  20. pub struct WebsocketActor {
  21. websocket: Option<WebSocket>,
  22. }
  23. impl Actor for WebsocketActor {
  24. fn start(self) -> ActorHandle<Self>
  25. where
  26. Self: Sized + Send + 'static,
  27. {
  28. println!("Starting websocket actor");
  29. WebsocketRuntime::run(self)
  30. }
  31. }
  32. impl WebsocketActor {
  33. pub fn new(ws: WebSocket) -> Self {
  34. Self {
  35. websocket: Some(ws),
  36. }
  37. }
  38. }
  39. static PROCESSED: AtomicUsize = AtomicUsize::new(0);
  40. #[pin_project]
  41. pub struct WebsocketRuntime {
  42. actor: WebsocketActor,
  43. status: ActorStatus,
  44. // Pin these 2 as we are polling them directly so we know they never move
  45. /// The receiving end of the websocket
  46. #[pin]
  47. ws_stream: SplitStream<WebSocket>,
  48. /// The sending end of the websocket
  49. #[pin]
  50. ws_sink: SplitSink<WebSocket, warp::ws::Message>,
  51. /// Actor message receiver
  52. message_rx: Receiver<Envelope<WebsocketActor>>,
  53. /// Actor command receiver
  54. command_rx: Receiver<ActorCommand>,
  55. /// Received, but not yet processed messages
  56. message_queue: VecDeque<Envelope<WebsocketActor>>,
  57. /// Received, but not yet processed websocket messages
  58. request_queue: VecDeque<warp::ws::Message>,
  59. /// Processed websocket messages ready to be flushed in the sink
  60. response_queue: VecDeque<warp::ws::Message>,
  61. }
  62. impl WebsocketRuntime {
  63. pub fn new(
  64. mut actor: WebsocketActor,
  65. command_rx: Receiver<ActorCommand>,
  66. message_rx: Receiver<Envelope<WebsocketActor>>,
  67. ) -> Self {
  68. let (ws_sink, ws_stream) = actor
  69. .websocket
  70. .take()
  71. .expect("Websocket runtime already started")
  72. .split();
  73. Self {
  74. actor,
  75. ws_sink,
  76. ws_stream,
  77. message_rx,
  78. command_rx,
  79. message_queue: VecDeque::new(),
  80. request_queue: VecDeque::new(),
  81. response_queue: VecDeque::new(),
  82. status: ActorStatus::Starting,
  83. }
  84. }
  85. }
  86. impl Future for WebsocketRuntime {
  87. type Output = Result<(), Error>;
  88. fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  89. let mut this = self.project();
  90. loop {
  91. // Poll command receiver
  92. match Pin::new(&mut this.command_rx.recv_async()).poll(cx) {
  93. Poll::Ready(Ok(message)) => match message {
  94. ActorCommand::Stop => {
  95. println!("Actor stopping");
  96. break Poll::Ready(Ok(())); // TODO drain the queue and all that graceful stuff
  97. }
  98. },
  99. Poll::Ready(Err(_)) => {
  100. println!("Actor stopping"); // TODO drain the queue and all that graceful stuff
  101. break Poll::Ready(Err(Error::ActorChannelClosed));
  102. }
  103. Poll::Pending => {}
  104. };
  105. // Poll the websocket stream for any messages and store them to the queue
  106. while let Poll::Ready(Some(ws_message)) = this.ws_stream.as_mut().poll_next(cx) {
  107. match ws_message {
  108. Ok(message) => this.request_queue.push_back(message),
  109. Err(e) => {
  110. eprintln!("WS error occurred {e}")
  111. }
  112. }
  113. }
  114. // Respond to any queued and processed websocket messages
  115. let mut idx = 0;
  116. while idx < this.request_queue.len() {
  117. let ws_message = &this.request_queue[idx];
  118. match this.actor.handle(ws_message.to_owned()).as_mut().poll(cx) {
  119. Poll::Ready(result) => match result {
  120. Ok(response) => {
  121. if let Some(response) = response {
  122. match Pin::new(&mut this.ws_sink.feed(response)).poll(cx) {
  123. Poll::Ready(result) => {
  124. result?;
  125. this.request_queue.swap_remove_front(idx);
  126. PROCESSED
  127. .fetch_add(1, std::sync::atomic::Ordering::Acquire);
  128. }
  129. Poll::Pending => idx += 1,
  130. }
  131. }
  132. }
  133. Err(e) => return Poll::Ready(Err(e)),
  134. },
  135. Poll::Pending => idx += 1,
  136. }
  137. }
  138. println!(
  139. "PROCESSED {}",
  140. PROCESSED.load(std::sync::atomic::Ordering::Acquire)
  141. );
  142. let _ = Pin::new(&mut this.ws_sink.flush()).poll(cx);
  143. // Process all messages
  144. this.message_queue
  145. .retain_mut(|message| message.handle(this.actor).as_mut().poll(cx).is_pending());
  146. // Poll message receiver and continue to process if anything comes up
  147. while let Poll::Ready(Ok(message)) =
  148. Pin::new(&mut this.message_rx.recv_async()).poll(cx)
  149. {
  150. this.message_queue.push_back(message);
  151. }
  152. // Poll again and process new messages if any
  153. match Pin::new(&mut this.message_rx.recv_async()).poll(cx) {
  154. Poll::Ready(Ok(message)) => {
  155. this.message_queue.push_back(message);
  156. continue;
  157. }
  158. Poll::Ready(Err(_)) => {
  159. println!("Message channel closed, ungracefully stopping actor");
  160. break Poll::Ready(Err(Error::ActorChannelClosed));
  161. }
  162. Poll::Pending => {
  163. if !this.message_queue.is_empty() {
  164. continue;
  165. }
  166. }
  167. };
  168. cx.waker().wake_by_ref();
  169. return Poll::Pending;
  170. }
  171. }
  172. }
  173. impl Runtime<WebsocketActor> for WebsocketRuntime {
  174. fn run(actor: WebsocketActor) -> ActorHandle<WebsocketActor> {
  175. let (message_tx, message_rx) = flume::unbounded();
  176. let (command_tx, command_rx) = flume::unbounded();
  177. tokio::spawn(WebsocketRuntime::new(actor, command_rx, message_rx));
  178. ActorHandle::new(message_tx, command_tx)
  179. }
  180. }
  181. impl crate::Message for warp::ws::Message {
  182. type Response = Option<warp::ws::Message>;
  183. }
  184. #[async_trait]
  185. impl Handler<warp::ws::Message> for WebsocketActor {
  186. async fn handle(
  187. &mut self,
  188. message: warp::ws::Message,
  189. ) -> Result<<warp::ws::Message as crate::message::Message>::Response, crate::Error> {
  190. // println!("Actor received message {message:?}");
  191. if message.is_text() {
  192. Ok(Some(message))
  193. } else {
  194. Ok(None)
  195. }
  196. }
  197. }