diff options
Diffstat (limited to 'mumd/src/network/tcp.rs')
| -rw-r--r-- | mumd/src/network/tcp.rs | 76 |
1 files changed, 46 insertions, 30 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 7606987..b513797 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -1,4 +1,4 @@ -use crate::error::{ServerSendError, TcpError}; +use crate::{error::{ServerSendError, TcpError}, notifications}; use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; @@ -30,17 +30,20 @@ type TcpReceiver = SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>; pub(crate) type TcpEventCallback = Box<dyn FnOnce(TcpEventData)>; +pub(crate) type TcpEventSubscriber = Box<dyn FnMut(TcpEventData) -> bool>; //the bool indicates if it should be kept or not #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub enum TcpEvent { Connected, //fires when the client has connected to a server Disconnected, //fires when the client has disconnected from a server + TextMessage, //fires when a text message comes in } #[derive(Clone)] pub enum TcpEventData<'a> { Connected(Result<&'a msgs::ServerSync, mumlib::Error>), Disconnected, + TextMessage(&'a msgs::TextMessage), } impl<'a> From<&TcpEventData<'a>> for TcpEvent { @@ -48,33 +51,53 @@ impl<'a> From<&TcpEventData<'a>> for TcpEvent { match t { TcpEventData::Connected(_) => TcpEvent::Connected, TcpEventData::Disconnected => TcpEvent::Disconnected, + TcpEventData::TextMessage(_) => TcpEvent::TextMessage, } } } #[derive(Clone)] -struct TcpEventQueue { - handlers: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, +pub struct TcpEventQueue { + callbacks: Arc<RwLock<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + subscribers: Arc<RwLock<HashMap<TcpEvent, Vec<TcpEventSubscriber>>>>, } impl TcpEventQueue { - fn new() -> Self { + /// Creates a new `TcpEventQueue`. + pub fn new() -> Self { Self { - handlers: Arc::new(Mutex::new(HashMap::new())), + callbacks: Arc::new(RwLock::new(HashMap::new())), + subscribers: Arc::new(RwLock::new(HashMap::new())), } } - async fn register(&self, at: TcpEvent, callback: TcpEventCallback) { - self.handlers.lock().await.entry(at).or_default().push(callback); + /// Registers a new callback to be triggered when an event is fired. + pub fn register_callback(&self, at: TcpEvent, callback: TcpEventCallback) { + self.callbacks.write().unwrap().entry(at).or_default().push(callback); } - async fn resolve<'a>(&self, data: TcpEventData<'a>) { - if let Some(vec) = self.handlers.lock().await.get_mut(&TcpEvent::from(&data)) { + /// Registers a new callback to be triggered when an event is fired. + pub fn register_subscriber(&self, at: TcpEvent, callback: TcpEventSubscriber) { + self.subscribers.write().unwrap().entry(at).or_default().push(callback); + } + + /// Fires all callbacks related to a specific TCP event and removes them from the event queue. + /// Also calls all event subscribers, but keeps them in the queue + pub fn resolve<'a>(&self, data: TcpEventData<'a>) { + if let Some(vec) = self.callbacks.write().unwrap().get_mut(&TcpEvent::from(&data)) { let old = std::mem::take(vec); for handler in old { handler(data.clone()); } } + if let Some(vec) = self.subscribers.write().unwrap().get_mut(&TcpEvent::from(&data)) { + let old = std::mem::take(vec); + for mut e in old { + if e(data.clone()) { + vec.push(e) + } + } + } } } @@ -84,7 +107,7 @@ pub async fn handle( crypt_state_sender: mpsc::Sender<ClientCryptState>, packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, - mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, + event_queue: TcpEventQueue, ) -> Result<(), TcpError> { loop { let connection_info = 'data: loop { @@ -114,7 +137,6 @@ pub async fn handle( (state_lock.phase_receiver(), state_lock.audio_input().receiver()) }; - let event_queue = TcpEventQueue::new(); info!("Logging in..."); @@ -137,13 +159,12 @@ pub async fn handle( phase_watcher_inner, ).fuse() => r, r = send_packets(sink, &mut packet_receiver).fuse() => r, - _ = register_events(&mut tcp_event_register_receiver, event_queue.clone()).fuse() => Ok(()), } }, phase_watcher, ).await.unwrap_or(Ok(()))?; - event_queue.resolve(TcpEventData::Disconnected).await; + event_queue.resolve(TcpEventData::Disconnected); debug!("Fully disconnected TCP stream, waiting for new connection info"); } @@ -270,11 +291,16 @@ async fn listen( }; match packet { ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); + let mut state = state.write().unwrap(); + let user = state.server() + .and_then(|server| server.users().get(&msg.get_actor())) + .map(|user| user.name()); + if let Some(user) = user { + notifications::send(format!("{}: {}", user, msg.get_message())); //TODO: probably want a config flag for this + } + state.register_message((msg.get_message().to_owned(), msg.get_actor())); + drop(state); + event_queue.resolve(TcpEventData::TextMessage(&*msg)); } ControlPacket::CryptSetup(msg) => { debug!("Crypt setup"); @@ -302,7 +328,7 @@ async fn listen( ) .await; } - event_queue.resolve(TcpEventData::Connected(Ok(&msg))).await; + event_queue.resolve(TcpEventData::Connected(Ok(&msg))); let mut state = state.write().unwrap(); let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); @@ -319,7 +345,7 @@ async fn listen( debug!("Login rejected: {:?}", msg); match msg.get_field_type() { msgs::Reject_RejectType::WrongServerPW => { - event_queue.resolve(TcpEventData::Connected(Err(mumlib::Error::InvalidServerPassword))).await; + event_queue.resolve(TcpEventData::Connected(Err(mumlib::Error::InvalidServerPassword))); } ty => { warn!("Unhandled reject type: {:?}", ty); @@ -377,13 +403,3 @@ async fn listen( } Ok(()) } - -async fn register_events( - tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, - event_queue: TcpEventQueue, -) { - loop { - let (event, handler) = tcp_event_register_receiver.recv().await.unwrap(); - event_queue.register(event, handler).await; - } -} |
