ws.rs 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. use std::{
  2. collections::VecDeque,
  3. pin::Pin,
  4. task::{Context, Poll},
  5. };
  6. use crate::{
  7. message::{Envelope, PackedMessage},
  8. runtime::Runtime,
  9. Actor, ActorCommand, ActorHandle, Error, Handler,
  10. };
  11. use flume::Receiver;
  12. use futures::{
  13. stream::{SplitSink, SplitStream},
  14. Future, SinkExt, StreamExt,
  15. };
  16. use pin_project::pin_project;
  17. use warp::ws::WebSocket;
  18. pub struct WebsocketActor {
  19. websocket: Option<WebSocket>,
  20. }
  21. impl Actor for WebsocketActor {
  22. fn start(self) -> ActorHandle<Self>
  23. where
  24. Self: Sized + Send + 'static,
  25. {
  26. println!("Starting websocket actor");
  27. WebsocketRuntime::run(self)
  28. }
  29. }
  30. impl WebsocketActor {
  31. pub fn new(ws: WebSocket) -> Self {
  32. Self {
  33. websocket: Some(ws),
  34. }
  35. }
  36. }
  37. #[pin_project]
  38. pub struct WebsocketRuntime {
  39. actor: WebsocketActor,
  40. ws_stream: SplitStream<WebSocket>,
  41. ws_sink: SplitSink<WebSocket, warp::ws::Message>,
  42. message_rx: Receiver<Envelope<WebsocketActor>>,
  43. command_rx: Receiver<ActorCommand>,
  44. message_queue: VecDeque<Envelope<WebsocketActor>>,
  45. ws_queue: VecDeque<warp::ws::Message>,
  46. }
  47. impl WebsocketRuntime {
  48. pub fn new(
  49. mut actor: WebsocketActor,
  50. command_rx: Receiver<ActorCommand>,
  51. message_rx: Receiver<Envelope<WebsocketActor>>,
  52. ) -> Self {
  53. let (ws_sink, ws_stream) = actor
  54. .websocket
  55. .take()
  56. .expect("Websocket runtime already started")
  57. .split();
  58. Self {
  59. actor,
  60. ws_sink,
  61. ws_stream,
  62. message_rx,
  63. command_rx,
  64. message_queue: VecDeque::new(),
  65. ws_queue: VecDeque::new(),
  66. }
  67. }
  68. }
  69. impl Future for WebsocketRuntime {
  70. type Output = Result<(), Error>;
  71. fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  72. let this = self.project();
  73. loop {
  74. // Poll command receiver
  75. match Pin::new(&mut this.command_rx.recv_async()).poll(cx) {
  76. Poll::Ready(Ok(message)) => match message {
  77. ActorCommand::Stop => {
  78. println!("Actor stopping");
  79. break Poll::Ready(Ok(())); // TODO drain the queue and all that graceful stuff
  80. }
  81. },
  82. Poll::Ready(Err(_)) => {
  83. println!("Command stream dropped, ungracefully stopping actor");
  84. break Poll::Ready(Err(Error::ActorChannelClosed));
  85. }
  86. Poll::Pending => {}
  87. };
  88. // Poll the websocket stream for any messages and store them to the queue
  89. while let Poll::Ready(Some(ws_message)) = Pin::new(&mut this.ws_stream.next()).poll(cx)
  90. {
  91. match ws_message {
  92. Ok(message) => this.ws_queue.push_back(message),
  93. Err(e) => {
  94. eprintln!("WS error occurred {e}")
  95. }
  96. }
  97. }
  98. // Respond to any queued websocket messages
  99. while let Some(ws_message) = this.ws_queue.pop_front() {
  100. if let Some(res) = this.actor.handle(ws_message)? {
  101. match Pin::new(&mut this.ws_sink.send(res)).poll(cx) {
  102. Poll::Ready(result) => result?,
  103. Poll::Pending => todo!(),
  104. }
  105. }
  106. }
  107. // Process all messages
  108. while let Some(mut message) = this.message_queue.pop_front() {
  109. message.handle(this.actor)
  110. }
  111. // Poll message receiver and continue to process if anything comes up
  112. while let Poll::Ready(Ok(message)) =
  113. Pin::new(&mut this.message_rx.recv_async()).poll(cx)
  114. {
  115. this.message_queue.push_back(message);
  116. }
  117. // Poll again and process new messages if any
  118. match Pin::new(&mut this.message_rx.recv_async()).poll(cx) {
  119. Poll::Ready(Ok(message)) => {
  120. this.message_queue.push_back(message);
  121. continue;
  122. }
  123. Poll::Ready(Err(_)) => {
  124. println!("Message channel closed, ungracefully stopping actor");
  125. break Poll::Ready(Err(Error::ActorChannelClosed));
  126. }
  127. Poll::Pending => {
  128. if !this.message_queue.is_empty() {
  129. continue;
  130. }
  131. }
  132. };
  133. cx.waker().wake_by_ref();
  134. return Poll::Pending;
  135. }
  136. }
  137. }
  138. impl Runtime<WebsocketActor> for WebsocketRuntime {
  139. fn run(actor: WebsocketActor) -> ActorHandle<WebsocketActor> {
  140. let (tx, rx) = flume::unbounded();
  141. let (cmd_tx, cmd_rx) = flume::unbounded();
  142. let rt = WebsocketRuntime::new(actor, cmd_rx, rx);
  143. tokio::spawn(rt);
  144. ActorHandle {
  145. message_tx: tx,
  146. command_tx: cmd_tx,
  147. }
  148. }
  149. }
  150. impl crate::Message for warp::ws::Message {
  151. type Response = Option<warp::ws::Message>;
  152. }
  153. impl Handler<warp::ws::Message> for WebsocketActor {
  154. fn handle(
  155. &mut self,
  156. message: warp::ws::Message,
  157. ) -> Result<<warp::ws::Message as crate::message::Message>::Response, crate::Error> {
  158. println!("Actor received message {message:?}");
  159. if message.is_text() {
  160. Ok(Some(message))
  161. } else {
  162. Ok(None)
  163. }
  164. }
  165. }