From c5af1b237027031be310951c36f23f0a0bc760b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Mon, 29 Mar 2021 18:10:41 +0200 Subject: tcp event queue --- mumd/src/network/tcp.rs | 76 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 25 deletions(-) diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 8ce49cb..cd178f8 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -36,9 +36,52 @@ pub enum TcpEvent { Disconnected, //fires when the client has disconnected from a server } +#[derive(Clone)] pub enum TcpEventData<'a> { Connected(&'a msgs::ServerSync), - Disconnected, + _Disconnected, +} + +impl<'a> From<&TcpEventData<'a>> for TcpEvent { + fn from(t: &TcpEventData) -> Self { + match t { + TcpEventData::Connected(_) => TcpEvent::Connected, + TcpEventData::_Disconnected => TcpEvent::Disconnected, + } + } +} + +struct TcpEventQueue { + handlers: Arc>>>, +} + +impl TcpEventQueue { + fn new() -> Self { + Self { + handlers: Arc::new(Mutex::new(HashMap::new())), + } + } + + async fn register(&mut self, at: TcpEvent, callback: TcpEventCallback) { + self.handlers.lock().await.entry(at).or_default().push(callback); + } + + async fn send<'a>(&mut self, data: TcpEventData<'a>) { + if let Some(vec) = self.handlers.lock().await.get_mut(&TcpEvent::from(&data)) { + let old = std::mem::take(vec); + for handler in old { + handler(data.clone()); + } + } + } +} + +impl Clone for TcpEventQueue { + fn clone(&self) -> Self { + Self { + handlers: Arc::clone(&self.handlers), + } + } } pub async fn handle( @@ -73,7 +116,7 @@ pub async fn handle( let phase_watcher = state_lock.phase_receiver(); let input_receiver = state_lock.audio().input_receiver(); drop(state_lock); - let event_queue = Arc::new(Mutex::new(HashMap::new())); + let event_queue = TcpEventQueue::new(); info!("Logging in..."); @@ -85,7 +128,7 @@ pub async fn handle( Arc::clone(&state), stream, crypt_state_sender.clone(), - Arc::clone(&event_queue), + event_queue.clone(), ), send_voice( packet_sender.clone(), @@ -93,18 +136,11 @@ pub async fn handle( phase_watcher.clone(), ), send_packets(sink, &mut packet_receiver), - register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)), + register_events(&mut tcp_event_register_receiver, event_queue.clone()), ).map(|_| ()), phase_watcher, ).await; - if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Disconnected); - } - } - debug!("Fully disconnected TCP stream, waiting for new connection info"); } } @@ -209,7 +245,7 @@ async fn listen( state: Arc>, mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender, - event_queue: Arc>>>, + mut event_queue: TcpEventQueue, ) { let mut crypt_state = None; let mut crypt_state_sender = Some(crypt_state_sender); @@ -250,12 +286,7 @@ async fn listen( ) .await; } - if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Connected(&msg)); - } - } + event_queue.send(TcpEventData::Connected(&msg)).await; let mut state = state.lock().await; let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); @@ -324,15 +355,10 @@ async fn listen( async fn register_events( tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, - event_data: Arc>>>, + mut event_queue: TcpEventQueue, ) { loop { let (event, handler) = tcp_event_register_receiver.recv().await.unwrap(); - event_data - .lock() - .await - .entry(event) - .or_default() - .push(handler); + event_queue.register(event, handler).await; } } -- cgit v1.2.1