aboutsummaryrefslogtreecommitdiffstats
path: root/mumd
diff options
context:
space:
mode:
Diffstat (limited to 'mumd')
-rw-r--r--mumd/src/client.rs25
-rw-r--r--mumd/src/error.rs28
-rw-r--r--mumd/src/main.rs20
-rw-r--r--mumd/src/network.rs19
-rw-r--r--mumd/src/network/tcp.rs46
5 files changed, 90 insertions, 48 deletions
diff --git a/mumd/src/client.rs b/mumd/src/client.rs
index 7c1b0b7..c1a0152 100644
--- a/mumd/src/client.rs
+++ b/mumd/src/client.rs
@@ -1,11 +1,13 @@
use crate::command;
+use crate::error::ClientError;
use crate::network::{tcp, udp, ConnectionInfo};
use crate::state::State;
+use futures_util::{select, FutureExt};
use mumble_protocol::{Serverbound, control::ControlPacket, crypt::ClientCryptState};
use mumlib::command::{Command, CommandResponse};
use std::sync::Arc;
-use tokio::{join, sync::{Mutex, mpsc, oneshot, watch}};
+use tokio::sync::{Mutex, mpsc, oneshot, watch};
pub async fn handle(
state: State,
@@ -13,7 +15,7 @@ pub async fn handle(
Command,
oneshot::Sender<mumlib::error::Result<Option<CommandResponse>>>,
)>,
-) {
+) -> Result<(), ClientError> {
let (connection_info_sender, connection_info_receiver) =
watch::channel::<Option<ConnectionInfo>>(None);
let (crypt_state_sender, crypt_state_receiver) =
@@ -27,29 +29,28 @@ pub async fn handle(
let state = Arc::new(Mutex::new(state));
- //TODO report error here
- let (_, _, _, _) = join!(
- tcp::handle(
+ select! {
+ r = tcp::handle(
Arc::clone(&state),
connection_info_receiver.clone(),
crypt_state_sender,
packet_sender.clone(),
packet_receiver,
response_receiver,
- ),
- udp::handle(
+ ).fuse() => r.map_err(|e| ClientError::TcpError(e)),
+ _ = udp::handle(
Arc::clone(&state),
connection_info_receiver.clone(),
crypt_state_receiver,
- ),
- command::handle(
+ ).fuse() => Ok(()),
+ _ = command::handle(
state,
command_receiver,
response_sender,
ping_request_sender,
packet_sender,
connection_info_sender,
- ),
- udp::handle_pings(ping_request_receiver),
- );
+ ).fuse() => Ok(()),
+ _ = udp::handle_pings(ping_request_receiver).fuse() => Ok(()),
+ }
}
diff --git a/mumd/src/error.rs b/mumd/src/error.rs
index e4a8fee..142e806 100644
--- a/mumd/src/error.rs
+++ b/mumd/src/error.rs
@@ -12,6 +12,21 @@ pub enum TcpError {
IOError(std::io::Error),
}
+impl fmt::Display for TcpError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ TcpError::NoConnectionInfoReceived
+ => write!(f, "No connection info received"),
+ TcpError::TlsConnectorBuilderError(e)
+ => write!(f, "Error building TLS connector: {}", e),
+ TcpError::TlsConnectError(e)
+ => write!(f, "TLS error when connecting: {}", e),
+ TcpError::SendError(e) => write!(f, "Couldn't send packet: {}", e),
+ TcpError::IOError(e) => write!(f, "IO error: {}", e),
+ }
+ }
+}
+
impl From<std::io::Error> for TcpError {
fn from(e: std::io::Error) -> Self {
TcpError::IOError(e)
@@ -37,6 +52,18 @@ impl From<std::io::Error> for UdpError {
}
}
+pub enum ClientError {
+ TcpError(TcpError),
+}
+
+impl fmt::Display for ClientError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ ClientError::TcpError(e) => write!(f, "TCP error: {}", e),
+ }
+ }
+}
+
pub enum AudioStream {
Input,
Output,
@@ -96,4 +123,3 @@ impl fmt::Display for StateError {
}
}
}
-
diff --git a/mumd/src/main.rs b/mumd/src/main.rs
index cd53d4a..d7bc2c0 100644
--- a/mumd/src/main.rs
+++ b/mumd/src/main.rs
@@ -8,11 +8,11 @@ mod state;
use crate::state::State;
-use futures_util::{SinkExt, StreamExt};
+use futures_util::{select, FutureExt, SinkExt, StreamExt};
use log::*;
use mumlib::command::{Command, CommandResponse};
use mumlib::setup_logger;
-use tokio::{join, net::{UnixListener, UnixStream}, sync::{mpsc, oneshot}};
+use tokio::{net::{UnixListener, UnixStream}, sync::{mpsc, oneshot}};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use bytes::{BufMut, BytesMut};
@@ -64,10 +64,18 @@ async fn main() {
}
};
- join!(
- client::handle(state, command_receiver),
- receive_commands(command_sender),
- );
+ let run = select! {
+ r = client::handle(state, command_receiver).fuse() => r,
+ _ = receive_commands(command_sender).fuse() => Ok(()),
+ };
+
+ match run {
+ Err(e) => {
+ error!("mumd: {}", e);
+ std::process::exit(1);
+ }
+ _ => {}
+ }
}
async fn receive_commands(
diff --git a/mumd/src/network.rs b/mumd/src/network.rs
index 38a97ce..4eca90d 100644
--- a/mumd/src/network.rs
+++ b/mumd/src/network.rs
@@ -4,7 +4,7 @@ pub mod udp;
use futures_util::FutureExt;
use log::*;
use std::{future::Future, net::SocketAddr};
-use tokio::{join, select, sync::{oneshot, watch}};
+use tokio::{select, sync::{oneshot, watch}};
use crate::state::StatePhase;
@@ -31,12 +31,12 @@ pub enum VoiceStreamType {
UDP,
}
-async fn run_until<F>(
+async fn run_until<F, R>(
phase_checker: impl Fn(StatePhase) -> bool,
fut: F,
mut phase_watcher: watch::Receiver<StatePhase>,
-) where
- F: Future<Output = ()>,
+) -> Option<R>
+ where F: Future<Output = R>,
{
let (tx, rx) = oneshot::channel();
let phase_transition_block = async {
@@ -55,10 +55,13 @@ async fn run_until<F>(
let rx = rx.fuse();
let fut = fut.fuse();
select! {
- _ = fut => (),
- _ = rx => (),
- };
+ r = fut => Some(r),
+ _ = rx => None,
+ }
};
- join!(main_block, phase_transition_block);
+ select! {
+ m = main_block => m,
+ _ = phase_transition_block => None,
+ }
}
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 6460cba..9b0b68e 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -4,6 +4,7 @@ use crate::state::{State, StatePhase};
use log::*;
use futures_util::{FutureExt, SinkExt, StreamExt};
+use futures_util::select;
use futures_util::stream::{SplitSink, SplitStream, Stream};
use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket};
use mumble_protocol::crypt::ClientCryptState;
@@ -20,7 +21,6 @@ use tokio_native_tls::{TlsConnector, TlsStream};
use tokio_util::codec::{Decoder, Framed};
use super::{run_until, VoiceStreamType};
-use futures_util::future::join5;
type TcpSender = SplitSink<
Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>,
@@ -114,27 +114,30 @@ pub async fn handle(
info!("Logging in...");
+ let phase_watcher_inner = phase_watcher.clone();
+
run_until(
|phase| matches!(phase, StatePhase::Disconnected),
- //TODO take out the errors here and return them
- join5(
- send_pings(packet_sender.clone(), 10),
- listen(
- Arc::clone(&state),
- stream,
- crypt_state_sender.clone(),
- event_queue.clone(),
- ),
- send_voice(
- packet_sender.clone(),
- Arc::clone(&input_receiver),
- phase_watcher.clone(),
- ),
- send_packets(sink, &mut packet_receiver),
- register_events(&mut tcp_event_register_receiver, event_queue.clone()),
- ).map(|_| ()),
+ async {
+ select! {
+ r = send_pings(packet_sender.clone(), 10).fuse() => r,
+ r = listen(
+ Arc::clone(&state),
+ stream,
+ crypt_state_sender.clone(),
+ event_queue.clone(),
+ ).fuse() => r,
+ r = send_voice(
+ packet_sender.clone(),
+ Arc::clone(&input_receiver),
+ phase_watcher_inner,
+ ).fuse() => r,
+ r = send_packets(sink, &mut packet_receiver).fuse() => r,
+ _ = register_events(&mut tcp_event_register_receiver, event_queue.clone()).fuse() => Ok(()),
+ }
+ },
phase_watcher,
- ).await;
+ ).await.unwrap_or(Ok(()))?;
event_queue.resolve(TcpEventData::Disconnected).await;
@@ -209,7 +212,7 @@ async fn send_voice(
packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
phase_watcher: watch::Receiver<StatePhase>,
-) {
+) -> Result<(), TcpError> {
loop {
let mut inner_phase_watcher = phase_watcher.clone();
loop {
@@ -243,7 +246,7 @@ async fn listen(
mut stream: TcpReceiver,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
event_queue: TcpEventQueue,
-) {
+) -> Result<(), TcpError> {
let mut crypt_state = None;
let mut crypt_state_sender = Some(crypt_state_sender);
@@ -369,6 +372,7 @@ async fn listen(
}
}
}
+ Ok(())
}
async fn register_events(