diff options
Diffstat (limited to 'mumd/src/network/tcp.rs')
| -rw-r--r-- | mumd/src/network/tcp.rs | 385 |
1 files changed, 213 insertions, 172 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 6471771..c2cb234 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -15,6 +15,10 @@ use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use std::collections::HashMap; +use std::future::Future; +use std::rc::Rc; +use std::cell::RefCell; type TcpSender = SplitSink< Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>, @@ -23,11 +27,25 @@ type TcpSender = SplitSink< type TcpReceiver = SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>; +pub(crate) type TcpEventCallback = Box<dyn FnOnce(&TcpEventData)>; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum TcpEvent { + Connected, //fires when the client has connected to a server + Disconnected, //fires when the client has disconnected from a server +} + +pub enum TcpEventData<'a> { + Connected(&'a msgs::ServerSync), + Disconnected, +} + pub async fn handle( state: Arc<Mutex<State>>, mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, crypt_state_sender: mpsc::Sender<ClientCryptState>, mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, + mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, ) { loop { let connection_info = loop { @@ -54,6 +72,7 @@ pub async fn handle( let phase_watcher = state_lock.phase_receiver(); let packet_sender = state_lock.packet_sender(); drop(state_lock); + let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); @@ -63,9 +82,11 @@ pub async fn handle( Arc::clone(&state), stream, crypt_state_sender.clone(), - phase_watcher.clone() + Arc::clone(&event_queue), + phase_watcher.clone(), ), - send_packets(sink, &mut packet_receiver, phase_watcher), + send_packets(sink, &mut packet_receiver, phase_watcher.clone()), + register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), ); debug!("Fully disconnected TCP stream, waiting for new connection info"); @@ -108,103 +129,207 @@ async fn authenticate(sink: &mut TcpSender, username: String) { async fn send_pings( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, delay_seconds: u64, - mut phase_watcher: watch::Receiver<StatePhase>, + phase_watcher: watch::Receiver<StatePhase>, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - while !matches!( - phase_watcher.recv().await.unwrap(), - StatePhase::Disconnected - ) {} - tx.send(true).unwrap(); - }; - - let mut interval = time::interval(Duration::from_secs(delay_seconds)); - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let interval_waiter = interval.tick().fuse(); - pin_mut!(interval_waiter); - let exitor = select! { - data = interval_waiter => Some(data), - _ = rx => None - }; - - match exitor { - Some(_) => { - trace!("Sending ping"); - let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); - } - None => break, - } - } - }; + let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs(delay_seconds)))); + let packet_sender = Rc::new(RefCell::new(packet_sender)); - join!(main_block, phase_transition_block); + run_until_disconnection( + || async { + Some(interval.borrow_mut().tick().await) + }, + |_| async { + trace!("Sending ping"); + let msg = msgs::Ping::new(); + packet_sender.borrow_mut().send(msg.into()).unwrap(); + }, + || async {}, + phase_watcher, + ).await; debug!("Ping sender process killed"); } async fn send_packets( - mut sink: TcpSender, + sink: TcpSender, packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, - mut phase_watcher: watch::Receiver<StatePhase>, + phase_watcher: watch::Receiver<StatePhase>, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - while !matches!( - phase_watcher.recv().await.unwrap(), - StatePhase::Disconnected - ) {} - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = packet_receiver.recv().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; - } - Some(Some(packet)) => { - sink.send(packet).await.unwrap(); - } - } - } + let sink = Rc::new(RefCell::new(sink)); + let packet_receiver = Rc::new(RefCell::new(packet_receiver)); + run_until_disconnection( + || async { + packet_receiver.borrow_mut().recv().await + }, + |packet| async { + sink.borrow_mut().send(packet).await.unwrap(); + }, + || async { + //clears queue of remaining packets + while packet_receiver.borrow_mut().try_recv().is_ok() {} - //clears queue of remaining packets - while packet_receiver.try_recv().is_ok() {} - - sink.close().await.unwrap(); - }; - - join!(main_block, phase_transition_block); + sink.borrow_mut().close().await.unwrap(); + }, + phase_watcher, + ).await; debug!("TCP packet sender killed"); } async fn listen( state: Arc<Mutex<State>>, - mut stream: TcpReceiver, + stream: TcpReceiver, crypt_state_sender: mpsc::Sender<ClientCryptState>, - mut phase_watcher: watch::Receiver<StatePhase>, + event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + phase_watcher: watch::Receiver<StatePhase>, +) { + let crypt_state = Rc::new(RefCell::new(None)); + let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); + + let stream = Rc::new(RefCell::new(stream)); + run_until_disconnection( + || async { + stream.borrow_mut().next().await + }, + |packet| async { + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(mut sender) = crypt_state_sender.borrow_mut().take() { + let _ = sender + .send( + crypt_state + .borrow_mut() + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; + } + if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Connected(&msg)); + } + } + let mut state = state.lock().unwrap(); + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), + } + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); + } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + let mut state = state.lock().unwrap(); + let session = msg.get_session(); + if *state.phase_receiver().borrow() == StatePhase::Connecting { + state.audio_mut().add_client(msg.get_session()); + state.parse_user_state(*msg); + } else { + state.parse_user_state(*msg); + } + let server = state.server_mut().unwrap(); + let user = server.users().get(&session).unwrap(); + info!("User {} connected to {}", user.name(), user.channel()); + } + ControlPacket::UserRemove(msg) => { + info!("User {} left", msg.get_session()); + state + .lock() + .unwrap() + .audio_mut() + .remove_client(msg.get_session()); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_remove(*msg); + } + _ => {} + } + }, + || async { + if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Disconnected); + } + } + }, + phase_watcher, + ).await; + + debug!("Killing TCP listener block"); +} + +async fn register_events( + tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, + event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + phase_watcher: watch::Receiver<StatePhase>, ) { - let mut crypt_state = None; - let mut crypt_state_sender = Some(crypt_state_sender); + let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); + run_until_disconnection( + || async { + tcp_event_register_receiver.borrow_mut().recv().await + }, + |(event, handler)| async { event_data.lock().unwrap().entry(event).or_default().push(handler); }, + || async {}, + phase_watcher, + ).await; +} +async fn run_until_disconnection<T, F, G, H>( + mut generator: impl FnMut() -> F, + mut handler: impl FnMut(T) -> G, + mut shutdown: impl FnMut() -> H, + mut phase_watcher: watch::Receiver<StatePhase>, +) + where + F: Future<Output = Option<T>>, + G: Future<Output = ()>, + H: Future<Output = ()>, +{ let (tx, rx) = oneshot::channel(); let phase_transition_block = async { while !matches!( @@ -214,11 +339,11 @@ async fn listen( tx.send(true).unwrap(); }; - let listener_block = async { + let main_block = async { let rx = rx.fuse(); pin_mut!(rx); loop { - let packet_recv = stream.next().fuse(); + let packet_recv = generator().fuse(); pin_mut!(packet_recv); let exitor = select! { data = packet_recv => Some(data), @@ -229,101 +354,17 @@ async fn listen( break; } Some(None) => { - warn!("Channel closed before disconnect command"); + //warn!("Channel closed before disconnect command"); //TODO make me informative break; } - Some(Some(packet)) => { - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - crypt_state = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(mut sender) = crypt_state_sender.take() { - let _ = sender - .send( - crypt_state - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; - } - let mut state = state.lock().unwrap(); - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), - } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); - } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - let mut state = state.lock().unwrap(); - let session = msg.get_session(); - - let user_state_diff = state.parse_user_state(*msg); - //TODO do something with user state diff - debug!("user state diff: {:#?}", &user_state_diff); - - let server = state.server_mut().unwrap(); - let user = server.users().get(&session).unwrap(); - info!("User {} connected to {}", user.name(), user.channel()); - } - ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_remove(*msg); - } - _ => {} - } + Some(Some(data)) => { + handler(data).await; } } } - //TODO? clean up stream + shutdown().await; }; - join!(phase_transition_block, listener_block); - - debug!("Killing TCP listener block"); -} + join!(main_block, phase_transition_block); +}
\ No newline at end of file |
