diff options
| author | Gustav Sörnäs <gustav@sornas.net> | 2021-03-31 21:51:47 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-03-31 21:51:47 +0200 |
| commit | 3f6281020b72ba949147a282c18c60a2842ad3dc (patch) | |
| tree | 0ba20ba532d325bf072969013fe8cf5bde84f6ba /mumd/src/network/tcp.rs | |
| parent | 795e46c98616801c678bd0a403b08cb0fcd5ee43 (diff) | |
| parent | 46a3938b6d9d81649e38e6e793599a52991d803d (diff) | |
| download | mum-3f6281020b72ba949147a282c18c60a2842ad3dc.tar.gz | |
Merge pull request #42 from mum-rs/handle-panics
Diffstat (limited to 'mumd/src/network/tcp.rs')
| -rw-r--r-- | mumd/src/network/tcp.rs | 87 |
1 files changed, 46 insertions, 41 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 47b1c20..6402a89 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -1,8 +1,10 @@ +use crate::error::{ServerSendError, TcpError}; use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; use futures_util::{FutureExt, SinkExt, StreamExt}; +use futures_util::select; use futures_util::stream::{SplitSink, SplitStream, Stream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; @@ -19,7 +21,6 @@ use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; use super::{run_until, VoiceStreamType}; -use futures_util::future::join5; type TcpSender = SplitSink< Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>, @@ -84,7 +85,7 @@ pub async fn handle( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, -) { +) -> Result<(), TcpError> { loop { let connection_info = 'data: loop { while connection_info_receiver.changed().await.is_ok() { @@ -92,20 +93,20 @@ pub async fn handle( break 'data data; } } - return; + return Err(TcpError::NoConnectionInfoReceived); }; let (mut sink, stream) = connect( connection_info.socket_addr, connection_info.hostname, connection_info.accept_invalid_cert, ) - .await; + .await?; // Handshake (omitting `Version` message for brevity) let state_lock = state.lock().await; let username = state_lock.username().unwrap().to_string(); let password = state_lock.password().map(|x| x.to_string()); - authenticate(&mut sink, username, password).await; + authenticate(&mut sink, username, password).await?; let phase_watcher = state_lock.phase_receiver(); let input_receiver = state_lock.audio().input_receiver(); drop(state_lock); @@ -113,26 +114,30 @@ pub async fn handle( info!("Logging in..."); + let phase_watcher_inner = phase_watcher.clone(); + run_until( |phase| matches!(phase, StatePhase::Disconnected), - join5( - send_pings(packet_sender.clone(), 10), - listen( - Arc::clone(&state), - stream, - crypt_state_sender.clone(), - event_queue.clone(), - ), - send_voice( - packet_sender.clone(), - Arc::clone(&input_receiver), - phase_watcher.clone(), - ), - send_packets(sink, &mut packet_receiver), - register_events(&mut tcp_event_register_receiver, event_queue.clone()), - ).map(|_| ()), + async { + select! { + r = send_pings(packet_sender.clone(), 10).fuse() => r, + r = listen( + Arc::clone(&state), + stream, + crypt_state_sender.clone(), + event_queue.clone(), + ).fuse() => r, + r = send_voice( + packet_sender.clone(), + Arc::clone(&input_receiver), + phase_watcher_inner, + ).fuse() => r, + r = send_packets(sink, &mut packet_receiver).fuse() => r, + _ = register_events(&mut tcp_event_register_receiver, event_queue.clone()).fuse() => Ok(()), + } + }, phase_watcher, - ).await; + ).await.unwrap_or(Ok(()))?; event_queue.resolve(TcpEventData::Disconnected).await; @@ -144,62 +149,62 @@ async fn connect( server_addr: SocketAddr, server_host: String, accept_invalid_cert: bool, -) -> (TcpSender, TcpReceiver) { - let stream = TcpStream::connect(&server_addr) - .await - .expect("failed to connect to server:"); +) -> Result<(TcpSender, TcpReceiver), TcpError> { + let stream = TcpStream::connect(&server_addr).await?; debug!("TCP connected"); let mut builder = native_tls::TlsConnector::builder(); builder.danger_accept_invalid_certs(accept_invalid_cert); let connector: TlsConnector = builder .build() - .expect("failed to create TLS connector") + .map_err(|e| TcpError::TlsConnectorBuilderError(e))? .into(); let tls_stream = connector .connect(&server_host, stream) .await - .expect("failed to connect TLS: {}"); + .map_err(|e| TcpError::TlsConnectError(e))?; debug!("TLS connected"); // Wrap the TLS stream with Mumble's client-side control-channel codec - ClientControlCodec::new().framed(tls_stream).split() + Ok(ClientControlCodec::new().framed(tls_stream).split()) } async fn authenticate( sink: &mut TcpSender, username: String, password: Option<String> -) { +) -> Result<(), TcpError> { let mut msg = msgs::Authenticate::new(); msg.set_username(username); if let Some(password) = password { msg.set_password(password); } msg.set_opus(true); - sink.send(msg.into()).await.unwrap(); + sink.send(msg.into()).await?; + Ok(()) } async fn send_pings( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, delay_seconds: u64, -) { +) -> Result<(), TcpError> { let mut interval = time::interval(Duration::from_secs(delay_seconds)); loop { interval.tick().await; trace!("Sending TCP ping"); let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); + packet_sender.send(msg.into())?; } } async fn send_packets( mut sink: TcpSender, packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, -) { +) -> Result<(), TcpError> { loop { + // Safe since we always have at least one sender alive. let packet = packet_receiver.recv().await.unwrap(); - sink.send(packet).await.unwrap(); + sink.send(packet).await?; } } @@ -207,7 +212,7 @@ async fn send_voice( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, phase_watcher: watch::Receiver<StatePhase>, -) { +) -> Result<(), TcpError> { loop { let mut inner_phase_watcher = phase_watcher.clone(); loop { @@ -226,13 +231,12 @@ async fn send_voice( .await .next() .await - .unwrap() - .into()) - .unwrap(); + .expect("No audio stream") + .into())?; } }, inner_phase_watcher.clone(), - ).await; + ).await.unwrap_or(Ok::<(), ServerSendError>(()))?; } } @@ -241,7 +245,7 @@ async fn listen( mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender<ClientCryptState>, event_queue: TcpEventQueue, -) { +) -> Result<(), TcpError> { let mut crypt_state = None; let mut crypt_state_sender = Some(crypt_state_sender); @@ -367,6 +371,7 @@ async fn listen( } } } + Ok(()) } async fn register_events( |
