aboutsummaryrefslogtreecommitdiffstats
path: root/mumd/src
diff options
context:
space:
mode:
Diffstat (limited to 'mumd/src')
-rw-r--r--mumd/src/command.rs38
-rw-r--r--mumd/src/main.rs4
-rw-r--r--mumd/src/network/tcp.rs78
-rw-r--r--mumd/src/state.rs33
4 files changed, 117 insertions, 36 deletions
diff --git a/mumd/src/command.rs b/mumd/src/command.rs
index a035a26..5285a9d 100644
--- a/mumd/src/command.rs
+++ b/mumd/src/command.rs
@@ -1,10 +1,11 @@
-use crate::state::{State, StatePhase};
+use crate::state::State;
use ipc_channel::ipc::IpcSender;
use log::*;
use mumlib::command::{Command, CommandResponse};
use std::sync::{Arc, Mutex};
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, oneshot};
+use crate::network::tcp::{TcpEvent, TcpEventCallback};
pub async fn handle(
state: Arc<Mutex<State>>,
@@ -12,23 +13,26 @@ pub async fn handle(
Command,
IpcSender<mumlib::error::Result<Option<CommandResponse>>>,
)>,
+ tcp_event_register_sender: mpsc::UnboundedSender<(TcpEvent, TcpEventCallback)>,
) {
debug!("Begin listening for commands");
- while let Some(command) = command_receiver.recv().await {
- debug!("Received command {:?}", command.0);
- let mut state = state.lock().unwrap();
- let (wait_for_connected, command_response) = state.handle_command(command.0).await;
- if wait_for_connected {
- let mut watcher = state.phase_receiver();
- drop(state);
- while !matches!(watcher.recv().await.unwrap(), StatePhase::Connected) {}
+ while let Some((command, response_sender)) = command_receiver.recv().await {
+ debug!("Received command {:?}", command);
+ let mut statee = state.lock().unwrap();
+ let (event_data, command_response) = statee.handle_command(command).await;
+ drop(statee);
+ if let Some((event, callback)) = event_data {
+ let (tx, rx) = oneshot::channel();
+ tcp_event_register_sender.send((event, Box::new(move |e| {
+ println!("något hände");
+ callback(e);
+ response_sender.send(command_response).unwrap();
+ tx.send(());
+ })));
+
+ rx.await;
+ } else {
+ response_sender.send(command_response).unwrap();
}
- command.1.send(command_response).unwrap();
}
- //TODO err if not connected
- //while let Some(command) = command_receiver.recv().await {
- // debug!("Parsing command {:?}", command);
- //}
-
- //debug!("Finished handling commands");
}
diff --git a/mumd/src/main.rs b/mumd/src/main.rs
index 75726f8..e88eede 100644
--- a/mumd/src/main.rs
+++ b/mumd/src/main.rs
@@ -33,6 +33,7 @@ async fn main() {
)>();
let (connection_info_sender, connection_info_receiver) =
watch::channel::<Option<ConnectionInfo>>(None);
+ let (response_sender, response_receiver) = mpsc::unbounded_channel();
let state = State::new(packet_sender, connection_info_sender);
let state = Arc::new(Mutex::new(state));
@@ -43,13 +44,14 @@ async fn main() {
connection_info_receiver.clone(),
crypt_state_sender,
packet_receiver,
+ response_receiver,
),
network::udp::handle(
Arc::clone(&state),
connection_info_receiver.clone(),
crypt_state_receiver,
),
- command::handle(state, command_receiver,),
+ command::handle(state, command_receiver, response_sender),
spawn_blocking(move || {
// IpcSender is blocking
receive_oneshot_commands(command_sender);
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 6471771..ab49417 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -15,6 +15,7 @@ use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::{self, Duration};
use tokio_tls::{TlsConnector, TlsStream};
use tokio_util::codec::{Decoder, Framed};
+use std::collections::HashMap;
type TcpSender = SplitSink<
Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>,
@@ -23,11 +24,25 @@ type TcpSender = SplitSink<
type TcpReceiver =
SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>;
+pub(crate) type TcpEventCallback = Box<dyn FnOnce(&TcpEventData)>;
+
+#[derive(Debug, Clone, Hash, Eq, PartialEq)]
+pub enum TcpEvent {
+ Connected,
+ Disconnected,
+}
+
+pub enum TcpEventData<'a> {
+ Connected(&'a msgs::ServerSync),
+ Disconnected,
+}
+
pub async fn handle(
state: Arc<Mutex<State>>,
mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
+ mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
) {
loop {
let connection_info = loop {
@@ -54,6 +69,7 @@ pub async fn handle(
let phase_watcher = state_lock.phase_receiver();
let packet_sender = state_lock.packet_sender();
drop(state_lock);
+ let event_queue = Arc::new(Mutex::new(HashMap::new()));
info!("Logging in...");
@@ -63,9 +79,11 @@ pub async fn handle(
Arc::clone(&state),
stream,
crypt_state_sender.clone(),
- phase_watcher.clone()
+ Arc::clone(&event_queue),
+ phase_watcher.clone(),
),
- send_packets(sink, &mut packet_receiver, phase_watcher),
+ send_packets(sink, &mut packet_receiver, phase_watcher.clone()),
+ register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher),
);
debug!("Fully disconnected TCP stream, waiting for new connection info");
@@ -200,6 +218,7 @@ async fn listen(
state: Arc<Mutex<State>>,
mut stream: TcpReceiver,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
+ event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
mut phase_watcher: watch::Receiver<StatePhase>,
) {
let mut crypt_state = None;
@@ -267,6 +286,12 @@ async fn listen(
)
.await;
}
+ if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Connected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(&TcpEventData::Connected(&msg));
+ }
+ }
let mut state = state.lock().unwrap();
let server = state.server_mut().unwrap();
server.parse_server_sync(*msg);
@@ -320,6 +345,13 @@ async fn listen(
}
}
+ if let Some(vec) = event_data.lock().unwrap().get_mut(&TcpEvent::Disconnected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(&TcpEventData::Disconnected);
+ }
+ }
+
//TODO? clean up stream
};
@@ -327,3 +359,45 @@ async fn listen(
debug!("Killing TCP listener block");
}
+
+async fn register_events(
+ tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
+ event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ mut phase_watcher: watch::Receiver<StatePhase>,
+) {
+ let (tx, rx) = oneshot::channel();
+ let phase_transition_block = async {
+ while !matches!(
+ phase_watcher.recv().await.unwrap(),
+ StatePhase::Disconnected
+ ) {}
+ tx.send(true).unwrap();
+ };
+
+ let main_block = async {
+ let rx = rx.fuse();
+ pin_mut!(rx);
+ loop {
+ let packet_recv = tcp_event_register_receiver.recv().fuse();
+ pin_mut!(packet_recv);
+ let exitor = select! {
+ data = packet_recv => Some(data),
+ _ = rx => None
+ };
+ match exitor {
+ None => {
+ break;
+ }
+ Some(None) => {
+ warn!("Channel closed before disconnect command");
+ break;
+ }
+ Some(Some((event, handler))) => {
+ event_data.lock().unwrap().entry(event).or_default().push(handler);
+ }
+ }
+ }
+ };
+
+ join!(main_block, phase_transition_block);
+} \ No newline at end of file
diff --git a/mumd/src/state.rs b/mumd/src/state.rs
index e9db616..c247b08 100644
--- a/mumd/src/state.rs
+++ b/mumd/src/state.rs
@@ -14,6 +14,7 @@ use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use tokio::sync::{mpsc, watch};
+use crate::network::tcp::{TcpEventCallback, TcpEvent};
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StatePhase {
@@ -55,11 +56,11 @@ impl State {
pub async fn handle_command(
&mut self,
command: Command,
- ) -> (bool, mumlib::error::Result<Option<CommandResponse>>) {
+ ) -> (Option<(TcpEvent, TcpEventCallback)>, mumlib::error::Result<Option<CommandResponse>>) {
match command {
Command::ChannelJoin { channel_identifier } => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (false, Err(Error::DisconnectedError));
+ return (None, Err(Error::DisconnectedError));
}
let channels = self.server()
@@ -77,27 +78,27 @@ impl State {
.filter(|e| e.1.ends_with(&channel_identifier.to_lowercase()))
.collect::<Vec<_>>();
match soft_matches.len() {
- 0 => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))),
+ 0 => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))),
1 => *soft_matches.get(0).unwrap().0,
- _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))),
+ _ => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))),
}
},
1 => *matches.get(0).unwrap().0,
- _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))),
+ _ => return (None, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))),
};
let mut msg = msgs::UserState::new();
msg.set_session(self.server.as_ref().unwrap().session_id.unwrap());
msg.set_channel_id(id);
self.packet_sender.send(msg.into()).unwrap();
- (false, Ok(None))
+ (None, Ok(None))
}
Command::ChannelList => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (false, Err(Error::DisconnectedError));
+ return (None, Err(Error::DisconnectedError));
}
(
- false,
+ None,
Ok(Some(CommandResponse::ChannelList {
channels: into_channel(
self.server.as_ref().unwrap().channels(),
@@ -113,7 +114,7 @@ impl State {
accept_invalid_cert,
} => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) {
- return (false, Err(Error::AlreadyConnectedError));
+ return (None, Err(Error::AlreadyConnectedError));
}
let mut server = Server::new();
server.username = Some(username);
@@ -131,7 +132,7 @@ impl State {
Ok(Some(v)) => v,
_ => {
warn!("Error parsing server addr");
- return (false, Err(Error::InvalidServerAddrError(host, port)));
+ return (None, Err(Error::InvalidServerAddrError(host, port)));
}
};
self.connection_info_sender
@@ -141,14 +142,14 @@ impl State {
accept_invalid_cert,
)))
.unwrap();
- (true, Ok(None))
+ (Some((TcpEvent::Connected, Box::new(|_| {}))), Ok(None))
}
Command::Status => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (false, Err(Error::DisconnectedError));
+ return (None, Err(Error::DisconnectedError));
}
(
- false,
+ None,
Ok(Some(CommandResponse::Status {
server_state: self.server.as_ref().unwrap().into(), //guaranteed not to panic because if we are connected, server is guaranteed to be Some
})),
@@ -156,7 +157,7 @@ impl State {
}
Command::ServerDisconnect => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (false, Err(Error::DisconnectedError));
+ return (None, Err(Error::DisconnectedError));
}
self.server = None;
@@ -166,11 +167,11 @@ impl State {
.0
.broadcast(StatePhase::Disconnected)
.unwrap();
- (false, Ok(None))
+ (None, Ok(None))
}
Command::InputVolumeSet(volume) => {
self.audio.set_input_volume(volume);
- (false, Ok(None))
+ (None, Ok(None))
}
Command::ConfigReload => {
self.reload_config();