diff options
| author | Gustav Sörnäs <gustav@sornas.net> | 2021-03-30 15:24:46 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-03-30 15:24:46 +0200 |
| commit | 795e46c98616801c678bd0a403b08cb0fcd5ee43 (patch) | |
| tree | 040efd79def19e28422980ebfb6ce414ff349570 /mumd/src/network/tcp.rs | |
| parent | a6d433e3ad95b9a21d5d473da4b1f65e78585bb2 (diff) | |
| parent | b52068eade50758673e29c79e7cb8be3f1b4151f (diff) | |
| download | mum-795e46c98616801c678bd0a403b08cb0fcd5ee43.tar.gz | |
Merge pull request #79 from mum-rs/tcp-event-queue
Diffstat (limited to 'mumd/src/network/tcp.rs')
| -rw-r--r-- | mumd/src/network/tcp.rs | 108 |
1 files changed, 75 insertions, 33 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 8ce49cb..47b1c20 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -36,11 +36,47 @@ pub enum TcpEvent { Disconnected, //fires when the client has disconnected from a server } +#[derive(Clone)] pub enum TcpEventData<'a> { - Connected(&'a msgs::ServerSync), + Connected(Result<&'a msgs::ServerSync, mumlib::Error>), Disconnected, } +impl<'a> From<&TcpEventData<'a>> for TcpEvent { + fn from(t: &TcpEventData) -> Self { + match t { + TcpEventData::Connected(_) => TcpEvent::Connected, + TcpEventData::Disconnected => TcpEvent::Disconnected, + } + } +} + +#[derive(Clone)] +struct TcpEventQueue { + handlers: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, +} + +impl TcpEventQueue { + fn new() -> Self { + Self { + handlers: Arc::new(Mutex::new(HashMap::new())), + } + } + + async fn register(&self, at: TcpEvent, callback: TcpEventCallback) { + self.handlers.lock().await.entry(at).or_default().push(callback); + } + + async fn resolve<'a>(&self, data: TcpEventData<'a>) { + if let Some(vec) = self.handlers.lock().await.get_mut(&TcpEvent::from(&data)) { + let old = std::mem::take(vec); + for handler in old { + handler(data.clone()); + } + } + } +} + pub async fn handle( state: Arc<Mutex<State>>, mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, @@ -73,7 +109,7 @@ pub async fn handle( let phase_watcher = state_lock.phase_receiver(); let input_receiver = state_lock.audio().input_receiver(); drop(state_lock); - let event_queue = Arc::new(Mutex::new(HashMap::new())); + let event_queue = TcpEventQueue::new(); info!("Logging in..."); @@ -85,7 +121,7 @@ pub async fn handle( Arc::clone(&state), stream, crypt_state_sender.clone(), - Arc::clone(&event_queue), + event_queue.clone(), ), send_voice( packet_sender.clone(), @@ -93,17 +129,12 @@ pub async fn handle( phase_watcher.clone(), ), send_packets(sink, &mut packet_receiver), - register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)), + register_events(&mut tcp_event_register_receiver, event_queue.clone()), ).map(|_| ()), phase_watcher, ).await; - if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Disconnected); - } - } + event_queue.resolve(TcpEventData::Disconnected).await; debug!("Fully disconnected TCP stream, waiting for new connection info"); } @@ -154,12 +185,12 @@ async fn send_pings( delay_seconds: u64, ) { let mut interval = time::interval(Duration::from_secs(delay_seconds)); - loop { - interval.tick().await; - trace!("Sending TCP ping"); - let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); - } + loop { + interval.tick().await; + trace!("Sending TCP ping"); + let msg = msgs::Ping::new(); + packet_sender.send(msg.into()).unwrap(); + } } async fn send_packets( @@ -209,14 +240,27 @@ async fn listen( state: Arc<Mutex<State>>, mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender<ClientCryptState>, - event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + event_queue: TcpEventQueue, ) { let mut crypt_state = None; let mut crypt_state_sender = Some(crypt_state_sender); loop { - let packet = stream.next().await.unwrap(); - match packet.unwrap() { + let packet = match stream.next().await { + Some(Ok(packet)) => packet, + Some(Err(e)) => { + error!("TCP error: {:?}", e); + continue; //TODO Break here? Maybe look at the error and handle it + } + None => { + // We end up here if the login was rejected. We probably want + // to exit before that. + warn!("TCP stream gone"); + state.lock().await.broadcast_phase(StatePhase::Disconnected); + break; + } + }; + match packet { ControlPacket::TextMessage(msg) => { info!( "Got message from user with session ID {}: {}", @@ -250,12 +294,7 @@ async fn listen( ) .await; } - if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Connected(&msg)); - } - } + event_queue.resolve(TcpEventData::Connected(Ok(&msg))).await; let mut state = state.lock().await; let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); @@ -269,7 +308,15 @@ async fn listen( state.initialized(); } ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); + debug!("Login rejected: {:?}", msg); + match msg.get_field_type() { + msgs::Reject_RejectType::WrongServerPW => { + event_queue.resolve(TcpEventData::Connected(Err(mumlib::Error::InvalidServerPassword))).await; + } + ty => { + warn!("Unhandled reject type: {:?}", ty); + } + } } ControlPacket::UserState(msg) => { state.lock().await.parse_user_state(*msg); @@ -324,15 +371,10 @@ async fn listen( async fn register_events( tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, - event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + event_queue: TcpEventQueue, ) { loop { let (event, handler) = tcp_event_register_receiver.recv().await.unwrap(); - event_data - .lock() - .await - .entry(event) - .or_default() - .push(handler); + event_queue.register(event, handler).await; } } |
