diff options
Diffstat (limited to 'mumd/src')
| -rw-r--r-- | mumd/src/client.rs | 13 | ||||
| -rw-r--r-- | mumd/src/command.rs | 48 | ||||
| -rw-r--r-- | mumd/src/main.rs | 30 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 76 | ||||
| -rw-r--r-- | mumd/src/state.rs | 781 | ||||
| -rw-r--r-- | mumd/src/state/channel.rs | 2 | ||||
| -rw-r--r-- | mumd/src/state/server.rs | 39 |
7 files changed, 557 insertions, 432 deletions
diff --git a/mumd/src/client.rs b/mumd/src/client.rs index 9c2c2a0..ba9cad4 100644 --- a/mumd/src/client.rs +++ b/mumd/src/client.rs @@ -1,4 +1,4 @@ -use crate::command; +use crate::{command, network::tcp::TcpEventQueue}; use crate::error::ClientError; use crate::network::{tcp, udp, ConnectionInfo}; use crate::state::State; @@ -7,13 +7,13 @@ use futures_util::{select, FutureExt}; use mumble_protocol::{Serverbound, control::ControlPacket, crypt::ClientCryptState}; use mumlib::command::{Command, CommandResponse}; use std::sync::{Arc, RwLock}; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch}; pub async fn handle( state: State, command_receiver: mpsc::UnboundedReceiver<( Command, - oneshot::Sender<mumlib::error::Result<Option<CommandResponse>>>, + mpsc::UnboundedSender<mumlib::error::Result<Option<CommandResponse>>>, )>, ) -> Result<(), ClientError> { let (connection_info_sender, connection_info_receiver) = @@ -24,8 +24,7 @@ pub async fn handle( mpsc::unbounded_channel::<ControlPacket<Serverbound>>(); let (ping_request_sender, ping_request_receiver) = mpsc::unbounded_channel(); - let (response_sender, response_receiver) = - mpsc::unbounded_channel(); + let event_queue = TcpEventQueue::new(); let state = Arc::new(RwLock::new(state)); @@ -36,7 +35,7 @@ pub async fn handle( crypt_state_sender, packet_sender.clone(), packet_receiver, - response_receiver, + event_queue.clone(), ).fuse() => r.map_err(|e| ClientError::TcpError(e)), _ = udp::handle( Arc::clone(&state), @@ -46,7 +45,7 @@ pub async fn handle( _ = command::handle( state, command_receiver, - response_sender, + event_queue, ping_request_sender, packet_sender, connection_info_sender, diff --git a/mumd/src/command.rs b/mumd/src/command.rs index 1337dce..5255afa 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -1,51 +1,53 @@ -use crate::network::{ - ConnectionInfo, - tcp::{TcpEvent, TcpEventCallback}, - udp::PingRequest -}; +use crate::network::{ConnectionInfo, tcp::TcpEventQueue, udp::PingRequest}; use crate::state::{ExecutionContext, State}; use log::*; use mumble_protocol::{Serverbound, control::ControlPacket}; use mumlib::command::{Command, CommandResponse}; use std::sync::{atomic::{AtomicU64, Ordering}, Arc, RwLock}; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch}; pub async fn handle( state: Arc<RwLock<State>>, mut command_receiver: mpsc::UnboundedReceiver<( Command, - oneshot::Sender<mumlib::error::Result<Option<CommandResponse>>>, + mpsc::UnboundedSender<mumlib::error::Result<Option<CommandResponse>>>, )>, - tcp_event_register_sender: mpsc::UnboundedSender<(TcpEvent, TcpEventCallback)>, + tcp_event_queue: TcpEventQueue, ping_request_sender: mpsc::UnboundedSender<PingRequest>, mut packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, mut connection_info_sender: watch::Sender<Option<ConnectionInfo>>, ) { debug!("Begin listening for commands"); let ping_count = AtomicU64::new(0); - while let Some((command, response_sender)) = command_receiver.recv().await { + while let Some((command, mut response_sender)) = command_receiver.recv().await { debug!("Received command {:?}", command); - let mut state = state.write().unwrap(); - let event = state.handle_command(command, &mut packet_sender, &mut connection_info_sender); - drop(state); + let event = crate::state::handle_command(Arc::clone(&state), command, &mut packet_sender, &mut connection_info_sender); match event { - ExecutionContext::TcpEvent(event, generator) => { - let (tx, rx) = oneshot::channel(); - //TODO handle this error - let _ = tcp_event_register_sender.send(( - event, + ExecutionContext::TcpEventCallback(event, generator) => { + tcp_event_queue.register_callback( + event, Box::new(move |e| { let response = generator(e); - response_sender.send(response).unwrap(); - tx.send(()).unwrap(); + for response in response { + response_sender.send(response).unwrap(); + } }), - )); - - rx.await.unwrap(); + ); + } + ExecutionContext::TcpEventSubscriber(event, mut handler) => { + tcp_event_queue.register_subscriber( + event, + Box::new(move |event| { + handler(event, &mut response_sender) + }), + ) } ExecutionContext::Now(generator) => { - response_sender.send(generator()).unwrap(); + for response in generator() { + response_sender.send(response).unwrap(); + } + drop(response_sender); } ExecutionContext::Ping(generator, converter) => { let ret = generator(); diff --git a/mumd/src/main.rs b/mumd/src/main.rs index f298070..0c175c2 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -8,13 +8,14 @@ mod state; use crate::state::State; +use bytes::{BufMut, BytesMut}; use futures_util::{select, FutureExt, SinkExt, StreamExt}; use log::*; use mumlib::command::{Command, CommandResponse}; use mumlib::setup_logger; -use tokio::{net::{UnixListener, UnixStream}, sync::{mpsc, oneshot}}; +use std::io::ErrorKind; +use tokio::{net::{UnixListener, UnixStream}, sync::mpsc}; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; -use bytes::{BufMut, BytesMut}; #[tokio::main] async fn main() { @@ -81,7 +82,7 @@ async fn main() { async fn receive_commands( command_sender: mpsc::UnboundedSender<( Command, - oneshot::Sender<mumlib::error::Result<Option<CommandResponse>>>, + mpsc::UnboundedSender<mumlib::error::Result<Option<CommandResponse>>>, )>, ) { let socket = UnixListener::bind(mumlib::SOCKET_PATH).unwrap(); @@ -105,21 +106,22 @@ async fn receive_commands( Err(_) => continue, }; - let (tx, rx) = oneshot::channel(); + let (tx, mut rx) = mpsc::unbounded_channel(); sender.send((command, tx)).unwrap(); - let response = match rx.await { - Ok(r) => r, - Err(_) => { - error!("Internal command response sender dropped"); - Ok(None) - } - }; - let mut serialized = BytesMut::new(); - bincode::serialize_into((&mut serialized).writer(), &response).unwrap(); + while let Some(response) = rx.recv().await { + let mut serialized = BytesMut::new(); + bincode::serialize_into((&mut serialized).writer(), &response).unwrap(); - let _ = writer.send(serialized.freeze()).await; + if let Err(e) = writer.send(serialized.freeze()).await { + if e.kind() != ErrorKind::BrokenPipe { //if the client closed the connection, ignore logging the error + //we just assume that they just don't want any more packets + error!("Error sending response: {:?}", e); + } + break; + } + } } }); } diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 7606987..b513797 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -1,4 +1,4 @@ -use crate::error::{ServerSendError, TcpError}; +use crate::{error::{ServerSendError, TcpError}, notifications}; use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; @@ -30,17 +30,20 @@ type TcpReceiver = SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>; pub(crate) type TcpEventCallback = Box<dyn FnOnce(TcpEventData)>; +pub(crate) type TcpEventSubscriber = Box<dyn FnMut(TcpEventData) -> bool>; //the bool indicates if it should be kept or not #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub enum TcpEvent { Connected, //fires when the client has connected to a server Disconnected, //fires when the client has disconnected from a server + TextMessage, //fires when a text message comes in } #[derive(Clone)] pub enum TcpEventData<'a> { Connected(Result<&'a msgs::ServerSync, mumlib::Error>), Disconnected, + TextMessage(&'a msgs::TextMessage), } impl<'a> From<&TcpEventData<'a>> for TcpEvent { @@ -48,33 +51,53 @@ impl<'a> From<&TcpEventData<'a>> for TcpEvent { match t { TcpEventData::Connected(_) => TcpEvent::Connected, TcpEventData::Disconnected => TcpEvent::Disconnected, + TcpEventData::TextMessage(_) => TcpEvent::TextMessage, } } } #[derive(Clone)] -struct TcpEventQueue { - handlers: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, +pub struct TcpEventQueue { + callbacks: Arc<RwLock<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + subscribers: Arc<RwLock<HashMap<TcpEvent, Vec<TcpEventSubscriber>>>>, } impl TcpEventQueue { - fn new() -> Self { + /// Creates a new `TcpEventQueue`. + pub fn new() -> Self { Self { - handlers: Arc::new(Mutex::new(HashMap::new())), + callbacks: Arc::new(RwLock::new(HashMap::new())), + subscribers: Arc::new(RwLock::new(HashMap::new())), } } - async fn register(&self, at: TcpEvent, callback: TcpEventCallback) { - self.handlers.lock().await.entry(at).or_default().push(callback); + /// Registers a new callback to be triggered when an event is fired. + pub fn register_callback(&self, at: TcpEvent, callback: TcpEventCallback) { + self.callbacks.write().unwrap().entry(at).or_default().push(callback); } - async fn resolve<'a>(&self, data: TcpEventData<'a>) { - if let Some(vec) = self.handlers.lock().await.get_mut(&TcpEvent::from(&data)) { + /// Registers a new callback to be triggered when an event is fired. + pub fn register_subscriber(&self, at: TcpEvent, callback: TcpEventSubscriber) { + self.subscribers.write().unwrap().entry(at).or_default().push(callback); + } + + /// Fires all callbacks related to a specific TCP event and removes them from the event queue. + /// Also calls all event subscribers, but keeps them in the queue + pub fn resolve<'a>(&self, data: TcpEventData<'a>) { + if let Some(vec) = self.callbacks.write().unwrap().get_mut(&TcpEvent::from(&data)) { let old = std::mem::take(vec); for handler in old { handler(data.clone()); } } + if let Some(vec) = self.subscribers.write().unwrap().get_mut(&TcpEvent::from(&data)) { + let old = std::mem::take(vec); + for mut e in old { + if e(data.clone()) { + vec.push(e) + } + } + } } } @@ -84,7 +107,7 @@ pub async fn handle( crypt_state_sender: mpsc::Sender<ClientCryptState>, packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, - mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, + event_queue: TcpEventQueue, ) -> Result<(), TcpError> { loop { let connection_info = 'data: loop { @@ -114,7 +137,6 @@ pub async fn handle( (state_lock.phase_receiver(), state_lock.audio_input().receiver()) }; - let event_queue = TcpEventQueue::new(); info!("Logging in..."); @@ -137,13 +159,12 @@ pub async fn handle( phase_watcher_inner, ).fuse() => r, r = send_packets(sink, &mut packet_receiver).fuse() => r, - _ = register_events(&mut tcp_event_register_receiver, event_queue.clone()).fuse() => Ok(()), } }, phase_watcher, ).await.unwrap_or(Ok(()))?; - event_queue.resolve(TcpEventData::Disconnected).await; + event_queue.resolve(TcpEventData::Disconnected); debug!("Fully disconnected TCP stream, waiting for new connection info"); } @@ -270,11 +291,16 @@ async fn listen( }; match packet { ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); + let mut state = state.write().unwrap(); + let user = state.server() + .and_then(|server| server.users().get(&msg.get_actor())) + .map(|user| user.name()); + if let Some(user) = user { + notifications::send(format!("{}: {}", user, msg.get_message())); //TODO: probably want a config flag for this + } + state.register_message((msg.get_message().to_owned(), msg.get_actor())); + drop(state); + event_queue.resolve(TcpEventData::TextMessage(&*msg)); } ControlPacket::CryptSetup(msg) => { debug!("Crypt setup"); @@ -302,7 +328,7 @@ async fn listen( ) .await; } - event_queue.resolve(TcpEventData::Connected(Ok(&msg))).await; + event_queue.resolve(TcpEventData::Connected(Ok(&msg))); let mut state = state.write().unwrap(); let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); @@ -319,7 +345,7 @@ async fn listen( debug!("Login rejected: {:?}", msg); match msg.get_field_type() { msgs::Reject_RejectType::WrongServerPW => { - event_queue.resolve(TcpEventData::Connected(Err(mumlib::Error::InvalidServerPassword))).await; + event_queue.resolve(TcpEventData::Connected(Err(mumlib::Error::InvalidServerPassword))); } ty => { warn!("Unhandled reject type: {:?}", ty); @@ -377,13 +403,3 @@ async fn listen( } Ok(()) } - -async fn register_events( - tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, - event_queue: TcpEventQueue, -) { - loop { - let (event, handler) = tcp_event_register_receiver.recv().await.unwrap(); - event_queue.register(event, handler).await; - } -} diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 46df421..a553e18 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -14,33 +14,38 @@ use mumble_protocol::control::msgs; use mumble_protocol::control::ControlPacket; use mumble_protocol::ping::PongPacket; use mumble_protocol::voice::Serverbound; -use mumlib::command::{Command, CommandResponse}; +use mumlib::command::{Command, CommandResponse, MessageTarget}; use mumlib::config::Config; -use mumlib::error::ChannelIdentifierError; use mumlib::Error; use crate::state::user::UserDiff; -use std::net::{SocketAddr, ToSocketAddrs}; +use std::{iter, net::{SocketAddr, ToSocketAddrs}, sync::{Arc, RwLock}}; use tokio::sync::{mpsc, watch}; macro_rules! at { ($event:expr, $generator:expr) => { - ExecutionContext::TcpEvent($event, Box::new($generator)) + ExecutionContext::TcpEventCallback($event, Box::new($generator)) }; } macro_rules! now { ($data:expr) => { - ExecutionContext::Now(Box::new(move || $data)) + ExecutionContext::Now(Box::new(move || Box::new(iter::once($data)))) }; } +type Responses = Box<dyn Iterator<Item = mumlib::error::Result<Option<CommandResponse>>>>; + //TODO give me a better name pub enum ExecutionContext { - TcpEvent( + TcpEventCallback( + TcpEvent, + Box<dyn FnOnce(TcpEventData) -> Responses>, + ), + TcpEventSubscriber( TcpEvent, - Box<dyn FnOnce(TcpEventData) -> mumlib::error::Result<Option<CommandResponse>>>, + Box<dyn FnMut(TcpEventData, &mut mpsc::UnboundedSender<mumlib::error::Result<Option<CommandResponse>>>) -> bool>, ), - Now(Box<dyn FnOnce() -> mumlib::error::Result<Option<CommandResponse>>>), + Now(Box<dyn FnOnce() -> Responses>), Ping( Box<dyn FnOnce() -> mumlib::error::Result<SocketAddr>>, Box<dyn FnOnce(Option<PongPacket>) -> mumlib::error::Result<Option<CommandResponse>> + Send>, @@ -59,6 +64,7 @@ pub struct State { server: Option<Server>, audio_input: AudioInput, audio_output: AudioOutput, + message_buffer: Vec<(String, u32)>, phase_watcher: (watch::Sender<StatePhase>, watch::Receiver<StatePhase>), } @@ -79,355 +85,13 @@ impl State { server: None, audio_input, audio_output, + message_buffer: Vec::new(), phase_watcher, }; state.reload_config(); Ok(state) } - pub fn handle_command( - &mut self, - command: Command, - packet_sender: &mut mpsc::UnboundedSender<ControlPacket<Serverbound>>, - connection_info_sender: &mut watch::Sender<Option<ConnectionInfo>>, - ) -> ExecutionContext { - match command { - Command::ChannelJoin { channel_identifier } => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { - return now!(Err(Error::Disconnected)); - } - - let channels = self.server().unwrap().channels(); - - let matches = channels - .iter() - .map(|e| (e.0, e.1.path(channels))) - .filter(|e| e.1.ends_with(&channel_identifier)) - .collect::<Vec<_>>(); - let id = match matches.len() { - 0 => { - let soft_matches = channels - .iter() - .map(|e| (e.0, e.1.path(channels).to_lowercase())) - .filter(|e| e.1.ends_with(&channel_identifier.to_lowercase())) - .collect::<Vec<_>>(); - match soft_matches.len() { - 0 => { - return now!(Err(Error::ChannelIdentifierError( - channel_identifier, - ChannelIdentifierError::Invalid - ))) - } - 1 => *soft_matches.get(0).unwrap().0, - _ => { - return now!(Err(Error::ChannelIdentifierError( - channel_identifier, - ChannelIdentifierError::Invalid - ))) - } - } - } - 1 => *matches.get(0).unwrap().0, - _ => { - return now!(Err(Error::ChannelIdentifierError( - channel_identifier, - ChannelIdentifierError::Ambiguous - ))) - } - }; - - let mut msg = msgs::UserState::new(); - msg.set_session(self.server.as_ref().unwrap().session_id().unwrap()); - msg.set_channel_id(id); - packet_sender.send(msg.into()).unwrap(); - now!(Ok(None)) - } - Command::ChannelList => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { - return now!(Err(Error::Disconnected)); - } - let list = channel::into_channel( - self.server.as_ref().unwrap().channels(), - self.server.as_ref().unwrap().users(), - ); - now!(Ok(Some(CommandResponse::ChannelList { channels: list }))) - } - Command::ConfigReload => { - self.reload_config(); - now!(Ok(None)) - } - Command::DeafenSelf(toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { - return now!(Err(Error::Disconnected)); - } - - let server = self.server().unwrap(); - let action = match (toggle, server.muted(), server.deafened()) { - (Some(false), false, false) => None, - (Some(false), false, true) => Some((false, false)), - (Some(false), true, false) => None, - (Some(false), true, true) => Some((true, false)), - (Some(true), false, false) => Some((false, true)), - (Some(true), false, true) => None, - (Some(true), true, false) => Some((true, true)), - (Some(true), true, true) => None, - (None, false, false) => Some((false, true)), - (None, false, true) => Some((false, false)), - (None, true, false) => Some((true, true)), - (None, true, true) => Some((true, false)), - }; - - let mut new_deaf = None; - if let Some((mute, deafen)) = action { - if server.deafened() != deafen { - self.audio_output.play_effect(if deafen { - NotificationEvents::Deafen - } else { - NotificationEvents::Undeafen - }); - } else if server.muted() != mute { - self.audio_output.play_effect(if mute { - NotificationEvents::Mute - } else { - NotificationEvents::Unmute - }); - } - let mut msg = msgs::UserState::new(); - if server.muted() != mute { - msg.set_self_mute(mute); - } else if !mute && !deafen && server.deafened() { - msg.set_self_mute(false); - } - if server.deafened() != deafen { - msg.set_self_deaf(deafen); - new_deaf = Some(deafen); - } - let server = self.server_mut().unwrap(); - server.set_muted(mute); - server.set_deafened(deafen); - packet_sender.send(msg.into()).unwrap(); - } - - now!(Ok(new_deaf.map(|b| CommandResponse::DeafenStatus { is_deafened: b }))) - } - Command::InputVolumeSet(volume) => { - self.audio_input.set_volume(volume); - now!(Ok(None)) - } - Command::MuteOther(string, toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { - return now!(Err(Error::Disconnected)); - } - - let id = self - .server_mut() - .unwrap() - .users_mut() - .iter_mut() - .find(|(_, user)| user.name() == string); - - let (id, user) = match id { - Some(id) => (*id.0, id.1), - None => return now!(Err(Error::InvalidUsername(string))), - }; - - let action = match toggle { - Some(state) => { - if user.suppressed() != state { - Some(state) - } else { - None - } - } - None => Some(!user.suppressed()), - }; - - if let Some(action) = action { - user.set_suppressed(action); - self.audio_output.set_mute(id, action); - } - - return now!(Ok(None)); - } - Command::MuteSelf(toggle) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { - return now!(Err(Error::Disconnected)); - } - - let server = self.server().unwrap(); - let action = match (toggle, server.muted(), server.deafened()) { - (Some(false), false, false) => None, - (Some(false), false, true) => Some((false, false)), - (Some(false), true, false) => Some((false, false)), - (Some(false), true, true) => Some((false, false)), - (Some(true), false, false) => Some((true, false)), - (Some(true), false, true) => None, - (Some(true), true, false) => None, - (Some(true), true, true) => None, - (None, false, false) => Some((true, false)), - (None, false, true) => Some((false, false)), - (None, true, false) => Some((false, false)), - (None, true, true) => Some((false, false)), - }; - - let mut new_mute = None; - if let Some((mute, deafen)) = action { - if server.deafened() != deafen { - self.audio_output.play_effect(if deafen { - NotificationEvents::Deafen - } else { - NotificationEvents::Undeafen - }); - } else if server.muted() != mute { - self.audio_output.play_effect(if mute { - NotificationEvents::Mute - } else { - NotificationEvents::Unmute - }); - } - let mut msg = msgs::UserState::new(); - if server.muted() != mute { - msg.set_self_mute(mute); - new_mute = Some(mute) - } else if !mute && !deafen && server.deafened() { - msg.set_self_mute(false); - new_mute = Some(false) - } - if server.deafened() != deafen { - msg.set_self_deaf(deafen); - } - let server = self.server_mut().unwrap(); - server.set_muted(mute); - server.set_deafened(deafen); - packet_sender.send(msg.into()).unwrap(); - } - - now!(Ok(new_mute.map(|b| CommandResponse::MuteStatus { is_muted: b }))) - } - Command::OutputVolumeSet(volume) => { - self.audio_output.set_volume(volume); - now!(Ok(None)) - } - Command::Ping => { - now!(Ok(Some(CommandResponse::Pong))) - } - Command::ServerConnect { - host, - port, - username, - password, - accept_invalid_cert, - } => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { - return now!(Err(Error::AlreadyConnected)); - } - let mut server = Server::new(); - *server.username_mut() = Some(username); - *server.password_mut() = password; - *server.host_mut() = Some(format!("{}:{}", host, port)); - self.server = Some(server); - self.phase_watcher - .0 - .send(StatePhase::Connecting) - .unwrap(); - - let socket_addr = match (host.as_ref(), port) - .to_socket_addrs() - .map(|mut e| e.next()) - { - Ok(Some(v)) => v, - _ => { - warn!("Error parsing server addr"); - return now!(Err(Error::InvalidServerAddr(host, port))); - } - }; - connection_info_sender - .send(Some(ConnectionInfo::new( - socket_addr, - host, - accept_invalid_cert, - ))) - .unwrap(); - at!(TcpEvent::Connected, |res| { - //runs the closure when the client is connected - if let TcpEventData::Connected(res) = res { - res.map(|msg| { - Some(CommandResponse::ServerConnect { - welcome_message: if msg.has_welcome_text() { - Some(msg.get_welcome_text().to_string()) - } else { - None - }, - }) - }) - } else { - unreachable!("callback should be provided with a TcpEventData::Connected"); - } - }) - } - Command::ServerDisconnect => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { - return now!(Err(Error::Disconnected)); - } - - self.server = None; - - self.phase_watcher - .0 - .send(StatePhase::Disconnected) - .unwrap(); - self.audio_output.play_effect(NotificationEvents::ServerDisconnect); - now!(Ok(None)) - } - Command::ServerStatus { host, port } => ExecutionContext::Ping( - Box::new(move || { - match (host.as_str(), port) - .to_socket_addrs() - .map(|mut e| e.next()) - { - Ok(Some(v)) => Ok(v), - _ => Err(Error::InvalidServerAddr(host, port)), - } - }), - Box::new(move |pong| { - Ok(pong.map(|pong| CommandResponse::ServerStatus { - version: pong.version, - users: pong.users, - max_users: pong.max_users, - bandwidth: pong.bandwidth, - })) - }), - ), - Command::Status => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { - return now!(Err(Error::Disconnected)); - } - let state = self.server.as_ref().unwrap().into(); - now!(Ok(Some(CommandResponse::Status { - server_state: state, //guaranteed not to panic because if we are connected, server is guaranteed to be Some - }))) - } - Command::UserVolumeSet(string, volume) => { - if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) { - return now!(Err(Error::Disconnected)); - } - let user_id = match self - .server() - .unwrap() - .users() - .iter() - .find(|e| e.1.name() == string) - .map(|e| *e.0) - { - None => return now!(Err(Error::InvalidUsername(string))), - Some(v) => v, - }; - - self.audio_output.set_user_volume(user_id, volume); - now!(Ok(None)) - } - } - } pub fn parse_user_state(&mut self, msg: msgs::UserState) { if !msg.has_session() { @@ -590,6 +254,10 @@ impl State { self.audio_output.load_sound_effects(sound_effects); } } + + pub fn register_message(&mut self, msg: (String, u32)) { + self.message_buffer.push(msg); + } pub fn broadcast_phase(&self, phase: StatePhase) { self.phase_watcher @@ -609,12 +277,6 @@ impl State { pub fn audio_output(&self) -> &AudioOutput { &self.audio_output } - pub fn audio_input_mut(&mut self) -> &mut AudioInput { - &mut self.audio_input - } - pub fn audio_output_mut(&mut self) -> &mut AudioOutput { - &mut self.audio_output - } pub fn phase_receiver(&self) -> watch::Receiver<StatePhase> { self.phase_watcher.1.clone() } @@ -640,4 +302,409 @@ impl State { .1 .channel() } + + /// Gets the username of a user with id `user` connected to the same server that we are connected to. + /// If we are connected to the server but the user with the id doesn't exist, the string "Unknown user {id}" + /// is returned instead. If we aren't connected to a server, None is returned instead. + fn get_user_name(&self, user: u32) -> Option<String> { + self.server() + .map(|e| e.users() + .get(&user).map(|e| e.name().to_string()) + .unwrap_or(format!("Unknown user {}", user))) + } +} + +pub fn handle_command( + og_state: Arc<RwLock<State>>, + command: Command, + packet_sender: &mut mpsc::UnboundedSender<ControlPacket<Serverbound>>, + connection_info_sender: &mut watch::Sender<Option<ConnectionInfo>>, +) -> ExecutionContext { + let mut state = og_state.write().unwrap(); + match command { + Command::ChannelJoin { channel_identifier } => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + + let id = match state.server().unwrap().channel_name(&channel_identifier) { + Ok((id, _)) => id, + Err(e) => return now!(Err(Error::ChannelIdentifierError(channel_identifier, e))), + }; + + let mut msg = msgs::UserState::new(); + msg.set_session(state.server.as_ref().unwrap().session_id().unwrap()); + msg.set_channel_id(id); + packet_sender.send(msg.into()).unwrap(); + now!(Ok(None)) + } + Command::ChannelList => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + let list = channel::into_channel( + state.server.as_ref().unwrap().channels(), + state.server.as_ref().unwrap().users(), + ); + now!(Ok(Some(CommandResponse::ChannelList { channels: list }))) + } + Command::ConfigReload => { + state.reload_config(); + now!(Ok(None)) + } + Command::DeafenSelf(toggle) => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + + let server = state.server().unwrap(); + let action = match (toggle, server.muted(), server.deafened()) { + (Some(false), false, false) => None, + (Some(false), false, true) => Some((false, false)), + (Some(false), true, false) => None, + (Some(false), true, true) => Some((true, false)), + (Some(true), false, false) => Some((false, true)), + (Some(true), false, true) => None, + (Some(true), true, false) => Some((true, true)), + (Some(true), true, true) => None, + (None, false, false) => Some((false, true)), + (None, false, true) => Some((false, false)), + (None, true, false) => Some((true, true)), + (None, true, true) => Some((true, false)), + }; + + let mut new_deaf = None; + if let Some((mute, deafen)) = action { + if server.deafened() != deafen { + state.audio_output.play_effect(if deafen { + NotificationEvents::Deafen + } else { + NotificationEvents::Undeafen + }); + } else if server.muted() != mute { + state.audio_output.play_effect(if mute { + NotificationEvents::Mute + } else { + NotificationEvents::Unmute + }); + } + let mut msg = msgs::UserState::new(); + if server.muted() != mute { + msg.set_self_mute(mute); + } else if !mute && !deafen && server.deafened() { + msg.set_self_mute(false); + } + if server.deafened() != deafen { + msg.set_self_deaf(deafen); + new_deaf = Some(deafen); + } + let server = state.server_mut().unwrap(); + server.set_muted(mute); + server.set_deafened(deafen); + packet_sender.send(msg.into()).unwrap(); + } + + now!(Ok(new_deaf.map(|b| CommandResponse::DeafenStatus { is_deafened: b }))) + } + Command::InputVolumeSet(volume) => { + state.audio_input.set_volume(volume); + now!(Ok(None)) + } + Command::MuteOther(string, toggle) => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + + let id = state + .server_mut() + .unwrap() + .users_mut() + .iter_mut() + .find(|(_, user)| user.name() == string); + + let (id, user) = match id { + Some(id) => (*id.0, id.1), + None => return now!(Err(Error::InvalidUsername(string))), + }; + + let action = match toggle { + Some(state) => { + if user.suppressed() != state { + Some(state) + } else { + None + } + } + None => Some(!user.suppressed()), + }; + + if let Some(action) = action { + user.set_suppressed(action); + state.audio_output.set_mute(id, action); + } + + return now!(Ok(None)); + } + Command::MuteSelf(toggle) => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + + let server = state.server().unwrap(); + let action = match (toggle, server.muted(), server.deafened()) { + (Some(false), false, false) => None, + (Some(false), false, true) => Some((false, false)), + (Some(false), true, false) => Some((false, false)), + (Some(false), true, true) => Some((false, false)), + (Some(true), false, false) => Some((true, false)), + (Some(true), false, true) => None, + (Some(true), true, false) => None, + (Some(true), true, true) => None, + (None, false, false) => Some((true, false)), + (None, false, true) => Some((false, false)), + (None, true, false) => Some((false, false)), + (None, true, true) => Some((false, false)), + }; + + let mut new_mute = None; + if let Some((mute, deafen)) = action { + if server.deafened() != deafen { + state.audio_output.play_effect(if deafen { + NotificationEvents::Deafen + } else { + NotificationEvents::Undeafen + }); + } else if server.muted() != mute { + state.audio_output.play_effect(if mute { + NotificationEvents::Mute + } else { + NotificationEvents::Unmute + }); + } + let mut msg = msgs::UserState::new(); + if server.muted() != mute { + msg.set_self_mute(mute); + new_mute = Some(mute) + } else if !mute && !deafen && server.deafened() { + msg.set_self_mute(false); + new_mute = Some(false) + } + if server.deafened() != deafen { + msg.set_self_deaf(deafen); + } + let server = state.server_mut().unwrap(); + server.set_muted(mute); + server.set_deafened(deafen); + packet_sender.send(msg.into()).unwrap(); + } + + now!(Ok(new_mute.map(|b| CommandResponse::MuteStatus { is_muted: b }))) + } + Command::OutputVolumeSet(volume) => { + state.audio_output.set_volume(volume); + now!(Ok(None)) + } + Command::Ping => { + now!(Ok(Some(CommandResponse::Pong))) + } + Command::ServerConnect { + host, + port, + username, + password, + accept_invalid_cert, + } => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Disconnected) { + return now!(Err(Error::AlreadyConnected)); + } + let mut server = Server::new(); + *server.username_mut() = Some(username); + *server.password_mut() = password; + *server.host_mut() = Some(format!("{}:{}", host, port)); + state.server = Some(server); + state.phase_watcher + .0 + .send(StatePhase::Connecting) + .unwrap(); + + let socket_addr = match (host.as_ref(), port) + .to_socket_addrs() + .map(|mut e| e.next()) + { + Ok(Some(v)) => v, + _ => { + warn!("Error parsing server addr"); + return now!(Err(Error::InvalidServerAddr(host, port))); + } + }; + connection_info_sender + .send(Some(ConnectionInfo::new( + socket_addr, + host, + accept_invalid_cert, + ))) + .unwrap(); + at!(TcpEvent::Connected, |res| { + //runs the closure when the client is connected + if let TcpEventData::Connected(res) = res { + Box::new(iter::once(res.map(|msg| { + Some(CommandResponse::ServerConnect { + welcome_message: if msg.has_welcome_text() { + Some(msg.get_welcome_text().to_string()) + } else { + None + }, + }) + }))) + } else { + unreachable!("callback should be provided with a TcpEventData::Connected"); + } + }) + } + Command::ServerDisconnect => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + + state.server = None; + + state.phase_watcher + .0 + .send(StatePhase::Disconnected) + .unwrap(); + state.audio_output.play_effect(NotificationEvents::ServerDisconnect); + now!(Ok(None)) + } + Command::ServerStatus { host, port } => ExecutionContext::Ping( + Box::new(move || { + match (host.as_str(), port) + .to_socket_addrs() + .map(|mut e| e.next()) + { + Ok(Some(v)) => Ok(v), + _ => Err(Error::InvalidServerAddr(host, port)), + } + }), + Box::new(move |pong| { + Ok(pong.map(|pong| (CommandResponse::ServerStatus { + version: pong.version, + users: pong.users, + max_users: pong.max_users, + bandwidth: pong.bandwidth, + }))) + }), + ), + Command::Status => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + let state = state.server.as_ref().unwrap().into(); + now!(Ok(Some(CommandResponse::Status { + server_state: state, //guaranteed not to panic because if we are connected, server is guaranteed to be Some + }))) + } + Command::UserVolumeSet(string, volume) => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + let user_id = match state + .server() + .unwrap() + .users() + .iter() + .find(|e| e.1.name() == string) + .map(|e| *e.0) + { + None => return now!(Err(Error::InvalidUsername(string))), + Some(v) => v, + }; + + state.audio_output.set_user_volume(user_id, volume); + now!(Ok(None)) + } + Command::PastMessages { block } => { + //does it make sense to wait for messages while not connected? + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + if block { + let ref_state = Arc::clone(&og_state); + ExecutionContext::TcpEventSubscriber( + TcpEvent::TextMessage, + Box::new(move |data, sender| { + if let TcpEventData::TextMessage(a) = data { + let message = ( + a.get_message().to_owned(), + ref_state.read().unwrap().get_user_name(a.get_actor()).unwrap() + ); + sender.send(Ok(Some(CommandResponse::PastMessage { message }))).is_ok() + } else { + unreachable!("Should only receive a TextMessage data when listening to TextMessage events"); + } + }), + ) + } else { + let messages = std::mem::take(&mut state.message_buffer); + let messages: Vec<_> = messages.into_iter() + .map(|(msg, user)| (msg, state.get_user_name(user).unwrap())) + .map(|e| Ok(Some(CommandResponse::PastMessage { message: e }))) + .collect(); + + ExecutionContext::Now(Box::new(move || { + Box::new(messages.into_iter()) + })) + } + } + Command::SendMessage { message, targets } => { + if !matches!(*state.phase_receiver().borrow(), StatePhase::Connected(_)) { + return now!(Err(Error::Disconnected)); + } + + let mut msg = msgs::TextMessage::new(); + + msg.set_message(message); + + for target in targets { + match target { + MessageTarget::Channel { recursive, name } => { + let channel_id = state + .server() + .unwrap() + .channel_name(&name); + + let channel_id = match channel_id { + Ok(id) => id, + Err(e) => return now!(Err(Error::ChannelIdentifierError(name, e))), + }.0; + + if recursive { + msg.mut_tree_id() + } else { + msg.mut_channel_id() + }.push(channel_id); + } + MessageTarget::User { name } => { + let id = state + .server() + .unwrap() + .users() + .iter() + .find(|(_, user)| user.name() == &name) + .map(|(e, _)| *e); + + let id = match id { + Some(id) => id, + None => return now!(Err(Error::InvalidUsername(name))), + }; + + msg.mut_session().push(id); + } + } + } + + packet_sender.send(msg.into()).unwrap(); + + now!(Ok(None)) + } + } } diff --git a/mumd/src/state/channel.rs b/mumd/src/state/channel.rs index 5b6d669..f58ed15 100644 --- a/mumd/src/state/channel.rs +++ b/mumd/src/state/channel.rs @@ -4,7 +4,7 @@ use mumble_protocol::control::msgs; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct Channel { description: Option<String>, links: Vec<u32>, diff --git a/mumd/src/state/server.rs b/mumd/src/state/server.rs index c9f8a69..78a10b9 100644 --- a/mumd/src/state/server.rs +++ b/mumd/src/state/server.rs @@ -3,6 +3,7 @@ use crate::state::user::User; use log::*; use mumble_protocol::control::msgs; +use mumlib::error::ChannelIdentifierError; use serde::{Deserialize, Serialize}; use std::collections::hash_map::Entry; use std::collections::HashMap; @@ -88,6 +89,44 @@ impl Server { &self.channels } + /// Takes a channel name and returns either a tuple with the channel id and a reference to the + /// channel struct if the channel name unambiguosly refers to a channel, or an error describing + /// if the channel identifier was ambigous or invalid. + /// note that doctests currently aren't run in binary crates yet (see #50784) + /// ``` + /// use crate::state::channel::Channel; + /// let mut server = Server::new(); + /// let channel = Channel { + /// name: "Foobar".to_owned(), + /// ..Default::default(), + /// }; + /// server.channels.insert(0, channel.clone); + /// assert_eq!(server.channel_name("Foobar"), Ok((0, &channel))); + /// ``` + pub fn channel_name(&self, channel_name: &str) -> Result<(u32, &Channel), ChannelIdentifierError> { + let matches = self.channels + .iter() + .map(|e| ((*e.0, e.1), e.1.path(&self.channels))) + .filter(|e| e.1.ends_with(channel_name)) + .collect::<Vec<_>>(); + Ok(match matches.len() { + 0 => { + let soft_matches = self.channels + .iter() + .map(|e| ((*e.0, e.1), e.1.path(&self.channels).to_lowercase())) + .filter(|e| e.1.ends_with(&channel_name.to_lowercase())) + .collect::<Vec<_>>(); + match soft_matches.len() { + 0 => return Err(ChannelIdentifierError::Invalid), + 1 => soft_matches.get(0).unwrap().0, + _ => return Err(ChannelIdentifierError::Ambiguous), + } + } + 1 => matches.get(0).unwrap().0, + _ => return Err(ChannelIdentifierError::Ambiguous), + }) + } + pub fn host_mut(&mut self) -> &mut Option<String> { &mut self.host } |
