aboutsummaryrefslogtreecommitdiffstats
path: root/mumd/src/network/tcp.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mumd/src/network/tcp.rs')
-rw-r--r--mumd/src/network/tcp.rs391
1 files changed, 213 insertions, 178 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 88d2b59..c2cb234 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -15,6 +15,10 @@ use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::{self, Duration};
use tokio_tls::{TlsConnector, TlsStream};
use tokio_util::codec::{Decoder, Framed};
+use std::collections::HashMap;
+use std::future::Future;
+use std::rc::Rc;
+use std::cell::RefCell;
type TcpSender = SplitSink<
Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>,
@@ -23,11 +27,25 @@ type TcpSender = SplitSink<
type TcpReceiver =
SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>;
+pub(crate) type TcpEventCallback = Box<dyn FnOnce(&TcpEventData)>;
+
+#[derive(Debug, Clone, Hash, Eq, PartialEq)]
+pub enum TcpEvent {
+ Connected, //fires when the client has connected to a server
+ Disconnected, //fires when the client has disconnected from a server
+}
+
+pub enum TcpEventData<'a> {
+ Connected(&'a msgs::ServerSync),
+ Disconnected,
+}
+
pub async fn handle(
state: Arc<Mutex<State>>,
mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
+ mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
) {
loop {
let connection_info = loop {
@@ -54,6 +72,7 @@ pub async fn handle(
let phase_watcher = state_lock.phase_receiver();
let packet_sender = state_lock.packet_sender();
drop(state_lock);
+ let event_queue = Arc::new(Mutex::new(HashMap::new()));
info!("Logging in...");
@@ -63,9 +82,11 @@ pub async fn handle(
Arc::clone(&state),
stream,
crypt_state_sender.clone(),
- phase_watcher.clone()
+ Arc::clone(&event_queue),
+ phase_watcher.clone(),
),
- send_packets(sink, &mut packet_receiver, phase_watcher),
+ send_packets(sink, &mut packet_receiver, phase_watcher.clone()),
+ register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher),
);
debug!("Fully disconnected TCP stream, waiting for new connection info");
@@ -108,103 +129,207 @@ async fn authenticate(sink: &mut TcpSender, username: String) {
async fn send_pings(
packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
delay_seconds: u64,
- mut phase_watcher: watch::Receiver<StatePhase>,
+ phase_watcher: watch::Receiver<StatePhase>,
) {
- let (tx, rx) = oneshot::channel();
- let phase_transition_block = async {
- while !matches!(
- phase_watcher.recv().await.unwrap(),
- StatePhase::Disconnected
- ) {}
- tx.send(true).unwrap();
- };
+ let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs(delay_seconds))));
+ let packet_sender = Rc::new(RefCell::new(packet_sender));
- let mut interval = time::interval(Duration::from_secs(delay_seconds));
- let main_block = async {
- let rx = rx.fuse();
- pin_mut!(rx);
- loop {
- let interval_waiter = interval.tick().fuse();
- pin_mut!(interval_waiter);
- let exitor = select! {
- data = interval_waiter => Some(data),
- _ = rx => None
- };
-
- match exitor {
- Some(_) => {
- trace!("Sending ping");
- let msg = msgs::Ping::new();
- packet_sender.send(msg.into()).unwrap();
- }
- None => break,
- }
- }
- };
-
- join!(main_block, phase_transition_block);
+ run_until_disconnection(
+ || async {
+ Some(interval.borrow_mut().tick().await)
+ },
+ |_| async {
+ trace!("Sending ping");
+ let msg = msgs::Ping::new();
+ packet_sender.borrow_mut().send(msg.into()).unwrap();
+ },
+ || async {},
+ phase_watcher,
+ ).await;
debug!("Ping sender process killed");
}
async fn send_packets(
- mut sink: TcpSender,
+ sink: TcpSender,
packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
- mut phase_watcher: watch::Receiver<StatePhase>,
+ phase_watcher: watch::Receiver<StatePhase>,
) {
- let (tx, rx) = oneshot::channel();
- let phase_transition_block = async {
- while !matches!(
- phase_watcher.recv().await.unwrap(),
- StatePhase::Disconnected
- ) {}
- tx.send(true).unwrap();
- };
-
- let main_block = async {
- let rx = rx.fuse();
- pin_mut!(rx);
- loop {
- let packet_recv = packet_receiver.recv().fuse();
- pin_mut!(packet_recv);
- let exitor = select! {
- data = packet_recv => Some(data),
- _ = rx => None
- };
- match exitor {
- None => {
- break;
- }
- Some(None) => {
- warn!("Channel closed before disconnect command");
- break;
- }
- Some(Some(packet)) => {
- sink.send(packet).await.unwrap();
- }
- }
- }
+ let sink = Rc::new(RefCell::new(sink));
+ let packet_receiver = Rc::new(RefCell::new(packet_receiver));
+ run_until_disconnection(
+ || async {
+ packet_receiver.borrow_mut().recv().await
+ },
+ |packet| async {
+ sink.borrow_mut().send(packet).await.unwrap();
+ },
+ || async {
+ //clears queue of remaining packets
+ while packet_receiver.borrow_mut().try_recv().is_ok() {}
- //clears queue of remaining packets
- while packet_receiver.try_recv().is_ok() {}
-
- sink.close().await.unwrap();
- };
-
- join!(main_block, phase_transition_block);
+ sink.borrow_mut().close().await.unwrap();
+ },
+ phase_watcher,
+ ).await;
debug!("TCP packet sender killed");
}
async fn listen(
state: Arc<Mutex<State>>,
- mut stream: TcpReceiver,
+ stream: TcpReceiver,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
- mut phase_watcher: watch::Receiver<StatePhase>,
+ event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ phase_watcher: watch::Receiver<StatePhase>,
) {
- let mut crypt_state = None;
- let mut crypt_state_sender = Some(crypt_state_sender);
+ let crypt_state = Rc::new(RefCell::new(None));
+ let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender)));
+ let stream = Rc::new(RefCell::new(stream));
+ run_until_disconnection(
+ || async {
+ stream.borrow_mut().next().await
+ },
+ |packet| async {
+ match packet.unwrap() {
+ ControlPacket::TextMessage(msg) => {
+ info!(
+ "Got message from user with session ID {}: {}",
+ msg.get_actor(),
+ msg.get_message()
+ );
+ }
+ ControlPacket::CryptSetup(msg) => {
+ debug!("Crypt setup");
+ // Wait until we're fully connected before initiating UDP voice
+ *crypt_state.borrow_mut() = Some(ClientCryptState::new_from(
+ msg.get_key()
+ .try_into()
+ .expect("Server sent private key with incorrect size"),
+ msg.get_client_nonce()
+ .try_into()
+ .expect("Server sent client_nonce with incorrect size"),
+ msg.get_server_nonce()
+ .try_into()
+ .expect("Server sent server_nonce with incorrect size"),
+ ));
+ }
+ ControlPacket::ServerSync(msg) => {
+ info!("Logged in");
+ if let Some(mut sender) = crypt_state_sender.borrow_mut().take() {
+ let _ = sender
+ .send(
+ crypt_state
+ .borrow_mut()
+ .take()
+ .expect("Server didn't send us any CryptSetup packet!"),
+ )
+ .await;
+ }
+ if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(&TcpEventData::Connected(&msg));
+ }
+ }
+ let mut state = state.lock().unwrap();
+ let server = state.server_mut().unwrap();
+ server.parse_server_sync(*msg);
+ match &server.welcome_text {
+ Some(s) => info!("Welcome: {}", s),
+ None => info!("No welcome received"),
+ }
+ for channel in server.channels().values() {
+ info!("Found channel {}", channel.name());
+ }
+ state.initialized();
+ }
+ ControlPacket::Reject(msg) => {
+ warn!("Login rejected: {:?}", msg);
+ }
+ ControlPacket::UserState(msg) => {
+ let mut state = state.lock().unwrap();
+ let session = msg.get_session();
+ if *state.phase_receiver().borrow() == StatePhase::Connecting {
+ state.audio_mut().add_client(msg.get_session());
+ state.parse_user_state(*msg);
+ } else {
+ state.parse_user_state(*msg);
+ }
+ let server = state.server_mut().unwrap();
+ let user = server.users().get(&session).unwrap();
+ info!("User {} connected to {}", user.name(), user.channel());
+ }
+ ControlPacket::UserRemove(msg) => {
+ info!("User {} left", msg.get_session());
+ state
+ .lock()
+ .unwrap()
+ .audio_mut()
+ .remove_client(msg.get_session());
+ }
+ ControlPacket::ChannelState(msg) => {
+ debug!("Channel state received");
+ state
+ .lock()
+ .unwrap()
+ .server_mut()
+ .unwrap()
+ .parse_channel_state(*msg); //TODO parse initial if initial
+ }
+ ControlPacket::ChannelRemove(msg) => {
+ state
+ .lock()
+ .unwrap()
+ .server_mut()
+ .unwrap()
+ .parse_channel_remove(*msg);
+ }
+ _ => {}
+ }
+ },
+ || async {
+ if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(&TcpEventData::Disconnected);
+ }
+ }
+ },
+ phase_watcher,
+ ).await;
+
+ debug!("Killing TCP listener block");
+}
+
+async fn register_events(
+ tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
+ event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ phase_watcher: watch::Receiver<StatePhase>,
+) {
+ let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver));
+ run_until_disconnection(
+ || async {
+ tcp_event_register_receiver.borrow_mut().recv().await
+ },
+ |(event, handler)| async { event_data.lock().unwrap().entry(event).or_default().push(handler); },
+ || async {},
+ phase_watcher,
+ ).await;
+}
+
+async fn run_until_disconnection<T, F, G, H>(
+ mut generator: impl FnMut() -> F,
+ mut handler: impl FnMut(T) -> G,
+ mut shutdown: impl FnMut() -> H,
+ mut phase_watcher: watch::Receiver<StatePhase>,
+)
+ where
+ F: Future<Output = Option<T>>,
+ G: Future<Output = ()>,
+ H: Future<Output = ()>,
+{
let (tx, rx) = oneshot::channel();
let phase_transition_block = async {
while !matches!(
@@ -214,11 +339,11 @@ async fn listen(
tx.send(true).unwrap();
};
- let listener_block = async {
+ let main_block = async {
let rx = rx.fuse();
pin_mut!(rx);
loop {
- let packet_recv = stream.next().fuse();
+ let packet_recv = generator().fuse();
pin_mut!(packet_recv);
let exitor = select! {
data = packet_recv => Some(data),
@@ -229,107 +354,17 @@ async fn listen(
break;
}
Some(None) => {
- warn!("Channel closed before disconnect command");
+ //warn!("Channel closed before disconnect command"); //TODO make me informative
break;
}
- Some(Some(packet)) => {
- match packet.unwrap() {
- ControlPacket::TextMessage(msg) => {
- info!(
- "Got message from user with session ID {}: {}",
- msg.get_actor(),
- msg.get_message()
- );
- }
- ControlPacket::CryptSetup(msg) => {
- debug!("Crypt setup");
- // Wait until we're fully connected before initiating UDP voice
- crypt_state = Some(ClientCryptState::new_from(
- msg.get_key()
- .try_into()
- .expect("Server sent private key with incorrect size"),
- msg.get_client_nonce()
- .try_into()
- .expect("Server sent client_nonce with incorrect size"),
- msg.get_server_nonce()
- .try_into()
- .expect("Server sent server_nonce with incorrect size"),
- ));
- }
- ControlPacket::ServerSync(msg) => {
- info!("Logged in");
- if let Some(mut sender) = crypt_state_sender.take() {
- let _ = sender
- .send(
- crypt_state
- .take()
- .expect("Server didn't send us any CryptSetup packet!"),
- )
- .await;
- }
- let mut state = state.lock().unwrap();
- let server = state.server_mut().unwrap();
- server.parse_server_sync(*msg);
- match &server.welcome_text {
- Some(s) => info!("Welcome: {}", s),
- None => info!("No welcome received"),
- }
- for channel in server.channels().values() {
- info!("Found channel {}", channel.name());
- }
- state.initialized();
- }
- ControlPacket::Reject(msg) => {
- warn!("Login rejected: {:?}", msg);
- }
- ControlPacket::UserState(msg) => {
- let mut state = state.lock().unwrap();
- let session = msg.get_session();
- if *state.phase_receiver().borrow() == StatePhase::Connecting {
- state.audio_mut().add_client(msg.get_session());
- state.parse_initial_user_state(*msg);
- } else {
- state.server_mut().unwrap().parse_user_state(*msg);
- }
- let server = state.server_mut().unwrap();
- let user = server.users().get(&session).unwrap();
- info!("User {} connected to {}", user.name(), user.channel());
- }
- ControlPacket::UserRemove(msg) => {
- info!("User {} left", msg.get_session());
- state
- .lock()
- .unwrap()
- .audio_mut()
- .remove_client(msg.get_session());
- }
- ControlPacket::ChannelState(msg) => {
- debug!("Channel state received");
- state
- .lock()
- .unwrap()
- .server_mut()
- .unwrap()
- .parse_channel_state(*msg); //TODO parse initial if initial
- }
- ControlPacket::ChannelRemove(msg) => {
- state
- .lock()
- .unwrap()
- .server_mut()
- .unwrap()
- .parse_channel_remove(*msg);
- }
- _ => {}
- }
+ Some(Some(data)) => {
+ handler(data).await;
}
}
}
- //TODO? clean up stream
+ shutdown().await;
};
- join!(phase_transition_block, listener_block);
-
- debug!("Killing TCP listener block");
-}
+ join!(main_block, phase_transition_block);
+} \ No newline at end of file