diff options
Diffstat (limited to 'mumd/src/network/udp.rs')
| -rw-r--r-- | mumd/src/network/udp.rs | 271 |
1 files changed, 123 insertions, 148 deletions
diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 0c00029..5f24b51 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,23 +1,27 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; -use bytes::Bytes; -use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt, Stream}; +use futures::{join, FutureExt, SinkExt, StreamExt, Stream}; use futures_util::stream::{SplitSink, SplitStream}; use log::*; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::ping::{PingPacket, PongPacket}; -use mumble_protocol::voice::{VoicePacket, VoicePacketPayload}; +use mumble_protocol::voice::VoicePacket; use mumble_protocol::Serverbound; use std::collections::HashMap; use std::convert::TryFrom; use std::net::{Ipv6Addr, SocketAddr}; use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use tokio::net::UdpSocket; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch, Mutex}; +use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; +use super::{run_until, VoiceStreamType}; +use futures_util::future::join4; + pub type PingRequest = (u64, SocketAddr, Box<dyn FnOnce(PongPacket)>); type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>; @@ -28,7 +32,7 @@ pub async fn handle( mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, mut crypt_state_receiver: mpsc::Receiver<ClientCryptState>, ) { - let receiver = state.lock().unwrap().audio_mut().take_receiver(); + let receiver = state.lock().await.audio().input_receiver(); loop { let connection_info = 'data: loop { @@ -39,28 +43,38 @@ pub async fn handle( } return; }; - let (mut sink, source) = connect(&mut crypt_state_receiver).await; - - // Note: A normal application would also send periodic Ping packets, and its own audio - // via UDP. We instead trick the server into accepting us by sending it one - // dummy voice packet. - send_ping(&mut sink, connection_info.socket_addr).await; + let (sink, source) = connect(&mut crypt_state_receiver).await; let sink = Arc::new(Mutex::new(sink)); let source = Arc::new(Mutex::new(source)); - let phase_watcher = state.lock().unwrap().phase_receiver(); - let mut audio_receiver_lock = receiver.lock().unwrap(); - join!( - listen(Arc::clone(&state), Arc::clone(&source), phase_watcher.clone()), - send_voice( - Arc::clone(&sink), - connection_info.socket_addr, - phase_watcher, - &mut *audio_receiver_lock - ), - new_crypt_state(&mut crypt_state_receiver, sink, source) - ); + let phase_watcher = state.lock().await.phase_receiver(); + let last_ping_recv = AtomicU64::new(0); + + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join4( + listen( + Arc::clone(&state), + Arc::clone(&source), + &last_ping_recv, + ), + send_voice( + Arc::clone(&sink), + connection_info.socket_addr, + phase_watcher.clone(), + Arc::clone(&receiver), + ), + send_pings( + Arc::clone(&state), + Arc::clone(&sink), + connection_info.socket_addr, + &last_ping_recv, + ), + new_crypt_state(&mut crypt_state_receiver, sink, source), + ).map(|_| ()), + phase_watcher, + ).await; debug!("Fully disconnected UDP stream, waiting for new connection info"); } @@ -98,8 +112,8 @@ async fn new_crypt_state( .await .expect("Failed to bind UDP socket"); let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().unwrap() = new_sink; - *source.lock().unwrap() = new_source; + *sink.lock().await = new_sink; + *source.lock().await = new_source; } } } @@ -107,143 +121,104 @@ async fn new_crypt_state( async fn listen( state: Arc<Mutex<State>>, source: Arc<Mutex<UdpReceiver>>, - mut phase_watcher: watch::Receiver<StatePhase>, + last_ping_recv: &AtomicU64, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; + loop { + let packet = source.lock().await.next().await.unwrap(); + let (packet, _src_addr) = match packet { + Ok(packet) => packet, + Err(err) => { + warn!("Got an invalid UDP packet: {}", err); + // To be expected, considering this is the internet, just ignore it + continue; } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let mut source = source.lock().unwrap(); - let packet_recv = source.next().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; - } - Some(Some(packet)) => { - let (packet, _src_addr) = match packet { - Ok(packet) => packet, - Err(err) => { - warn!("Got an invalid UDP packet: {}", err); - // To be expected, considering this is the internet, just ignore it - continue; - } - }; - match packet { - VoicePacket::Ping { .. } => { - // Note: A normal application would handle these and only use UDP for voice - // once it has received one. - continue; - } - VoicePacket::Audio { - session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state - .lock() - .unwrap() - .audio() - .decode_packet(session_id, payload); - } - } - } + }; + match packet { + VoicePacket::Ping { timestamp } => { + state + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); + last_ping_recv.store(timestamp, Ordering::Relaxed); + } + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() //TODO change so that we only have to lock audio and not the whole state + .await + .audio() + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); } } - }; - - join!(main_block, phase_transition_block); - - debug!("UDP listener process killed"); + } } -async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) { - sink.send(( - VoicePacket::Audio { - _dst: std::marker::PhantomData, - target: 0, - session_id: (), - seq_num: 0, - payload: VoicePacketPayload::Opus(Bytes::from([0u8; 128].as_ref()), true), - position_info: None, - }, - server_addr, - )) - .await - .unwrap(); +async fn send_pings( + state: Arc<Mutex<State>>, + sink: Arc<Mutex<UdpSender>>, + server_addr: SocketAddr, + last_ping_recv: &AtomicU64, +) { + let mut last_send = None; + let mut interval = interval(Duration::from_millis(1000)); + + loop { + interval.tick().await; + let last_recv = last_ping_recv.load(Ordering::Relaxed); + if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); + state + .lock() + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); + } + match sink + .lock() + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) + .await + { + Ok(_) => { + last_send = Some(last_recv + 1); + }, + Err(e) => { + debug!("Error sending UDP ping: {}", e); + } + } + } } async fn send_voice( sink: Arc<Mutex<UdpSender>>, server_addr: SocketAddr, - mut phase_watcher: watch::Receiver<StatePhase>, - receiver: &mut (dyn Stream<Item = VoicePacket<Serverbound>> + Unpin), + phase_watcher: watch::Receiver<StatePhase>, + receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, ) { - pin_mut!(receiver); - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { + loop { + let mut inner_phase_watcher = phase_watcher.clone(); loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { break; } } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = receiver.next().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + async { + let mut receiver = receiver.lock().await; + loop { + let sending = (receiver.next().await.unwrap(), server_addr); + sink.lock().await.send(sending).await.unwrap(); } - Some(Some(reply)) => { - sink.lock() - .unwrap() - .send((reply, server_addr)) - .await - .unwrap(); - } - } - } - }; - - join!(main_block, phase_transition_block); - - debug!("UDP sender process killed"); + }, + phase_watcher.clone(), + ).await; + } } pub async fn handle_pings( @@ -260,7 +235,7 @@ pub async fn handle_pings( let packet = PingPacket { id }; let packet: [u8; 12] = packet.into(); udp_socket.send_to(&packet, &socket_addr).await.unwrap(); - pending.lock().unwrap().insert(id, handle); + pending.lock().await.insert(id, handle); } }; @@ -271,7 +246,7 @@ pub async fn handle_pings( let packet = PongPacket::try_from(buf.as_slice()).unwrap(); - if let Some(handler) = pending.lock().unwrap().remove(&packet.id) { + if let Some(handler) = pending.lock().await.remove(&packet.id) { handler(packet); } } |
