diff options
Diffstat (limited to 'mumd/src/network/tcp.rs')
| -rw-r--r-- | mumd/src/network/tcp.rs | 78 |
1 files changed, 76 insertions, 2 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 6471771..ab49417 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -15,6 +15,7 @@ use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use std::collections::HashMap; type TcpSender = SplitSink< Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>, @@ -23,11 +24,25 @@ type TcpSender = SplitSink< type TcpReceiver = SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>; +pub(crate) type TcpEventCallback = Box<dyn FnOnce(&TcpEventData)>; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum TcpEvent { + Connected, + Disconnected, +} + +pub enum TcpEventData<'a> { + Connected(&'a msgs::ServerSync), + Disconnected, +} + pub async fn handle( state: Arc<Mutex<State>>, mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, crypt_state_sender: mpsc::Sender<ClientCryptState>, mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, + mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, ) { loop { let connection_info = loop { @@ -54,6 +69,7 @@ pub async fn handle( let phase_watcher = state_lock.phase_receiver(); let packet_sender = state_lock.packet_sender(); drop(state_lock); + let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); @@ -63,9 +79,11 @@ pub async fn handle( Arc::clone(&state), stream, crypt_state_sender.clone(), - phase_watcher.clone() + Arc::clone(&event_queue), + phase_watcher.clone(), ), - send_packets(sink, &mut packet_receiver, phase_watcher), + send_packets(sink, &mut packet_receiver, phase_watcher.clone()), + register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), ); debug!("Fully disconnected TCP stream, waiting for new connection info"); @@ -200,6 +218,7 @@ async fn listen( state: Arc<Mutex<State>>, mut stream: TcpReceiver, crypt_state_sender: mpsc::Sender<ClientCryptState>, + event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, mut phase_watcher: watch::Receiver<StatePhase>, ) { let mut crypt_state = None; @@ -267,6 +286,12 @@ async fn listen( ) .await; } + if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Connected(&msg)); + } + } let mut state = state.lock().unwrap(); let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); @@ -320,6 +345,13 @@ async fn listen( } } + if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Disconnected); + } + } + //TODO? clean up stream }; @@ -327,3 +359,45 @@ async fn listen( debug!("Killing TCP listener block"); } + +async fn register_events( + tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, + event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + mut phase_watcher: watch::Receiver<StatePhase>, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} + tx.send(true).unwrap(); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = tcp_event_register_receiver.recv().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some((event, handler))) => { + event_data.lock().unwrap().entry(event).or_default().push(handler); + } + } + } + }; + + join!(main_block, phase_transition_block); +}
\ No newline at end of file |
