diff options
Diffstat (limited to 'mumd')
| -rw-r--r-- | mumd/src/command.rs | 53 | ||||
| -rw-r--r-- | mumd/src/main.rs | 13 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 6 | ||||
| -rw-r--r-- | mumd/src/network/udp.rs | 45 | ||||
| -rw-r--r-- | mumd/src/state.rs | 42 |
5 files changed, 127 insertions, 32 deletions
diff --git a/mumd/src/command.rs b/mumd/src/command.rs index d4b25d0..ff53dc7 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -1,4 +1,4 @@ -use crate::state::State; +use crate::state::{State, ExecutionContext}; use crate::network::tcp::{TcpEvent, TcpEventCallback}; use ipc_channel::ipc::IpcSender; @@ -6,6 +6,8 @@ use log::*; use mumlib::command::{Command, CommandResponse}; use std::sync::{Arc, Mutex}; use tokio::sync::{mpsc, oneshot}; +use mumble_protocol::ping::PongPacket; +use std::net::SocketAddr; pub async fn handle( state: Arc<Mutex<State>>, @@ -14,28 +16,47 @@ pub async fn handle( IpcSender<mumlib::error::Result<Option<CommandResponse>>>, )>, tcp_event_register_sender: mpsc::UnboundedSender<(TcpEvent, TcpEventCallback)>, + ping_request_sender: mpsc::UnboundedSender<(u64, SocketAddr, Box<dyn FnOnce(PongPacket)>)>, ) { debug!("Begin listening for commands"); while let Some((command, response_sender)) = command_receiver.recv().await { debug!("Received command {:?}", command); let mut state = state.lock().unwrap(); - let (event, generator) = state.handle_command(command); + let event = state.handle_command(command); drop(state); - if let Some(event) = event { - let (tx, rx) = oneshot::channel(); - //TODO handle this error - let _ = tcp_event_register_sender.send(( - event, - Box::new(move |e| { - let response = generator(Some(e)); - response_sender.send(response).unwrap(); - tx.send(()).unwrap(); - }), - )); + match event { + ExecutionContext::TcpEvent(event, generator) => { + let (tx, rx) = oneshot::channel(); + //TODO handle this error + let _ = tcp_event_register_sender.send(( + event, + Box::new(move |e| { + let response = generator(e); + response_sender.send(response).unwrap(); + tx.send(()).unwrap(); + }), + )); - rx.await.unwrap(); - } else { - response_sender.send(generator(None)).unwrap(); + rx.await.unwrap(); + } + ExecutionContext::Now(generator) => { + response_sender.send(generator()).unwrap(); + } + ExecutionContext::Ping(generator, converter) => { + match generator() { + Ok(addr) => { + let res = ping_request_sender.send((0, addr, Box::new(move |packet| { + response_sender.send(converter(packet)).unwrap(); + }))); + if res.is_err() { + panic!(); + } + }, + Err(e) => { + response_sender.send(Err(e)).unwrap(); + } + }; + } } } } diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 37ff0dd..70cc21b 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -36,11 +36,12 @@ async fn main() { let (connection_info_sender, connection_info_receiver) = watch::channel::<Option<ConnectionInfo>>(None); let (response_sender, response_receiver) = mpsc::unbounded_channel(); + let (ping_request_sender, ping_request_receiver) = mpsc::unbounded_channel(); let state = State::new(packet_sender, connection_info_sender); let state = Arc::new(Mutex::new(state)); - let (_, _, _, e) = join!( + let (_, _, _, e, _) = join!( network::tcp::handle( Arc::clone(&state), connection_info_receiver.clone(), @@ -53,11 +54,19 @@ async fn main() { connection_info_receiver.clone(), crypt_state_receiver, ), - command::handle(state, command_receiver, response_sender), + command::handle( + state, + command_receiver, + response_sender, + ping_request_sender, + ), spawn_blocking(move || { // IpcSender is blocking receive_oneshot_commands(command_sender); }), + network::udp::handle_pings( + ping_request_receiver + ), ); e.unwrap(); } diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index cd11690..131f066 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -27,7 +27,7 @@ type TcpSender = SplitSink< type TcpReceiver = SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>; -pub(crate) type TcpEventCallback = Box<dyn FnOnce(&TcpEventData)>; +pub(crate) type TcpEventCallback = Box<dyn FnOnce(TcpEventData)>; #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub enum TcpEvent { @@ -228,7 +228,7 @@ async fn listen( if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { let old = std::mem::take(vec); for handler in old { - handler(&TcpEventData::Connected(&msg)); + handler(TcpEventData::Connected(&msg)); } } let mut state = state.lock().unwrap(); @@ -282,7 +282,7 @@ async fn listen( if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { let old = std::mem::take(vec); for handler in old { - handler(&TcpEventData::Disconnected); + handler(TcpEventData::Disconnected); } } }, diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 4f96c4c..febf7f1 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -13,6 +13,10 @@ use std::sync::{Arc, Mutex}; use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, watch}; use tokio_util::udp::UdpFramed; +use std::collections::HashMap; +use mumble_protocol::ping::{PingPacket, PongPacket}; +use std::rc::Rc; +use std::convert::TryFrom; type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>; type UdpReceiver = SplitStream<UdpFramed<ClientCryptState>>; @@ -225,3 +229,44 @@ async fn send_voice( debug!("UDP sender process killed"); } + +pub async fn handle_pings( + mut ping_request_receiver: mpsc::UnboundedReceiver<(u64, SocketAddr, Box<dyn FnOnce(PongPacket)>)>, +) { + let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) + .await + .expect("Failed to bind UDP socket"); + + let (mut receiver, mut sender) = udp_socket.split(); + + let pending = Rc::new(Mutex::new(HashMap::new())); + + let sender_handle = async { + while let Some((id, socket_addr, handle)) = ping_request_receiver.recv().await { + let packet = PingPacket { id }; + let packet: [u8; 12] = packet.into(); + sender.send_to(&packet, &socket_addr).await.unwrap(); + pending.lock().unwrap().insert(id, handle); + } + }; + + let receiver_handle = async { + let mut buf = vec![0; 24]; + while let Ok(read) = receiver.recv(&mut buf).await { + assert_eq!(read, 24); + + let packet = match PongPacket::try_from(buf.as_slice()) { + Ok(v) => v, + Err(_) => panic!(), + }; + + if let Some(handler) = pending.lock().unwrap().remove(&packet.id) { + handler(packet); + } + } + }; + + debug!("Waiting for ping requests"); + + join!(sender_handle, receiver_handle); +}
\ No newline at end of file diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 81b6c98..0d0fad8 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -16,21 +16,29 @@ use mumlib::command::{Command, CommandResponse}; use mumlib::config::Config; use mumlib::error::{ChannelIdentifierError, Error}; use mumlib::state::UserDiff; -use std::net::ToSocketAddrs; +use std::net::{ToSocketAddrs, SocketAddr}; use tokio::sync::{mpsc, watch}; +use mumble_protocol::ping::PongPacket; macro_rules! at { ($event:expr, $generator:expr) => { - (Some($event), Box::new($generator)) + ExecutionContext::TcpEvent($event, Box::new($generator)) }; } macro_rules! now { ($data:expr) => { - (None, Box::new(move |_| $data)) + ExecutionContext::Now(Box::new(move || $data)) }; } +//TODO give me a better name +pub enum ExecutionContext { + TcpEvent(TcpEvent, Box<dyn FnOnce(TcpEventData) -> mumlib::error::Result<Option<CommandResponse>>>), + Now(Box<dyn FnOnce() -> mumlib::error::Result<Option<CommandResponse>>>), + Ping(Box<dyn FnOnce() -> mumlib::error::Result<SocketAddr>>, Box<dyn FnOnce(PongPacket) -> mumlib::error::Result<Option<CommandResponse>>>), +} + #[derive(Clone, Debug, Eq, PartialEq)] pub enum StatePhase { Disconnected, @@ -71,10 +79,7 @@ impl State { pub fn handle_command( &mut self, command: Command, - ) -> ( - Option<TcpEvent>, - Box<dyn FnOnce(Option<&TcpEventData>) -> mumlib::error::Result<Option<CommandResponse>>>, - ) { + ) -> ExecutionContext { match command { Command::ChannelJoin { channel_identifier } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { @@ -128,7 +133,7 @@ impl State { } Command::ChannelList => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (None, Box::new(|_| Err(Error::DisconnectedError))); + return now!(Err(Error::DisconnectedError)); } let list = channel::into_channel( self.server.as_ref().unwrap().channels(), @@ -173,7 +178,7 @@ impl State { .unwrap(); at!(TcpEvent::Connected, |e| { //runs the closure when the client is connected - if let Some(TcpEventData::Connected(msg)) = e { + if let TcpEventData::Connected(msg) = e { Ok(Some(CommandResponse::ServerConnect { welcome_message: if msg.has_welcome_text() { Some(msg.get_welcome_text().to_string()) @@ -217,6 +222,21 @@ impl State { self.reload_config(); now!(Ok(None)) } + Command::ServerStatus { host, port } => { + ExecutionContext::Ping(Box::new(move || { + match (host.as_str(), port).to_socket_addrs().map(|mut e| e.next()) { + Ok(Some(v)) => Ok(v), + _ => Err(mumlib::error::Error::InvalidServerAddrError(host, port)), + } + }), Box::new(move |pong| { + Ok(Some(CommandResponse::ServerStatus { + version: pong.version, + users: pong.users, + max_users: pong.max_users, + bandwidth: pong.bandwidth, + })) + })) + } } } @@ -229,9 +249,9 @@ impl State { // check if this is initial state if !self.server().unwrap().users().contains_key(&session) { self.parse_initial_user_state(session, msg); - return None; + None } else { - return Some(self.parse_updated_user_state(session, msg)); + Some(self.parse_updated_user_state(session, msg)) } } |
