diff options
Diffstat (limited to 'mumd/src/network/udp.rs')
| -rw-r--r-- | mumd/src/network/udp.rs | 259 |
1 files changed, 113 insertions, 146 deletions
diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 4167c15..9dd6ed3 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,7 +1,7 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; -use futures::{join, SinkExt, StreamExt, Stream}; +use futures::{join, FutureExt, SinkExt, StreamExt, Stream}; use futures_util::stream::{SplitSink, SplitStream}; use log::*; use mumble_protocol::crypt::ClientCryptState; @@ -20,6 +20,7 @@ use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; use super::{run_until, VoiceStreamType}; +use futures_util::future::join4; pub type PingRequest = (u64, SocketAddr, Box<dyn FnOnce(PongPacket)>); @@ -49,28 +50,32 @@ pub async fn handle( let phase_watcher = state.lock().await.phase_receiver(); let last_ping_recv = AtomicU64::new(0); - join!( - listen( - Arc::clone(&state), - Arc::clone(&source), - phase_watcher.clone(), - &last_ping_recv, - ), - send_voice( - Arc::clone(&sink), - connection_info.socket_addr, - phase_watcher.clone(), - Arc::clone(&receiver), - ), - send_pings( - Arc::clone(&state), - Arc::clone(&sink), - connection_info.socket_addr, - &last_ping_recv, - phase_watcher.clone(), - ), - new_crypt_state(&mut crypt_state_receiver, sink, source, phase_watcher), - ); + + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join4( + listen( + Arc::clone(&state), + Arc::clone(&source), + &last_ping_recv, + ), + send_voice( + Arc::clone(&sink), + connection_info.socket_addr, + phase_watcher.clone(), + Arc::clone(&receiver), + ), + send_pings( + Arc::clone(&state), + Arc::clone(&sink), + connection_info.socket_addr, + &last_ping_recv, + ), + new_crypt_state(&mut crypt_state_receiver, sink, source), + ).map(|_| ()), + || async {}, + phase_watcher, + ).await; debug!("Fully disconnected UDP stream, waiting for new connection info"); } @@ -100,76 +105,58 @@ async fn new_crypt_state( crypt_state: &mut mpsc::Receiver<ClientCryptState>, sink: Arc<Mutex<UdpSender>>, source: Arc<Mutex<UdpReceiver>>, - phase_watcher: watch::Receiver<StatePhase>, ) { - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - if let Some(crypt_state) = crypt_state.recv().await { - info!("Received new crypt state"); - let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) - .await - .expect("Failed to bind UDP socket"); - let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().await = new_sink; - *source.lock().await = new_source; - } - } - }, - || async {}, - phase_watcher, - ).await; + loop { + if let Some(crypt_state) = crypt_state.recv().await { + info!("Received new crypt state"); + let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) + .await + .expect("Failed to bind UDP socket"); + let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); + *sink.lock().await = new_sink; + *source.lock().await = new_source; + } + } } async fn listen( state: Arc<Mutex<State>>, source: Arc<Mutex<UdpReceiver>>, - phase_watcher: watch::Receiver<StatePhase>, last_ping_recv: &AtomicU64, ) { - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let packet = source.lock().await.next().await.unwrap(); - let (packet, _src_addr) = match packet { - Ok(packet) => packet, - Err(err) => { - warn!("Got an invalid UDP packet: {}", err); - // To be expected, considering this is the internet, just ignore it - continue; - } - }; - match packet { - VoicePacket::Ping { timestamp } => { - state - .lock() //TODO clean up unnecessary lock by only updating phase if it should change - .await - .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); - last_ping_recv.store(timestamp, Ordering::Relaxed); - } - VoicePacket::Audio { - session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state - .lock() //TODO change so that we only have to lock audio and not the whole state - .await - .audio() - .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); - } - } + loop { + let packet = source.lock().await.next().await.unwrap(); + let (packet, _src_addr) = match packet { + Ok(packet) => packet, + Err(err) => { + warn!("Got an invalid UDP packet: {}", err); + // To be expected, considering this is the internet, just ignore it + continue; } - }, - || async {}, - phase_watcher - ).await; - - debug!("UDP listener process killed"); + }; + match packet { + VoicePacket::Ping { timestamp } => { + state + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); + last_ping_recv.store(timestamp, Ordering::Relaxed); + } + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() //TODO change so that we only have to lock audio and not the whole state + .await + .audio() + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); + } + } + } } async fn send_pings( @@ -177,44 +164,34 @@ async fn send_pings( sink: Arc<Mutex<UdpSender>>, server_addr: SocketAddr, last_ping_recv: &AtomicU64, - phase_watcher: watch::Receiver<StatePhase>, ) { let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - interval.tick().await; - let last_recv = last_ping_recv.load(Ordering::Relaxed); - if last_send.is_some() && last_send.unwrap() != last_recv { - debug!("Sending TCP voice"); - state - .lock() - .await - .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); - } - match sink - .lock() - .await - .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) - .await - { - Ok(_) => { - last_send = Some(last_recv + 1); - }, - Err(e) => { - debug!("Error sending UDP ping: {}", e); - } - } + loop { + interval.tick().await; + let last_recv = last_ping_recv.load(Ordering::Relaxed); + if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); + state + .lock() + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); + } + match sink + .lock() + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) + .await + { + Ok(_) => { + last_send = Some(last_recv + 1); + }, + Err(e) => { + debug!("Error sending UDP ping: {}", e); } - }, - || async {}, - phase_watcher, - ).await; - - debug!("UDP ping sender process killed"); + } + } } async fn send_voice( @@ -223,37 +200,27 @@ async fn send_voice( phase_watcher: watch::Receiver<StatePhase>, receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, ) { - let inner_phase_watcher = phase_watcher.clone(); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + let mut inner_phase_watcher = phase_watcher.clone(); + loop { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + async { + let mut receiver = receiver.lock().await; loop { - inner_phase_watcher_2.changed().await.unwrap(); - if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { - break; - } + let sending = (receiver.next().await.unwrap(), server_addr); + sink.lock().await.send(sending).await.unwrap(); } - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), - async { - let mut receiver = receiver.lock().await; - loop { - let sending = (receiver.next().await.unwrap(), server_addr); - sink.lock().await.send(sending).await.unwrap(); - } - }, - || async {}, - inner_phase_watcher.clone(), - ).await; - } - }, - || async {}, - phase_watcher, - ).await; - - debug!("UDP sender process killed"); + }, + || async {}, + phase_watcher.clone(), + ).await; + } } pub async fn handle_pings( |
