diff options
| author | Gustav Sörnäs <gustav@sornas.net> | 2021-06-12 06:37:06 +0200 |
|---|---|---|
| committer | Gustav Sörnäs <gustav@sornas.net> | 2021-06-12 06:37:06 +0200 |
| commit | cf3f8c185cede889faccd3d55655a494ccd6f707 (patch) | |
| tree | 149bb196e2a16cb8d297d03fc16f56f03c84dcfc /mumd/src | |
| parent | dcd70175a98c83a3334d7980e5196bc866e04efb (diff) | |
| parent | b7701a6f61b525c116e29981f122a58552751f78 (diff) | |
| download | mum-cf3f8c185cede889faccd3d55655a494ccd6f707.tar.gz | |
Merge remote-tracking branch 'origin/invalid-cert'
Diffstat (limited to 'mumd/src')
| -rw-r--r-- | mumd/src/command.rs | 34 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 58 | ||||
| -rw-r--r-- | mumd/src/state.rs | 47 |
3 files changed, 91 insertions, 48 deletions
diff --git a/mumd/src/command.rs b/mumd/src/command.rs index 410751a..2069178 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -4,10 +4,7 @@ use crate::state::{ExecutionContext, State}; use log::*; use mumble_protocol::{control::ControlPacket, Serverbound}; use mumlib::command::{Command, CommandResponse}; -use std::sync::{ - atomic::{AtomicU64, Ordering}, - Arc, RwLock, -}; +use std::{rc::Rc, sync::{Arc, RwLock, atomic::{AtomicBool, AtomicU64, Ordering}}}; use tokio::sync::{mpsc, watch}; pub async fn handle( @@ -32,16 +29,25 @@ pub async fn handle( &mut connection_info_sender, ); match event { - ExecutionContext::TcpEventCallback(event, generator) => { - tcp_event_queue.register_callback( - event, - Box::new(move |e| { - let response = generator(e); - for response in response { - response_sender.send(response).unwrap(); - } - }), - ); + ExecutionContext::TcpEventCallback(callbacks) => { + // A shared bool ensures that only one of the supplied callbacks is run. + let should_handle = Rc::new(AtomicBool::new(true)); + for (event, generator) in callbacks { + let should_handle = Rc::clone(&should_handle); + let response_sender = response_sender.clone(); + tcp_event_queue.register_callback( + event, + Box::new(move |e| { + // If should_handle == true no other callback has been run yet. + if should_handle.swap(false, Ordering::Relaxed) { + let response = generator(e); + for response in response { + response_sender.send(response).unwrap(); + } + } + }), + ); + } } ExecutionContext::TcpEventSubscriber(event, mut handler) => tcp_event_queue .register_subscriber( diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 4a753bf..f620a32 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -1,14 +1,12 @@ +use crate::error::{ServerSendError, TcpError}; use crate::network::ConnectionInfo; +use crate::notifications; use crate::state::{State, StatePhase}; -use crate::{ - error::{ServerSendError, TcpError}, - notifications, -}; -use log::*; use futures_util::select; use futures_util::stream::{SplitSink, SplitStream, Stream}; use futures_util::{FutureExt, SinkExt, StreamExt}; +use log::*; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::voice::VoicePacket; @@ -36,17 +34,33 @@ type TcpReceiver = pub(crate) type TcpEventCallback = Box<dyn FnOnce(TcpEventData)>; pub(crate) type TcpEventSubscriber = Box<dyn FnMut(TcpEventData) -> bool>; //the bool indicates if it should be kept or not -#[derive(Debug, Clone, Hash, Eq, PartialEq)] +/// Why the TCP was disconnected. +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +pub enum DisconnectedReason { + InvalidTls, + User, + TcpError, +} + +/// Something a callback can register to. Data is sent via a respective [TcpEventData]. +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] pub enum TcpEvent { Connected, //fires when the client has connected to a server - Disconnected, //fires when the client has disconnected from a server + Disconnected(DisconnectedReason), //fires when the client has disconnected from a server TextMessage, //fires when a text message comes in } +/// When a [TcpEvent] occurs, this contains the data for the event. +/// +/// The events are picked up by a [crate::state::ExecutionContext]. +/// +/// Having two different types might feel a bit confusing. Essentially, a +/// callback _registers_ to a [TcpEvent] but _takes_ a [TcpEventData] as +/// parameter. #[derive(Clone)] pub enum TcpEventData<'a> { Connected(Result<&'a msgs::ServerSync, mumlib::Error>), - Disconnected, + Disconnected(DisconnectedReason), TextMessage(&'a msgs::TextMessage), } @@ -54,7 +68,7 @@ impl<'a> From<&TcpEventData<'a>> for TcpEvent { fn from(t: &TcpEventData) -> Self { match t { TcpEventData::Connected(_) => TcpEvent::Connected, - TcpEventData::Disconnected => TcpEvent::Disconnected, + TcpEventData::Disconnected(reason) => TcpEvent::Disconnected(*reason), TcpEventData::TextMessage(_) => TcpEvent::TextMessage, } } @@ -142,12 +156,25 @@ pub async fn handle( } return Err(TcpError::NoConnectionInfoReceived); }; - let (mut sink, stream) = connect( + let connect_result = connect( connection_info.socket_addr, connection_info.hostname, connection_info.accept_invalid_cert, ) - .await?; + .await; + + let (mut sink, stream) = match connect_result { + Ok(ok) => ok, + Err(TcpError::TlsConnectError(_)) => { + warn!("Invalid TLS"); + state.read().unwrap().broadcast_phase(StatePhase::Disconnected); + event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::InvalidTls)); + continue; + } + Err(e) => { + return Err(e); + } + }; // Handshake (omitting `Version` message for brevity) let (username, password) = { @@ -170,7 +197,7 @@ pub async fn handle( let phase_watcher_inner = phase_watcher.clone(); - run_until( + let result = run_until( |phase| matches!(phase, StatePhase::Disconnected), async { select! { @@ -192,9 +219,12 @@ pub async fn handle( phase_watcher, ) .await - .unwrap_or(Ok(()))?; + .unwrap_or(Ok(())); - event_queue.resolve(TcpEventData::Disconnected); + match result { + Ok(()) => event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::User)), + Err(_) => event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::TcpError)), + } debug!("Fully disconnected TCP stream, waiting for new connection info"); } diff --git a/mumd/src/state.rs b/mumd/src/state.rs index d12b5b6..d2d77b1 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -2,7 +2,7 @@ pub mod channel; pub mod server; pub mod user; -use crate::audio::{AudioInput, AudioOutput, NotificationEvents}; +use crate::{audio::{AudioInput, AudioOutput, NotificationEvents}, network::tcp::DisconnectedReason}; use crate::error::StateError; use crate::network::tcp::{TcpEvent, TcpEventData}; use crate::network::{ConnectionInfo, VoiceStreamType}; @@ -27,8 +27,10 @@ use std::{ use tokio::sync::{mpsc, watch}; macro_rules! at { - ($event:expr, $generator:expr) => { - ExecutionContext::TcpEventCallback($event, Box::new($generator)) + ( $( $event:expr => $generator:expr ),+ $(,)? ) => { + ExecutionContext::TcpEventCallback(vec![ + $( ($event, Box::new($generator)), )* + ]) }; } @@ -42,7 +44,7 @@ type Responses = Box<dyn Iterator<Item = mumlib::error::Result<Option<CommandRes //TODO give me a better name pub enum ExecutionContext { - TcpEventCallback(TcpEvent, Box<dyn FnOnce(TcpEventData) -> Responses>), + TcpEventCallback(Vec<(TcpEvent, Box<dyn FnOnce(TcpEventData) -> Responses>)>), TcpEventSubscriber( TcpEvent, Box< @@ -606,23 +608,28 @@ pub fn handle_command( ))) .unwrap(); let state = Arc::clone(&og_state); - at!(TcpEvent::Connected, move |res| { - //runs the closure when the client is connected - if let TcpEventData::Connected(res) = res { - Box::new(iter::once(res.map(|msg| { - Some(CommandResponse::ServerConnect { - welcome_message: if msg.has_welcome_text() { - Some(msg.get_welcome_text().to_string()) - } else { - None - }, - server_state: state.read().unwrap().server.as_ref().unwrap().into(), - }) - }))) - } else { - unreachable!("callback should be provided with a TcpEventData::Connected"); + at!( + TcpEvent::Connected => move |res| { + //runs the closure when the client is connected + if let TcpEventData::Connected(res) = res { + Box::new(iter::once(res.map(|msg| { + Some(CommandResponse::ServerConnect { + welcome_message: if msg.has_welcome_text() { + Some(msg.get_welcome_text().to_string()) + } else { + None + }, + server_state: state.read().unwrap().server.as_ref().unwrap().into(), + }) + }))) + } else { + unreachable!("callback should be provided with a TcpEventData::Connected"); + } + }, + TcpEvent::Disconnected(DisconnectedReason::InvalidTls) => |_| { + Box::new(iter::once(Err(Error::ServerCertReject))) } - }) + ) } Command::ServerDisconnect => { if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { |
