diff options
| author | Kapten Z∅∅m <55669224+default-username-852@users.noreply.github.com> | 2021-01-07 22:22:24 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-01-07 22:22:24 +0100 |
| commit | 154a2930b3bbe5ede06648c3a10b8fa4904277f4 (patch) | |
| tree | 18ee2f8b569991d1d0e6b6248539f70da63a62d7 /mumd/src/network/tcp.rs | |
| parent | ba4aa72654f2d57d59f6e25151315213bec21192 (diff) | |
| parent | 62d3e3d6bf3842a1aad28874a69992b0b880137e (diff) | |
| download | mum-154a2930b3bbe5ede06648c3a10b8fa4904277f4.tar.gz | |
Merge pull request #58 from mum-rs/tcp-voice-tunnel-2
TCP voice tunnel
Diffstat (limited to 'mumd/src/network/tcp.rs')
| -rw-r--r-- | mumd/src/network/tcp.rs | 406 |
1 files changed, 189 insertions, 217 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 47ea311..3a32b9f 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -2,24 +2,25 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; -use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt}; +use futures::{FutureExt, SinkExt, Stream, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; +use mumble_protocol::voice::VoicePacket; use mumble_protocol::{Clientbound, Serverbound}; -use std::cell::RefCell; use std::collections::HashMap; use std::convert::{Into, TryInto}; -use std::future::Future; use std::net::SocketAddr; -use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use tokio::net::TcpStream; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch, Mutex}; use tokio::time::{self, Duration}; use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use super::{run_until, VoiceStreamType}; +use futures_util::future::join5; + type TcpSender = SplitSink< Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>, ControlPacket<Serverbound>, @@ -65,26 +66,42 @@ pub async fn handle( .await; // Handshake (omitting `Version` message for brevity) - let state_lock = state.lock().unwrap(); + let state_lock = state.lock().await; authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; let phase_watcher = state_lock.phase_receiver(); + let input_receiver = state_lock.audio().input_receiver(); drop(state_lock); let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); - join!( - send_pings(packet_sender.clone(), 10, phase_watcher.clone()), - listen( - Arc::clone(&state), - stream, - crypt_state_sender.clone(), - Arc::clone(&event_queue), - phase_watcher.clone(), - ), - send_packets(sink, &mut packet_receiver, phase_watcher.clone()), - register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), - ); + run_until( + |phase| matches!(phase, StatePhase::Disconnected), + join5( + send_pings(packet_sender.clone(), 10), + listen( + Arc::clone(&state), + stream, + crypt_state_sender.clone(), + Arc::clone(&event_queue), + ), + send_voice( + packet_sender.clone(), + Arc::clone(&input_receiver), + phase_watcher.clone(), + ), + send_packets(sink, &mut packet_receiver), + register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)), + ).map(|_| ()), + phase_watcher, + ).await; + + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Disconnected); + } + } debug!("Fully disconnected TCP stream, waiting for new connection info"); } @@ -126,232 +143,187 @@ async fn authenticate(sink: &mut TcpSender, username: String) { async fn send_pings( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, delay_seconds: u64, - phase_watcher: watch::Receiver<StatePhase>, ) { - let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs( - delay_seconds, - )))); - let packet_sender = Rc::new(RefCell::new(packet_sender)); - - run_until_disconnection( - || async { Some(interval.borrow_mut().tick().await) }, - |_| async { - trace!("Sending ping"); + let mut interval = time::interval(Duration::from_secs(delay_seconds)); + loop { + interval.tick().await; + trace!("Sending TCP ping"); let msg = msgs::Ping::new(); - packet_sender.borrow_mut().send(msg.into()).unwrap(); - }, - || async {}, - phase_watcher, - ) - .await; - - debug!("Ping sender process killed"); + packet_sender.send(msg.into()).unwrap(); + } } async fn send_packets( - sink: TcpSender, + mut sink: TcpSender, packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, - phase_watcher: watch::Receiver<StatePhase>, ) { - let sink = Rc::new(RefCell::new(sink)); - let packet_receiver = Rc::new(RefCell::new(packet_receiver)); - run_until_disconnection( - || async { packet_receiver.borrow_mut().recv().await }, - |packet| async { - sink.borrow_mut().send(packet).await.unwrap(); - }, - || async { - sink.borrow_mut().close().await.unwrap(); - }, - phase_watcher, - ) - .await; + loop { + let packet = packet_receiver.recv().await.unwrap(); + sink.send(packet).await.unwrap(); + } +} - debug!("TCP packet sender killed"); +async fn send_voice( + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, + phase_watcher: watch::Receiver<StatePhase>, +) { + loop { + let mut inner_phase_watcher = phase_watcher.clone(); + loop { + inner_phase_watcher.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + async { + loop { + packet_sender.send( + receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); + } + }, + inner_phase_watcher.clone(), + ).await; + } } async fn listen( state: Arc<Mutex<State>>, - stream: TcpReceiver, + mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender<ClientCryptState>, event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, - phase_watcher: watch::Receiver<StatePhase>, ) { - let crypt_state = Rc::new(RefCell::new(None)); - let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); + let mut crypt_state = None; + let mut crypt_state_sender = Some(crypt_state_sender); - let stream = Rc::new(RefCell::new(stream)); - run_until_disconnection( - || async { stream.borrow_mut().next().await }, - |packet| async { - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); + loop { + let packet = stream.next().await.unwrap(); + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + crypt_state = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(sender) = crypt_state_sender.take() { + let _ = sender + .send( + crypt_state + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.borrow_mut().take() { - let _ = sender - .send( - crypt_state - .borrow_mut() - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(TcpEventData::Connected(&msg)); } - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Connected(&msg)); - } - } - let mut state = state.lock().unwrap(); - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), - } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); - } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); } - ControlPacket::UserState(msg) => { - state.lock().unwrap().parse_user_state(*msg); + let mut state = state.lock().await; + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), } - ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_remove(*msg); - } - packet => { - debug!("Received unhandled ControlPacket {:#?}", packet); + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + state.lock().await.parse_user_state(*msg); + } + ControlPacket::UserRemove(msg) => { + state.lock().await.remove_client(*msg); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .await + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .await + .server_mut() + .unwrap() + .parse_channel_remove(*msg); } - }, - || async { - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { - let old = std::mem::take(vec); - for handler in old { - handler(TcpEventData::Disconnected); + ControlPacket::UDPTunnel(msg) => { + match *msg { + VoicePacket::Ping { .. } => {} + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() + .await + .audio() + .decode_packet_payload( + VoiceStreamType::TCP, + session_id, + payload); + } } } - }, - phase_watcher, - ) - .await; - - debug!("Killing TCP listener block"); + packet => { + debug!("Received unhandled ControlPacket {:#?}", packet); + } + } + } } async fn register_events( tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, - phase_watcher: watch::Receiver<StatePhase>, ) { - let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); - run_until_disconnection( - || async { tcp_event_register_receiver.borrow_mut().recv().await }, - |(event, handler)| async { - event_data - .lock() - .unwrap() - .entry(event) - .or_default() - .push(handler); - }, - || async {}, - phase_watcher, - ) - .await; -} - -async fn run_until_disconnection<T, F, G, H>( - mut generator: impl FnMut() -> F, - mut handler: impl FnMut(T) -> G, - mut shutdown: impl FnMut() -> H, - mut phase_watcher: watch::Receiver<StatePhase>, -) where - F: Future<Output = Option<T>>, - G: Future<Output = ()>, - H: Future<Output = ()>, -{ - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - loop { - phase_watcher.changed().await.unwrap(); - if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) { - break; - } - } - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = generator().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - //warn!("Channel closed before disconnect command"); //TODO make me informative - break; - } - Some(Some(data)) => { - handler(data).await; - } - } - } - - shutdown().await; - }; - - join!(main_block, phase_transition_block); + loop { + let (event, handler) = tcp_event_register_receiver.recv().await.unwrap(); + event_data + .lock() + .await + .entry(event) + .or_default() + .push(handler); + } } |
