aboutsummaryrefslogtreecommitdiffstats
path: root/mumd/src
diff options
context:
space:
mode:
Diffstat (limited to 'mumd/src')
-rw-r--r--mumd/src/command.rs34
-rw-r--r--mumd/src/network/tcp.rs58
-rw-r--r--mumd/src/state.rs47
3 files changed, 91 insertions, 48 deletions
diff --git a/mumd/src/command.rs b/mumd/src/command.rs
index 410751a..2069178 100644
--- a/mumd/src/command.rs
+++ b/mumd/src/command.rs
@@ -4,10 +4,7 @@ use crate::state::{ExecutionContext, State};
use log::*;
use mumble_protocol::{control::ControlPacket, Serverbound};
use mumlib::command::{Command, CommandResponse};
-use std::sync::{
- atomic::{AtomicU64, Ordering},
- Arc, RwLock,
-};
+use std::{rc::Rc, sync::{Arc, RwLock, atomic::{AtomicBool, AtomicU64, Ordering}}};
use tokio::sync::{mpsc, watch};
pub async fn handle(
@@ -32,16 +29,25 @@ pub async fn handle(
&mut connection_info_sender,
);
match event {
- ExecutionContext::TcpEventCallback(event, generator) => {
- tcp_event_queue.register_callback(
- event,
- Box::new(move |e| {
- let response = generator(e);
- for response in response {
- response_sender.send(response).unwrap();
- }
- }),
- );
+ ExecutionContext::TcpEventCallback(callbacks) => {
+ // A shared bool ensures that only one of the supplied callbacks is run.
+ let should_handle = Rc::new(AtomicBool::new(true));
+ for (event, generator) in callbacks {
+ let should_handle = Rc::clone(&should_handle);
+ let response_sender = response_sender.clone();
+ tcp_event_queue.register_callback(
+ event,
+ Box::new(move |e| {
+ // If should_handle == true no other callback has been run yet.
+ if should_handle.swap(false, Ordering::Relaxed) {
+ let response = generator(e);
+ for response in response {
+ response_sender.send(response).unwrap();
+ }
+ }
+ }),
+ );
+ }
}
ExecutionContext::TcpEventSubscriber(event, mut handler) => tcp_event_queue
.register_subscriber(
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 4a753bf..f620a32 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -1,14 +1,12 @@
+use crate::error::{ServerSendError, TcpError};
use crate::network::ConnectionInfo;
+use crate::notifications;
use crate::state::{State, StatePhase};
-use crate::{
- error::{ServerSendError, TcpError},
- notifications,
-};
-use log::*;
use futures_util::select;
use futures_util::stream::{SplitSink, SplitStream, Stream};
use futures_util::{FutureExt, SinkExt, StreamExt};
+use log::*;
use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket};
use mumble_protocol::crypt::ClientCryptState;
use mumble_protocol::voice::VoicePacket;
@@ -36,17 +34,33 @@ type TcpReceiver =
pub(crate) type TcpEventCallback = Box<dyn FnOnce(TcpEventData)>;
pub(crate) type TcpEventSubscriber = Box<dyn FnMut(TcpEventData) -> bool>; //the bool indicates if it should be kept or not
-#[derive(Debug, Clone, Hash, Eq, PartialEq)]
+/// Why the TCP was disconnected.
+#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
+pub enum DisconnectedReason {
+ InvalidTls,
+ User,
+ TcpError,
+}
+
+/// Something a callback can register to. Data is sent via a respective [TcpEventData].
+#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum TcpEvent {
Connected, //fires when the client has connected to a server
- Disconnected, //fires when the client has disconnected from a server
+ Disconnected(DisconnectedReason), //fires when the client has disconnected from a server
TextMessage, //fires when a text message comes in
}
+/// When a [TcpEvent] occurs, this contains the data for the event.
+///
+/// The events are picked up by a [crate::state::ExecutionContext].
+///
+/// Having two different types might feel a bit confusing. Essentially, a
+/// callback _registers_ to a [TcpEvent] but _takes_ a [TcpEventData] as
+/// parameter.
#[derive(Clone)]
pub enum TcpEventData<'a> {
Connected(Result<&'a msgs::ServerSync, mumlib::Error>),
- Disconnected,
+ Disconnected(DisconnectedReason),
TextMessage(&'a msgs::TextMessage),
}
@@ -54,7 +68,7 @@ impl<'a> From<&TcpEventData<'a>> for TcpEvent {
fn from(t: &TcpEventData) -> Self {
match t {
TcpEventData::Connected(_) => TcpEvent::Connected,
- TcpEventData::Disconnected => TcpEvent::Disconnected,
+ TcpEventData::Disconnected(reason) => TcpEvent::Disconnected(*reason),
TcpEventData::TextMessage(_) => TcpEvent::TextMessage,
}
}
@@ -142,12 +156,25 @@ pub async fn handle(
}
return Err(TcpError::NoConnectionInfoReceived);
};
- let (mut sink, stream) = connect(
+ let connect_result = connect(
connection_info.socket_addr,
connection_info.hostname,
connection_info.accept_invalid_cert,
)
- .await?;
+ .await;
+
+ let (mut sink, stream) = match connect_result {
+ Ok(ok) => ok,
+ Err(TcpError::TlsConnectError(_)) => {
+ warn!("Invalid TLS");
+ state.read().unwrap().broadcast_phase(StatePhase::Disconnected);
+ event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::InvalidTls));
+ continue;
+ }
+ Err(e) => {
+ return Err(e);
+ }
+ };
// Handshake (omitting `Version` message for brevity)
let (username, password) = {
@@ -170,7 +197,7 @@ pub async fn handle(
let phase_watcher_inner = phase_watcher.clone();
- run_until(
+ let result = run_until(
|phase| matches!(phase, StatePhase::Disconnected),
async {
select! {
@@ -192,9 +219,12 @@ pub async fn handle(
phase_watcher,
)
.await
- .unwrap_or(Ok(()))?;
+ .unwrap_or(Ok(()));
- event_queue.resolve(TcpEventData::Disconnected);
+ match result {
+ Ok(()) => event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::User)),
+ Err(_) => event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::TcpError)),
+ }
debug!("Fully disconnected TCP stream, waiting for new connection info");
}
diff --git a/mumd/src/state.rs b/mumd/src/state.rs
index d12b5b6..d2d77b1 100644
--- a/mumd/src/state.rs
+++ b/mumd/src/state.rs
@@ -2,7 +2,7 @@ pub mod channel;
pub mod server;
pub mod user;
-use crate::audio::{AudioInput, AudioOutput, NotificationEvents};
+use crate::{audio::{AudioInput, AudioOutput, NotificationEvents}, network::tcp::DisconnectedReason};
use crate::error::StateError;
use crate::network::tcp::{TcpEvent, TcpEventData};
use crate::network::{ConnectionInfo, VoiceStreamType};
@@ -27,8 +27,10 @@ use std::{
use tokio::sync::{mpsc, watch};
macro_rules! at {
- ($event:expr, $generator:expr) => {
- ExecutionContext::TcpEventCallback($event, Box::new($generator))
+ ( $( $event:expr => $generator:expr ),+ $(,)? ) => {
+ ExecutionContext::TcpEventCallback(vec![
+ $( ($event, Box::new($generator)), )*
+ ])
};
}
@@ -42,7 +44,7 @@ type Responses = Box<dyn Iterator<Item = mumlib::error::Result<Option<CommandRes
//TODO give me a better name
pub enum ExecutionContext {
- TcpEventCallback(TcpEvent, Box<dyn FnOnce(TcpEventData) -> Responses>),
+ TcpEventCallback(Vec<(TcpEvent, Box<dyn FnOnce(TcpEventData) -> Responses>)>),
TcpEventSubscriber(
TcpEvent,
Box<
@@ -606,23 +608,28 @@ pub fn handle_command(
)))
.unwrap();
let state = Arc::clone(&og_state);
- at!(TcpEvent::Connected, move |res| {
- //runs the closure when the client is connected
- if let TcpEventData::Connected(res) = res {
- Box::new(iter::once(res.map(|msg| {
- Some(CommandResponse::ServerConnect {
- welcome_message: if msg.has_welcome_text() {
- Some(msg.get_welcome_text().to_string())
- } else {
- None
- },
- server_state: state.read().unwrap().server.as_ref().unwrap().into(),
- })
- })))
- } else {
- unreachable!("callback should be provided with a TcpEventData::Connected");
+ at!(
+ TcpEvent::Connected => move |res| {
+ //runs the closure when the client is connected
+ if let TcpEventData::Connected(res) = res {
+ Box::new(iter::once(res.map(|msg| {
+ Some(CommandResponse::ServerConnect {
+ welcome_message: if msg.has_welcome_text() {
+ Some(msg.get_welcome_text().to_string())
+ } else {
+ None
+ },
+ server_state: state.read().unwrap().server.as_ref().unwrap().into(),
+ })
+ })))
+ } else {
+ unreachable!("callback should be provided with a TcpEventData::Connected");
+ }
+ },
+ TcpEvent::Disconnected(DisconnectedReason::InvalidTls) => |_| {
+ Box::new(iter::once(Err(Error::ServerCertReject)))
}
- })
+ )
}
Command::ServerDisconnect => {
if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) {