diff options
Diffstat (limited to 'mumd/src/state.rs')
| -rw-r--r-- | mumd/src/state.rs | 120 |
1 files changed, 71 insertions, 49 deletions
diff --git a/mumd/src/state.rs b/mumd/src/state.rs index d355ef5..f9ed077 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -15,6 +15,7 @@ use mumlib::config::Config; use mumlib::error::{ChannelIdentifierError, Error}; use std::net::ToSocketAddrs; use tokio::sync::{mpsc, watch}; +use crate::network::tcp::{TcpEvent, TcpEventData}; #[derive(Clone, Debug, Eq, PartialEq)] pub enum StatePhase { @@ -56,11 +57,11 @@ impl State { pub async fn handle_command( &mut self, command: Command, - ) -> (bool, mumlib::error::Result<Option<CommandResponse>>) { + ) -> (Option<TcpEvent>, Box<dyn FnOnce(Option<&TcpEventData>) -> mumlib::error::Result<Option<CommandResponse>>>) { match command { Command::ChannelJoin { channel_identifier } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } let channels = self.server() @@ -78,33 +79,34 @@ impl State { .filter(|e| e.1.ends_with(&channel_identifier.to_lowercase())) .collect::<Vec<_>>(); match soft_matches.len() { - 0 => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + 0 => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid)))), 1 => *soft_matches.get(0).unwrap().0, - _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + _ => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid)))), } }, 1 => *matches.get(0).unwrap().0, - _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))), + _ => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous)))), }; let mut msg = msgs::UserState::new(); msg.set_session(self.server.as_ref().unwrap().session_id().unwrap()); msg.set_channel_id(id); self.packet_sender.send(msg.into()).unwrap(); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::ChannelList => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } + let list = channel::into_channel( + self.server.as_ref().unwrap().channels(), + self.server.as_ref().unwrap().users(), + ); ( - false, - Ok(Some(CommandResponse::ChannelList { - channels: channel::into_channel( - self.server.as_ref().unwrap().channels(), - self.server.as_ref().unwrap().users(), - ), - })), + None, + Box::new(move |_| Ok(Some(CommandResponse::ChannelList { + channels: list, + }))), ) } Command::ServerConnect { @@ -114,7 +116,7 @@ impl State { accept_invalid_cert, } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { - return (false, Err(Error::AlreadyConnectedError)); + return (None, Box::new(|_| Err(Error::AlreadyConnectedError))); } let mut server = Server::new(); *server.username_mut() = Some(username); @@ -132,7 +134,7 @@ impl State { Ok(Some(v)) => v, _ => { warn!("Error parsing server addr"); - return (false, Err(Error::InvalidServerAddrError(host, port))); + return (None, Box::new(move |_| Err(Error::InvalidServerAddrError(host, port)))); } }; self.connection_info_sender @@ -142,22 +144,35 @@ impl State { accept_invalid_cert, ))) .unwrap(); - (true, Ok(None)) + (Some(TcpEvent::Connected), Box::new(|e| { //runs the closure when the client is connected + if let Some(TcpEventData::Connected(msg)) = e { + Ok(Some(CommandResponse::ServerConnect { + welcome_message: if msg.has_welcome_text() { + Some(msg.get_welcome_text().to_string()) + } else { + None + } + })) + } else { + unreachable!("callback should be provided with a TcpEventData::Connected"); + } + })) } Command::Status => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } + let state = self.server.as_ref().unwrap().into(); ( - false, - Ok(Some(CommandResponse::Status { - server_state: self.server.as_ref().unwrap().into(), //guaranteed not to panic because if we are connected, server is guaranteed to be Some - })), + None, + Box::new(move |_| Ok(Some(CommandResponse::Status { + server_state: state, //guaranteed not to panic because if we are connected, server is guaranteed to be Some + }))), ) } Command::ServerDisconnect => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } self.server = None; @@ -167,46 +182,54 @@ impl State { .0 .broadcast(StatePhase::Disconnected) .unwrap(); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::InputVolumeSet(volume) => { self.audio.set_input_volume(volume); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::ConfigReload => { self.reload_config(); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } } } - pub fn parse_initial_user_state(&mut self, msg: msgs::UserState) { + pub fn parse_user_state(&mut self, msg: msgs::UserState) -> Option<mumlib::state::UserDiff> { if !msg.has_session() { warn!("Can't parse user state without session"); - return; + return None; } - if !msg.has_name() { - warn!("Missing name in initial user state"); - } else if msg.get_name() == self.server.as_ref().unwrap().username().unwrap() { - match self.server.as_ref().unwrap().session_id() { - None => { - debug!("Found our session id: {}", msg.get_session()); - *self.server_mut().unwrap().session_id_mut() = Some(msg.get_session()); - } - Some(session) => { - if session != msg.get_session() { - error!( - "Got two different session IDs ({} and {}) for ourselves", - session, - msg.get_session() - ); - } else { - debug!("Got our session ID twice"); - } - } + let sess = msg.get_session(); + // check if this is initial state + if !self.server().unwrap().users().contains_key(&sess) { + if !msg.has_name() { + warn!("Missing name in initial user state"); + } else if msg.get_name() == self.server().unwrap().username().unwrap() { + // this is us + *self.server_mut().unwrap().session_id_mut() = Some(sess); + } else { + // this is someone else + self.audio_mut().add_client(sess); } + self.server_mut().unwrap().users_mut().insert(sess, user::User::new(msg)); + None + } else { + let user = self.server_mut().unwrap().users_mut().get_mut(&sess).unwrap(); + let diff = mumlib::state::UserDiff::from(msg); + user.apply_user_diff(&diff); + Some(diff) } - self.server.as_mut().unwrap().parse_user_state(msg); + } + + pub fn remove_client(&mut self, msg: msgs::UserRemove) { + if !msg.has_session() { + warn!("Tried to remove user state without session"); + return; + } + self.audio().remove_client(msg.get_session()); + self.server_mut().unwrap().users_mut().remove(&msg.get_session()); + info!("User {} disconnected", msg.get_session()); } pub fn reload_config(&mut self) { @@ -252,4 +275,3 @@ impl State { self.server.as_ref().map(|e| e.username()).flatten() } } - |
