diff options
Diffstat (limited to 'mumd/src')
| -rw-r--r-- | mumd/src/audio.rs | 59 | ||||
| -rw-r--r-- | mumd/src/audio/output.rs | 6 | ||||
| -rw-r--r-- | mumd/src/client.rs | 4 | ||||
| -rw-r--r-- | mumd/src/command.rs | 6 | ||||
| -rw-r--r-- | mumd/src/network.rs | 50 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 406 | ||||
| -rw-r--r-- | mumd/src/network/udp.rs | 271 | ||||
| -rw-r--r-- | mumd/src/state.rs | 34 |
8 files changed, 422 insertions, 414 deletions
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index 680433c..598dde6 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -2,6 +2,7 @@ pub mod input; pub mod output; use crate::audio::output::SaturatingAdd; +use crate::network::VoiceStreamType; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::{SampleFormat, SampleRate, StreamConfig}; @@ -75,14 +76,14 @@ pub struct Audio { _output_stream: cpal::Stream, _input_stream: cpal::Stream, - input_channel_receiver: Arc<Mutex<Box<dyn Stream<Item = VoicePacket<Serverbound>> + Unpin>>>, + input_channel_receiver: Arc<tokio::sync::Mutex<Box<dyn Stream<Item = VoicePacket<Serverbound>> + Unpin>>>, input_volume_sender: watch::Sender<f32>, output_volume_sender: watch::Sender<f32>, user_volumes: Arc<Mutex<HashMap<u32, (f32, bool)>>>, - client_streams: Arc<Mutex<HashMap<u32, output::ClientStream>>>, + client_streams: Arc<Mutex<HashMap<(VoiceStreamType, u32), output::ClientStream>>>, sounds: HashMap<NotificationEvents, Vec<f32>>, play_sounds: Arc<Mutex<VecDeque<f32>>>, @@ -131,11 +132,11 @@ impl Audio { let err_fn = |err| error!("An error occurred on the output audio stream: {}", err); - let user_volumes = Arc::new(Mutex::new(HashMap::new())); + let user_volumes = Arc::new(std::sync::Mutex::new(HashMap::new())); let (output_volume_sender, output_volume_receiver) = watch::channel::<f32>(output_volume); - let play_sounds = Arc::new(Mutex::new(VecDeque::new())); + let play_sounds = Arc::new(std::sync::Mutex::new(VecDeque::new())); - let client_streams = Arc::new(Mutex::new(HashMap::new())); + let client_streams = Arc::new(std::sync::Mutex::new(HashMap::new())); let output_stream = match output_supported_sample_format { SampleFormat::F32 => output_device.build_output_stream( &output_config, @@ -226,7 +227,7 @@ impl Audio { _output_stream: output_stream, _input_stream: input_stream, input_volume_sender, - input_channel_receiver: Arc::new(Mutex::new(Box::new(opus_stream))), + input_channel_receiver: Arc::new(tokio::sync::Mutex::new(Box::new(opus_stream))), client_streams, sounds: HashMap::new(), output_volume_sender, @@ -291,8 +292,8 @@ impl Audio { .collect(); } - pub fn decode_packet(&self, session_id: u32, payload: VoicePacketPayload) { - match self.client_streams.lock().unwrap().entry(session_id) { + pub fn decode_packet_payload(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) { + match self.client_streams.lock().unwrap().entry((stream_type, session_id)) { Entry::Occupied(mut entry) => { entry .get_mut() @@ -305,34 +306,38 @@ impl Audio { } pub fn add_client(&self, session_id: u32) { - match self.client_streams.lock().unwrap().entry(session_id) { - Entry::Occupied(_) => { - warn!("Session id {} already exists", session_id); - } - Entry::Vacant(entry) => { - entry.insert(output::ClientStream::new( - self.output_config.sample_rate.0, - self.output_config.channels, - )); + for stream_type in [VoiceStreamType::TCP, VoiceStreamType::UDP].iter() { + match self.client_streams.lock().unwrap().entry((*stream_type, session_id)) { + Entry::Occupied(_) => { + warn!("Session id {} already exists", session_id); + } + Entry::Vacant(entry) => { + entry.insert(output::ClientStream::new( + self.output_config.sample_rate.0, + self.output_config.channels, + )); + } } } } pub fn remove_client(&self, session_id: u32) { - match self.client_streams.lock().unwrap().entry(session_id) { - Entry::Occupied(entry) => { - entry.remove(); - } - Entry::Vacant(_) => { - warn!( - "Tried to remove session id {} that doesn't exist", - session_id - ); + for stream_type in [VoiceStreamType::TCP, VoiceStreamType::UDP].iter() { + match self.client_streams.lock().unwrap().entry((*stream_type, session_id)) { + Entry::Occupied(entry) => { + entry.remove(); + } + Entry::Vacant(_) => { + warn!( + "Tried to remove session id {} that doesn't exist", + session_id + ); + } } } } - pub fn take_receiver(&mut self) -> Arc<Mutex<Box<dyn Stream<Item = VoicePacket<Serverbound>> + Unpin>>> { + pub fn input_receiver(&self) -> Arc<tokio::sync::Mutex<Box<dyn Stream<Item = VoicePacket<Serverbound>> + Unpin>>> { Arc::clone(&self.input_channel_receiver) } diff --git a/mumd/src/audio/output.rs b/mumd/src/audio/output.rs index 5e0cb8d..421d395 100644 --- a/mumd/src/audio/output.rs +++ b/mumd/src/audio/output.rs @@ -1,3 +1,5 @@ +use crate::network::VoiceStreamType; + use cpal::{OutputCallbackInfo, Sample}; use mumble_protocol::voice::VoicePacketPayload; use opus::Channels; @@ -73,7 +75,7 @@ impl SaturatingAdd for u16 { pub fn curry_callback<T: Sample + AddAssign + SaturatingAdd + std::fmt::Display>( effect_sound: Arc<Mutex<VecDeque<f32>>>, - user_bufs: Arc<Mutex<HashMap<u32, ClientStream>>>, + user_bufs: Arc<Mutex<HashMap<(VoiceStreamType, u32), ClientStream>>>, output_volume_receiver: watch::Receiver<f32>, user_volumes: Arc<Mutex<HashMap<u32, (f32, bool)>>>, ) -> impl FnMut(&mut [T], &OutputCallbackInfo) + Send + 'static { @@ -86,7 +88,7 @@ pub fn curry_callback<T: Sample + AddAssign + SaturatingAdd + std::fmt::Display> let mut effects_sound = effect_sound.lock().unwrap(); let mut user_bufs = user_bufs.lock().unwrap(); - for (id, client_stream) in &mut *user_bufs { + for ((_, id), client_stream) in &mut *user_bufs { let (user_volume, muted) = user_volumes .lock() .unwrap() diff --git a/mumd/src/client.rs b/mumd/src/client.rs index 3613061..222e2a7 100644 --- a/mumd/src/client.rs +++ b/mumd/src/client.rs @@ -6,8 +6,8 @@ use futures_util::join; use ipc_channel::ipc::IpcSender; use mumble_protocol::{Serverbound, control::ControlPacket, crypt::ClientCryptState}; use mumlib::command::{Command, CommandResponse}; -use std::sync::{Arc, Mutex}; -use tokio::sync::{mpsc, watch}; +use std::sync::Arc; +use tokio::sync::{mpsc, watch, Mutex}; pub async fn handle( command_receiver: mpsc::UnboundedReceiver<( diff --git a/mumd/src/command.rs b/mumd/src/command.rs index e77b34b..653d1fa 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -9,8 +9,8 @@ use ipc_channel::ipc::IpcSender; use log::*; use mumble_protocol::{Serverbound, control::ControlPacket}; use mumlib::command::{Command, CommandResponse}; -use std::sync::{Arc, Mutex}; -use tokio::sync::{mpsc, oneshot, watch}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; pub async fn handle( state: Arc<Mutex<State>>, @@ -26,7 +26,7 @@ pub async fn handle( debug!("Begin listening for commands"); while let Some((command, response_sender)) = command_receiver.recv().await { debug!("Received command {:?}", command); - let mut state = state.lock().unwrap(); + let mut state = state.lock().await; let event = state.handle_command(command, &mut packet_sender, &mut connection_info_sender); drop(state); match event { diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 1a31ee2..6c67b3a 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -1,7 +1,17 @@ pub mod tcp; pub mod udp; +use futures::Future; +use futures::FutureExt; +use futures::channel::oneshot; +use futures::join; +use futures::pin_mut; +use futures::select; +use log::*; use std::net::SocketAddr; +use tokio::sync::watch; + +use crate::state::StatePhase; #[derive(Clone, Debug)] pub struct ConnectionInfo { @@ -19,3 +29,43 @@ impl ConnectionInfo { } } } + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub enum VoiceStreamType { + TCP, + UDP, +} + +async fn run_until<F>( + phase_checker: impl Fn(StatePhase) -> bool, + fut: F, + mut phase_watcher: watch::Receiver<StatePhase>, +) where + F: Future<Output = ()>, +{ + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + loop { + phase_watcher.changed().await.unwrap(); + if phase_checker(*phase_watcher.borrow()) { + break; + } + } + if tx.send(true).is_err() { + warn!("future resolved before it could be cancelled"); + } + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + let fut = fut.fuse(); + pin_mut!(fut); + select! { + _ = fut => (), + _ = rx => (), + }; + }; + + join!(main_block, phase_transition_block); +} diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 47ea311..3a32b9f 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -2,24 +2,25 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; -use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt}; +use futures::{FutureExt, SinkExt, Stream, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; +use mumble_protocol::voice::VoicePacket; use mumble_protocol::{Clientbound, Serverbound}; -use std::cell::RefCell; use std::collections::HashMap; use std::convert::{Into, TryInto}; -use std::future::Future; use std::net::SocketAddr; -use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use tokio::net::TcpStream; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch, Mutex}; use tokio::time::{self, Duration}; use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use super::{run_until, VoiceStreamType}; +use futures_util::future::join5; + type TcpSender = SplitSink< Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>, ControlPacket<Serverbound>, @@ -65,26 +66,42 @@ pub async fn handle( .await; // Handshake (omitting `Version` message for brevity) - let state_lock = state.lock().unwrap(); + let state_lock = state.lock().await; authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; let phase_watcher = state_lock.phase_receiver(); + let input_receiver = state_lock.audio().input_receiver(); drop(state_lock); let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); - join!( - send_pings(packet_sender.clone(), 10, phase_watcher.clone()), - listen( - Arc::clone(&state), - stream, - crypt_state_sender.clone(), - Arc::clone(&event_queue), - phase_watcher.clone(), - ), - send_packets(sink, &mut packet_receiver, phase_watcher.clone()), - register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), - ); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join5( + send_pings(packet_sender.clone(), 10), + listen( + Arc::clone(&state), + stream, + crypt_state_sender.clone(), + Arc::clone(&event_queue), + ), + send_voice( + packet_sender.clone(), + Arc::clone(&input_receiver), + phase_watcher.clone(), + ), + send_packets(sink, &mut packet_receiver), + register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)), + ).map(|_| ()), + phase_watcher, + ).await; + + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Disconnected); + } + } debug!("Fully disconnected TCP stream, waiting for new connection info"); } @@ -126,232 +143,187 @@ async fn authenticate(sink: &mut TcpSender, username: String) { async fn send_pings( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, delay_seconds: u64, - phase_watcher: watch::Receiver<StatePhase>, ) { - let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs( - delay_seconds, - )))); - let packet_sender = Rc::new(RefCell::new(packet_sender)); - - run_until_disconnection( - || async { Some(interval.borrow_mut().tick().await) }, - |_| async { - trace!("Sending ping"); + let mut interval = time::interval(Duration::from_secs(delay_seconds)); + loop { + interval.tick().await; + trace!("Sending TCP ping"); let msg = msgs::Ping::new(); - packet_sender.borrow_mut().send(msg.into()).unwrap(); - }, - || async {}, - phase_watcher, - ) - .await; - - debug!("Ping sender process killed"); + packet_sender.send(msg.into()).unwrap(); + } } async fn send_packets( - sink: TcpSender, + mut sink: TcpSender, packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, - phase_watcher: watch::Receiver<StatePhase>, ) { - let sink = Rc::new(RefCell::new(sink)); - let packet_receiver = Rc::new(RefCell::new(packet_receiver)); - run_until_disconnection( - || async { packet_receiver.borrow_mut().recv().await }, - |packet| async { - sink.borrow_mut().send(packet).await.unwrap(); - }, - || async { - sink.borrow_mut().close().await.unwrap(); - }, - phase_watcher, - ) - .await; + loop { + let packet = packet_receiver.recv().await.unwrap(); + sink.send(packet).await.unwrap(); + } +} - debug!("TCP packet sender killed"); +async fn send_voice( + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, + phase_watcher: watch::Receiver<StatePhase>, +) { + loop { + let mut inner_phase_watcher = phase_watcher.clone(); + loop { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + async { + loop { + packet_sender.send( + receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); + } + }, + inner_phase_watcher.clone(), + ).await; + } } async fn listen( state: Arc<Mutex<State>>, - stream: TcpReceiver, + mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender<ClientCryptState>, event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, - phase_watcher: watch::Receiver<StatePhase>, ) { - let crypt_state = Rc::new(RefCell::new(None)); - let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); + let mut crypt_state = None; + let mut crypt_state_sender = Some(crypt_state_sender); - let stream = Rc::new(RefCell::new(stream)); - run_until_disconnection( - || async { stream.borrow_mut().next().await }, - |packet| async { - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); + loop { + let packet = stream.next().await.unwrap(); + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + crypt_state = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(sender) = crypt_state_sender.take() { + let _ = sender + .send( + crypt_state + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.borrow_mut().take() { - let _ = sender - .send( - crypt_state - .borrow_mut() - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Connected(&msg)); } - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Connected(&msg)); - } - } - let mut state = state.lock().unwrap(); - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), - } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); - } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); } - ControlPacket::UserState(msg) => { - state.lock().unwrap().parse_user_state(*msg); + let mut state = state.lock().await; + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), } - ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_remove(*msg); - } - packet => { - debug!("Received unhandled ControlPacket {:#?}", packet); + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + state.lock().await.parse_user_state(*msg); + } + ControlPacket::UserRemove(msg) => { + state.lock().await.remove_client(*msg); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .await + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .await + .server_mut() + .unwrap() + .parse_channel_remove(*msg); } - }, - || async { - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Disconnected); + ControlPacket::UDPTunnel(msg) => { + match *msg { + VoicePacket::Ping { .. } => {} + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() + .await + .audio() + .decode_packet_payload( + VoiceStreamType::TCP, + session_id, + payload); + } } } - }, - phase_watcher, - ) - .await; - - debug!("Killing TCP listener block"); + packet => { + debug!("Received unhandled ControlPacket {:#?}", packet); + } + } + } } async fn register_events( tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, - phase_watcher: watch::Receiver<StatePhase>, ) { - let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); - run_until_disconnection( - || async { tcp_event_register_receiver.borrow_mut().recv().await }, - |(event, handler)| async { - event_data - .lock() - .unwrap() - .entry(event) - .or_default() - .push(handler); - }, - || async {}, - phase_watcher, - ) - .await; -} - -async fn run_until_disconnection<T, F, G, H>( - mut generator: impl FnMut() -> F, - mut handler: impl FnMut(T) -> G, - mut shutdown: impl FnMut() -> H, - mut phase_watcher: watch::Receiver<StatePhase>, -) where - F: Future<Output = Option<T>>, - G: Future<Output = ()>, - H: Future<Output = ()>, -{ - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; - } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = generator().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - //warn!("Channel closed before disconnect command"); //TODO make me informative - break; - } - Some(Some(data)) => { - handler(data).await; - } - } - } - - shutdown().await; - }; - - join!(main_block, phase_transition_block); + loop { + let (event, handler) = tcp_event_register_receiver.recv().await.unwrap(); + event_data + .lock() + .await + .entry(event) + .or_default() + .push(handler); + } } diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 0c00029..5f24b51 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,23 +1,27 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; -use bytes::Bytes; -use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt, Stream}; +use futures::{join, FutureExt, SinkExt, StreamExt, Stream}; use futures_util::stream::{SplitSink, SplitStream}; use log::*; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::ping::{PingPacket, PongPacket}; -use mumble_protocol::voice::{VoicePacket, VoicePacketPayload}; +use mumble_protocol::voice::VoicePacket; use mumble_protocol::Serverbound; use std::collections::HashMap; use std::convert::TryFrom; use std::net::{Ipv6Addr, SocketAddr}; use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use tokio::net::UdpSocket; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch, Mutex}; +use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; +use super::{run_until, VoiceStreamType}; +use futures_util::future::join4; + pub type PingRequest = (u64, SocketAddr, Box<dyn FnOnce(PongPacket)>); type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>; @@ -28,7 +32,7 @@ pub async fn handle( mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, mut crypt_state_receiver: mpsc::Receiver<ClientCryptState>, ) { - let receiver = state.lock().unwrap().audio_mut().take_receiver(); + let receiver = state.lock().await.audio().input_receiver(); loop { let connection_info = 'data: loop { @@ -39,28 +43,38 @@ pub async fn handle( } return; }; - let (mut sink, source) = connect(&mut crypt_state_receiver).await; - - // Note: A normal application would also send periodic Ping packets, and its own audio - // via UDP. We instead trick the server into accepting us by sending it one - // dummy voice packet. - send_ping(&mut sink, connection_info.socket_addr).await; + let (sink, source) = connect(&mut crypt_state_receiver).await; let sink = Arc::new(Mutex::new(sink)); let source = Arc::new(Mutex::new(source)); - let phase_watcher = state.lock().unwrap().phase_receiver(); - let mut audio_receiver_lock = receiver.lock().unwrap(); - join!( - listen(Arc::clone(&state), Arc::clone(&source), phase_watcher.clone()), - send_voice( - Arc::clone(&sink), - connection_info.socket_addr, - phase_watcher, - &mut *audio_receiver_lock - ), - new_crypt_state(&mut crypt_state_receiver, sink, source) - ); + let phase_watcher = state.lock().await.phase_receiver(); + let last_ping_recv = AtomicU64::new(0); + + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join4( + listen( + Arc::clone(&state), + Arc::clone(&source), + &last_ping_recv, + ), + send_voice( + Arc::clone(&sink), + connection_info.socket_addr, + phase_watcher.clone(), + Arc::clone(&receiver), + ), + send_pings( + Arc::clone(&state), + Arc::clone(&sink), + connection_info.socket_addr, + &last_ping_recv, + ), + new_crypt_state(&mut crypt_state_receiver, sink, source), + ).map(|_| ()), + phase_watcher, + ).await; debug!("Fully disconnected UDP stream, waiting for new connection info"); } @@ -98,8 +112,8 @@ async fn new_crypt_state( .await .expect("Failed to bind UDP socket"); let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().unwrap() = new_sink; - *source.lock().unwrap() = new_source; + *sink.lock().await = new_sink; + *source.lock().await = new_source; } } } @@ -107,143 +121,104 @@ async fn new_crypt_state( async fn listen( state: Arc<Mutex<State>>, source: Arc<Mutex<UdpReceiver>>, - mut phase_watcher: watch::Receiver<StatePhase>, + last_ping_recv: &AtomicU64, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; + loop { + let packet = source.lock().await.next().await.unwrap(); + let (packet, _src_addr) = match packet { + Ok(packet) => packet, + Err(err) => { + warn!("Got an invalid UDP packet: {}", err); + // To be expected, considering this is the internet, just ignore it + continue; } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let mut source = source.lock().unwrap(); - let packet_recv = source.next().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; - } - Some(Some(packet)) => { - let (packet, _src_addr) = match packet { - Ok(packet) => packet, - Err(err) => { - warn!("Got an invalid UDP packet: {}", err); - // To be expected, considering this is the internet, just ignore it - continue; - } - }; - match packet { - VoicePacket::Ping { .. } => { - // Note: A normal application would handle these and only use UDP for voice - // once it has received one. - continue; - } - VoicePacket::Audio { - session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state - .lock() - .unwrap() - .audio() - .decode_packet(session_id, payload); - } - } - } + }; + match packet { + VoicePacket::Ping { timestamp } => { + state + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); + last_ping_recv.store(timestamp, Ordering::Relaxed); + } + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() //TODO change so that we only have to lock audio and not the whole state + .await + .audio() + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); } } - }; - - join!(main_block, phase_transition_block); - - debug!("UDP listener process killed"); + } } -async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) { - sink.send(( - VoicePacket::Audio { - _dst: std::marker::PhantomData, - target: 0, - session_id: (), - seq_num: 0, - payload: VoicePacketPayload::Opus(Bytes::from([0u8; 128].as_ref()), true), - position_info: None, - }, - server_addr, - )) - .await - .unwrap(); +async fn send_pings( + state: Arc<Mutex<State>>, + sink: Arc<Mutex<UdpSender>>, + server_addr: SocketAddr, + last_ping_recv: &AtomicU64, +) { + let mut last_send = None; + let mut interval = interval(Duration::from_millis(1000)); + + loop { + interval.tick().await; + let last_recv = last_ping_recv.load(Ordering::Relaxed); + if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); + state + .lock() + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); + } + match sink + .lock() + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) + .await + { + Ok(_) => { + last_send = Some(last_recv + 1); + }, + Err(e) => { + debug!("Error sending UDP ping: {}", e); + } + } + } } async fn send_voice( sink: Arc<Mutex<UdpSender>>, server_addr: SocketAddr, - mut phase_watcher: watch::Receiver<StatePhase>, - receiver: &mut (dyn Stream<Item = VoicePacket<Serverbound>> + Unpin), + phase_watcher: watch::Receiver<StatePhase>, + receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, ) { - pin_mut!(receiver); - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { + loop { + let mut inner_phase_watcher = phase_watcher.clone(); loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { break; } } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = receiver.next().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + async { + let mut receiver = receiver.lock().await; + loop { + let sending = (receiver.next().await.unwrap(), server_addr); + sink.lock().await.send(sending).await.unwrap(); } - Some(Some(reply)) => { - sink.lock() - .unwrap() - .send((reply, server_addr)) - .await - .unwrap(); - } - } - } - }; - - join!(main_block, phase_transition_block); - - debug!("UDP sender process killed"); + }, + phase_watcher.clone(), + ).await; + } } pub async fn handle_pings( @@ -260,7 +235,7 @@ pub async fn handle_pings( let packet = PingPacket { id }; let packet: [u8; 12] = packet.into(); udp_socket.send_to(&packet, &socket_addr).await.unwrap(); - pending.lock().unwrap().insert(id, handle); + pending.lock().await.insert(id, handle); } }; @@ -271,7 +246,7 @@ pub async fn handle_pings( let packet = PongPacket::try_from(buf.as_slice()).unwrap(); - if let Some(handler) = pending.lock().unwrap().remove(&packet.id) { + if let Some(handler) = pending.lock().await.remove(&packet.id) { handler(packet); } } diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 84247bc..2ed73b2 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -3,11 +3,11 @@ pub mod server; pub mod user; use crate::audio::{Audio, NotificationEvents}; -use crate::network::ConnectionInfo; +use crate::network::{ConnectionInfo, VoiceStreamType}; +use crate::network::tcp::{TcpEvent, TcpEventData}; use crate::notify; use crate::state::server::Server; -use crate::network::tcp::{TcpEvent, TcpEventData}; use log::*; use mumble_protocol::control::msgs; use mumble_protocol::control::ControlPacket; @@ -45,11 +45,11 @@ pub enum ExecutionContext { ), } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum StatePhase { Disconnected, Connecting, - Connected, + Connected(VoiceStreamType), } pub struct State { @@ -85,7 +85,7 @@ impl State { ) -> ExecutionContext { match command { Command::ChannelJoin { channel_identifier } => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -135,7 +135,7 @@ impl State { now!(Ok(None)) } Command::ChannelList => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } let list = channel::into_channel( @@ -149,7 +149,7 @@ impl State { now!(Ok(None)) } Command::DeafenSelf(toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -207,7 +207,7 @@ impl State { now!(Ok(None)) } Command::MuteOther(string, toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -242,7 +242,7 @@ impl State { return now!(Ok(None)); } Command::MuteSelf(toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -354,7 +354,7 @@ impl State { }) } Command::ServerDisconnect => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -388,7 +388,7 @@ impl State { }), ), Command::Status => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } let state = self.server.as_ref().unwrap().into(); @@ -397,7 +397,7 @@ impl State { }))) } Command::UserVolumeSet(string, volume) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } let user_id = match self @@ -448,7 +448,7 @@ impl State { self.audio_mut().add_client(session); // send notification only if we've passed the connecting phase - if *self.phase_receiver().borrow() == StatePhase::Connected { + if matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { let channel_id = msg.get_channel_id(); if channel_id @@ -578,11 +578,15 @@ impl State { } } - pub fn initialized(&self) { + pub fn broadcast_phase(&self, phase: StatePhase) { self.phase_watcher .0 - .send(StatePhase::Connected) + .send(phase) .unwrap(); + } + + pub fn initialized(&self) { + self.broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); self.audio.play_effect(NotificationEvents::ServerConnect); } |
