From ab407d694e5a8ce6f831f8a84fc32dbdf6685aac Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Thu, 7 Jan 2021 12:39:49 +0100 Subject: lift waiting for disconnection --- mumd/src/network/tcp.rs | 402 +++++++++++++++++++++--------------------------- mumd/src/network/udp.rs | 259 ++++++++++++++----------------- 2 files changed, 287 insertions(+), 374 deletions(-) (limited to 'mumd/src') diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 3e4cbf3..e639dd0 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -2,17 +2,15 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; -use futures::{join, SinkExt, Stream, StreamExt}; +use futures::{FutureExt, SinkExt, Stream, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::voice::VoicePacket; use mumble_protocol::{Clientbound, Serverbound}; -use std::cell::RefCell; use std::collections::HashMap; use std::convert::{Into, TryInto}; use std::net::SocketAddr; -use std::rc::Rc; use std::sync::Arc; use tokio::net::TcpStream; use tokio::sync::{mpsc, watch, Mutex}; @@ -21,6 +19,7 @@ use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; use super::{run_until, VoiceStreamType}; +use futures_util::future::join5; type TcpSender = SplitSink< Framed, ControlCodec>, @@ -76,24 +75,34 @@ pub async fn handle( info!("Logging in..."); - //TODO force exit all futures on disconnection - join!( - send_pings(packet_sender.clone(), 10, phase_watcher.clone()), - listen( - Arc::clone(&state), - stream, - crypt_state_sender.clone(), - Arc::clone(&event_queue), - phase_watcher.clone(), - ), - send_voice( - packet_sender.clone(), - Arc::clone(&input_receiver), - phase_watcher.clone(), - ), - send_packets(sink, &mut packet_receiver, phase_watcher.clone()), - register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), - ); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join5( + send_pings(packet_sender.clone(), 10), + listen( + Arc::clone(&state), + stream, + crypt_state_sender.clone(), + Arc::clone(&event_queue), + ), + send_voice( + packet_sender.clone(), + Arc::clone(&input_receiver), + phase_watcher.clone(), + ), + send_packets(sink, &mut packet_receiver), + register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)), + ).map(|_| ()), + || async {}, + phase_watcher, + ).await; + + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Disconnected); + } + } debug!("Fully disconnected TCP stream, waiting for new connection info"); } @@ -135,50 +144,24 @@ async fn authenticate(sink: &mut TcpSender, username: String) { async fn send_pings( packet_sender: mpsc::UnboundedSender>, delay_seconds: u64, - phase_watcher: watch::Receiver, ) { let mut interval = time::interval(Duration::from_secs(delay_seconds)); - - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - interval.tick().await; - trace!("Sending TCP ping"); - let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); - } - }, - || async {}, - phase_watcher, - ) - .await; - - debug!("Ping sender process killed"); + loop { + interval.tick().await; + trace!("Sending TCP ping"); + let msg = msgs::Ping::new(); + packet_sender.send(msg.into()).unwrap(); + } } async fn send_packets( - sink: TcpSender, + mut sink: TcpSender, packet_receiver: &mut mpsc::UnboundedReceiver>, - phase_watcher: watch::Receiver, ) { - let sink = Rc::new(RefCell::new(sink)); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let packet = packet_receiver.recv().await.unwrap(); - sink.borrow_mut().send(packet).await.unwrap(); - } - }, - || async { - sink.borrow_mut().close().await.unwrap(); - }, - phase_watcher, - ) - .await; - - debug!("TCP packet sender killed"); + loop { + let packet = packet_receiver.recv().await.unwrap(); + sink.send(packet).await.unwrap(); + } } async fn send_voice( @@ -186,41 +169,33 @@ async fn send_voice( receiver: Arc> + Unpin)>>>, phase_watcher: watch::Receiver, ) { - let inner_phase_watcher = phase_watcher.clone(); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + let mut inner_phase_watcher = phase_watcher.clone(); + loop { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + async { loop { - inner_phase_watcher_2.changed().await.unwrap(); - if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { - break; - } + packet_sender.send( + receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); } - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), - async { - loop { - packet_sender.send( - receiver - .lock() - .await - .next() - .await - .unwrap() - .into()) - .unwrap(); - } - }, - || async {}, - inner_phase_watcher.clone(), - ).await; - } - }, - || async {}, - phase_watcher, - ).await; + }, + || async {}, + inner_phase_watcher.clone(), + ).await; + } } async fn listen( @@ -228,158 +203,129 @@ async fn listen( mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender, event_queue: Arc>>>, - phase_watcher: watch::Receiver, ) { - let crypt_state = Rc::new(RefCell::new(None)); - let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); + let mut crypt_state = None; + let mut crypt_state_sender = Some(crypt_state_sender); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let packet = stream.next().await.unwrap(); - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.borrow_mut().take() { - let _ = sender - .send( - crypt_state - .borrow_mut() - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; - } - if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Connected(&msg)); - } - } - let mut state = state.lock().await; - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), - } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); - } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - state.lock().await.parse_user_state(*msg); - } - ControlPacket::UserRemove(msg) => { - state.lock().await.remove_client(*msg); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .await - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial + loop { + let packet = stream.next().await.unwrap(); + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + crypt_state = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(sender) = crypt_state_sender.take() { + let _ = sender + .send( + crypt_state + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; + } + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Connected(&msg)); } - ControlPacket::ChannelRemove(msg) => { + } + let mut state = state.lock().await; + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), + } + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); + } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + state.lock().await.parse_user_state(*msg); + } + ControlPacket::UserRemove(msg) => { + state.lock().await.remove_client(*msg); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .await + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .await + .server_mut() + .unwrap() + .parse_channel_remove(*msg); + } + ControlPacket::UDPTunnel(msg) => { + match *msg { + VoicePacket::Ping { .. } => {} + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { state .lock() .await - .server_mut() - .unwrap() - .parse_channel_remove(*msg); - } - ControlPacket::UDPTunnel(msg) => { - match *msg { - VoicePacket::Ping { .. } => {} - VoicePacket::Audio { + .audio() + .decode_packet_payload( + VoiceStreamType::TCP, session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state - .lock() - .await - .audio() - .decode_packet_payload( - VoiceStreamType::TCP, - session_id, - payload); - } - } - } - packet => { - debug!("Received unhandled ControlPacket {:#?}", packet); + payload); } } } - }, - || async { - if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Disconnected); - } + packet => { + debug!("Received unhandled ControlPacket {:#?}", packet); } - }, - phase_watcher, - ) - .await; - - debug!("Killing TCP listener block"); + } + } } async fn register_events( tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, event_data: Arc>>>, - phase_watcher: watch::Receiver, ) { - let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap(); - event_data - .lock() - .await - .entry(event) - .or_default() - .push(handler); - } - }, - || async {}, - phase_watcher, - ) - .await; + loop { + let (event, handler) = tcp_event_register_receiver.recv().await.unwrap(); + event_data + .lock() + .await + .entry(event) + .or_default() + .push(handler); + } } diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 4167c15..9dd6ed3 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,7 +1,7 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; -use futures::{join, SinkExt, StreamExt, Stream}; +use futures::{join, FutureExt, SinkExt, StreamExt, Stream}; use futures_util::stream::{SplitSink, SplitStream}; use log::*; use mumble_protocol::crypt::ClientCryptState; @@ -20,6 +20,7 @@ use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; use super::{run_until, VoiceStreamType}; +use futures_util::future::join4; pub type PingRequest = (u64, SocketAddr, Box); @@ -49,28 +50,32 @@ pub async fn handle( let phase_watcher = state.lock().await.phase_receiver(); let last_ping_recv = AtomicU64::new(0); - join!( - listen( - Arc::clone(&state), - Arc::clone(&source), - phase_watcher.clone(), - &last_ping_recv, - ), - send_voice( - Arc::clone(&sink), - connection_info.socket_addr, - phase_watcher.clone(), - Arc::clone(&receiver), - ), - send_pings( - Arc::clone(&state), - Arc::clone(&sink), - connection_info.socket_addr, - &last_ping_recv, - phase_watcher.clone(), - ), - new_crypt_state(&mut crypt_state_receiver, sink, source, phase_watcher), - ); + + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join4( + listen( + Arc::clone(&state), + Arc::clone(&source), + &last_ping_recv, + ), + send_voice( + Arc::clone(&sink), + connection_info.socket_addr, + phase_watcher.clone(), + Arc::clone(&receiver), + ), + send_pings( + Arc::clone(&state), + Arc::clone(&sink), + connection_info.socket_addr, + &last_ping_recv, + ), + new_crypt_state(&mut crypt_state_receiver, sink, source), + ).map(|_| ()), + || async {}, + phase_watcher, + ).await; debug!("Fully disconnected UDP stream, waiting for new connection info"); } @@ -100,76 +105,58 @@ async fn new_crypt_state( crypt_state: &mut mpsc::Receiver, sink: Arc>, source: Arc>, - phase_watcher: watch::Receiver, ) { - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - if let Some(crypt_state) = crypt_state.recv().await { - info!("Received new crypt state"); - let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) - .await - .expect("Failed to bind UDP socket"); - let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().await = new_sink; - *source.lock().await = new_source; - } - } - }, - || async {}, - phase_watcher, - ).await; + loop { + if let Some(crypt_state) = crypt_state.recv().await { + info!("Received new crypt state"); + let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) + .await + .expect("Failed to bind UDP socket"); + let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); + *sink.lock().await = new_sink; + *source.lock().await = new_source; + } + } } async fn listen( state: Arc>, source: Arc>, - phase_watcher: watch::Receiver, last_ping_recv: &AtomicU64, ) { - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let packet = source.lock().await.next().await.unwrap(); - let (packet, _src_addr) = match packet { - Ok(packet) => packet, - Err(err) => { - warn!("Got an invalid UDP packet: {}", err); - // To be expected, considering this is the internet, just ignore it - continue; - } - }; - match packet { - VoicePacket::Ping { timestamp } => { - state - .lock() //TODO clean up unnecessary lock by only updating phase if it should change - .await - .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); - last_ping_recv.store(timestamp, Ordering::Relaxed); - } - VoicePacket::Audio { - session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state - .lock() //TODO change so that we only have to lock audio and not the whole state - .await - .audio() - .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); - } - } + loop { + let packet = source.lock().await.next().await.unwrap(); + let (packet, _src_addr) = match packet { + Ok(packet) => packet, + Err(err) => { + warn!("Got an invalid UDP packet: {}", err); + // To be expected, considering this is the internet, just ignore it + continue; } - }, - || async {}, - phase_watcher - ).await; - - debug!("UDP listener process killed"); + }; + match packet { + VoicePacket::Ping { timestamp } => { + state + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); + last_ping_recv.store(timestamp, Ordering::Relaxed); + } + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() //TODO change so that we only have to lock audio and not the whole state + .await + .audio() + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); + } + } + } } async fn send_pings( @@ -177,44 +164,34 @@ async fn send_pings( sink: Arc>, server_addr: SocketAddr, last_ping_recv: &AtomicU64, - phase_watcher: watch::Receiver, ) { let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - interval.tick().await; - let last_recv = last_ping_recv.load(Ordering::Relaxed); - if last_send.is_some() && last_send.unwrap() != last_recv { - debug!("Sending TCP voice"); - state - .lock() - .await - .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); - } - match sink - .lock() - .await - .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) - .await - { - Ok(_) => { - last_send = Some(last_recv + 1); - }, - Err(e) => { - debug!("Error sending UDP ping: {}", e); - } - } + loop { + interval.tick().await; + let last_recv = last_ping_recv.load(Ordering::Relaxed); + if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); + state + .lock() + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); + } + match sink + .lock() + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) + .await + { + Ok(_) => { + last_send = Some(last_recv + 1); + }, + Err(e) => { + debug!("Error sending UDP ping: {}", e); } - }, - || async {}, - phase_watcher, - ).await; - - debug!("UDP ping sender process killed"); + } + } } async fn send_voice( @@ -223,37 +200,27 @@ async fn send_voice( phase_watcher: watch::Receiver, receiver: Arc> + Unpin)>>>, ) { - let inner_phase_watcher = phase_watcher.clone(); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + let mut inner_phase_watcher = phase_watcher.clone(); + loop { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + async { + let mut receiver = receiver.lock().await; loop { - inner_phase_watcher_2.changed().await.unwrap(); - if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { - break; - } + let sending = (receiver.next().await.unwrap(), server_addr); + sink.lock().await.send(sending).await.unwrap(); } - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), - async { - let mut receiver = receiver.lock().await; - loop { - let sending = (receiver.next().await.unwrap(), server_addr); - sink.lock().await.send(sending).await.unwrap(); - } - }, - || async {}, - inner_phase_watcher.clone(), - ).await; - } - }, - || async {}, - phase_watcher, - ).await; - - debug!("UDP sender process killed"); + }, + || async {}, + phase_watcher.clone(), + ).await; + } } pub async fn handle_pings( -- cgit v1.2.1