From 46a53f38cde86439a2ca8b6d24887f842530f679 Mon Sep 17 00:00:00 2001 From: Eskil Queseth Date: Wed, 21 Oct 2020 01:53:38 +0200 Subject: add tcp event system --- mumd/src/command.rs | 38 +++++++++++++----------- mumd/src/main.rs | 4 ++- mumd/src/network/tcp.rs | 78 +++++++++++++++++++++++++++++++++++++++++++++++-- mumd/src/state.rs | 33 +++++++++++---------- 4 files changed, 117 insertions(+), 36 deletions(-) diff --git a/mumd/src/command.rs b/mumd/src/command.rs index a035a26..5285a9d 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -1,10 +1,11 @@ -use crate::state::{State, StatePhase}; +use crate::state::State; use ipc_channel::ipc::IpcSender; use log::*; use mumlib::command::{Command, CommandResponse}; use std::sync::{Arc, Mutex}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; +use crate::network::tcp::{TcpEvent, TcpEventCallback}; pub async fn handle( state: Arc>, @@ -12,23 +13,26 @@ pub async fn handle( Command, IpcSender>>, )>, + tcp_event_register_sender: mpsc::UnboundedSender<(TcpEvent, TcpEventCallback)>, ) { debug!("Begin listening for commands"); - while let Some(command) = command_receiver.recv().await { - debug!("Received command {:?}", command.0); - let mut state = state.lock().unwrap(); - let (wait_for_connected, command_response) = state.handle_command(command.0).await; - if wait_for_connected { - let mut watcher = state.phase_receiver(); - drop(state); - while !matches!(watcher.recv().await.unwrap(), StatePhase::Connected) {} + while let Some((command, response_sender)) = command_receiver.recv().await { + debug!("Received command {:?}", command); + let mut statee = state.lock().unwrap(); + let (event_data, command_response) = statee.handle_command(command).await; + drop(statee); + if let Some((event, callback)) = event_data { + let (tx, rx) = oneshot::channel(); + tcp_event_register_sender.send((event, Box::new(move |e| { + println!("något hände"); + callback(e); + response_sender.send(command_response).unwrap(); + tx.send(()); + }))); + + rx.await; + } else { + response_sender.send(command_response).unwrap(); } - command.1.send(command_response).unwrap(); } - //TODO err if not connected - //while let Some(command) = command_receiver.recv().await { - // debug!("Parsing command {:?}", command); - //} - - //debug!("Finished handling commands"); } diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 75726f8..e88eede 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -33,6 +33,7 @@ async fn main() { )>(); let (connection_info_sender, connection_info_receiver) = watch::channel::>(None); + let (response_sender, response_receiver) = mpsc::unbounded_channel(); let state = State::new(packet_sender, connection_info_sender); let state = Arc::new(Mutex::new(state)); @@ -43,13 +44,14 @@ async fn main() { connection_info_receiver.clone(), crypt_state_sender, packet_receiver, + response_receiver, ), network::udp::handle( Arc::clone(&state), connection_info_receiver.clone(), crypt_state_receiver, ), - command::handle(state, command_receiver,), + command::handle(state, command_receiver, response_sender), spawn_blocking(move || { // IpcSender is blocking receive_oneshot_commands(command_sender); diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 6471771..ab49417 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -15,6 +15,7 @@ use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use std::collections::HashMap; type TcpSender = SplitSink< Framed, ControlCodec>, @@ -23,11 +24,25 @@ type TcpSender = SplitSink< type TcpReceiver = SplitStream, ControlCodec>>; +pub(crate) type TcpEventCallback = Box; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum TcpEvent { + Connected, + Disconnected, +} + +pub enum TcpEventData<'a> { + Connected(&'a msgs::ServerSync), + Disconnected, +} + pub async fn handle( state: Arc>, mut connection_info_receiver: watch::Receiver>, crypt_state_sender: mpsc::Sender, mut packet_receiver: mpsc::UnboundedReceiver>, + mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, ) { loop { let connection_info = loop { @@ -54,6 +69,7 @@ pub async fn handle( let phase_watcher = state_lock.phase_receiver(); let packet_sender = state_lock.packet_sender(); drop(state_lock); + let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); @@ -63,9 +79,11 @@ pub async fn handle( Arc::clone(&state), stream, crypt_state_sender.clone(), - phase_watcher.clone() + Arc::clone(&event_queue), + phase_watcher.clone(), ), - send_packets(sink, &mut packet_receiver, phase_watcher), + send_packets(sink, &mut packet_receiver, phase_watcher.clone()), + register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), ); debug!("Fully disconnected TCP stream, waiting for new connection info"); @@ -200,6 +218,7 @@ async fn listen( state: Arc>, mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender, + event_data: Arc>>>, mut phase_watcher: watch::Receiver, ) { let mut crypt_state = None; @@ -267,6 +286,12 @@ async fn listen( ) .await; } + if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Connected(&msg)); + } + } let mut state = state.lock().unwrap(); let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); @@ -320,6 +345,13 @@ async fn listen( } } + if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Disconnected); + } + } + //TODO? clean up stream }; @@ -327,3 +359,45 @@ async fn listen( debug!("Killing TCP listener block"); } + +async fn register_events( + tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, + event_data: Arc>>>, + mut phase_watcher: watch::Receiver, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} + tx.send(true).unwrap(); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = tcp_event_register_receiver.recv().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some((event, handler))) => { + event_data.lock().unwrap().entry(event).or_default().push(handler); + } + } + } + }; + + join!(main_block, phase_transition_block); +} \ No newline at end of file diff --git a/mumd/src/state.rs b/mumd/src/state.rs index e9db616..c247b08 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -14,6 +14,7 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::net::ToSocketAddrs; use tokio::sync::{mpsc, watch}; +use crate::network::tcp::{TcpEventCallback, TcpEvent}; #[derive(Clone, Debug, Eq, PartialEq)] pub enum StatePhase { @@ -55,11 +56,11 @@ impl State { pub async fn handle_command( &mut self, command: Command, - ) -> (bool, mumlib::error::Result>) { + ) -> (Option<(TcpEvent, TcpEventCallback)>, mumlib::error::Result>) { match command { Command::ChannelJoin { channel_identifier } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Err(Error::DisconnectedError)); } let channels = self.server() @@ -77,27 +78,27 @@ impl State { .filter(|e| e.1.ends_with(&channel_identifier.to_lowercase())) .collect::>(); match soft_matches.len() { - 0 => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + 0 => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), 1 => *soft_matches.get(0).unwrap().0, - _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + _ => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), } }, 1 => *matches.get(0).unwrap().0, - _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))), + _ => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))), }; let mut msg = msgs::UserState::new(); msg.set_session(self.server.as_ref().unwrap().session_id.unwrap()); msg.set_channel_id(id); self.packet_sender.send(msg.into()).unwrap(); - (false, Ok(None)) + (None, Ok(None)) } Command::ChannelList => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Err(Error::DisconnectedError)); } ( - false, + None, Ok(Some(CommandResponse::ChannelList { channels: into_channel( self.server.as_ref().unwrap().channels(), @@ -113,7 +114,7 @@ impl State { accept_invalid_cert, } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { - return (false, Err(Error::AlreadyConnectedError)); + return (None, Err(Error::AlreadyConnectedError)); } let mut server = Server::new(); server.username = Some(username); @@ -131,7 +132,7 @@ impl State { Ok(Some(v)) => v, _ => { warn!("Error parsing server addr"); - return (false, Err(Error::InvalidServerAddrError(host, port))); + return (None, Err(Error::InvalidServerAddrError(host, port))); } }; self.connection_info_sender @@ -141,14 +142,14 @@ impl State { accept_invalid_cert, ))) .unwrap(); - (true, Ok(None)) + (Some((TcpEvent::Connected, Box::new(|_| {}))), Ok(None)) } Command::Status => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Err(Error::DisconnectedError)); } ( - false, + None, Ok(Some(CommandResponse::Status { server_state: self.server.as_ref().unwrap().into(), //guaranteed not to panic because if we are connected, server is guaranteed to be Some })), @@ -156,7 +157,7 @@ impl State { } Command::ServerDisconnect => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Err(Error::DisconnectedError)); } self.server = None; @@ -166,11 +167,11 @@ impl State { .0 .broadcast(StatePhase::Disconnected) .unwrap(); - (false, Ok(None)) + (None, Ok(None)) } Command::InputVolumeSet(volume) => { self.audio.set_input_volume(volume); - (false, Ok(None)) + (None, Ok(None)) } Command::ConfigReload => { self.reload_config(); -- cgit v1.2.1