diff options
| author | Eskil Q <eskilq@kth.se> | 2021-01-06 18:31:49 +0100 |
|---|---|---|
| committer | Eskil Q <eskilq@kth.se> | 2021-01-06 18:31:49 +0100 |
| commit | 02e6f2b84d72294b29a1698c1b73fbb5697815da (patch) | |
| tree | af85a0277c89ef7983f79ff795acf1bd94eee848 | |
| parent | b15e010a6bebc7b7c6b8afb1b51f2673d0695e06 (diff) | |
| download | mum-02e6f2b84d72294b29a1698c1b73fbb5697815da.tar.gz | |
clean up network::run_until
| -rw-r--r-- | mumd/src/network.rs | 41 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 225 | ||||
| -rw-r--r-- | mumd/src/network/udp.rs | 8 |
3 files changed, 131 insertions, 143 deletions
diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 03bc436..75b983e 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -10,6 +10,7 @@ use futures::join; use futures::pin_mut; use futures::select; use tokio::sync::watch; +use log::*; use crate::state::StatePhase; @@ -36,16 +37,14 @@ pub enum VoiceStreamType { UDP, } -async fn run_until<T, F, G, H>( +async fn run_until<F, G>( phase_checker: impl Fn(StatePhase) -> bool, - mut generator: impl FnMut() -> F, - mut handler: impl FnMut(T) -> G, - mut shutdown: impl FnMut() -> H, + fut: F, + mut shutdown: impl FnMut() -> G, mut phase_watcher: watch::Receiver<StatePhase>, ) where - F: Future<Output = Option<T>>, + F: Future<Output = ()>, G: Future<Output = ()>, - H: Future<Output = ()>, { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { @@ -55,32 +54,20 @@ async fn run_until<T, F, G, H>( break; } } - tx.send(true).unwrap(); + if tx.send(true).is_err() { + warn!("future resolved before it could be cancelled"); + } }; let main_block = async { let rx = rx.fuse(); pin_mut!(rx); - loop { - let packet_recv = generator().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - //warn!("Channel closed before disconnect command"); //TODO make me informative - break; - } - Some(Some(data)) => { - handler(data).await; - } - } - } + let fut = fut.fuse(); + pin_mut!(fut); + select! { + _ = fut => (), + _ = rx => (), + }; shutdown().await; }; diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 717b195..982e747 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -145,11 +145,13 @@ async fn send_pings( run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { Some(interval.borrow_mut().tick().await) }, - |_| async { - trace!("Sending ping"); - let msg = msgs::Ping::new(); - packet_sender.borrow_mut().send(msg.into()).unwrap(); + async { + loop { + interval.borrow_mut().tick().await; + trace!("Sending ping"); + let msg = msgs::Ping::new(); + packet_sender.borrow_mut().send(msg.into()).unwrap(); + } }, || async {}, phase_watcher, @@ -168,9 +170,11 @@ async fn send_packets( let packet_receiver = Rc::new(RefCell::new(packet_receiver)); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { packet_receiver.borrow_mut().recv().await }, - |packet| async { - sink.borrow_mut().send(packet).await.unwrap(); + async { + loop { + let packet = packet_receiver.borrow_mut().recv().await.unwrap(); + sink.borrow_mut().send(packet).await.unwrap(); + } }, || async { sink.borrow_mut().close().await.unwrap(); @@ -190,28 +194,26 @@ async fn send_voice( let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { + async { run_until( |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), - || async { - packet_sender.send(receiver - .lock() - .await - .next() - .await - .unwrap() - .into()) - .unwrap(); - Some(Some(())) + async { + loop { + packet_sender.send(receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); + } }, - |_| async {}, || async {}, inner_phase_watcher.clone(), ).await; debug!("Stopped sending TCP voice"); - Some(Some(())) }, - |_| async {}, || async {}, phase_watcher, ).await; @@ -219,7 +221,7 @@ async fn send_voice( async fn listen( state: Arc<Mutex<State>>, - stream: TcpReceiver, + mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender<ClientCryptState>, event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, phase_watcher: watch::Receiver<StatePhase>, @@ -227,92 +229,93 @@ async fn listen( let crypt_state = Rc::new(RefCell::new(None)); let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); - let stream = Rc::new(RefCell::new(stream)); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { stream.borrow_mut().next().await }, - |packet| async { - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.borrow_mut().take() { - let _ = sender - .send( - crypt_state - .borrow_mut() - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; + async { + loop { + let packet = stream.next().await.unwrap(); + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); } - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Connected(&msg)); + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(sender) = crypt_state_sender.borrow_mut().take() { + let _ = sender + .send( + crypt_state + .borrow_mut() + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; + } + if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Connected(&msg)); + } + } + let mut state = state.lock().unwrap(); + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), } + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); + } + state.initialized(); } - let mut state = state.lock().unwrap(); - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); + ControlPacket::UserState(msg) => { + state.lock().unwrap().parse_user_state(*msg); + } + ControlPacket::UserRemove(msg) => { + state.lock().unwrap().remove_client(*msg); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_remove(*msg); + } + packet => { + debug!("Received unhandled ControlPacket {:#?}", packet); } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - state.lock().unwrap().parse_user_state(*msg); - } - ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_remove(*msg); - } - packet => { - debug!("Received unhandled ControlPacket {:#?}", packet); } } }, @@ -339,14 +342,16 @@ async fn register_events( let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { tcp_event_register_receiver.borrow_mut().recv().await }, - |(event, handler)| async { - event_data - .lock() - .unwrap() - .entry(event) - .or_default() - .push(handler); + async { + loop { + let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap(); + event_data + .lock() + .unwrap() + .entry(event) + .or_default() + .push(handler); + } }, || async {}, phase_watcher, diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 5e725cd..d35a255 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -233,22 +233,18 @@ async fn send_voice( let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { + async { run_until( |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), - || async { + async { debug!("Sending UDP audio"); sink.lock().unwrap().send((receiver.lock().await.next().await.unwrap(), server_addr)).await.unwrap(); debug!("Sent UDP audio"); - Some(Some(())) }, - |_| async {}, || async {}, inner_phase_watcher.clone(), ).await; - Some(Some(())) }, - |_| async {}, || async {}, phase_watcher, ).await; |
