aboutsummaryrefslogtreecommitdiffstats
path: root/mumd/src
diff options
context:
space:
mode:
Diffstat (limited to 'mumd/src')
-rw-r--r--mumd/src/command.rs53
-rw-r--r--mumd/src/main.rs13
-rw-r--r--mumd/src/network/tcp.rs6
-rw-r--r--mumd/src/network/udp.rs45
-rw-r--r--mumd/src/state.rs42
5 files changed, 127 insertions, 32 deletions
diff --git a/mumd/src/command.rs b/mumd/src/command.rs
index d4b25d0..ff53dc7 100644
--- a/mumd/src/command.rs
+++ b/mumd/src/command.rs
@@ -1,4 +1,4 @@
-use crate::state::State;
+use crate::state::{State, ExecutionContext};
use crate::network::tcp::{TcpEvent, TcpEventCallback};
use ipc_channel::ipc::IpcSender;
@@ -6,6 +6,8 @@ use log::*;
use mumlib::command::{Command, CommandResponse};
use std::sync::{Arc, Mutex};
use tokio::sync::{mpsc, oneshot};
+use mumble_protocol::ping::PongPacket;
+use std::net::SocketAddr;
pub async fn handle(
state: Arc<Mutex<State>>,
@@ -14,28 +16,47 @@ pub async fn handle(
IpcSender<mumlib::error::Result<Option<CommandResponse>>>,
)>,
tcp_event_register_sender: mpsc::UnboundedSender<(TcpEvent, TcpEventCallback)>,
+ ping_request_sender: mpsc::UnboundedSender<(u64, SocketAddr, Box<dyn FnOnce(PongPacket)>)>,
) {
debug!("Begin listening for commands");
while let Some((command, response_sender)) = command_receiver.recv().await {
debug!("Received command {:?}", command);
let mut state = state.lock().unwrap();
- let (event, generator) = state.handle_command(command);
+ let event = state.handle_command(command);
drop(state);
- if let Some(event) = event {
- let (tx, rx) = oneshot::channel();
- //TODO handle this error
- let _ = tcp_event_register_sender.send((
- event,
- Box::new(move |e| {
- let response = generator(Some(e));
- response_sender.send(response).unwrap();
- tx.send(()).unwrap();
- }),
- ));
+ match event {
+ ExecutionContext::TcpEvent(event, generator) => {
+ let (tx, rx) = oneshot::channel();
+ //TODO handle this error
+ let _ = tcp_event_register_sender.send((
+ event,
+ Box::new(move |e| {
+ let response = generator(e);
+ response_sender.send(response).unwrap();
+ tx.send(()).unwrap();
+ }),
+ ));
- rx.await.unwrap();
- } else {
- response_sender.send(generator(None)).unwrap();
+ rx.await.unwrap();
+ }
+ ExecutionContext::Now(generator) => {
+ response_sender.send(generator()).unwrap();
+ }
+ ExecutionContext::Ping(generator, converter) => {
+ match generator() {
+ Ok(addr) => {
+ let res = ping_request_sender.send((0, addr, Box::new(move |packet| {
+ response_sender.send(converter(packet)).unwrap();
+ })));
+ if res.is_err() {
+ panic!();
+ }
+ },
+ Err(e) => {
+ response_sender.send(Err(e)).unwrap();
+ }
+ };
+ }
}
}
}
diff --git a/mumd/src/main.rs b/mumd/src/main.rs
index 37ff0dd..70cc21b 100644
--- a/mumd/src/main.rs
+++ b/mumd/src/main.rs
@@ -36,11 +36,12 @@ async fn main() {
let (connection_info_sender, connection_info_receiver) =
watch::channel::<Option<ConnectionInfo>>(None);
let (response_sender, response_receiver) = mpsc::unbounded_channel();
+ let (ping_request_sender, ping_request_receiver) = mpsc::unbounded_channel();
let state = State::new(packet_sender, connection_info_sender);
let state = Arc::new(Mutex::new(state));
- let (_, _, _, e) = join!(
+ let (_, _, _, e, _) = join!(
network::tcp::handle(
Arc::clone(&state),
connection_info_receiver.clone(),
@@ -53,11 +54,19 @@ async fn main() {
connection_info_receiver.clone(),
crypt_state_receiver,
),
- command::handle(state, command_receiver, response_sender),
+ command::handle(
+ state,
+ command_receiver,
+ response_sender,
+ ping_request_sender,
+ ),
spawn_blocking(move || {
// IpcSender is blocking
receive_oneshot_commands(command_sender);
}),
+ network::udp::handle_pings(
+ ping_request_receiver
+ ),
);
e.unwrap();
}
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index cd11690..131f066 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -27,7 +27,7 @@ type TcpSender = SplitSink<
type TcpReceiver =
SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>;
-pub(crate) type TcpEventCallback = Box<dyn FnOnce(&TcpEventData)>;
+pub(crate) type TcpEventCallback = Box<dyn FnOnce(TcpEventData)>;
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum TcpEvent {
@@ -228,7 +228,7 @@ async fn listen(
if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) {
let old = std::mem::take(vec);
for handler in old {
- handler(&TcpEventData::Connected(&msg));
+ handler(TcpEventData::Connected(&msg));
}
}
let mut state = state.lock().unwrap();
@@ -282,7 +282,7 @@ async fn listen(
if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) {
let old = std::mem::take(vec);
for handler in old {
- handler(&TcpEventData::Disconnected);
+ handler(TcpEventData::Disconnected);
}
}
},
diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs
index 4f96c4c..febf7f1 100644
--- a/mumd/src/network/udp.rs
+++ b/mumd/src/network/udp.rs
@@ -13,6 +13,10 @@ use std::sync::{Arc, Mutex};
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot, watch};
use tokio_util::udp::UdpFramed;
+use std::collections::HashMap;
+use mumble_protocol::ping::{PingPacket, PongPacket};
+use std::rc::Rc;
+use std::convert::TryFrom;
type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>;
type UdpReceiver = SplitStream<UdpFramed<ClientCryptState>>;
@@ -225,3 +229,44 @@ async fn send_voice(
debug!("UDP sender process killed");
}
+
+pub async fn handle_pings(
+ mut ping_request_receiver: mpsc::UnboundedReceiver<(u64, SocketAddr, Box<dyn FnOnce(PongPacket)>)>,
+) {
+ let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16))
+ .await
+ .expect("Failed to bind UDP socket");
+
+ let (mut receiver, mut sender) = udp_socket.split();
+
+ let pending = Rc::new(Mutex::new(HashMap::new()));
+
+ let sender_handle = async {
+ while let Some((id, socket_addr, handle)) = ping_request_receiver.recv().await {
+ let packet = PingPacket { id };
+ let packet: [u8; 12] = packet.into();
+ sender.send_to(&packet, &socket_addr).await.unwrap();
+ pending.lock().unwrap().insert(id, handle);
+ }
+ };
+
+ let receiver_handle = async {
+ let mut buf = vec![0; 24];
+ while let Ok(read) = receiver.recv(&mut buf).await {
+ assert_eq!(read, 24);
+
+ let packet = match PongPacket::try_from(buf.as_slice()) {
+ Ok(v) => v,
+ Err(_) => panic!(),
+ };
+
+ if let Some(handler) = pending.lock().unwrap().remove(&packet.id) {
+ handler(packet);
+ }
+ }
+ };
+
+ debug!("Waiting for ping requests");
+
+ join!(sender_handle, receiver_handle);
+} \ No newline at end of file
diff --git a/mumd/src/state.rs b/mumd/src/state.rs
index 81b6c98..0d0fad8 100644
--- a/mumd/src/state.rs
+++ b/mumd/src/state.rs
@@ -16,21 +16,29 @@ use mumlib::command::{Command, CommandResponse};
use mumlib::config::Config;
use mumlib::error::{ChannelIdentifierError, Error};
use mumlib::state::UserDiff;
-use std::net::ToSocketAddrs;
+use std::net::{ToSocketAddrs, SocketAddr};
use tokio::sync::{mpsc, watch};
+use mumble_protocol::ping::PongPacket;
macro_rules! at {
($event:expr, $generator:expr) => {
- (Some($event), Box::new($generator))
+ ExecutionContext::TcpEvent($event, Box::new($generator))
};
}
macro_rules! now {
($data:expr) => {
- (None, Box::new(move |_| $data))
+ ExecutionContext::Now(Box::new(move || $data))
};
}
+//TODO give me a better name
+pub enum ExecutionContext {
+ TcpEvent(TcpEvent, Box<dyn FnOnce(TcpEventData) -> mumlib::error::Result<Option<CommandResponse>>>),
+ Now(Box<dyn FnOnce() -> mumlib::error::Result<Option<CommandResponse>>>),
+ Ping(Box<dyn FnOnce() -> mumlib::error::Result<SocketAddr>>, Box<dyn FnOnce(PongPacket) -> mumlib::error::Result<Option<CommandResponse>>>),
+}
+
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StatePhase {
Disconnected,
@@ -71,10 +79,7 @@ impl State {
pub fn handle_command(
&mut self,
command: Command,
- ) -> (
- Option<TcpEvent>,
- Box<dyn FnOnce(Option<&TcpEventData>) -> mumlib::error::Result<Option<CommandResponse>>>,
- ) {
+ ) -> ExecutionContext {
match command {
Command::ChannelJoin { channel_identifier } => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
@@ -128,7 +133,7 @@ impl State {
}
Command::ChannelList => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (None, Box::new(|_| Err(Error::DisconnectedError)));
+ return now!(Err(Error::DisconnectedError));
}
let list = channel::into_channel(
self.server.as_ref().unwrap().channels(),
@@ -173,7 +178,7 @@ impl State {
.unwrap();
at!(TcpEvent::Connected, |e| {
//runs the closure when the client is connected
- if let Some(TcpEventData::Connected(msg)) = e {
+ if let TcpEventData::Connected(msg) = e {
Ok(Some(CommandResponse::ServerConnect {
welcome_message: if msg.has_welcome_text() {
Some(msg.get_welcome_text().to_string())
@@ -217,6 +222,21 @@ impl State {
self.reload_config();
now!(Ok(None))
}
+ Command::ServerStatus { host, port } => {
+ ExecutionContext::Ping(Box::new(move || {
+ match (host.as_str(), port).to_socket_addrs().map(|mut e| e.next()) {
+ Ok(Some(v)) => Ok(v),
+ _ => Err(mumlib::error::Error::InvalidServerAddrError(host, port)),
+ }
+ }), Box::new(move |pong| {
+ Ok(Some(CommandResponse::ServerStatus {
+ version: pong.version,
+ users: pong.users,
+ max_users: pong.max_users,
+ bandwidth: pong.bandwidth,
+ }))
+ }))
+ }
}
}
@@ -229,9 +249,9 @@ impl State {
// check if this is initial state
if !self.server().unwrap().users().contains_key(&session) {
self.parse_initial_user_state(session, msg);
- return None;
+ None
} else {
- return Some(self.parse_updated_user_state(session, msg));
+ Some(self.parse_updated_user_state(session, msg))
}
}