aboutsummaryrefslogtreecommitdiffstats
path: root/mumd/src/network
diff options
context:
space:
mode:
authorGustav Sörnäs <gustav@sornas.net>2021-03-30 15:24:46 +0200
committerGitHub <noreply@github.com>2021-03-30 15:24:46 +0200
commit795e46c98616801c678bd0a403b08cb0fcd5ee43 (patch)
tree040efd79def19e28422980ebfb6ce414ff349570 /mumd/src/network
parenta6d433e3ad95b9a21d5d473da4b1f65e78585bb2 (diff)
parentb52068eade50758673e29c79e7cb8be3f1b4151f (diff)
downloadmum-795e46c98616801c678bd0a403b08cb0fcd5ee43.tar.gz
Merge pull request #79 from mum-rs/tcp-event-queue
Diffstat (limited to 'mumd/src/network')
-rw-r--r--mumd/src/network/tcp.rs108
1 files changed, 75 insertions, 33 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 8ce49cb..47b1c20 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -36,11 +36,47 @@ pub enum TcpEvent {
Disconnected, //fires when the client has disconnected from a server
}
+#[derive(Clone)]
pub enum TcpEventData<'a> {
- Connected(&'a msgs::ServerSync),
+ Connected(Result<&'a msgs::ServerSync, mumlib::Error>),
Disconnected,
}
+impl<'a> From<&TcpEventData<'a>> for TcpEvent {
+ fn from(t: &TcpEventData) -> Self {
+ match t {
+ TcpEventData::Connected(_) => TcpEvent::Connected,
+ TcpEventData::Disconnected => TcpEvent::Disconnected,
+ }
+ }
+}
+
+#[derive(Clone)]
+struct TcpEventQueue {
+ handlers: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+}
+
+impl TcpEventQueue {
+ fn new() -> Self {
+ Self {
+ handlers: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ async fn register(&self, at: TcpEvent, callback: TcpEventCallback) {
+ self.handlers.lock().await.entry(at).or_default().push(callback);
+ }
+
+ async fn resolve<'a>(&self, data: TcpEventData<'a>) {
+ if let Some(vec) = self.handlers.lock().await.get_mut(&TcpEvent::from(&data)) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(data.clone());
+ }
+ }
+ }
+}
+
pub async fn handle(
state: Arc<Mutex<State>>,
mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
@@ -73,7 +109,7 @@ pub async fn handle(
let phase_watcher = state_lock.phase_receiver();
let input_receiver = state_lock.audio().input_receiver();
drop(state_lock);
- let event_queue = Arc::new(Mutex::new(HashMap::new()));
+ let event_queue = TcpEventQueue::new();
info!("Logging in...");
@@ -85,7 +121,7 @@ pub async fn handle(
Arc::clone(&state),
stream,
crypt_state_sender.clone(),
- Arc::clone(&event_queue),
+ event_queue.clone(),
),
send_voice(
packet_sender.clone(),
@@ -93,17 +129,12 @@ pub async fn handle(
phase_watcher.clone(),
),
send_packets(sink, &mut packet_receiver),
- register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)),
+ register_events(&mut tcp_event_register_receiver, event_queue.clone()),
).map(|_| ()),
phase_watcher,
).await;
- if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) {
- let old = std::mem::take(vec);
- for handler in old {
- handler(TcpEventData::Disconnected);
- }
- }
+ event_queue.resolve(TcpEventData::Disconnected).await;
debug!("Fully disconnected TCP stream, waiting for new connection info");
}
@@ -154,12 +185,12 @@ async fn send_pings(
delay_seconds: u64,
) {
let mut interval = time::interval(Duration::from_secs(delay_seconds));
- loop {
- interval.tick().await;
- trace!("Sending TCP ping");
- let msg = msgs::Ping::new();
- packet_sender.send(msg.into()).unwrap();
- }
+ loop {
+ interval.tick().await;
+ trace!("Sending TCP ping");
+ let msg = msgs::Ping::new();
+ packet_sender.send(msg.into()).unwrap();
+ }
}
async fn send_packets(
@@ -209,14 +240,27 @@ async fn listen(
state: Arc<Mutex<State>>,
mut stream: TcpReceiver,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
- event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ event_queue: TcpEventQueue,
) {
let mut crypt_state = None;
let mut crypt_state_sender = Some(crypt_state_sender);
loop {
- let packet = stream.next().await.unwrap();
- match packet.unwrap() {
+ let packet = match stream.next().await {
+ Some(Ok(packet)) => packet,
+ Some(Err(e)) => {
+ error!("TCP error: {:?}", e);
+ continue; //TODO Break here? Maybe look at the error and handle it
+ }
+ None => {
+ // We end up here if the login was rejected. We probably want
+ // to exit before that.
+ warn!("TCP stream gone");
+ state.lock().await.broadcast_phase(StatePhase::Disconnected);
+ break;
+ }
+ };
+ match packet {
ControlPacket::TextMessage(msg) => {
info!(
"Got message from user with session ID {}: {}",
@@ -250,12 +294,7 @@ async fn listen(
)
.await;
}
- if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) {
- let old = std::mem::take(vec);
- for handler in old {
- handler(TcpEventData::Connected(&msg));
- }
- }
+ event_queue.resolve(TcpEventData::Connected(Ok(&msg))).await;
let mut state = state.lock().await;
let server = state.server_mut().unwrap();
server.parse_server_sync(*msg);
@@ -269,7 +308,15 @@ async fn listen(
state.initialized();
}
ControlPacket::Reject(msg) => {
- warn!("Login rejected: {:?}", msg);
+ debug!("Login rejected: {:?}", msg);
+ match msg.get_field_type() {
+ msgs::Reject_RejectType::WrongServerPW => {
+ event_queue.resolve(TcpEventData::Connected(Err(mumlib::Error::InvalidServerPassword))).await;
+ }
+ ty => {
+ warn!("Unhandled reject type: {:?}", ty);
+ }
+ }
}
ControlPacket::UserState(msg) => {
state.lock().await.parse_user_state(*msg);
@@ -324,15 +371,10 @@ async fn listen(
async fn register_events(
tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
- event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ event_queue: TcpEventQueue,
) {
loop {
let (event, handler) = tcp_event_register_receiver.recv().await.unwrap();
- event_data
- .lock()
- .await
- .entry(event)
- .or_default()
- .push(handler);
+ event_queue.register(event, handler).await;
}
}