From 46a53f38cde86439a2ca8b6d24887f842530f679 Mon Sep 17 00:00:00 2001 From: Eskil Queseth Date: Wed, 21 Oct 2020 01:53:38 +0200 Subject: add tcp event system --- mumd/src/command.rs | 38 +++++++++++++----------- mumd/src/main.rs | 4 ++- mumd/src/network/tcp.rs | 78 +++++++++++++++++++++++++++++++++++++++++++++++-- mumd/src/state.rs | 33 +++++++++++---------- 4 files changed, 117 insertions(+), 36 deletions(-) (limited to 'mumd/src') diff --git a/mumd/src/command.rs b/mumd/src/command.rs index a035a26..5285a9d 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -1,10 +1,11 @@ -use crate::state::{State, StatePhase}; +use crate::state::State; use ipc_channel::ipc::IpcSender; use log::*; use mumlib::command::{Command, CommandResponse}; use std::sync::{Arc, Mutex}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; +use crate::network::tcp::{TcpEvent, TcpEventCallback}; pub async fn handle( state: Arc>, @@ -12,23 +13,26 @@ pub async fn handle( Command, IpcSender>>, )>, + tcp_event_register_sender: mpsc::UnboundedSender<(TcpEvent, TcpEventCallback)>, ) { debug!("Begin listening for commands"); - while let Some(command) = command_receiver.recv().await { - debug!("Received command {:?}", command.0); - let mut state = state.lock().unwrap(); - let (wait_for_connected, command_response) = state.handle_command(command.0).await; - if wait_for_connected { - let mut watcher = state.phase_receiver(); - drop(state); - while !matches!(watcher.recv().await.unwrap(), StatePhase::Connected) {} + while let Some((command, response_sender)) = command_receiver.recv().await { + debug!("Received command {:?}", command); + let mut statee = state.lock().unwrap(); + let (event_data, command_response) = statee.handle_command(command).await; + drop(statee); + if let Some((event, callback)) = event_data { + let (tx, rx) = oneshot::channel(); + tcp_event_register_sender.send((event, Box::new(move |e| { + println!("något hände"); + callback(e); + response_sender.send(command_response).unwrap(); + tx.send(()); + }))); + + rx.await; + } else { + response_sender.send(command_response).unwrap(); } - command.1.send(command_response).unwrap(); } - //TODO err if not connected - //while let Some(command) = command_receiver.recv().await { - // debug!("Parsing command {:?}", command); - //} - - //debug!("Finished handling commands"); } diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 75726f8..e88eede 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -33,6 +33,7 @@ async fn main() { )>(); let (connection_info_sender, connection_info_receiver) = watch::channel::>(None); + let (response_sender, response_receiver) = mpsc::unbounded_channel(); let state = State::new(packet_sender, connection_info_sender); let state = Arc::new(Mutex::new(state)); @@ -43,13 +44,14 @@ async fn main() { connection_info_receiver.clone(), crypt_state_sender, packet_receiver, + response_receiver, ), network::udp::handle( Arc::clone(&state), connection_info_receiver.clone(), crypt_state_receiver, ), - command::handle(state, command_receiver,), + command::handle(state, command_receiver, response_sender), spawn_blocking(move || { // IpcSender is blocking receive_oneshot_commands(command_sender); diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 6471771..ab49417 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -15,6 +15,7 @@ use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use std::collections::HashMap; type TcpSender = SplitSink< Framed, ControlCodec>, @@ -23,11 +24,25 @@ type TcpSender = SplitSink< type TcpReceiver = SplitStream, ControlCodec>>; +pub(crate) type TcpEventCallback = Box; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum TcpEvent { + Connected, + Disconnected, +} + +pub enum TcpEventData<'a> { + Connected(&'a msgs::ServerSync), + Disconnected, +} + pub async fn handle( state: Arc>, mut connection_info_receiver: watch::Receiver>, crypt_state_sender: mpsc::Sender, mut packet_receiver: mpsc::UnboundedReceiver>, + mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, ) { loop { let connection_info = loop { @@ -54,6 +69,7 @@ pub async fn handle( let phase_watcher = state_lock.phase_receiver(); let packet_sender = state_lock.packet_sender(); drop(state_lock); + let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); @@ -63,9 +79,11 @@ pub async fn handle( Arc::clone(&state), stream, crypt_state_sender.clone(), - phase_watcher.clone() + Arc::clone(&event_queue), + phase_watcher.clone(), ), - send_packets(sink, &mut packet_receiver, phase_watcher), + send_packets(sink, &mut packet_receiver, phase_watcher.clone()), + register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), ); debug!("Fully disconnected TCP stream, waiting for new connection info"); @@ -200,6 +218,7 @@ async fn listen( state: Arc>, mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender, + event_data: Arc>>>, mut phase_watcher: watch::Receiver, ) { let mut crypt_state = None; @@ -267,6 +286,12 @@ async fn listen( ) .await; } + if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Connected(&msg)); + } + } let mut state = state.lock().unwrap(); let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); @@ -320,6 +345,13 @@ async fn listen( } } + if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Disconnected); + } + } + //TODO? clean up stream }; @@ -327,3 +359,45 @@ async fn listen( debug!("Killing TCP listener block"); } + +async fn register_events( + tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, + event_data: Arc>>>, + mut phase_watcher: watch::Receiver, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} + tx.send(true).unwrap(); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = tcp_event_register_receiver.recv().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some((event, handler))) => { + event_data.lock().unwrap().entry(event).or_default().push(handler); + } + } + } + }; + + join!(main_block, phase_transition_block); +} \ No newline at end of file diff --git a/mumd/src/state.rs b/mumd/src/state.rs index e9db616..c247b08 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -14,6 +14,7 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::net::ToSocketAddrs; use tokio::sync::{mpsc, watch}; +use crate::network::tcp::{TcpEventCallback, TcpEvent}; #[derive(Clone, Debug, Eq, PartialEq)] pub enum StatePhase { @@ -55,11 +56,11 @@ impl State { pub async fn handle_command( &mut self, command: Command, - ) -> (bool, mumlib::error::Result>) { + ) -> (Option<(TcpEvent, TcpEventCallback)>, mumlib::error::Result>) { match command { Command::ChannelJoin { channel_identifier } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Err(Error::DisconnectedError)); } let channels = self.server() @@ -77,27 +78,27 @@ impl State { .filter(|e| e.1.ends_with(&channel_identifier.to_lowercase())) .collect::>(); match soft_matches.len() { - 0 => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + 0 => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), 1 => *soft_matches.get(0).unwrap().0, - _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + _ => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), } }, 1 => *matches.get(0).unwrap().0, - _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))), + _ => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))), }; let mut msg = msgs::UserState::new(); msg.set_session(self.server.as_ref().unwrap().session_id.unwrap()); msg.set_channel_id(id); self.packet_sender.send(msg.into()).unwrap(); - (false, Ok(None)) + (None, Ok(None)) } Command::ChannelList => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Err(Error::DisconnectedError)); } ( - false, + None, Ok(Some(CommandResponse::ChannelList { channels: into_channel( self.server.as_ref().unwrap().channels(), @@ -113,7 +114,7 @@ impl State { accept_invalid_cert, } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { - return (false, Err(Error::AlreadyConnectedError)); + return (None, Err(Error::AlreadyConnectedError)); } let mut server = Server::new(); server.username = Some(username); @@ -131,7 +132,7 @@ impl State { Ok(Some(v)) => v, _ => { warn!("Error parsing server addr"); - return (false, Err(Error::InvalidServerAddrError(host, port))); + return (None, Err(Error::InvalidServerAddrError(host, port))); } }; self.connection_info_sender @@ -141,14 +142,14 @@ impl State { accept_invalid_cert, ))) .unwrap(); - (true, Ok(None)) + (Some((TcpEvent::Connected, Box::new(|_| {}))), Ok(None)) } Command::Status => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Err(Error::DisconnectedError)); } ( - false, + None, Ok(Some(CommandResponse::Status { server_state: self.server.as_ref().unwrap().into(), //guaranteed not to panic because if we are connected, server is guaranteed to be Some })), @@ -156,7 +157,7 @@ impl State { } Command::ServerDisconnect => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Err(Error::DisconnectedError)); } self.server = None; @@ -166,11 +167,11 @@ impl State { .0 .broadcast(StatePhase::Disconnected) .unwrap(); - (false, Ok(None)) + (None, Ok(None)) } Command::InputVolumeSet(volume) => { self.audio.set_input_volume(volume); - (false, Ok(None)) + (None, Ok(None)) } Command::ConfigReload => { self.reload_config(); -- cgit v1.2.1 From d215385473f5380a1166101596e135ec6ede5501 Mon Sep 17 00:00:00 2001 From: Eskil Queseth Date: Wed, 21 Oct 2020 02:24:16 +0200 Subject: add printing of welcome message to server connect --- mumd/src/command.rs | 22 +++++++++--------- mumd/src/state.rs | 64 ++++++++++++++++++++++++++++++++--------------------- 2 files changed, 50 insertions(+), 36 deletions(-) (limited to 'mumd/src') diff --git a/mumd/src/command.rs b/mumd/src/command.rs index 5285a9d..075bfaf 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -18,21 +18,21 @@ pub async fn handle( debug!("Begin listening for commands"); while let Some((command, response_sender)) = command_receiver.recv().await { debug!("Received command {:?}", command); - let mut statee = state.lock().unwrap(); - let (event_data, command_response) = statee.handle_command(command).await; - drop(statee); - if let Some((event, callback)) = event_data { + let mut state = state.lock().unwrap(); + let (event, generator) = state.handle_command(command).await; + drop(state); + if let Some(event) = event { let (tx, rx) = oneshot::channel(); - tcp_event_register_sender.send((event, Box::new(move |e| { - println!("något hände"); - callback(e); - response_sender.send(command_response).unwrap(); - tx.send(()); + //TODO handle this error + let _ = tcp_event_register_sender.send((event, Box::new(move |e| { + let response = generator(Some(e)); + response_sender.send(response).unwrap(); + tx.send(()).unwrap(); }))); - rx.await; + rx.await.unwrap(); } else { - response_sender.send(command_response).unwrap(); + response_sender.send(generator(None)).unwrap(); } } } diff --git a/mumd/src/state.rs b/mumd/src/state.rs index c247b08..0822de0 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -14,7 +14,7 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::net::ToSocketAddrs; use tokio::sync::{mpsc, watch}; -use crate::network::tcp::{TcpEventCallback, TcpEvent}; +use crate::network::tcp::{TcpEvent, TcpEventData}; #[derive(Clone, Debug, Eq, PartialEq)] pub enum StatePhase { @@ -56,11 +56,11 @@ impl State { pub async fn handle_command( &mut self, command: Command, - ) -> (Option<(TcpEvent, TcpEventCallback)>, mumlib::error::Result>) { + ) -> (Option, Box) -> mumlib::error::Result>>) { match command { Command::ChannelJoin { channel_identifier } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (None, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } let channels = self.server() @@ -78,33 +78,34 @@ impl State { .filter(|e| e.1.ends_with(&channel_identifier.to_lowercase())) .collect::>(); match soft_matches.len() { - 0 => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + 0 => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid)))), 1 => *soft_matches.get(0).unwrap().0, - _ => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + _ => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid)))), } }, 1 => *matches.get(0).unwrap().0, - _ => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))), + _ => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous)))), }; let mut msg = msgs::UserState::new(); msg.set_session(self.server.as_ref().unwrap().session_id.unwrap()); msg.set_channel_id(id); self.packet_sender.send(msg.into()).unwrap(); - (None, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::ChannelList => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (None, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } + let list = into_channel( + self.server.as_ref().unwrap().channels(), + self.server.as_ref().unwrap().users(), + ); ( None, - Ok(Some(CommandResponse::ChannelList { - channels: into_channel( - self.server.as_ref().unwrap().channels(), - self.server.as_ref().unwrap().users(), - ), - })), + Box::new(move |_| Ok(Some(CommandResponse::ChannelList { + channels: list, + }))), ) } Command::ServerConnect { @@ -114,7 +115,7 @@ impl State { accept_invalid_cert, } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { - return (None, Err(Error::AlreadyConnectedError)); + return (None, Box::new(|_| Err(Error::AlreadyConnectedError))); } let mut server = Server::new(); server.username = Some(username); @@ -132,7 +133,7 @@ impl State { Ok(Some(v)) => v, _ => { warn!("Error parsing server addr"); - return (None, Err(Error::InvalidServerAddrError(host, port))); + return (None, Box::new(move |_| Err(Error::InvalidServerAddrError(host, port)))); } }; self.connection_info_sender @@ -142,22 +143,35 @@ impl State { accept_invalid_cert, ))) .unwrap(); - (Some((TcpEvent::Connected, Box::new(|_| {}))), Ok(None)) + (Some(TcpEvent::Connected), Box::new(|e| { + if let Some(TcpEventData::Connected(msg)) = e { + Ok(Some(CommandResponse::ServerConnect { + welcome_message: if msg.has_welcome_text() { + Some(msg.get_welcome_text().to_string()) + } else { + None + } + })) + } else { + unreachable!("callback should be provided with a TcpEventData::Connected"); + } + })) } Command::Status => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (None, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } + let state = self.server.as_ref().unwrap().into(); ( None, - Ok(Some(CommandResponse::Status { - server_state: self.server.as_ref().unwrap().into(), //guaranteed not to panic because if we are connected, server is guaranteed to be Some - })), + Box::new(move |_| Ok(Some(CommandResponse::Status { + server_state: state, //guaranteed not to panic because if we are connected, server is guaranteed to be Some + }))), ) } Command::ServerDisconnect => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (None, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } self.server = None; @@ -167,15 +181,15 @@ impl State { .0 .broadcast(StatePhase::Disconnected) .unwrap(); - (None, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::InputVolumeSet(volume) => { self.audio.set_input_volume(volume); - (None, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::ConfigReload => { self.reload_config(); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } } } -- cgit v1.2.1 From c0f24a185ebc270cbaa121cc8ce69ac8ba3f4d30 Mon Sep 17 00:00:00 2001 From: Eskil Queseth Date: Wed, 21 Oct 2020 03:52:58 +0200 Subject: reduce code duplication --- mumd/src/network/tcp.rs | 389 ++++++++++++++++++++++-------------------------- 1 file changed, 178 insertions(+), 211 deletions(-) (limited to 'mumd/src') diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index ab49417..7ac0474 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -16,6 +16,9 @@ use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; use std::collections::HashMap; +use std::future::Future; +use std::rc::Rc; +use std::cell::RefCell; type TcpSender = SplitSink< Framed, ControlCodec>, @@ -126,236 +129,176 @@ async fn authenticate(sink: &mut TcpSender, username: String) { async fn send_pings( packet_sender: mpsc::UnboundedSender>, delay_seconds: u64, - mut phase_watcher: watch::Receiver, + phase_watcher: watch::Receiver, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - while !matches!( - phase_watcher.recv().await.unwrap(), - StatePhase::Disconnected - ) {} - tx.send(true).unwrap(); - }; - - let mut interval = time::interval(Duration::from_secs(delay_seconds)); - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let interval_waiter = interval.tick().fuse(); - pin_mut!(interval_waiter); - let exitor = select! { - data = interval_waiter => Some(data), - _ = rx => None - }; - - match exitor { - Some(_) => { - trace!("Sending ping"); - let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); - } - None => break, - } - } - }; - - join!(main_block, phase_transition_block); + let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs(delay_seconds)))); + let packet_sender = Rc::new(RefCell::new(packet_sender)); + + run_until_disconnection( + || async { + Some(interval.borrow_mut().tick().await) + }, + |_| async { + trace!("Sending ping"); + let msg = msgs::Ping::new(); + packet_sender.borrow_mut().send(msg.into()).unwrap(); + }, + || async {}, + phase_watcher, + ).await; debug!("Ping sender process killed"); } async fn send_packets( - mut sink: TcpSender, + sink: TcpSender, packet_receiver: &mut mpsc::UnboundedReceiver>, - mut phase_watcher: watch::Receiver, + phase_watcher: watch::Receiver, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - while !matches!( - phase_watcher.recv().await.unwrap(), - StatePhase::Disconnected - ) {} - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = packet_receiver.recv().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; - } - Some(Some(packet)) => { - sink.send(packet).await.unwrap(); - } - } - } - - //clears queue of remaining packets - while packet_receiver.try_recv().is_ok() {} - - sink.close().await.unwrap(); - }; - - join!(main_block, phase_transition_block); + let sink = Rc::new(RefCell::new(sink)); + let packet_receiver = Rc::new(RefCell::new(packet_receiver)); + run_until_disconnection( + || async { + packet_receiver.borrow_mut().recv().await + }, + |packet| async { + sink.borrow_mut().send(packet).await.unwrap(); + }, + || async { + //clears queue of remaining packets + while packet_receiver.borrow_mut().try_recv().is_ok() {} + + sink.borrow_mut().close().await.unwrap(); + }, + phase_watcher, + ).await; debug!("TCP packet sender killed"); } async fn listen( state: Arc>, - mut stream: TcpReceiver, + stream: TcpReceiver, crypt_state_sender: mpsc::Sender, event_data: Arc>>>, - mut phase_watcher: watch::Receiver, + phase_watcher: watch::Receiver, ) { - let mut crypt_state = None; - let mut crypt_state_sender = Some(crypt_state_sender); - - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - while !matches!( - phase_watcher.recv().await.unwrap(), - StatePhase::Disconnected - ) {} - tx.send(true).unwrap(); - }; - - let listener_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = stream.next().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; + let crypt_state = Rc::new(RefCell::new(None)); + let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); + + let stream = Rc::new(RefCell::new(stream)); + run_until_disconnection( + || async { + stream.borrow_mut().next().await + }, + |packet| async { + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); } - Some(Some(packet)) => { - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - crypt_state = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(mut sender) = crypt_state_sender.take() { - let _ = sender - .send( - crypt_state - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; - } - if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Connected) { - let old = std::mem::take(vec); - for handler in old { - handler(&TcpEventData::Connected(&msg)); - } - } - let mut state = state.lock().unwrap(); - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), - } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); - } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - let mut state = state.lock().unwrap(); - let session = msg.get_session(); - - let user_state_diff = state.parse_user_state(*msg); - //TODO do something with user state diff - debug!("user state diff: {:#?}", &user_state_diff); - - let server = state.server_mut().unwrap(); - let user = server.users().get(&session).unwrap(); - info!("User {} connected to {}", user.name(), user.channel()); - } - ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_remove(*msg); + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(mut sender) = crypt_state_sender.borrow_mut().take() { + let _ = sender + .send( + crypt_state + .borrow_mut() + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; + } + if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Connected(&msg)); } - _ => {} } + let mut state = state.lock().unwrap(); + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), + } + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); + } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + let mut state = state.lock().unwrap(); + let session = msg.get_session(); + if *state.phase_receiver().borrow() == StatePhase::Connecting { + state.audio_mut().add_client(msg.get_session()); + state.parse_user_state(*msg); + } else { + state.parse_user_state(*msg); + } + let server = state.server_mut().unwrap(); + let user = server.users().get(&session).unwrap(); + info!("User {} connected to {}", user.name(), user.channel()); + } + ControlPacket::UserRemove(msg) => { + info!("User {} left", msg.get_session()); + state + .lock() + .unwrap() + .audio_mut() + .remove_client(msg.get_session()); } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_remove(*msg); + } + _ => {} } - } - - if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Disconnected) { - let old = std::mem::take(vec); - for handler in old { - handler(&TcpEventData::Disconnected); + }, + || async { + if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Disconnected); + } } - } - - //TODO? clean up stream - }; - - join!(phase_transition_block, listener_block); + }, + phase_watcher, + ).await; debug!("Killing TCP listener block"); } @@ -363,8 +306,30 @@ async fn listen( async fn register_events( tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, event_data: Arc>>>, - mut phase_watcher: watch::Receiver, + phase_watcher: watch::Receiver, ) { + let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); + run_until_disconnection( + || async { + tcp_event_register_receiver.borrow_mut().recv().await + }, + |(event, handler)| async { event_data.lock().unwrap().entry(event).or_default().push(handler); }, + || async {}, + phase_watcher, + ).await; +} + +async fn run_until_disconnection( + mut generator: impl FnMut() -> F, + mut handler: impl FnMut(T) -> G, + mut shutdown: impl FnMut() -> H, + mut phase_watcher: watch::Receiver, +) + where + F: Future>, + G: Future, + H: Future, +{ let (tx, rx) = oneshot::channel(); let phase_transition_block = async { while !matches!( @@ -378,7 +343,7 @@ async fn register_events( let rx = rx.fuse(); pin_mut!(rx); loop { - let packet_recv = tcp_event_register_receiver.recv().fuse(); + let packet_recv = generator().fuse(); pin_mut!(packet_recv); let exitor = select! { data = packet_recv => Some(data), @@ -389,14 +354,16 @@ async fn register_events( break; } Some(None) => { - warn!("Channel closed before disconnect command"); + //warn!("Channel closed before disconnect command"); //TODO make me informative break; } - Some(Some((event, handler))) => { - event_data.lock().unwrap().entry(event).or_default().push(handler); + Some(Some(data)) => { + handler(data).await; } } } + + shutdown().await; }; join!(main_block, phase_transition_block); -- cgit v1.2.1 From 9d865becb19e7ce870b23c4d96d9127baff44d56 Mon Sep 17 00:00:00 2001 From: Eskil Queseth Date: Wed, 21 Oct 2020 04:05:22 +0200 Subject: minor changes --- mumd/src/network/tcp.rs | 10 +++++----- mumd/src/state.rs | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'mumd/src') diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 7ac0474..c2cb234 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -31,8 +31,8 @@ pub(crate) type TcpEventCallback = Box; #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub enum TcpEvent { - Connected, - Disconnected, + Connected, //fires when the client has connected to a server + Disconnected, //fires when the client has disconnected from a server } pub enum TcpEventData<'a> { @@ -180,7 +180,7 @@ async fn listen( state: Arc>, stream: TcpReceiver, crypt_state_sender: mpsc::Sender, - event_data: Arc>>>, + event_queue: Arc>>>, phase_watcher: watch::Receiver, ) { let crypt_state = Rc::new(RefCell::new(None)); @@ -227,7 +227,7 @@ async fn listen( ) .await; } - if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Connected) { + if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { let old = std::mem::take(vec); for handler in old { handler(&TcpEventData::Connected(&msg)); @@ -290,7 +290,7 @@ async fn listen( } }, || async { - if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { let old = std::mem::take(vec); for handler in old { handler(&TcpEventData::Disconnected); diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 0822de0..1a02068 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -143,7 +143,7 @@ impl State { accept_invalid_cert, ))) .unwrap(); - (Some(TcpEvent::Connected), Box::new(|e| { + (Some(TcpEvent::Connected), Box::new(|e| { //runs the closure when the client is connected if let Some(TcpEventData::Connected(msg)) = e { Ok(Some(CommandResponse::ServerConnect { welcome_message: if msg.has_welcome_text() { -- cgit v1.2.1