From 1d331f0707eaa4a056aa6261410fb1edb63097b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Tue, 30 Mar 2021 16:15:53 +0200 Subject: report tcp errors all the way --- mumd/src/client.rs | 25 +++++++++++++------------ mumd/src/error.rs | 28 +++++++++++++++++++++++++++- mumd/src/main.rs | 20 ++++++++++++++------ mumd/src/network.rs | 19 +++++++++++-------- mumd/src/network/tcp.rs | 46 +++++++++++++++++++++++++--------------------- 5 files changed, 90 insertions(+), 48 deletions(-) diff --git a/mumd/src/client.rs b/mumd/src/client.rs index 7c1b0b7..c1a0152 100644 --- a/mumd/src/client.rs +++ b/mumd/src/client.rs @@ -1,11 +1,13 @@ use crate::command; +use crate::error::ClientError; use crate::network::{tcp, udp, ConnectionInfo}; use crate::state::State; +use futures_util::{select, FutureExt}; use mumble_protocol::{Serverbound, control::ControlPacket, crypt::ClientCryptState}; use mumlib::command::{Command, CommandResponse}; use std::sync::Arc; -use tokio::{join, sync::{Mutex, mpsc, oneshot, watch}}; +use tokio::sync::{Mutex, mpsc, oneshot, watch}; pub async fn handle( state: State, @@ -13,7 +15,7 @@ pub async fn handle( Command, oneshot::Sender>>, )>, -) { +) -> Result<(), ClientError> { let (connection_info_sender, connection_info_receiver) = watch::channel::>(None); let (crypt_state_sender, crypt_state_receiver) = @@ -27,29 +29,28 @@ pub async fn handle( let state = Arc::new(Mutex::new(state)); - //TODO report error here - let (_, _, _, _) = join!( - tcp::handle( + select! { + r = tcp::handle( Arc::clone(&state), connection_info_receiver.clone(), crypt_state_sender, packet_sender.clone(), packet_receiver, response_receiver, - ), - udp::handle( + ).fuse() => r.map_err(|e| ClientError::TcpError(e)), + _ = udp::handle( Arc::clone(&state), connection_info_receiver.clone(), crypt_state_receiver, - ), - command::handle( + ).fuse() => Ok(()), + _ = command::handle( state, command_receiver, response_sender, ping_request_sender, packet_sender, connection_info_sender, - ), - udp::handle_pings(ping_request_receiver), - ); + ).fuse() => Ok(()), + _ = udp::handle_pings(ping_request_receiver).fuse() => Ok(()), + } } diff --git a/mumd/src/error.rs b/mumd/src/error.rs index e4a8fee..142e806 100644 --- a/mumd/src/error.rs +++ b/mumd/src/error.rs @@ -12,6 +12,21 @@ pub enum TcpError { IOError(std::io::Error), } +impl fmt::Display for TcpError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TcpError::NoConnectionInfoReceived + => write!(f, "No connection info received"), + TcpError::TlsConnectorBuilderError(e) + => write!(f, "Error building TLS connector: {}", e), + TcpError::TlsConnectError(e) + => write!(f, "TLS error when connecting: {}", e), + TcpError::SendError(e) => write!(f, "Couldn't send packet: {}", e), + TcpError::IOError(e) => write!(f, "IO error: {}", e), + } + } +} + impl From for TcpError { fn from(e: std::io::Error) -> Self { TcpError::IOError(e) @@ -37,6 +52,18 @@ impl From for UdpError { } } +pub enum ClientError { + TcpError(TcpError), +} + +impl fmt::Display for ClientError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ClientError::TcpError(e) => write!(f, "TCP error: {}", e), + } + } +} + pub enum AudioStream { Input, Output, @@ -96,4 +123,3 @@ impl fmt::Display for StateError { } } } - diff --git a/mumd/src/main.rs b/mumd/src/main.rs index cd53d4a..d7bc2c0 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -8,11 +8,11 @@ mod state; use crate::state::State; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{select, FutureExt, SinkExt, StreamExt}; use log::*; use mumlib::command::{Command, CommandResponse}; use mumlib::setup_logger; -use tokio::{join, net::{UnixListener, UnixStream}, sync::{mpsc, oneshot}}; +use tokio::{net::{UnixListener, UnixStream}, sync::{mpsc, oneshot}}; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; use bytes::{BufMut, BytesMut}; @@ -64,10 +64,18 @@ async fn main() { } }; - join!( - client::handle(state, command_receiver), - receive_commands(command_sender), - ); + let run = select! { + r = client::handle(state, command_receiver).fuse() => r, + _ = receive_commands(command_sender).fuse() => Ok(()), + }; + + match run { + Err(e) => { + error!("mumd: {}", e); + std::process::exit(1); + } + _ => {} + } } async fn receive_commands( diff --git a/mumd/src/network.rs b/mumd/src/network.rs index 38a97ce..4eca90d 100644 --- a/mumd/src/network.rs +++ b/mumd/src/network.rs @@ -4,7 +4,7 @@ pub mod udp; use futures_util::FutureExt; use log::*; use std::{future::Future, net::SocketAddr}; -use tokio::{join, select, sync::{oneshot, watch}}; +use tokio::{select, sync::{oneshot, watch}}; use crate::state::StatePhase; @@ -31,12 +31,12 @@ pub enum VoiceStreamType { UDP, } -async fn run_until( +async fn run_until( phase_checker: impl Fn(StatePhase) -> bool, fut: F, mut phase_watcher: watch::Receiver, -) where - F: Future, +) -> Option + where F: Future, { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { @@ -55,10 +55,13 @@ async fn run_until( let rx = rx.fuse(); let fut = fut.fuse(); select! { - _ = fut => (), - _ = rx => (), - }; + r = fut => Some(r), + _ = rx => None, + } }; - join!(main_block, phase_transition_block); + select! { + m = main_block => m, + _ = phase_transition_block => None, + } } diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 6460cba..9b0b68e 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -4,6 +4,7 @@ use crate::state::{State, StatePhase}; use log::*; use futures_util::{FutureExt, SinkExt, StreamExt}; +use futures_util::select; use futures_util::stream::{SplitSink, SplitStream, Stream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; @@ -20,7 +21,6 @@ use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; use super::{run_until, VoiceStreamType}; -use futures_util::future::join5; type TcpSender = SplitSink< Framed, ControlCodec>, @@ -114,27 +114,30 @@ pub async fn handle( info!("Logging in..."); + let phase_watcher_inner = phase_watcher.clone(); + run_until( |phase| matches!(phase, StatePhase::Disconnected), - //TODO take out the errors here and return them - join5( - send_pings(packet_sender.clone(), 10), - listen( - Arc::clone(&state), - stream, - crypt_state_sender.clone(), - event_queue.clone(), - ), - send_voice( - packet_sender.clone(), - Arc::clone(&input_receiver), - phase_watcher.clone(), - ), - send_packets(sink, &mut packet_receiver), - register_events(&mut tcp_event_register_receiver, event_queue.clone()), - ).map(|_| ()), + async { + select! { + r = send_pings(packet_sender.clone(), 10).fuse() => r, + r = listen( + Arc::clone(&state), + stream, + crypt_state_sender.clone(), + event_queue.clone(), + ).fuse() => r, + r = send_voice( + packet_sender.clone(), + Arc::clone(&input_receiver), + phase_watcher_inner, + ).fuse() => r, + r = send_packets(sink, &mut packet_receiver).fuse() => r, + _ = register_events(&mut tcp_event_register_receiver, event_queue.clone()).fuse() => Ok(()), + } + }, phase_watcher, - ).await; + ).await.unwrap_or(Ok(()))?; event_queue.resolve(TcpEventData::Disconnected).await; @@ -209,7 +212,7 @@ async fn send_voice( packet_sender: mpsc::UnboundedSender>, receiver: Arc> + Unpin)>>>, phase_watcher: watch::Receiver, -) { +) -> Result<(), TcpError> { loop { let mut inner_phase_watcher = phase_watcher.clone(); loop { @@ -243,7 +246,7 @@ async fn listen( mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender, event_queue: TcpEventQueue, -) { +) -> Result<(), TcpError> { let mut crypt_state = None; let mut crypt_state_sender = Some(crypt_state_sender); @@ -369,6 +372,7 @@ async fn listen( } } } + Ok(()) } async fn register_events( -- cgit v1.2.1