aboutsummaryrefslogtreecommitdiffstats
path: root/mumd
diff options
context:
space:
mode:
authorGustav Sörnäs <gustav@sornas.net>2021-03-30 15:24:46 +0200
committerGitHub <noreply@github.com>2021-03-30 15:24:46 +0200
commit795e46c98616801c678bd0a403b08cb0fcd5ee43 (patch)
tree040efd79def19e28422980ebfb6ce414ff349570 /mumd
parenta6d433e3ad95b9a21d5d473da4b1f65e78585bb2 (diff)
parentb52068eade50758673e29c79e7cb8be3f1b4151f (diff)
downloadmum-795e46c98616801c678bd0a403b08cb0fcd5ee43.tar.gz
Merge pull request #79 from mum-rs/tcp-event-queue
Diffstat (limited to 'mumd')
-rw-r--r--mumd/src/main.rs2
-rw-r--r--mumd/src/network/tcp.rs108
-rw-r--r--mumd/src/state.rs49
3 files changed, 102 insertions, 57 deletions
diff --git a/mumd/src/main.rs b/mumd/src/main.rs
index 26e8d49..276e2ce 100644
--- a/mumd/src/main.rs
+++ b/mumd/src/main.rs
@@ -34,7 +34,7 @@ async fn main() {
bincode::serialize_into((&mut command).writer(), &Command::Ping).unwrap();
if let Ok(()) = writer.send(command.freeze()).await {
if let Some(Ok(buf)) = reader.next().await {
- if let Ok(Ok::<Option<CommandResponse>, mumlib::error::Error>(Some(CommandResponse::Pong))) = bincode::deserialize(&buf) {
+ if let Ok(Ok::<Option<CommandResponse>, mumlib::Error>(Some(CommandResponse::Pong))) = bincode::deserialize(&buf) {
error!("Another instance of mumd is already running");
return;
}
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 8ce49cb..47b1c20 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -36,11 +36,47 @@ pub enum TcpEvent {
Disconnected, //fires when the client has disconnected from a server
}
+#[derive(Clone)]
pub enum TcpEventData<'a> {
- Connected(&'a msgs::ServerSync),
+ Connected(Result<&'a msgs::ServerSync, mumlib::Error>),
Disconnected,
}
+impl<'a> From<&TcpEventData<'a>> for TcpEvent {
+ fn from(t: &TcpEventData) -> Self {
+ match t {
+ TcpEventData::Connected(_) => TcpEvent::Connected,
+ TcpEventData::Disconnected => TcpEvent::Disconnected,
+ }
+ }
+}
+
+#[derive(Clone)]
+struct TcpEventQueue {
+ handlers: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+}
+
+impl TcpEventQueue {
+ fn new() -> Self {
+ Self {
+ handlers: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ async fn register(&self, at: TcpEvent, callback: TcpEventCallback) {
+ self.handlers.lock().await.entry(at).or_default().push(callback);
+ }
+
+ async fn resolve<'a>(&self, data: TcpEventData<'a>) {
+ if let Some(vec) = self.handlers.lock().await.get_mut(&TcpEvent::from(&data)) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(data.clone());
+ }
+ }
+ }
+}
+
pub async fn handle(
state: Arc<Mutex<State>>,
mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
@@ -73,7 +109,7 @@ pub async fn handle(
let phase_watcher = state_lock.phase_receiver();
let input_receiver = state_lock.audio().input_receiver();
drop(state_lock);
- let event_queue = Arc::new(Mutex::new(HashMap::new()));
+ let event_queue = TcpEventQueue::new();
info!("Logging in...");
@@ -85,7 +121,7 @@ pub async fn handle(
Arc::clone(&state),
stream,
crypt_state_sender.clone(),
- Arc::clone(&event_queue),
+ event_queue.clone(),
),
send_voice(
packet_sender.clone(),
@@ -93,17 +129,12 @@ pub async fn handle(
phase_watcher.clone(),
),
send_packets(sink, &mut packet_receiver),
- register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)),
+ register_events(&mut tcp_event_register_receiver, event_queue.clone()),
).map(|_| ()),
phase_watcher,
).await;
- if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) {
- let old = std::mem::take(vec);
- for handler in old {
- handler(TcpEventData::Disconnected);
- }
- }
+ event_queue.resolve(TcpEventData::Disconnected).await;
debug!("Fully disconnected TCP stream, waiting for new connection info");
}
@@ -154,12 +185,12 @@ async fn send_pings(
delay_seconds: u64,
) {
let mut interval = time::interval(Duration::from_secs(delay_seconds));
- loop {
- interval.tick().await;
- trace!("Sending TCP ping");
- let msg = msgs::Ping::new();
- packet_sender.send(msg.into()).unwrap();
- }
+ loop {
+ interval.tick().await;
+ trace!("Sending TCP ping");
+ let msg = msgs::Ping::new();
+ packet_sender.send(msg.into()).unwrap();
+ }
}
async fn send_packets(
@@ -209,14 +240,27 @@ async fn listen(
state: Arc<Mutex<State>>,
mut stream: TcpReceiver,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
- event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ event_queue: TcpEventQueue,
) {
let mut crypt_state = None;
let mut crypt_state_sender = Some(crypt_state_sender);
loop {
- let packet = stream.next().await.unwrap();
- match packet.unwrap() {
+ let packet = match stream.next().await {
+ Some(Ok(packet)) => packet,
+ Some(Err(e)) => {
+ error!("TCP error: {:?}", e);
+ continue; //TODO Break here? Maybe look at the error and handle it
+ }
+ None => {
+ // We end up here if the login was rejected. We probably want
+ // to exit before that.
+ warn!("TCP stream gone");
+ state.lock().await.broadcast_phase(StatePhase::Disconnected);
+ break;
+ }
+ };
+ match packet {
ControlPacket::TextMessage(msg) => {
info!(
"Got message from user with session ID {}: {}",
@@ -250,12 +294,7 @@ async fn listen(
)
.await;
}
- if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) {
- let old = std::mem::take(vec);
- for handler in old {
- handler(TcpEventData::Connected(&msg));
- }
- }
+ event_queue.resolve(TcpEventData::Connected(Ok(&msg))).await;
let mut state = state.lock().await;
let server = state.server_mut().unwrap();
server.parse_server_sync(*msg);
@@ -269,7 +308,15 @@ async fn listen(
state.initialized();
}
ControlPacket::Reject(msg) => {
- warn!("Login rejected: {:?}", msg);
+ debug!("Login rejected: {:?}", msg);
+ match msg.get_field_type() {
+ msgs::Reject_RejectType::WrongServerPW => {
+ event_queue.resolve(TcpEventData::Connected(Err(mumlib::Error::InvalidServerPassword))).await;
+ }
+ ty => {
+ warn!("Unhandled reject type: {:?}", ty);
+ }
+ }
}
ControlPacket::UserState(msg) => {
state.lock().await.parse_user_state(*msg);
@@ -324,15 +371,10 @@ async fn listen(
async fn register_events(
tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
- event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ event_queue: TcpEventQueue,
) {
loop {
let (event, handler) = tcp_event_register_receiver.recv().await.unwrap();
- event_data
- .lock()
- .await
- .entry(event)
- .or_default()
- .push(handler);
+ event_queue.register(event, handler).await;
}
}
diff --git a/mumd/src/state.rs b/mumd/src/state.rs
index b279dfd..20fe660 100644
--- a/mumd/src/state.rs
+++ b/mumd/src/state.rs
@@ -15,7 +15,8 @@ use mumble_protocol::ping::PongPacket;
use mumble_protocol::voice::Serverbound;
use mumlib::command::{Command, CommandResponse};
use mumlib::config::Config;
-use mumlib::error::{ChannelIdentifierError, Error};
+use mumlib::error::ChannelIdentifierError;
+use mumlib::Error;
use crate::state::user::UserDiff;
use std::net::{SocketAddr, ToSocketAddrs};
use tokio::sync::{mpsc, watch};
@@ -88,7 +89,7 @@ impl State {
match command {
Command::ChannelJoin { channel_identifier } => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
- return now!(Err(Error::DisconnectedError));
+ return now!(Err(Error::Disconnected));
}
let channels = self.server().unwrap().channels();
@@ -138,7 +139,7 @@ impl State {
}
Command::ChannelList => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
- return now!(Err(Error::DisconnectedError));
+ return now!(Err(Error::Disconnected));
}
let list = channel::into_channel(
self.server.as_ref().unwrap().channels(),
@@ -152,7 +153,7 @@ impl State {
}
Command::DeafenSelf(toggle) => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
- return now!(Err(Error::DisconnectedError));
+ return now!(Err(Error::Disconnected));
}
let server = self.server().unwrap();
@@ -210,7 +211,7 @@ impl State {
}
Command::MuteOther(string, toggle) => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
- return now!(Err(Error::DisconnectedError));
+ return now!(Err(Error::Disconnected));
}
let id = self
@@ -222,7 +223,7 @@ impl State {
let (id, user) = match id {
Some(id) => (*id.0, id.1),
- None => return now!(Err(Error::InvalidUsernameError(string))),
+ None => return now!(Err(Error::InvalidUsername(string))),
};
let action = match toggle {
@@ -245,7 +246,7 @@ impl State {
}
Command::MuteSelf(toggle) => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
- return now!(Err(Error::DisconnectedError));
+ return now!(Err(Error::Disconnected));
}
let server = self.server().unwrap();
@@ -313,7 +314,7 @@ impl State {
accept_invalid_cert,
} => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) {
- return now!(Err(Error::AlreadyConnectedError));
+ return now!(Err(Error::AlreadyConnected));
}
let mut server = Server::new();
*server.username_mut() = Some(username);
@@ -332,7 +333,7 @@ impl State {
Ok(Some(v)) => v,
_ => {
warn!("Error parsing server addr");
- return now!(Err(Error::InvalidServerAddrError(host, port)));
+ return now!(Err(Error::InvalidServerAddr(host, port)));
}
};
connection_info_sender
@@ -342,16 +343,18 @@ impl State {
accept_invalid_cert,
)))
.unwrap();
- at!(TcpEvent::Connected, |e| {
+ at!(TcpEvent::Connected, |res| {
//runs the closure when the client is connected
- if let TcpEventData::Connected(msg) = e {
- Ok(Some(CommandResponse::ServerConnect {
- welcome_message: if msg.has_welcome_text() {
- Some(msg.get_welcome_text().to_string())
- } else {
- None
- },
- }))
+ if let TcpEventData::Connected(res) = res {
+ res.map(|msg| {
+ Some(CommandResponse::ServerConnect {
+ welcome_message: if msg.has_welcome_text() {
+ Some(msg.get_welcome_text().to_string())
+ } else {
+ None
+ },
+ })
+ })
} else {
unreachable!("callback should be provided with a TcpEventData::Connected");
}
@@ -359,7 +362,7 @@ impl State {
}
Command::ServerDisconnect => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
- return now!(Err(Error::DisconnectedError));
+ return now!(Err(Error::Disconnected));
}
self.server = None;
@@ -379,7 +382,7 @@ impl State {
.map(|mut e| e.next())
{
Ok(Some(v)) => Ok(v),
- _ => Err(mumlib::error::Error::InvalidServerAddrError(host, port)),
+ _ => Err(Error::InvalidServerAddr(host, port)),
}
}),
Box::new(move |pong| {
@@ -393,7 +396,7 @@ impl State {
),
Command::Status => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
- return now!(Err(Error::DisconnectedError));
+ return now!(Err(Error::Disconnected));
}
let state = self.server.as_ref().unwrap().into();
now!(Ok(Some(CommandResponse::Status {
@@ -402,7 +405,7 @@ impl State {
}
Command::UserVolumeSet(string, volume) => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
- return now!(Err(Error::DisconnectedError));
+ return now!(Err(Error::Disconnected));
}
let user_id = match self
.server()
@@ -412,7 +415,7 @@ impl State {
.find(|e| e.1.name() == string)
.map(|e| *e.0)
{
- None => return now!(Err(Error::InvalidUsernameError(string))),
+ None => return now!(Err(Error::InvalidUsername(string))),
Some(v) => v,
};