From 55644de7b35421997198c9dec4a8bba5dfb8dd8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Tue, 5 Jan 2021 12:47:04 +0100 Subject: add voice stream type --- mumd/src/audio.rs | 47 ++++++++++++++++++++++++++--------------------- mumd/src/audio/output.rs | 6 ++++-- mumd/src/network.rs | 6 ++++++ mumd/src/network/udp.rs | 4 +++- mumd/src/state.rs | 28 ++++++++++++++-------------- 5 files changed, 53 insertions(+), 38 deletions(-) diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index 680433c..40cdcb2 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -2,6 +2,7 @@ pub mod input; pub mod output; use crate::audio::output::SaturatingAdd; +use crate::network::VoiceStreamType; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::{SampleFormat, SampleRate, StreamConfig}; @@ -82,7 +83,7 @@ pub struct Audio { user_volumes: Arc>>, - client_streams: Arc>>, + client_streams: Arc>>, sounds: HashMap>, play_sounds: Arc>>, @@ -291,8 +292,8 @@ impl Audio { .collect(); } - pub fn decode_packet(&self, session_id: u32, payload: VoicePacketPayload) { - match self.client_streams.lock().unwrap().entry(session_id) { + pub fn decode_packet(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) { + match self.client_streams.lock().unwrap().entry((stream_type, session_id)) { Entry::Occupied(mut entry) => { entry .get_mut() @@ -305,29 +306,33 @@ impl Audio { } pub fn add_client(&self, session_id: u32) { - match self.client_streams.lock().unwrap().entry(session_id) { - Entry::Occupied(_) => { - warn!("Session id {} already exists", session_id); - } - Entry::Vacant(entry) => { - entry.insert(output::ClientStream::new( - self.output_config.sample_rate.0, - self.output_config.channels, - )); + for stream_type in [VoiceStreamType::TCP, VoiceStreamType::UDP].iter() { + match self.client_streams.lock().unwrap().entry((*stream_type, session_id)) { + Entry::Occupied(_) => { + warn!("Session id {} already exists", session_id); + } + Entry::Vacant(entry) => { + entry.insert(output::ClientStream::new( + self.output_config.sample_rate.0, + self.output_config.channels, + )); + } } } } pub fn remove_client(&self, session_id: u32) { - match self.client_streams.lock().unwrap().entry(session_id) { - Entry::Occupied(entry) => { - entry.remove(); - } - Entry::Vacant(_) => { - warn!( - "Tried to remove session id {} that doesn't exist", - session_id - ); + for stream_type in [VoiceStreamType::TCP, VoiceStreamType::UDP].iter() { + match self.client_streams.lock().unwrap().entry((*stream_type, session_id)) { + Entry::Occupied(entry) => { + entry.remove(); + } + Entry::Vacant(_) => { + warn!( + "Tried to remove session id {} that doesn't exist", + session_id + ); + } } } } diff --git a/mumd/src/audio/output.rs b/mumd/src/audio/output.rs index 5e0cb8d..421d395 100644 --- a/mumd/src/audio/output.rs +++ b/mumd/src/audio/output.rs @@ -1,3 +1,5 @@ +use crate::network::VoiceStreamType; + use cpal::{OutputCallbackInfo, Sample}; use mumble_protocol::voice::VoicePacketPayload; use opus::Channels; @@ -73,7 +75,7 @@ impl SaturatingAdd for u16 { pub fn curry_callback( effect_sound: Arc>>, - user_bufs: Arc>>, + user_bufs: Arc>>, output_volume_receiver: watch::Receiver, user_volumes: Arc>>, ) -> impl FnMut(&mut [T], &OutputCallbackInfo) + Send + 'static { @@ -86,7 +88,7 @@ pub fn curry_callback let mut effects_sound = effect_sound.lock().unwrap(); let mut user_bufs = user_bufs.lock().unwrap(); - for (id, client_stream) in &mut *user_bufs { + for ((_, id), client_stream) in &mut *user_bufs { let (user_volume, muted) = user_volumes .lock() .unwrap() diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 1a31ee2..4fb2e77 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -19,3 +19,9 @@ impl ConnectionInfo { } } } + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub enum VoiceStreamType { + TCP, + UDP, +} diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 0c00029..1465e8c 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -18,6 +18,8 @@ use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, watch}; use tokio_util::udp::UdpFramed; +use super::VoiceStreamType; + pub type PingRequest = (u64, SocketAddr, Box); type UdpSender = SplitSink, (VoicePacket, SocketAddr)>; @@ -165,7 +167,7 @@ async fn listen( .lock() .unwrap() .audio() - .decode_packet(session_id, payload); + .decode_packet(VoiceStreamType::UDP, session_id, payload); } } } diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 84247bc..4e8a886 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -3,11 +3,11 @@ pub mod server; pub mod user; use crate::audio::{Audio, NotificationEvents}; -use crate::network::ConnectionInfo; +use crate::network::{ConnectionInfo, VoiceStreamType}; +use crate::network::tcp::{TcpEvent, TcpEventData}; use crate::notify; use crate::state::server::Server; -use crate::network::tcp::{TcpEvent, TcpEventData}; use log::*; use mumble_protocol::control::msgs; use mumble_protocol::control::ControlPacket; @@ -45,11 +45,11 @@ pub enum ExecutionContext { ), } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum StatePhase { Disconnected, Connecting, - Connected, + Connected(VoiceStreamType), } pub struct State { @@ -85,7 +85,7 @@ impl State { ) -> ExecutionContext { match command { Command::ChannelJoin { channel_identifier } => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -135,7 +135,7 @@ impl State { now!(Ok(None)) } Command::ChannelList => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } let list = channel::into_channel( @@ -149,7 +149,7 @@ impl State { now!(Ok(None)) } Command::DeafenSelf(toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -207,7 +207,7 @@ impl State { now!(Ok(None)) } Command::MuteOther(string, toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -242,7 +242,7 @@ impl State { return now!(Ok(None)); } Command::MuteSelf(toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -354,7 +354,7 @@ impl State { }) } Command::ServerDisconnect => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } @@ -388,7 +388,7 @@ impl State { }), ), Command::Status => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } let state = self.server.as_ref().unwrap().into(); @@ -397,7 +397,7 @@ impl State { }))) } Command::UserVolumeSet(string, volume) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { return now!(Err(Error::DisconnectedError)); } let user_id = match self @@ -448,7 +448,7 @@ impl State { self.audio_mut().add_client(session); // send notification only if we've passed the connecting phase - if *self.phase_receiver().borrow() == StatePhase::Connected { + if matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { let channel_id = msg.get_channel_id(); if channel_id @@ -581,7 +581,7 @@ impl State { pub fn initialized(&self) { self.phase_watcher .0 - .send(StatePhase::Connected) + .send(StatePhase::Connected(VoiceStreamType::UDP)) .unwrap(); self.audio.play_effect(NotificationEvents::ServerConnect); } -- cgit v1.2.1 From 00969263678bf0626de8229fd21b1d5d183b62e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Tue, 5 Jan 2021 12:58:57 +0100 Subject: send actual udp pings regularly --- mumd/src/network/udp.rs | 54 +++++++++++++++++++++++++++---------------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 1465e8c..cfbabe1 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,13 +1,12 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; -use bytes::Bytes; use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt, Stream}; use futures_util::stream::{SplitSink, SplitStream}; use log::*; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::ping::{PingPacket, PongPacket}; -use mumble_protocol::voice::{VoicePacket, VoicePacketPayload}; +use mumble_protocol::voice::VoicePacket; use mumble_protocol::Serverbound; use std::collections::HashMap; use std::convert::TryFrom; @@ -16,6 +15,7 @@ use std::rc::Rc; use std::sync::{Arc, Mutex}; use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, watch}; +use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; use super::VoiceStreamType; @@ -41,12 +41,7 @@ pub async fn handle( } return; }; - let (mut sink, source) = connect(&mut crypt_state_receiver).await; - - // Note: A normal application would also send periodic Ping packets, and its own audio - // via UDP. We instead trick the server into accepting us by sending it one - // dummy voice packet. - send_ping(&mut sink, connection_info.socket_addr).await; + let (sink, source) = connect(&mut crypt_state_receiver).await; let sink = Arc::new(Mutex::new(sink)); let source = Arc::new(Mutex::new(source)); @@ -54,14 +49,22 @@ pub async fn handle( let phase_watcher = state.lock().unwrap().phase_receiver(); let mut audio_receiver_lock = receiver.lock().unwrap(); join!( - listen(Arc::clone(&state), Arc::clone(&source), phase_watcher.clone()), + listen( + Arc::clone(&state), + Arc::clone(&source), + phase_watcher.clone() + ), send_voice( Arc::clone(&sink), connection_info.socket_addr, phase_watcher, &mut *audio_receiver_lock ), - new_crypt_state(&mut crypt_state_receiver, sink, source) + send_pings( + Arc::clone(&sink), + connection_info.socket_addr, + ), + new_crypt_state(&mut crypt_state_receiver, sink, source), ); debug!("Fully disconnected UDP stream, waiting for new connection info"); @@ -180,20 +183,23 @@ async fn listen( debug!("UDP listener process killed"); } -async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) { - sink.send(( - VoicePacket::Audio { - _dst: std::marker::PhantomData, - target: 0, - session_id: (), - seq_num: 0, - payload: VoicePacketPayload::Opus(Bytes::from([0u8; 128].as_ref()), true), - position_info: None, - }, - server_addr, - )) - .await - .unwrap(); +async fn send_pings(sink: Arc>, server_addr: SocketAddr) { + let mut interval = interval(Duration::from_millis(1000)); + + loop { + match sink + .lock() + .unwrap() + .send((VoicePacket::Ping { timestamp: 0 }, server_addr)) + .await + { + Ok(_) => { /* TODO */ }, + Err(e) => { + debug!("Error sending UDP ping: {}", e); + } + } + interval.tick().await; + } } async fn send_voice( -- cgit v1.2.1 From 6c59a37fbfce72a92581b362048b509dcb67dae1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Tue, 5 Jan 2021 13:09:23 +0100 Subject: compare udp ping responses to sent values --- mumd/src/network/udp.rs | 39 +++++++++++++++++++++++++++++++-------- mumd/src/state.rs | 8 ++++++-- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index cfbabe1..1bc012d 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -12,6 +12,7 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::net::{Ipv6Addr, SocketAddr}; use std::rc::Rc; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, watch}; @@ -47,22 +48,26 @@ pub async fn handle( let source = Arc::new(Mutex::new(source)); let phase_watcher = state.lock().unwrap().phase_receiver(); + let last_ping_recv = AtomicU64::new(0); let mut audio_receiver_lock = receiver.lock().unwrap(); join!( listen( Arc::clone(&state), Arc::clone(&source), - phase_watcher.clone() + phase_watcher.clone(), + &last_ping_recv, ), send_voice( Arc::clone(&sink), connection_info.socket_addr, phase_watcher, - &mut *audio_receiver_lock + &mut *audio_receiver_lock, ), send_pings( + Arc::clone(&state), Arc::clone(&sink), connection_info.socket_addr, + &last_ping_recv, ), new_crypt_state(&mut crypt_state_receiver, sink, source), ); @@ -113,6 +118,7 @@ async fn listen( state: Arc>, source: Arc>, mut phase_watcher: watch::Receiver, + last_ping_recv: &AtomicU64, ) { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { @@ -154,10 +160,12 @@ async fn listen( } }; match packet { - VoicePacket::Ping { .. } => { - // Note: A normal application would handle these and only use UDP for voice - // once it has received one. - continue; + VoicePacket::Ping { timestamp } => { + state + .lock() + .unwrap() + .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); + last_ping_recv.store(timestamp, Ordering::Relaxed); } VoicePacket::Audio { session_id, @@ -183,17 +191,32 @@ async fn listen( debug!("UDP listener process killed"); } -async fn send_pings(sink: Arc>, server_addr: SocketAddr) { +async fn send_pings( + state: Arc>, + sink: Arc>, + server_addr: SocketAddr, + last_ping_recv: &AtomicU64, +) { + let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); loop { + let last_recv = last_ping_recv.load(Ordering::Relaxed); + if last_send.is_some() && last_send.unwrap() != last_recv { + state + .lock() + .unwrap() + .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); + } match sink .lock() .unwrap() .send((VoicePacket::Ping { timestamp: 0 }, server_addr)) .await { - Ok(_) => { /* TODO */ }, + Ok(_) => { + last_send = Some(last_recv + 1); + }, Err(e) => { debug!("Error sending UDP ping: {}", e); } diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 4e8a886..2ed73b2 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -578,11 +578,15 @@ impl State { } } - pub fn initialized(&self) { + pub fn broadcast_phase(&self, phase: StatePhase) { self.phase_watcher .0 - .send(StatePhase::Connected(VoiceStreamType::UDP)) + .send(phase) .unwrap(); + } + + pub fn initialized(&self) { + self.broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); self.audio.play_effect(NotificationEvents::ServerConnect); } -- cgit v1.2.1 From ab038b58b4440804cdfded56167ce72b599d87c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Tue, 5 Jan 2021 17:08:48 +0100 Subject: yikes --- mumd/src/audio.rs | 2 +- mumd/src/network.rs | 62 ++++++++++++++++++++++++++ mumd/src/network/tcp.rs | 113 +++++++++++++++++++++++------------------------- mumd/src/network/udp.rs | 73 +++++++++++-------------------- 4 files changed, 143 insertions(+), 107 deletions(-) diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index 40cdcb2..4f9b73c 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -337,7 +337,7 @@ impl Audio { } } - pub fn take_receiver(&mut self) -> Arc> + Unpin>>> { + pub fn input_receiver(&self) -> Arc> + Unpin>>> { Arc::clone(&self.input_channel_receiver) } diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 4fb2e77..03bc436 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -3,6 +3,16 @@ pub mod udp; use std::net::SocketAddr; +use futures::Future; +use futures::FutureExt; +use futures::channel::oneshot; +use futures::join; +use futures::pin_mut; +use futures::select; +use tokio::sync::watch; + +use crate::state::StatePhase; + #[derive(Clone, Debug)] pub struct ConnectionInfo { socket_addr: SocketAddr, @@ -25,3 +35,55 @@ pub enum VoiceStreamType { TCP, UDP, } + +async fn run_until( + phase_checker: impl Fn(StatePhase) -> bool, + mut generator: impl FnMut() -> F, + mut handler: impl FnMut(T) -> G, + mut shutdown: impl FnMut() -> H, + mut phase_watcher: watch::Receiver, +) where + F: Future>, + G: Future, + H: Future, +{ + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + loop { + phase_watcher.changed().await.unwrap(); + if phase_checker(*phase_watcher.borrow()) { + break; + } + } + tx.send(true).unwrap(); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = generator().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + //warn!("Channel closed before disconnect command"); //TODO make me informative + break; + } + Some(Some(data)) => { + handler(data).await; + } + } + } + + shutdown().await; + }; + + join!(main_block, phase_transition_block); +} diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 47ea311..f767446 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -1,25 +1,28 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; +use futures::Stream; use log::*; -use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt}; +use futures::{join, SinkExt, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; +use mumble_protocol::voice::VoicePacket; use mumble_protocol::{Clientbound, Serverbound}; use std::cell::RefCell; use std::collections::HashMap; use std::convert::{Into, TryInto}; -use std::future::Future; use std::net::SocketAddr; use std::rc::Rc; use std::sync::{Arc, Mutex}; use tokio::net::TcpStream; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch}; use tokio::time::{self, Duration}; use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use super::{run_until, VoiceStreamType}; + type TcpSender = SplitSink< Framed, ControlCodec>, ControlPacket, @@ -68,11 +71,13 @@ pub async fn handle( let state_lock = state.lock().unwrap(); authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; let phase_watcher = state_lock.phase_receiver(); + let input_receiver = state_lock.audio().input_receiver(); drop(state_lock); let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); + //TODO force exit all futures on disconnection join!( send_pings(packet_sender.clone(), 10, phase_watcher.clone()), listen( @@ -82,6 +87,11 @@ pub async fn handle( Arc::clone(&event_queue), phase_watcher.clone(), ), + send_voice( + packet_sender.clone(), + Arc::clone(&input_receiver), + phase_watcher.clone(), + ), send_packets(sink, &mut packet_receiver, phase_watcher.clone()), register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), ); @@ -133,7 +143,8 @@ async fn send_pings( )))); let packet_sender = Rc::new(RefCell::new(packet_sender)); - run_until_disconnection( + run_until( + |phase| matches!(phase, StatePhase::Disconnected), || async { Some(interval.borrow_mut().tick().await) }, |_| async { trace!("Sending ping"); @@ -155,7 +166,8 @@ async fn send_packets( ) { let sink = Rc::new(RefCell::new(sink)); let packet_receiver = Rc::new(RefCell::new(packet_receiver)); - run_until_disconnection( + run_until( + |phase| matches!(phase, StatePhase::Disconnected), || async { packet_receiver.borrow_mut().recv().await }, |packet| async { sink.borrow_mut().send(packet).await.unwrap(); @@ -170,6 +182,40 @@ async fn send_packets( debug!("TCP packet sender killed"); } +async fn send_voice( + packet_sender: mpsc::UnboundedSender>, + receiver: Arc> + Unpin)>>>, + phase_watcher: watch::Receiver, +) { + let inner_phase_watcher = phase_watcher.clone(); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + || async { + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + || async { + packet_sender.send(receiver + .lock() + .unwrap() + .next() + .await + .unwrap() + .into()) + .unwrap(); + Some(Some(())) + }, + |_| async {}, + || async {}, + inner_phase_watcher.clone(), + ).await; + Some(Some(())) + }, + |_| async {}, + || async {}, + phase_watcher, + ).await; +} + async fn listen( state: Arc>, stream: TcpReceiver, @@ -181,7 +227,8 @@ async fn listen( let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); let stream = Rc::new(RefCell::new(stream)); - run_until_disconnection( + run_until( + |phase| matches!(phase, StatePhase::Disconnected), || async { stream.borrow_mut().next().await }, |packet| async { match packet.unwrap() { @@ -289,7 +336,8 @@ async fn register_events( phase_watcher: watch::Receiver, ) { let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); - run_until_disconnection( + run_until( + |phase| matches!(phase, StatePhase::Disconnected), || async { tcp_event_register_receiver.borrow_mut().recv().await }, |(event, handler)| async { event_data @@ -304,54 +352,3 @@ async fn register_events( ) .await; } - -async fn run_until_disconnection( - mut generator: impl FnMut() -> F, - mut handler: impl FnMut(T) -> G, - mut shutdown: impl FnMut() -> H, - mut phase_watcher: watch::Receiver, -) where - F: Future>, - G: Future, - H: Future, -{ - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; - } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = generator().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - //warn!("Channel closed before disconnect command"); //TODO make me informative - break; - } - Some(Some(data)) => { - handler(data).await; - } - } - } - - shutdown().await; - }; - - join!(main_block, phase_transition_block); -} diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 1bc012d..9435e94 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -19,7 +19,7 @@ use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; -use super::VoiceStreamType; +use super::{run_until, VoiceStreamType}; pub type PingRequest = (u64, SocketAddr, Box); @@ -31,7 +31,7 @@ pub async fn handle( mut connection_info_receiver: watch::Receiver>, mut crypt_state_receiver: mpsc::Receiver, ) { - let receiver = state.lock().unwrap().audio_mut().take_receiver(); + let receiver = state.lock().unwrap().audio().input_receiver(); loop { let connection_info = 'data: loop { @@ -49,7 +49,6 @@ pub async fn handle( let phase_watcher = state.lock().unwrap().phase_receiver(); let last_ping_recv = AtomicU64::new(0); - let mut audio_receiver_lock = receiver.lock().unwrap(); join!( listen( Arc::clone(&state), @@ -61,7 +60,7 @@ pub async fn handle( Arc::clone(&sink), connection_info.socket_addr, phase_watcher, - &mut *audio_receiver_lock, + Arc::clone(&receiver), ), send_pings( Arc::clone(&state), @@ -228,51 +227,29 @@ async fn send_pings( async fn send_voice( sink: Arc>, server_addr: SocketAddr, - mut phase_watcher: watch::Receiver, - receiver: &mut (dyn Stream> + Unpin), + phase_watcher: watch::Receiver, + receiver: Arc> + Unpin)>>>, ) { - pin_mut!(receiver); - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; - } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = receiver.next().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; - } - Some(Some(reply)) => { - sink.lock() - .unwrap() - .send((reply, server_addr)) - .await - .unwrap(); - } - } - } - }; - - join!(main_block, phase_transition_block); + let inner_phase_watcher = phase_watcher.clone(); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + || async { + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + || async { + sink.lock().unwrap().send((receiver.lock().unwrap().next().await.unwrap(), server_addr)).await.unwrap(); + Some(Some(())) + }, + |_| async {}, + || async {}, + inner_phase_watcher.clone(), + ).await; + Some(Some(())) + }, + |_| async {}, + || async {}, + phase_watcher, + ).await; debug!("UDP sender process killed"); } -- cgit v1.2.1 From b15e010a6bebc7b7c6b8afb1b51f2673d0695e06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Tue, 5 Jan 2021 20:02:32 +0100 Subject: tokio mutex --- mumd/src/audio.rs | 6 +++--- mumd/src/network/tcp.rs | 5 +++-- mumd/src/network/udp.rs | 6 ++++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index 4f9b73c..3f03e61 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -76,7 +76,7 @@ pub struct Audio { _output_stream: cpal::Stream, _input_stream: cpal::Stream, - input_channel_receiver: Arc> + Unpin>>>, + input_channel_receiver: Arc> + Unpin>>>, input_volume_sender: watch::Sender, output_volume_sender: watch::Sender, @@ -227,7 +227,7 @@ impl Audio { _output_stream: output_stream, _input_stream: input_stream, input_volume_sender, - input_channel_receiver: Arc::new(Mutex::new(Box::new(opus_stream))), + input_channel_receiver: Arc::new(tokio::sync::Mutex::new(Box::new(opus_stream))), client_streams, sounds: HashMap::new(), output_volume_sender, @@ -337,7 +337,7 @@ impl Audio { } } - pub fn input_receiver(&self) -> Arc> + Unpin>>> { + pub fn input_receiver(&self) -> Arc> + Unpin>>> { Arc::clone(&self.input_channel_receiver) } diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index f767446..717b195 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -184,7 +184,7 @@ async fn send_packets( async fn send_voice( packet_sender: mpsc::UnboundedSender>, - receiver: Arc> + Unpin)>>>, + receiver: Arc> + Unpin)>>>, phase_watcher: watch::Receiver, ) { let inner_phase_watcher = phase_watcher.clone(); @@ -196,7 +196,7 @@ async fn send_voice( || async { packet_sender.send(receiver .lock() - .unwrap() + .await .next() .await .unwrap() @@ -208,6 +208,7 @@ async fn send_voice( || async {}, inner_phase_watcher.clone(), ).await; + debug!("Stopped sending TCP voice"); Some(Some(())) }, |_| async {}, diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 9435e94..5e725cd 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -228,7 +228,7 @@ async fn send_voice( sink: Arc>, server_addr: SocketAddr, phase_watcher: watch::Receiver, - receiver: Arc> + Unpin)>>>, + receiver: Arc> + Unpin)>>>, ) { let inner_phase_watcher = phase_watcher.clone(); run_until( @@ -237,7 +237,9 @@ async fn send_voice( run_until( |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), || async { - sink.lock().unwrap().send((receiver.lock().unwrap().next().await.unwrap(), server_addr)).await.unwrap(); + debug!("Sending UDP audio"); + sink.lock().unwrap().send((receiver.lock().await.next().await.unwrap(), server_addr)).await.unwrap(); + debug!("Sent UDP audio"); Some(Some(())) }, |_| async {}, -- cgit v1.2.1 From 02e6f2b84d72294b29a1698c1b73fbb5697815da Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Wed, 6 Jan 2021 18:31:49 +0100 Subject: clean up network::run_until --- mumd/src/network.rs | 41 +++------ mumd/src/network/tcp.rs | 225 +++++++++++++++++++++++++----------------------- mumd/src/network/udp.rs | 8 +- 3 files changed, 131 insertions(+), 143 deletions(-) diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 03bc436..75b983e 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -10,6 +10,7 @@ use futures::join; use futures::pin_mut; use futures::select; use tokio::sync::watch; +use log::*; use crate::state::StatePhase; @@ -36,16 +37,14 @@ pub enum VoiceStreamType { UDP, } -async fn run_until( +async fn run_until( phase_checker: impl Fn(StatePhase) -> bool, - mut generator: impl FnMut() -> F, - mut handler: impl FnMut(T) -> G, - mut shutdown: impl FnMut() -> H, + fut: F, + mut shutdown: impl FnMut() -> G, mut phase_watcher: watch::Receiver, ) where - F: Future>, + F: Future, G: Future, - H: Future, { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { @@ -55,32 +54,20 @@ async fn run_until( break; } } - tx.send(true).unwrap(); + if tx.send(true).is_err() { + warn!("future resolved before it could be cancelled"); + } }; let main_block = async { let rx = rx.fuse(); pin_mut!(rx); - loop { - let packet_recv = generator().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - //warn!("Channel closed before disconnect command"); //TODO make me informative - break; - } - Some(Some(data)) => { - handler(data).await; - } - } - } + let fut = fut.fuse(); + pin_mut!(fut); + select! { + _ = fut => (), + _ = rx => (), + }; shutdown().await; }; diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 717b195..982e747 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -145,11 +145,13 @@ async fn send_pings( run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { Some(interval.borrow_mut().tick().await) }, - |_| async { - trace!("Sending ping"); - let msg = msgs::Ping::new(); - packet_sender.borrow_mut().send(msg.into()).unwrap(); + async { + loop { + interval.borrow_mut().tick().await; + trace!("Sending ping"); + let msg = msgs::Ping::new(); + packet_sender.borrow_mut().send(msg.into()).unwrap(); + } }, || async {}, phase_watcher, @@ -168,9 +170,11 @@ async fn send_packets( let packet_receiver = Rc::new(RefCell::new(packet_receiver)); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { packet_receiver.borrow_mut().recv().await }, - |packet| async { - sink.borrow_mut().send(packet).await.unwrap(); + async { + loop { + let packet = packet_receiver.borrow_mut().recv().await.unwrap(); + sink.borrow_mut().send(packet).await.unwrap(); + } }, || async { sink.borrow_mut().close().await.unwrap(); @@ -190,28 +194,26 @@ async fn send_voice( let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { + async { run_until( |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), - || async { - packet_sender.send(receiver - .lock() - .await - .next() - .await - .unwrap() - .into()) - .unwrap(); - Some(Some(())) + async { + loop { + packet_sender.send(receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); + } }, - |_| async {}, || async {}, inner_phase_watcher.clone(), ).await; debug!("Stopped sending TCP voice"); - Some(Some(())) }, - |_| async {}, || async {}, phase_watcher, ).await; @@ -219,7 +221,7 @@ async fn send_voice( async fn listen( state: Arc>, - stream: TcpReceiver, + mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender, event_queue: Arc>>>, phase_watcher: watch::Receiver, @@ -227,92 +229,93 @@ async fn listen( let crypt_state = Rc::new(RefCell::new(None)); let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); - let stream = Rc::new(RefCell::new(stream)); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { stream.borrow_mut().next().await }, - |packet| async { - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.borrow_mut().take() { - let _ = sender - .send( - crypt_state - .borrow_mut() - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; + async { + loop { + let packet = stream.next().await.unwrap(); + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); } - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Connected(&msg)); + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(sender) = crypt_state_sender.borrow_mut().take() { + let _ = sender + .send( + crypt_state + .borrow_mut() + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; + } + if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Connected(&msg)); + } + } + let mut state = state.lock().unwrap(); + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), } + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); + } + state.initialized(); } - let mut state = state.lock().unwrap(); - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); + ControlPacket::UserState(msg) => { + state.lock().unwrap().parse_user_state(*msg); + } + ControlPacket::UserRemove(msg) => { + state.lock().unwrap().remove_client(*msg); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_remove(*msg); + } + packet => { + debug!("Received unhandled ControlPacket {:#?}", packet); } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - state.lock().unwrap().parse_user_state(*msg); - } - ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_remove(*msg); - } - packet => { - debug!("Received unhandled ControlPacket {:#?}", packet); } } }, @@ -339,14 +342,16 @@ async fn register_events( let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { tcp_event_register_receiver.borrow_mut().recv().await }, - |(event, handler)| async { - event_data - .lock() - .unwrap() - .entry(event) - .or_default() - .push(handler); + async { + loop { + let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap(); + event_data + .lock() + .unwrap() + .entry(event) + .or_default() + .push(handler); + } }, || async {}, phase_watcher, diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 5e725cd..d35a255 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -233,22 +233,18 @@ async fn send_voice( let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), - || async { + async { run_until( |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), - || async { + async { debug!("Sending UDP audio"); sink.lock().unwrap().send((receiver.lock().await.next().await.unwrap(), server_addr)).await.unwrap(); debug!("Sent UDP audio"); - Some(Some(())) }, - |_| async {}, || async {}, inner_phase_watcher.clone(), ).await; - Some(Some(())) }, - |_| async {}, || async {}, phase_watcher, ).await; -- cgit v1.2.1 From 92d5b21bf0f910f219c473002f83ba93ddcbee6d Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Wed, 6 Jan 2021 23:50:09 +0100 Subject: fix deadlock --- mumd/src/audio.rs | 10 ++--- mumd/src/client.rs | 4 +- mumd/src/command.rs | 8 ++-- mumd/src/main.rs | 2 +- mumd/src/network/tcp.rs | 100 ++++++++++++++++++++++++++++++------------------ mumd/src/network/udp.rs | 69 ++++++++++++++++++++------------- 6 files changed, 118 insertions(+), 75 deletions(-) diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index 3f03e61..bdc8377 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -31,7 +31,7 @@ use std::{ }; use strum::IntoEnumIterator; use strum_macros::EnumIter; -use tokio::sync::watch; +use tokio::sync::{watch}; const SAMPLE_RATE: u32 = 48000; @@ -132,11 +132,11 @@ impl Audio { let err_fn = |err| error!("An error occurred on the output audio stream: {}", err); - let user_volumes = Arc::new(Mutex::new(HashMap::new())); + let user_volumes = Arc::new(std::sync::Mutex::new(HashMap::new())); let (output_volume_sender, output_volume_receiver) = watch::channel::(output_volume); - let play_sounds = Arc::new(Mutex::new(VecDeque::new())); + let play_sounds = Arc::new(std::sync::Mutex::new(VecDeque::new())); - let client_streams = Arc::new(Mutex::new(HashMap::new())); + let client_streams = Arc::new(std::sync::Mutex::new(HashMap::new())); let output_stream = match output_supported_sample_format { SampleFormat::F32 => output_device.build_output_stream( &output_config, @@ -292,7 +292,7 @@ impl Audio { .collect(); } - pub fn decode_packet(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) { + pub fn decode_packet_payload(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) { match self.client_streams.lock().unwrap().entry((stream_type, session_id)) { Entry::Occupied(mut entry) => { entry diff --git a/mumd/src/client.rs b/mumd/src/client.rs index 3613061..222e2a7 100644 --- a/mumd/src/client.rs +++ b/mumd/src/client.rs @@ -6,8 +6,8 @@ use futures_util::join; use ipc_channel::ipc::IpcSender; use mumble_protocol::{Serverbound, control::ControlPacket, crypt::ClientCryptState}; use mumlib::command::{Command, CommandResponse}; -use std::sync::{Arc, Mutex}; -use tokio::sync::{mpsc, watch}; +use std::sync::Arc; +use tokio::sync::{mpsc, watch, Mutex}; pub async fn handle( command_receiver: mpsc::UnboundedReceiver<( diff --git a/mumd/src/command.rs b/mumd/src/command.rs index e77b34b..b099ae1 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -9,8 +9,8 @@ use ipc_channel::ipc::IpcSender; use log::*; use mumble_protocol::{Serverbound, control::ControlPacket}; use mumlib::command::{Command, CommandResponse}; -use std::sync::{Arc, Mutex}; -use tokio::sync::{mpsc, oneshot, watch}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; pub async fn handle( state: Arc>, @@ -26,9 +26,11 @@ pub async fn handle( debug!("Begin listening for commands"); while let Some((command, response_sender)) = command_receiver.recv().await { debug!("Received command {:?}", command); - let mut state = state.lock().unwrap(); + debug!("locking state"); + let mut state = state.lock().await; let event = state.handle_command(command, &mut packet_sender, &mut connection_info_sender); drop(state); + debug!("unlocking state"); match event { ExecutionContext::TcpEvent(event, generator) => { let (tx, rx) = oneshot::channel(); diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 67481f9..a8cb230 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -14,7 +14,7 @@ use std::fs; use tokio::sync::mpsc; use tokio::task::spawn_blocking; -#[tokio::main] +#[tokio::main(worker_threads = 4)] async fn main() { setup_logger(std::io::stderr(), true); notify::init(); diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 982e747..6f18473 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -14,9 +14,9 @@ use std::collections::HashMap; use std::convert::{Into, TryInto}; use std::net::SocketAddr; use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use tokio::net::TcpStream; -use tokio::sync::{mpsc, watch}; +use tokio::sync::{mpsc, watch, Mutex}; use tokio::time::{self, Duration}; use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; @@ -68,7 +68,7 @@ pub async fn handle( .await; // Handshake (omitting `Version` message for brevity) - let state_lock = state.lock().unwrap(); + let state_lock = state.lock().await; authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; let phase_watcher = state_lock.phase_receiver(); let input_receiver = state_lock.audio().input_receiver(); @@ -138,19 +138,16 @@ async fn send_pings( delay_seconds: u64, phase_watcher: watch::Receiver, ) { - let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs( - delay_seconds, - )))); - let packet_sender = Rc::new(RefCell::new(packet_sender)); + let mut interval = time::interval(Duration::from_secs(delay_seconds)); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { loop { - interval.borrow_mut().tick().await; + interval.tick().await; trace!("Sending ping"); let msg = msgs::Ping::new(); - packet_sender.borrow_mut().send(msg.into()).unwrap(); + packet_sender.send(msg.into()).unwrap(); } }, || async {}, @@ -167,12 +164,11 @@ async fn send_packets( phase_watcher: watch::Receiver, ) { let sink = Rc::new(RefCell::new(sink)); - let packet_receiver = Rc::new(RefCell::new(packet_receiver)); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { loop { - let packet = packet_receiver.borrow_mut().recv().await.unwrap(); + let packet = packet_receiver.recv().await.unwrap(); sink.borrow_mut().send(packet).await.unwrap(); } }, @@ -188,31 +184,40 @@ async fn send_packets( async fn send_voice( packet_sender: mpsc::UnboundedSender>, - receiver: Arc> + Unpin)>>>, + receiver: Arc> + Unpin)>>>, phase_watcher: watch::Receiver, ) { let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), - async { - loop { - packet_sender.send(receiver - .lock() - .await - .next() - .await - .unwrap() - .into()) - .unwrap(); + loop { + let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + inner_phase_watcher_2.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { + break; } - }, - || async {}, - inner_phase_watcher.clone(), - ).await; - debug!("Stopped sending TCP voice"); + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + async { + loop { + packet_sender.send( + receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); + } + }, + || async {}, + inner_phase_watcher.clone(), + ).await; + } }, || async {}, phase_watcher, @@ -269,13 +274,13 @@ async fn listen( ) .await; } - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { let old = std::mem::take(vec); for handler in old { handler(TcpEventData::Connected(&msg)); } } - let mut state = state.lock().unwrap(); + let mut state = state.lock().await; let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); match &server.welcome_text { @@ -291,16 +296,16 @@ async fn listen( warn!("Login rejected: {:?}", msg); } ControlPacket::UserState(msg) => { - state.lock().unwrap().parse_user_state(*msg); + state.lock().await.parse_user_state(*msg); } ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); + state.lock().await.remove_client(*msg); } ControlPacket::ChannelState(msg) => { debug!("Channel state received"); state .lock() - .unwrap() + .await .server_mut() .unwrap() .parse_channel_state(*msg); //TODO parse initial if initial @@ -308,11 +313,32 @@ async fn listen( ControlPacket::ChannelRemove(msg) => { state .lock() - .unwrap() + .await .server_mut() .unwrap() .parse_channel_remove(*msg); } + ControlPacket::UDPTunnel(msg) => { + match *msg { + VoicePacket::Ping { .. } => {} + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() + .await + .audio() + .decode_packet_payload( + VoiceStreamType::TCP, + session_id, + payload); + } + } + } packet => { debug!("Received unhandled ControlPacket {:#?}", packet); } @@ -320,7 +346,7 @@ async fn listen( } }, || async { - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { let old = std::mem::take(vec); for handler in old { handler(TcpEventData::Disconnected); @@ -347,7 +373,7 @@ async fn register_events( let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap(); event_data .lock() - .unwrap() + .await .entry(event) .or_default() .push(handler); diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index d35a255..25ec8d5 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -13,9 +13,9 @@ use std::convert::TryFrom; use std::net::{Ipv6Addr, SocketAddr}; use std::rc::Rc; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc}; use tokio::net::UdpSocket; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; @@ -31,7 +31,7 @@ pub async fn handle( mut connection_info_receiver: watch::Receiver>, mut crypt_state_receiver: mpsc::Receiver, ) { - let receiver = state.lock().unwrap().audio().input_receiver(); + let receiver = state.lock().await.audio().input_receiver(); loop { let connection_info = 'data: loop { @@ -47,7 +47,7 @@ pub async fn handle( let sink = Arc::new(Mutex::new(sink)); let source = Arc::new(Mutex::new(source)); - let phase_watcher = state.lock().unwrap().phase_receiver(); + let phase_watcher = state.lock().await.phase_receiver(); let last_ping_recv = AtomicU64::new(0); join!( listen( @@ -107,8 +107,8 @@ async fn new_crypt_state( .await .expect("Failed to bind UDP socket"); let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().unwrap() = new_sink; - *source.lock().unwrap() = new_source; + *sink.lock().await = new_sink; + *source.lock().await = new_source; } } } @@ -134,13 +134,14 @@ async fn listen( let rx = rx.fuse(); pin_mut!(rx); loop { - let mut source = source.lock().unwrap(); + let mut source = source.lock().await; let packet_recv = source.next().fuse(); pin_mut!(packet_recv); let exitor = select! { data = packet_recv => Some(data), _ = rx => None }; + drop(source); match exitor { None => { break; @@ -160,9 +161,10 @@ async fn listen( }; match packet { VoicePacket::Ping { timestamp } => { + // debug!("Sending UDP voice"); state - .lock() - .unwrap() + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); last_ping_recv.store(timestamp, Ordering::Relaxed); } @@ -175,9 +177,9 @@ async fn listen( } => { state .lock() - .unwrap() + .await .audio() - .decode_packet(VoiceStreamType::UDP, session_id, payload); + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); } } } @@ -198,19 +200,21 @@ async fn send_pings( ) { let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); + interval.tick().await; //this is so we get rid of the first instant resolve loop { let last_recv = last_ping_recv.load(Ordering::Relaxed); if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); state .lock() - .unwrap() + .await .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); } match sink .lock() - .unwrap() - .send((VoicePacket::Ping { timestamp: 0 }, server_addr)) + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) .await { Ok(_) => { @@ -228,22 +232,33 @@ async fn send_voice( sink: Arc>, server_addr: SocketAddr, phase_watcher: watch::Receiver, - receiver: Arc> + Unpin)>>>, + receiver: Arc> + Unpin)>>>, ) { let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), - async { - debug!("Sending UDP audio"); - sink.lock().unwrap().send((receiver.lock().await.next().await.unwrap(), server_addr)).await.unwrap(); - debug!("Sent UDP audio"); - }, - || async {}, - inner_phase_watcher.clone(), - ).await; + loop { + let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + inner_phase_watcher_2.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + async { + let mut receiver = receiver.lock().await; + loop { + let sending = (receiver.next().await.unwrap(), server_addr); + sink.lock().await.send(sending).await.unwrap(); + } + }, + || async {}, + inner_phase_watcher.clone(), + ).await; + } }, || async {}, phase_watcher, @@ -266,7 +281,7 @@ pub async fn handle_pings( let packet = PingPacket { id }; let packet: [u8; 12] = packet.into(); udp_socket.send_to(&packet, &socket_addr).await.unwrap(); - pending.lock().unwrap().insert(id, handle); + pending.lock().await.insert(id, handle); } }; @@ -277,7 +292,7 @@ pub async fn handle_pings( let packet = PongPacket::try_from(buf.as_slice()).unwrap(); - if let Some(handler) = pending.lock().unwrap().remove(&packet.id) { + if let Some(handler) = pending.lock().await.remove(&packet.id) { handler(packet); } } -- cgit v1.2.1 From ce5bce8681220b2460cacc931937e2772545fe52 Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Wed, 6 Jan 2021 23:56:58 +0100 Subject: clean up UDP packet listening --- mumd/src/network/udp.rs | 110 ++++++++++++++++++------------------------------ 1 file changed, 41 insertions(+), 69 deletions(-) diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 25ec8d5..441d08b 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,7 +1,7 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; -use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt, Stream}; +use futures::{join, SinkExt, StreamExt, Stream}; use futures_util::stream::{SplitSink, SplitStream}; use log::*; use mumble_protocol::crypt::ClientCryptState; @@ -15,7 +15,7 @@ use std::rc::Rc; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc}; use tokio::net::UdpSocket; -use tokio::sync::{mpsc, oneshot, watch, Mutex}; +use tokio::sync::{mpsc, watch, Mutex}; use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; @@ -116,78 +116,50 @@ async fn new_crypt_state( async fn listen( state: Arc>, source: Arc>, - mut phase_watcher: watch::Receiver, + phase_watcher: watch::Receiver, last_ping_recv: &AtomicU64, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; - } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let mut source = source.lock().await; - let packet_recv = source.next().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - drop(source); - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; - } - Some(Some(packet)) => { - let (packet, _src_addr) = match packet { - Ok(packet) => packet, - Err(err) => { - warn!("Got an invalid UDP packet: {}", err); - // To be expected, considering this is the internet, just ignore it - continue; - } - }; - match packet { - VoicePacket::Ping { timestamp } => { - // debug!("Sending UDP voice"); - state - .lock() //TODO clean up unnecessary lock by only updating phase if it should change - .await - .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); - last_ping_recv.store(timestamp, Ordering::Relaxed); - } - VoicePacket::Audio { - session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state - .lock() - .await - .audio() - .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); - } + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + async { + loop { + let packet = source.lock().await.next().await.unwrap(); + let (packet, _src_addr) = match packet { + Ok(packet) => packet, + Err(err) => { + warn!("Got an invalid UDP packet: {}", err); + // To be expected, considering this is the internet, just ignore it + continue; + } + }; + match packet { + VoicePacket::Ping { timestamp } => { + // debug!("Sending UDP voice"); + state + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); + last_ping_recv.store(timestamp, Ordering::Relaxed); + } + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() + .await + .audio() + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); } } } - } - }; - - join!(main_block, phase_transition_block); + }, + || async {}, + phase_watcher + ).await; debug!("UDP listener process killed"); } -- cgit v1.2.1 From fe6e5eb67405cd929e65d8ff01ba52dbac9565ad Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Thu, 7 Jan 2021 00:02:30 +0100 Subject: remove unnecessary debug print --- mumd/src/command.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/mumd/src/command.rs b/mumd/src/command.rs index b099ae1..653d1fa 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -26,11 +26,9 @@ pub async fn handle( debug!("Begin listening for commands"); while let Some((command, response_sender)) = command_receiver.recv().await { debug!("Received command {:?}", command); - debug!("locking state"); let mut state = state.lock().await; let event = state.handle_command(command, &mut packet_sender, &mut connection_info_sender); drop(state); - debug!("unlocking state"); match event { ExecutionContext::TcpEvent(event, generator) => { let (tx, rx) = oneshot::channel(); -- cgit v1.2.1 From 11e026e7a8edda97274c984a31fd51f3f480e623 Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Thu, 7 Jan 2021 00:03:10 +0100 Subject: restore tokio configuration --- mumd/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mumd/src/main.rs b/mumd/src/main.rs index a8cb230..67481f9 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -14,7 +14,7 @@ use std::fs; use tokio::sync::mpsc; use tokio::task::spawn_blocking; -#[tokio::main(worker_threads = 4)] +#[tokio::main] async fn main() { setup_logger(std::io::stderr(), true); notify::init(); -- cgit v1.2.1 From 215689a4fb9fd6fa8dbfd263fcbf7b59b04942e5 Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Thu, 7 Jan 2021 11:58:09 +0100 Subject: update changelog --- CHANGELOG | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG b/CHANGELOG index 5d6d64b..468d9a6 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -19,6 +19,7 @@ Added ~~~~~ * Added a noise gate +* Added tunneling audio through TCP if UDP connection goes down // Changed // ~~~~~~~ -- cgit v1.2.1 From 8b042801d090e1a17ca72ddb559d92ccbbb41091 Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Thu, 7 Jan 2021 12:02:43 +0100 Subject: update according to feedback --- mumd/src/audio.rs | 2 +- mumd/src/network.rs | 5 ++--- mumd/src/network/tcp.rs | 5 ++--- mumd/src/network/udp.rs | 8 +++----- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index bdc8377..598dde6 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -31,7 +31,7 @@ use std::{ }; use strum::IntoEnumIterator; use strum_macros::EnumIter; -use tokio::sync::{watch}; +use tokio::sync::watch; const SAMPLE_RATE: u32 = 48000; diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 75b983e..9463ad7 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -1,16 +1,15 @@ pub mod tcp; pub mod udp; -use std::net::SocketAddr; - use futures::Future; use futures::FutureExt; use futures::channel::oneshot; use futures::join; use futures::pin_mut; use futures::select; -use tokio::sync::watch; use log::*; +use std::net::SocketAddr; +use tokio::sync::watch; use crate::state::StatePhase; diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 6f18473..3e4cbf3 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -1,9 +1,8 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; -use futures::Stream; use log::*; -use futures::{join, SinkExt, StreamExt}; +use futures::{join, SinkExt, Stream, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; @@ -145,7 +144,7 @@ async fn send_pings( async { loop { interval.tick().await; - trace!("Sending ping"); + trace!("Sending TCP ping"); let msg = msgs::Ping::new(); packet_sender.send(msg.into()).unwrap(); } diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 441d08b..ac67bcb 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -13,7 +13,7 @@ use std::convert::TryFrom; use std::net::{Ipv6Addr, SocketAddr}; use std::rc::Rc; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc}; +use std::sync::Arc; use tokio::net::UdpSocket; use tokio::sync::{mpsc, watch, Mutex}; use tokio::time::{interval, Duration}; @@ -134,7 +134,6 @@ async fn listen( }; match packet { VoicePacket::Ping { timestamp } => { - // debug!("Sending UDP voice"); state .lock() //TODO clean up unnecessary lock by only updating phase if it should change .await @@ -149,7 +148,7 @@ async fn listen( .. } => { state - .lock() + .lock() //TODO change so that we only have to lock audio and not the whole state .await .audio() .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); @@ -172,9 +171,9 @@ async fn send_pings( ) { let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); - interval.tick().await; //this is so we get rid of the first instant resolve loop { + interval.tick().await; let last_recv = last_ping_recv.load(Ordering::Relaxed); if last_send.is_some() && last_send.unwrap() != last_recv { debug!("Sending TCP voice"); @@ -196,7 +195,6 @@ async fn send_pings( debug!("Error sending UDP ping: {}", e); } } - interval.tick().await; } } -- cgit v1.2.1 From f6a8a126e67ff1a89dcbdb35033e1f324add50dc Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Thu, 7 Jan 2021 12:10:32 +0100 Subject: fix UDP shutting down properly --- mumd/src/network/udp.rs | 91 ++++++++++++++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 36 deletions(-) diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index ac67bcb..4167c15 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -59,7 +59,7 @@ pub async fn handle( send_voice( Arc::clone(&sink), connection_info.socket_addr, - phase_watcher, + phase_watcher.clone(), Arc::clone(&receiver), ), send_pings( @@ -67,8 +67,9 @@ pub async fn handle( Arc::clone(&sink), connection_info.socket_addr, &last_ping_recv, + phase_watcher.clone(), ), - new_crypt_state(&mut crypt_state_receiver, sink, source), + new_crypt_state(&mut crypt_state_receiver, sink, source, phase_watcher), ); debug!("Fully disconnected UDP stream, waiting for new connection info"); @@ -99,18 +100,26 @@ async fn new_crypt_state( crypt_state: &mut mpsc::Receiver, sink: Arc>, source: Arc>, + phase_watcher: watch::Receiver, ) { - loop { - if let Some(crypt_state) = crypt_state.recv().await { - info!("Received new crypt state"); - let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) - .await - .expect("Failed to bind UDP socket"); - let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().await = new_sink; - *source.lock().await = new_source; - } - } + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + async { + loop { + if let Some(crypt_state) = crypt_state.recv().await { + info!("Received new crypt state"); + let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) + .await + .expect("Failed to bind UDP socket"); + let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); + *sink.lock().await = new_sink; + *source.lock().await = new_source; + } + } + }, + || async {}, + phase_watcher, + ).await; } async fn listen( @@ -168,34 +177,44 @@ async fn send_pings( sink: Arc>, server_addr: SocketAddr, last_ping_recv: &AtomicU64, + phase_watcher: watch::Receiver, ) { let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); - loop { - interval.tick().await; - let last_recv = last_ping_recv.load(Ordering::Relaxed); - if last_send.is_some() && last_send.unwrap() != last_recv { - debug!("Sending TCP voice"); - state - .lock() - .await - .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); - } - match sink - .lock() - .await - .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) - .await - { - Ok(_) => { - last_send = Some(last_recv + 1); - }, - Err(e) => { - debug!("Error sending UDP ping: {}", e); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + async { + loop { + interval.tick().await; + let last_recv = last_ping_recv.load(Ordering::Relaxed); + if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); + state + .lock() + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); + } + match sink + .lock() + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) + .await + { + Ok(_) => { + last_send = Some(last_recv + 1); + }, + Err(e) => { + debug!("Error sending UDP ping: {}", e); + } + } } - } - } + }, + || async {}, + phase_watcher, + ).await; + + debug!("UDP ping sender process killed"); } async fn send_voice( -- cgit v1.2.1 From ab407d694e5a8ce6f831f8a84fc32dbdf6685aac Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Thu, 7 Jan 2021 12:39:49 +0100 Subject: lift waiting for disconnection --- mumd/src/network/tcp.rs | 402 +++++++++++++++++++++--------------------------- mumd/src/network/udp.rs | 259 ++++++++++++++----------------- 2 files changed, 287 insertions(+), 374 deletions(-) diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 3e4cbf3..e639dd0 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -2,17 +2,15 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; -use futures::{join, SinkExt, Stream, StreamExt}; +use futures::{FutureExt, SinkExt, Stream, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::voice::VoicePacket; use mumble_protocol::{Clientbound, Serverbound}; -use std::cell::RefCell; use std::collections::HashMap; use std::convert::{Into, TryInto}; use std::net::SocketAddr; -use std::rc::Rc; use std::sync::Arc; use tokio::net::TcpStream; use tokio::sync::{mpsc, watch, Mutex}; @@ -21,6 +19,7 @@ use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; use super::{run_until, VoiceStreamType}; +use futures_util::future::join5; type TcpSender = SplitSink< Framed, ControlCodec>, @@ -76,24 +75,34 @@ pub async fn handle( info!("Logging in..."); - //TODO force exit all futures on disconnection - join!( - send_pings(packet_sender.clone(), 10, phase_watcher.clone()), - listen( - Arc::clone(&state), - stream, - crypt_state_sender.clone(), - Arc::clone(&event_queue), - phase_watcher.clone(), - ), - send_voice( - packet_sender.clone(), - Arc::clone(&input_receiver), - phase_watcher.clone(), - ), - send_packets(sink, &mut packet_receiver, phase_watcher.clone()), - register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), - ); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join5( + send_pings(packet_sender.clone(), 10), + listen( + Arc::clone(&state), + stream, + crypt_state_sender.clone(), + Arc::clone(&event_queue), + ), + send_voice( + packet_sender.clone(), + Arc::clone(&input_receiver), + phase_watcher.clone(), + ), + send_packets(sink, &mut packet_receiver), + register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)), + ).map(|_| ()), + || async {}, + phase_watcher, + ).await; + + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Disconnected); + } + } debug!("Fully disconnected TCP stream, waiting for new connection info"); } @@ -135,50 +144,24 @@ async fn authenticate(sink: &mut TcpSender, username: String) { async fn send_pings( packet_sender: mpsc::UnboundedSender>, delay_seconds: u64, - phase_watcher: watch::Receiver, ) { let mut interval = time::interval(Duration::from_secs(delay_seconds)); - - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - interval.tick().await; - trace!("Sending TCP ping"); - let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); - } - }, - || async {}, - phase_watcher, - ) - .await; - - debug!("Ping sender process killed"); + loop { + interval.tick().await; + trace!("Sending TCP ping"); + let msg = msgs::Ping::new(); + packet_sender.send(msg.into()).unwrap(); + } } async fn send_packets( - sink: TcpSender, + mut sink: TcpSender, packet_receiver: &mut mpsc::UnboundedReceiver>, - phase_watcher: watch::Receiver, ) { - let sink = Rc::new(RefCell::new(sink)); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let packet = packet_receiver.recv().await.unwrap(); - sink.borrow_mut().send(packet).await.unwrap(); - } - }, - || async { - sink.borrow_mut().close().await.unwrap(); - }, - phase_watcher, - ) - .await; - - debug!("TCP packet sender killed"); + loop { + let packet = packet_receiver.recv().await.unwrap(); + sink.send(packet).await.unwrap(); + } } async fn send_voice( @@ -186,41 +169,33 @@ async fn send_voice( receiver: Arc> + Unpin)>>>, phase_watcher: watch::Receiver, ) { - let inner_phase_watcher = phase_watcher.clone(); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + let mut inner_phase_watcher = phase_watcher.clone(); + loop { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + async { loop { - inner_phase_watcher_2.changed().await.unwrap(); - if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { - break; - } + packet_sender.send( + receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); } - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), - async { - loop { - packet_sender.send( - receiver - .lock() - .await - .next() - .await - .unwrap() - .into()) - .unwrap(); - } - }, - || async {}, - inner_phase_watcher.clone(), - ).await; - } - }, - || async {}, - phase_watcher, - ).await; + }, + || async {}, + inner_phase_watcher.clone(), + ).await; + } } async fn listen( @@ -228,158 +203,129 @@ async fn listen( mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender, event_queue: Arc>>>, - phase_watcher: watch::Receiver, ) { - let crypt_state = Rc::new(RefCell::new(None)); - let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); + let mut crypt_state = None; + let mut crypt_state_sender = Some(crypt_state_sender); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let packet = stream.next().await.unwrap(); - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.borrow_mut().take() { - let _ = sender - .send( - crypt_state - .borrow_mut() - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; - } - if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Connected(&msg)); - } - } - let mut state = state.lock().await; - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), - } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); - } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - state.lock().await.parse_user_state(*msg); - } - ControlPacket::UserRemove(msg) => { - state.lock().await.remove_client(*msg); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .await - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial + loop { + let packet = stream.next().await.unwrap(); + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + crypt_state = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(sender) = crypt_state_sender.take() { + let _ = sender + .send( + crypt_state + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; + } + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Connected(&msg)); } - ControlPacket::ChannelRemove(msg) => { + } + let mut state = state.lock().await; + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), + } + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); + } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + state.lock().await.parse_user_state(*msg); + } + ControlPacket::UserRemove(msg) => { + state.lock().await.remove_client(*msg); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .await + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .await + .server_mut() + .unwrap() + .parse_channel_remove(*msg); + } + ControlPacket::UDPTunnel(msg) => { + match *msg { + VoicePacket::Ping { .. } => {} + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { state .lock() .await - .server_mut() - .unwrap() - .parse_channel_remove(*msg); - } - ControlPacket::UDPTunnel(msg) => { - match *msg { - VoicePacket::Ping { .. } => {} - VoicePacket::Audio { + .audio() + .decode_packet_payload( + VoiceStreamType::TCP, session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state - .lock() - .await - .audio() - .decode_packet_payload( - VoiceStreamType::TCP, - session_id, - payload); - } - } - } - packet => { - debug!("Received unhandled ControlPacket {:#?}", packet); + payload); } } } - }, - || async { - if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Disconnected); - } + packet => { + debug!("Received unhandled ControlPacket {:#?}", packet); } - }, - phase_watcher, - ) - .await; - - debug!("Killing TCP listener block"); + } + } } async fn register_events( tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, event_data: Arc>>>, - phase_watcher: watch::Receiver, ) { - let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap(); - event_data - .lock() - .await - .entry(event) - .or_default() - .push(handler); - } - }, - || async {}, - phase_watcher, - ) - .await; + loop { + let (event, handler) = tcp_event_register_receiver.recv().await.unwrap(); + event_data + .lock() + .await + .entry(event) + .or_default() + .push(handler); + } } diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 4167c15..9dd6ed3 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,7 +1,7 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; -use futures::{join, SinkExt, StreamExt, Stream}; +use futures::{join, FutureExt, SinkExt, StreamExt, Stream}; use futures_util::stream::{SplitSink, SplitStream}; use log::*; use mumble_protocol::crypt::ClientCryptState; @@ -20,6 +20,7 @@ use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; use super::{run_until, VoiceStreamType}; +use futures_util::future::join4; pub type PingRequest = (u64, SocketAddr, Box); @@ -49,28 +50,32 @@ pub async fn handle( let phase_watcher = state.lock().await.phase_receiver(); let last_ping_recv = AtomicU64::new(0); - join!( - listen( - Arc::clone(&state), - Arc::clone(&source), - phase_watcher.clone(), - &last_ping_recv, - ), - send_voice( - Arc::clone(&sink), - connection_info.socket_addr, - phase_watcher.clone(), - Arc::clone(&receiver), - ), - send_pings( - Arc::clone(&state), - Arc::clone(&sink), - connection_info.socket_addr, - &last_ping_recv, - phase_watcher.clone(), - ), - new_crypt_state(&mut crypt_state_receiver, sink, source, phase_watcher), - ); + + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join4( + listen( + Arc::clone(&state), + Arc::clone(&source), + &last_ping_recv, + ), + send_voice( + Arc::clone(&sink), + connection_info.socket_addr, + phase_watcher.clone(), + Arc::clone(&receiver), + ), + send_pings( + Arc::clone(&state), + Arc::clone(&sink), + connection_info.socket_addr, + &last_ping_recv, + ), + new_crypt_state(&mut crypt_state_receiver, sink, source), + ).map(|_| ()), + || async {}, + phase_watcher, + ).await; debug!("Fully disconnected UDP stream, waiting for new connection info"); } @@ -100,76 +105,58 @@ async fn new_crypt_state( crypt_state: &mut mpsc::Receiver, sink: Arc>, source: Arc>, - phase_watcher: watch::Receiver, ) { - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - if let Some(crypt_state) = crypt_state.recv().await { - info!("Received new crypt state"); - let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) - .await - .expect("Failed to bind UDP socket"); - let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().await = new_sink; - *source.lock().await = new_source; - } - } - }, - || async {}, - phase_watcher, - ).await; + loop { + if let Some(crypt_state) = crypt_state.recv().await { + info!("Received new crypt state"); + let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) + .await + .expect("Failed to bind UDP socket"); + let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); + *sink.lock().await = new_sink; + *source.lock().await = new_source; + } + } } async fn listen( state: Arc>, source: Arc>, - phase_watcher: watch::Receiver, last_ping_recv: &AtomicU64, ) { - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let packet = source.lock().await.next().await.unwrap(); - let (packet, _src_addr) = match packet { - Ok(packet) => packet, - Err(err) => { - warn!("Got an invalid UDP packet: {}", err); - // To be expected, considering this is the internet, just ignore it - continue; - } - }; - match packet { - VoicePacket::Ping { timestamp } => { - state - .lock() //TODO clean up unnecessary lock by only updating phase if it should change - .await - .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); - last_ping_recv.store(timestamp, Ordering::Relaxed); - } - VoicePacket::Audio { - session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state - .lock() //TODO change so that we only have to lock audio and not the whole state - .await - .audio() - .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); - } - } + loop { + let packet = source.lock().await.next().await.unwrap(); + let (packet, _src_addr) = match packet { + Ok(packet) => packet, + Err(err) => { + warn!("Got an invalid UDP packet: {}", err); + // To be expected, considering this is the internet, just ignore it + continue; } - }, - || async {}, - phase_watcher - ).await; - - debug!("UDP listener process killed"); + }; + match packet { + VoicePacket::Ping { timestamp } => { + state + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); + last_ping_recv.store(timestamp, Ordering::Relaxed); + } + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() //TODO change so that we only have to lock audio and not the whole state + .await + .audio() + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); + } + } + } } async fn send_pings( @@ -177,44 +164,34 @@ async fn send_pings( sink: Arc>, server_addr: SocketAddr, last_ping_recv: &AtomicU64, - phase_watcher: watch::Receiver, ) { let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - interval.tick().await; - let last_recv = last_ping_recv.load(Ordering::Relaxed); - if last_send.is_some() && last_send.unwrap() != last_recv { - debug!("Sending TCP voice"); - state - .lock() - .await - .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); - } - match sink - .lock() - .await - .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) - .await - { - Ok(_) => { - last_send = Some(last_recv + 1); - }, - Err(e) => { - debug!("Error sending UDP ping: {}", e); - } - } + loop { + interval.tick().await; + let last_recv = last_ping_recv.load(Ordering::Relaxed); + if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); + state + .lock() + .await + .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); + } + match sink + .lock() + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) + .await + { + Ok(_) => { + last_send = Some(last_recv + 1); + }, + Err(e) => { + debug!("Error sending UDP ping: {}", e); } - }, - || async {}, - phase_watcher, - ).await; - - debug!("UDP ping sender process killed"); + } + } } async fn send_voice( @@ -223,37 +200,27 @@ async fn send_voice( phase_watcher: watch::Receiver, receiver: Arc> + Unpin)>>>, ) { - let inner_phase_watcher = phase_watcher.clone(); - run_until( - |phase| matches!(phase, StatePhase::Disconnected), - async { - loop { - let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + let mut inner_phase_watcher = phase_watcher.clone(); + loop { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + async { + let mut receiver = receiver.lock().await; loop { - inner_phase_watcher_2.changed().await.unwrap(); - if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { - break; - } + let sending = (receiver.next().await.unwrap(), server_addr); + sink.lock().await.send(sending).await.unwrap(); } - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), - async { - let mut receiver = receiver.lock().await; - loop { - let sending = (receiver.next().await.unwrap(), server_addr); - sink.lock().await.send(sending).await.unwrap(); - } - }, - || async {}, - inner_phase_watcher.clone(), - ).await; - } - }, - || async {}, - phase_watcher, - ).await; - - debug!("UDP sender process killed"); + }, + || async {}, + phase_watcher.clone(), + ).await; + } } pub async fn handle_pings( -- cgit v1.2.1 From 62d3e3d6bf3842a1aad28874a69992b0b880137e Mon Sep 17 00:00:00 2001 From: Eskil Q Date: Thu, 7 Jan 2021 12:41:43 +0100 Subject: remove shutdown function on run_until it wasn't used and there are other ways of accomplishing the same thing --- mumd/src/network.rs | 6 +----- mumd/src/network/tcp.rs | 2 -- mumd/src/network/udp.rs | 2 -- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 9463ad7..6c67b3a 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -36,14 +36,12 @@ pub enum VoiceStreamType { UDP, } -async fn run_until( +async fn run_until( phase_checker: impl Fn(StatePhase) -> bool, fut: F, - mut shutdown: impl FnMut() -> G, mut phase_watcher: watch::Receiver, ) where F: Future, - G: Future, { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { @@ -67,8 +65,6 @@ async fn run_until( _ = fut => (), _ = rx => (), }; - - shutdown().await; }; join!(main_block, phase_transition_block); diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index e639dd0..3a32b9f 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -93,7 +93,6 @@ pub async fn handle( send_packets(sink, &mut packet_receiver), register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)), ).map(|_| ()), - || async {}, phase_watcher, ).await; @@ -192,7 +191,6 @@ async fn send_voice( .unwrap(); } }, - || async {}, inner_phase_watcher.clone(), ).await; } diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 9dd6ed3..5f24b51 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -73,7 +73,6 @@ pub async fn handle( ), new_crypt_state(&mut crypt_state_receiver, sink, source), ).map(|_| ()), - || async {}, phase_watcher, ).await; @@ -217,7 +216,6 @@ async fn send_voice( sink.lock().await.send(sending).await.unwrap(); } }, - || async {}, phase_watcher.clone(), ).await; } -- cgit v1.2.1