From ab0cdc240c65fdc6b764ed17f6611786d449acc3 Mon Sep 17 00:00:00 2001 From: Eskil Queseth Date: Wed, 14 Oct 2020 17:45:04 +0200 Subject: add support for reconnecting to server --- mumd/src/audio.rs | 10 +++-- mumd/src/command.rs | 2 +- mumd/src/main.rs | 19 ++++----- mumd/src/network/tcp.rs | 65 +++++++++++++++--------------- mumd/src/network/udp.rs | 105 ++++++++++++++++++------------------------------ 5 files changed, 91 insertions(+), 110 deletions(-) (limited to 'mumd') diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index e13845e..aa06a9d 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -293,9 +293,13 @@ fn input_callback( .unwrap(); opus_buf.truncate(result); let bytes = Bytes::copy_from_slice(&opus_buf); - input_sender - .try_send(VoicePacketPayload::Opus(bytes, false)) - .unwrap(); //TODO handle full buffer / disconnect + match input_sender + .try_send(VoicePacketPayload::Opus(bytes, false)) { //TODO handle full buffer / disconnect + Ok(_) => {}, + Err(_e) => { + //warn!("Error sending audio packet: {:?}", e); + } + } *buf = tail; } } diff --git a/mumd/src/command.rs b/mumd/src/command.rs index bfdb7dd..1104671 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Mutex}; use tokio::sync::mpsc; use log::*; -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum Command { ChannelJoin { channel_id: u32, diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 812e7a1..6d435fa 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -11,7 +11,6 @@ use argparse::ArgumentParser; use argparse::Store; use argparse::StoreTrue; use colored::*; -use tokio::sync::oneshot; use futures::join; use log::*; use mumble_protocol::control::ControlPacket; @@ -72,16 +71,12 @@ async fn main() { // Oneshot channel for setting UDP CryptState from control task // For simplicity we don't deal with re-syncing, real applications would have to. - let (crypt_state_sender, crypt_state_receiver) = oneshot::channel::(); + let (crypt_state_sender, crypt_state_receiver) = mpsc::channel::(1); // crypt state should always be consumed before sending a new one let (packet_sender, packet_receiver) = mpsc::unbounded_channel::>(); let (command_sender, command_receiver) = mpsc::unbounded_channel::(); let (command_response_sender, command_response_receiver) = mpsc::unbounded_channel::, ()>>(); let (connection_info_sender, connection_info_receiver) = watch::channel::>(None); - command_sender.send(Command::ChannelList).unwrap(); - command_sender.send(Command::ServerConnect{host: server_host, port: server_port, username: username.clone(), accept_invalid_cert}); - //command_sender.send(Command::ChannelJoin{channel_id: 1}).unwrap(); - command_sender.send(Command::ChannelList).unwrap(); let state = State::new(packet_sender, command_sender.clone(), connection_info_sender); let state = Arc::new(Mutex::new(state)); @@ -104,7 +99,8 @@ async fn main() { command_response_sender, ), send_commands( - command_sender + command_sender, + Command::ServerConnect{host: server_host, port: server_port, username: username.clone(), accept_invalid_cert} ), receive_command_responses( command_response_receiver, @@ -112,8 +108,13 @@ async fn main() { ); } -async fn send_commands(command_sender: mpsc::UnboundedSender) { - tokio::time::delay_for(Duration::from_secs(5)).await; +async fn send_commands(command_sender: mpsc::UnboundedSender, connect_command: Command) { + command_sender.send(connect_command.clone()); + tokio::time::delay_for(Duration::from_secs(2)).await; + command_sender.send(Command::ServerDisconnect); + tokio::time::delay_for(Duration::from_secs(2)).await; + command_sender.send(connect_command.clone()); + tokio::time::delay_for(Duration::from_secs(2)).await; command_sender.send(Command::ServerDisconnect); debug!("Finished sending commands"); diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 1e0feee..d45b49d 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -2,7 +2,6 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; -use tokio::sync::oneshot; use futures::{join, select, pin_mut, SinkExt, StreamExt, FutureExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; @@ -12,7 +11,7 @@ use std::convert::{Into, TryInto}; use std::net::{SocketAddr}; use std::sync::{Arc, Mutex}; use tokio::net::TcpStream; -use tokio::sync::{mpsc, watch}; +use tokio::sync::{mpsc, watch, oneshot}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; @@ -27,37 +26,39 @@ type TcpReceiver = pub async fn handle( state: Arc>, mut connection_info_receiver: watch::Receiver>, - crypt_state_sender: oneshot::Sender, - packet_receiver: mpsc::UnboundedReceiver>, + crypt_state_sender: mpsc::Sender, + mut packet_receiver: mpsc::UnboundedReceiver>, ) { - let connection_info = loop { - match connection_info_receiver.recv().await { - None => { return; } - Some(None) => {} - Some(Some(connection_info)) => { break connection_info; } - } - }; - let (mut sink, stream) = connect(connection_info.socket_addr, - connection_info.hostname, - connection_info.accept_invalid_cert) - .await; + loop { + let connection_info = loop { + match connection_info_receiver.recv().await { + None => { return; } + Some(None) => {} + Some(Some(connection_info)) => { break connection_info; } + } + }; + let (mut sink, stream) = connect(connection_info.socket_addr, + connection_info.hostname, + connection_info.accept_invalid_cert) + .await; - // Handshake (omitting `Version` message for brevity) - let state_lock = state.lock().unwrap(); - authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; - let phase_watcher = state_lock.phase_receiver(); - let packet_sender = state_lock.packet_sender(); - drop(state_lock); + // Handshake (omitting `Version` message for brevity) + let state_lock = state.lock().unwrap(); + authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; + let phase_watcher = state_lock.phase_receiver(); + let packet_sender = state_lock.packet_sender(); + drop(state_lock); - info!("Logging in..."); + info!("Logging in..."); - join!( - send_pings(packet_sender, 10, phase_watcher.clone()), - listen(state, stream, crypt_state_sender, phase_watcher.clone()), - send_packets(sink, packet_receiver, phase_watcher), - ); + join!( + send_pings(packet_sender, 10, phase_watcher.clone()), + listen(Arc::clone(&state), stream, crypt_state_sender.clone(), phase_watcher.clone()), + send_packets(sink, &mut packet_receiver, phase_watcher), + ); - debug!("Fully disconnected TCP stream"); + debug!("Fully disconnected TCP stream, waiting for new connection info"); + } } async fn connect( @@ -134,7 +135,7 @@ async fn send_pings( async fn send_packets( mut sink: TcpSender, - mut packet_receiver: mpsc::UnboundedReceiver>, + packet_receiver: &mut mpsc::UnboundedReceiver>, mut phase_watcher: watch::Receiver, ) { let (tx, rx) = oneshot::channel(); @@ -181,7 +182,7 @@ async fn send_packets( async fn listen( state: Arc>, mut stream: TcpReceiver, - crypt_state_sender: oneshot::Sender, + crypt_state_sender: mpsc::Sender, mut phase_watcher: watch::Receiver, ) { let mut crypt_state = None; @@ -238,12 +239,12 @@ async fn listen( } ControlPacket::ServerSync(msg) => { info!("Logged in"); - if let Some(sender) = crypt_state_sender.take() { + if let Some(mut sender) = crypt_state_sender.take() { let _ = sender.send( crypt_state .take() .expect("Server didn't send us any CryptSetup packet!"), - ); + ).await; } let mut state = state.lock().unwrap(); let server = state.server_mut(); diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index a757a2b..45e6e80 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -11,14 +11,48 @@ use mumble_protocol::Serverbound; use std::net::{Ipv6Addr, SocketAddr}; use std::sync::{Arc, Mutex}; use tokio::net::UdpSocket; -use tokio::sync::{watch, oneshot}; +use tokio::sync::{watch, oneshot, mpsc}; use tokio_util::udp::UdpFramed; type UdpSender = SplitSink, (VoicePacket, SocketAddr)>; type UdpReceiver = SplitStream>; +pub async fn handle( + state: Arc>, + mut connection_info_receiver: watch::Receiver>, + mut crypt_state: mpsc::Receiver, +) { + let mut receiver = state.lock().unwrap().audio_mut().take_receiver().unwrap(); + + loop { + let connection_info = loop { + match connection_info_receiver.recv().await { + None => { return; } + Some(None) => {} + Some(Some(connection_info)) => { break connection_info; } + } + }; + let (mut sink, source) = connect(&mut crypt_state).await; + + // Note: A normal application would also send periodic Ping packets, and its own audio + // via UDP. We instead trick the server into accepting us by sending it one + // dummy voice packet. + send_ping(&mut sink, connection_info.socket_addr).await; + + let sink = Arc::new(Mutex::new(sink)); + + let phase_watcher = state.lock().unwrap().phase_receiver(); + join!( + listen(Arc::clone(&state), source, phase_watcher.clone()), + send_voice(sink, connection_info.socket_addr, phase_watcher, &mut receiver), + ); + + debug!("Fully disconnected UDP stream, waiting for new connection info"); + } +} + pub async fn connect( - crypt_state: oneshot::Receiver, + crypt_state: &mut mpsc::Receiver, ) -> (UdpSender, UdpReceiver) { // Bind UDP socket let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) @@ -26,10 +60,10 @@ pub async fn connect( .expect("Failed to bind UDP socket"); // Wait for initial CryptState - let crypt_state = match crypt_state.await { - Ok(crypt_state) => crypt_state, + let crypt_state = match crypt_state.recv().await { + Some(crypt_state) => crypt_state, // disconnected before we received the CryptSetup packet, oh well - Err(_) => panic!("disconnect before crypt packet received"), //TODO exit gracefully + None => panic!("disconnect before crypt packet received"), //TODO exit gracefully }; debug!("UDP connected"); @@ -37,36 +71,6 @@ pub async fn connect( UdpFramed::new(udp_socket, crypt_state).split() } -pub async fn handle( - state: Arc>, - mut connection_info_receiver: watch::Receiver>, - crypt_state: oneshot::Receiver, -) { - let connection_info = loop { - match connection_info_receiver.recv().await { - None => { return; } - Some(None) => {} - Some(Some(connection_info)) => { break connection_info; } - } - }; - let (mut sink, source) = connect(crypt_state).await; - - // Note: A normal application would also send periodic Ping packets, and its own audio - // via UDP. We instead trick the server into accepting us by sending it one - // dummy voice packet. - send_ping(&mut sink, connection_info.socket_addr).await; - - let sink = Arc::new(Mutex::new(sink)); - - let phase_watcher = state.lock().unwrap().phase_receiver(); - join!( - listen(Arc::clone(&state), source, phase_watcher.clone()), - send_voice(state, sink, connection_info.socket_addr, phase_watcher), - ); - - debug!("Fully disconnected UPD stream"); -} - async fn listen( state: Arc>, mut source: UdpReceiver, @@ -129,33 +133,6 @@ async fn listen( join!(main_block, phase_transition_block); debug!("UDP listener process killed"); - - /*while let Some(packet) = source.next().await { - let (packet, _src_addr) = match packet { - Ok(packet) => packet, - Err(err) => { - warn!("Got an invalid UDP packet: {}", err); - // To be expected, considering this is the internet, just ignore it - continue; - } - }; - match packet { - VoicePacket::Ping { .. } => { - // Note: A normal application would handle these and only use UDP for voice - // once it has received one. - continue; - } - VoicePacket::Audio { - session_id, - // seq_num, - payload, - // position_info, - .. - } => { - state.lock().unwrap().audio().decode_packet(session_id, payload); - } - } - }*/ } async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) { @@ -175,13 +152,11 @@ async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) { } async fn send_voice( - state: Arc>, sink: Arc>, server_addr: SocketAddr, mut phase_watcher: watch::Receiver, + receiver: &mut mpsc::Receiver, ) { - let mut receiver = state.lock().unwrap().audio_mut().take_receiver().unwrap(); - let (tx, rx) = oneshot::channel(); let phase_transition_block = async { while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} -- cgit v1.2.1