From ab038b58b4440804cdfded56167ce72b599d87c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Tue, 5 Jan 2021 17:08:48 +0100 Subject: yikes --- mumd/src/audio.rs | 2 +- mumd/src/network.rs | 62 ++++++++++++++++++++++++++ mumd/src/network/tcp.rs | 113 +++++++++++++++++++++++------------------------- mumd/src/network/udp.rs | 73 +++++++++++-------------------- 4 files changed, 143 insertions(+), 107 deletions(-) (limited to 'mumd') diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index 40cdcb2..4f9b73c 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -337,7 +337,7 @@ impl Audio { } } - pub fn take_receiver(&mut self) -> Arc> + Unpin>>> { + pub fn input_receiver(&self) -> Arc> + Unpin>>> { Arc::clone(&self.input_channel_receiver) } diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 4fb2e77..03bc436 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -3,6 +3,16 @@ pub mod udp; use std::net::SocketAddr; +use futures::Future; +use futures::FutureExt; +use futures::channel::oneshot; +use futures::join; +use futures::pin_mut; +use futures::select; +use tokio::sync::watch; + +use crate::state::StatePhase; + #[derive(Clone, Debug)] pub struct ConnectionInfo { socket_addr: SocketAddr, @@ -25,3 +35,55 @@ pub enum VoiceStreamType { TCP, UDP, } + +async fn run_until( + phase_checker: impl Fn(StatePhase) -> bool, + mut generator: impl FnMut() -> F, + mut handler: impl FnMut(T) -> G, + mut shutdown: impl FnMut() -> H, + mut phase_watcher: watch::Receiver, +) where + F: Future>, + G: Future, + H: Future, +{ + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + loop { + phase_watcher.changed().await.unwrap(); + if phase_checker(*phase_watcher.borrow()) { + break; + } + } + tx.send(true).unwrap(); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = generator().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + //warn!("Channel closed before disconnect command"); //TODO make me informative + break; + } + Some(Some(data)) => { + handler(data).await; + } + } + } + + shutdown().await; + }; + + join!(main_block, phase_transition_block); +} diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 47ea311..f767446 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -1,25 +1,28 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; +use futures::Stream; use log::*; -use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt}; +use futures::{join, SinkExt, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; +use mumble_protocol::voice::VoicePacket; use mumble_protocol::{Clientbound, Serverbound}; use std::cell::RefCell; use std::collections::HashMap; use std::convert::{Into, TryInto}; -use std::future::Future; use std::net::SocketAddr; use std::rc::Rc; use std::sync::{Arc, Mutex}; use tokio::net::TcpStream; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch}; use tokio::time::{self, Duration}; use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use super::{run_until, VoiceStreamType}; + type TcpSender = SplitSink< Framed, ControlCodec>, ControlPacket, @@ -68,11 +71,13 @@ pub async fn handle( let state_lock = state.lock().unwrap(); authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; let phase_watcher = state_lock.phase_receiver(); + let input_receiver = state_lock.audio().input_receiver(); drop(state_lock); let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); + //TODO force exit all futures on disconnection join!( send_pings(packet_sender.clone(), 10, phase_watcher.clone()), listen( @@ -82,6 +87,11 @@ pub async fn handle( Arc::clone(&event_queue), phase_watcher.clone(), ), + send_voice( + packet_sender.clone(), + Arc::clone(&input_receiver), + phase_watcher.clone(), + ), send_packets(sink, &mut packet_receiver, phase_watcher.clone()), register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), ); @@ -133,7 +143,8 @@ async fn send_pings( )))); let packet_sender = Rc::new(RefCell::new(packet_sender)); - run_until_disconnection( + run_until( + |phase| matches!(phase, StatePhase::Disconnected), || async { Some(interval.borrow_mut().tick().await) }, |_| async { trace!("Sending ping"); @@ -155,7 +166,8 @@ async fn send_packets( ) { let sink = Rc::new(RefCell::new(sink)); let packet_receiver = Rc::new(RefCell::new(packet_receiver)); - run_until_disconnection( + run_until( + |phase| matches!(phase, StatePhase::Disconnected), || async { packet_receiver.borrow_mut().recv().await }, |packet| async { sink.borrow_mut().send(packet).await.unwrap(); @@ -170,6 +182,40 @@ async fn send_packets( debug!("TCP packet sender killed"); } +async fn send_voice( + packet_sender: mpsc::UnboundedSender>, + receiver: Arc> + Unpin)>>>, + phase_watcher: watch::Receiver, +) { + let inner_phase_watcher = phase_watcher.clone(); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + || async { + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + || async { + packet_sender.send(receiver + .lock() + .unwrap() + .next() + .await + .unwrap() + .into()) + .unwrap(); + Some(Some(())) + }, + |_| async {}, + || async {}, + inner_phase_watcher.clone(), + ).await; + Some(Some(())) + }, + |_| async {}, + || async {}, + phase_watcher, + ).await; +} + async fn listen( state: Arc>, stream: TcpReceiver, @@ -181,7 +227,8 @@ async fn listen( let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); let stream = Rc::new(RefCell::new(stream)); - run_until_disconnection( + run_until( + |phase| matches!(phase, StatePhase::Disconnected), || async { stream.borrow_mut().next().await }, |packet| async { match packet.unwrap() { @@ -289,7 +336,8 @@ async fn register_events( phase_watcher: watch::Receiver, ) { let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); - run_until_disconnection( + run_until( + |phase| matches!(phase, StatePhase::Disconnected), || async { tcp_event_register_receiver.borrow_mut().recv().await }, |(event, handler)| async { event_data @@ -304,54 +352,3 @@ async fn register_events( ) .await; } - -async fn run_until_disconnection( - mut generator: impl FnMut() -> F, - mut handler: impl FnMut(T) -> G, - mut shutdown: impl FnMut() -> H, - mut phase_watcher: watch::Receiver, -) where - F: Future>, - G: Future, - H: Future, -{ - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; - } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = generator().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - //warn!("Channel closed before disconnect command"); //TODO make me informative - break; - } - Some(Some(data)) => { - handler(data).await; - } - } - } - - shutdown().await; - }; - - join!(main_block, phase_transition_block); -} diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 1bc012d..9435e94 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -19,7 +19,7 @@ use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; -use super::VoiceStreamType; +use super::{run_until, VoiceStreamType}; pub type PingRequest = (u64, SocketAddr, Box); @@ -31,7 +31,7 @@ pub async fn handle( mut connection_info_receiver: watch::Receiver>, mut crypt_state_receiver: mpsc::Receiver, ) { - let receiver = state.lock().unwrap().audio_mut().take_receiver(); + let receiver = state.lock().unwrap().audio().input_receiver(); loop { let connection_info = 'data: loop { @@ -49,7 +49,6 @@ pub async fn handle( let phase_watcher = state.lock().unwrap().phase_receiver(); let last_ping_recv = AtomicU64::new(0); - let mut audio_receiver_lock = receiver.lock().unwrap(); join!( listen( Arc::clone(&state), @@ -61,7 +60,7 @@ pub async fn handle( Arc::clone(&sink), connection_info.socket_addr, phase_watcher, - &mut *audio_receiver_lock, + Arc::clone(&receiver), ), send_pings( Arc::clone(&state), @@ -228,51 +227,29 @@ async fn send_pings( async fn send_voice( sink: Arc>, server_addr: SocketAddr, - mut phase_watcher: watch::Receiver, - receiver: &mut (dyn Stream> + Unpin), + phase_watcher: watch::Receiver, + receiver: Arc> + Unpin)>>>, ) { - pin_mut!(receiver); - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; - } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = receiver.next().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; - } - Some(Some(reply)) => { - sink.lock() - .unwrap() - .send((reply, server_addr)) - .await - .unwrap(); - } - } - } - }; - - join!(main_block, phase_transition_block); + let inner_phase_watcher = phase_watcher.clone(); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + || async { + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + || async { + sink.lock().unwrap().send((receiver.lock().unwrap().next().await.unwrap(), server_addr)).await.unwrap(); + Some(Some(())) + }, + |_| async {}, + || async {}, + inner_phase_watcher.clone(), + ).await; + Some(Some(())) + }, + |_| async {}, + || async {}, + phase_watcher, + ).await; debug!("UDP sender process killed"); } -- cgit v1.2.1