diff options
Diffstat (limited to 'mumd/src')
| -rw-r--r-- | mumd/src/command.rs | 2 | ||||
| -rw-r--r-- | mumd/src/main.rs | 24 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 291 | ||||
| -rw-r--r-- | mumd/src/network/udp.rs | 144 | ||||
| -rw-r--r-- | mumd/src/state.rs | 5 |
5 files changed, 344 insertions, 122 deletions
diff --git a/mumd/src/command.rs b/mumd/src/command.rs index 0e5bdc7..bfdb7dd 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -49,4 +49,6 @@ pub async fn handle( } command_response_sender.send(command_response).unwrap(); } + + debug!("Finished handling commands"); } diff --git a/mumd/src/main.rs b/mumd/src/main.rs index a2665ba..c923857 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -11,14 +11,18 @@ use argparse::ArgumentParser; use argparse::Store; use argparse::StoreTrue; use colored::*; -use futures::channel::oneshot; -use futures::join; +use tokio::sync::oneshot; +use futures::{join, select}; use log::*; use mumble_protocol::control::ControlPacket; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::voice::Serverbound; use std::sync::{Arc, Mutex}; use tokio::sync::{mpsc, watch}; +use std::thread; +use std::time::Duration; +use tokio::stream::StreamExt; +use futures::FutureExt; #[tokio::main] async fn main() { @@ -79,7 +83,7 @@ async fn main() { command_sender.send(Command::ChannelList).unwrap(); command_sender.send(Command::ServerConnect{host: server_host, port: server_port, username: username.clone(), accept_invalid_cert}); - command_sender.send(Command::ChannelJoin{channel_id: 1}).unwrap(); + //command_sender.send(Command::ChannelJoin{channel_id: 1}).unwrap(); command_sender.send(Command::ChannelList).unwrap(); let state = State::new(packet_sender, command_sender.clone(), connection_info_sender, username); let state = Arc::new(Mutex::new(state)); @@ -102,16 +106,28 @@ async fn main() { command_receiver, command_response_sender, ), + send_commands( + command_sender + ), receive_command_responses( command_response_receiver, ), ); } +async fn send_commands(command_sender: mpsc::UnboundedSender<Command>) { + tokio::time::delay_for(Duration::from_secs(5)).await; + command_sender.send(Command::ServerDisconnect); + + debug!("Finished sending commands"); +} + async fn receive_command_responses( mut command_response_receiver: mpsc::UnboundedReceiver<Result<Option<CommandResponse>, ()>>, ) { while let Some(command_response) = command_response_receiver.recv().await { - debug!("{:#?}", command_response); + debug!("{:?}", command_response); } + + debug!("Finished receiving commands"); } diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 9fb5ae4..0aca19e 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -2,8 +2,8 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; -use futures::channel::oneshot; -use futures::{join, SinkExt, StreamExt}; +use tokio::sync::oneshot; +use futures::{join, select, pin_mut, SinkExt, StreamExt, FutureExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; @@ -16,6 +16,7 @@ use tokio::sync::{mpsc, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use futures_util::core_reexport::cell::RefCell; type TcpSender = SplitSink< Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>, @@ -43,15 +44,21 @@ pub async fn handle( .await; // Handshake (omitting `Version` message for brevity) - authenticate(&mut sink, state.lock().unwrap().username().unwrap().to_string()).await; + let mut state_lock = state.lock().unwrap(); + authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; + let phase_watcher = state_lock.phase_receiver(); + let packet_sender = state_lock.packet_sender(); + drop(state_lock); info!("Logging in..."); join!( - send_pings(state.lock().unwrap().packet_sender(), 10), - listen(state, stream, crypt_state_sender), - send_packets(sink, packet_receiver), + send_pings(packet_sender, 10, phase_watcher.clone()), + listen(state, stream, crypt_state_sender, phase_watcher.clone()), + send_packets(sink, packet_receiver, phase_watcher), ); + + debug!("Fully disconnected TCP stream"); } async fn connect( @@ -87,109 +94,209 @@ async fn authenticate(sink: &mut TcpSender, username: String) { sink.send(msg.into()).await.unwrap(); } -async fn send_pings(packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, - delay_seconds: u64) { +async fn send_pings( + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + delay_seconds: u64, + mut phase_watcher: watch::Receiver<StatePhase>, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + tx.send(true); + }; + let mut interval = time::interval(Duration::from_secs(delay_seconds)); - loop { - interval.tick().await; - trace!("Sending ping"); - let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); - } + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let interval_waiter = interval.tick().fuse(); + pin_mut!(interval_waiter); + let exitor = select! { + data = interval_waiter => Some(data), + _ = rx => None + }; + + match exitor { + Some(_) => { + trace!("Sending ping"); + let msg = msgs::Ping::new(); + packet_sender.send(msg.into()).unwrap(); + } + None => break, + } + } + }; + + join!(main_block, phase_transition_block); + + debug!("Ping sender process killed"); } -async fn send_packets(mut sink: TcpSender, - mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>) { +async fn send_packets( + mut sink: TcpSender, + mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, + mut phase_watcher: watch::Receiver<StatePhase>, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + tx.send(true); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = packet_receiver.recv().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some(packet)) => { + sink.send(packet).await.unwrap(); + } + } + } - while let Some(packet) = packet_receiver.recv().await { - sink.send(packet).await.unwrap(); - } + //clears queue of remaining packets + while let Ok(_) = packet_receiver.try_recv() {} + + sink.close().await.unwrap(); + }; + + join!(main_block, phase_transition_block); + + debug!("TCP packet sender killed"); } async fn listen( state: Arc<Mutex<State>>, mut stream: TcpReceiver, crypt_state_sender: oneshot::Sender<ClientCryptState>, + mut phase_watcher: watch::Receiver<StatePhase>, ) { let mut crypt_state = None; let mut crypt_state_sender = Some(crypt_state_sender); - while let Some(packet) = stream.next().await { - //TODO handle types separately - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - crypt_state = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.take() { - let _ = sender.send( - crypt_state - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ); - } - let mut state = state.lock().unwrap(); - let server = state.server_mut(); - server.parse_server_sync(msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + tx.send(true); + }; + + let listener_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = stream.next().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; } - for (_, channel) in server.channels() { - info!("Found channel {}", channel.name()); + Some(None) => { + warn!("Channel closed before disconnect command"); + break; } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - let mut state = state.lock().unwrap(); - let session = msg.get_session(); - state.audio_mut().add_client(msg.get_session()); //TODO - if *state.phase_receiver().borrow() == StatePhase::Connecting { - state.parse_initial_user_state(msg); - } else { - state.server_mut().parse_user_state(msg); + Some(Some(packet)) => { + //TODO handle types separately + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + crypt_state = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(sender) = crypt_state_sender.take() { + let _ = sender.send( + crypt_state + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ); + } + let mut state = state.lock().unwrap(); + let server = state.server_mut(); + server.parse_server_sync(msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), + } + for (_, channel) in server.channels() { + info!("Found channel {}", channel.name()); + } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + let mut state = state.lock().unwrap(); + let session = msg.get_session(); + state.audio_mut().add_client(msg.get_session()); //TODO + if *state.phase_receiver().borrow() == StatePhase::Connecting { + state.parse_initial_user_state(msg); + } else { + state.server_mut().parse_user_state(msg); + } + let server = state.server_mut(); + let user = server.users().get(&session).unwrap(); + info!("User {} connected to {}", + user.name(), + user.channel()); + } + ControlPacket::UserRemove(msg) => { + info!("User {} left", msg.get_session()); + state.lock().unwrap().audio_mut().remove_client(msg.get_session()); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state.lock().unwrap().server_mut().parse_channel_state(msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state.lock().unwrap().server_mut().parse_channel_remove(msg); + } + _ => {} + } } - let server = state.server_mut(); - let user = server.users().get(&session).unwrap(); - info!("User {} connected to {}", - user.name(), - user.channel()); - } - ControlPacket::UserRemove(msg) => { - info!("User {} left", msg.get_session()); - state.lock().unwrap().audio_mut().remove_client(msg.get_session()); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state.lock().unwrap().server_mut().parse_channel_state(msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state.lock().unwrap().server_mut().parse_channel_remove(msg); } - _ => {} } - } + + //TODO? clean up stream + }; + + join!(phase_transition_block, listener_block); + + debug!("Killing TCP listener block"); } diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index cf0305b..ab4ca1d 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,10 +1,9 @@ use crate::network::ConnectionInfo; -use crate::state::State; +use crate::state::{State, StatePhase}; use log::*; use bytes::Bytes; -use futures::channel::oneshot; -use futures::{join, SinkExt, StreamExt}; +use futures::{join, select, pin_mut, SinkExt, StreamExt, FutureExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::voice::{VoicePacket, VoicePacketPayload}; @@ -12,7 +11,7 @@ use mumble_protocol::Serverbound; use std::net::{Ipv6Addr, SocketAddr}; use std::sync::{Arc, Mutex}; use tokio::net::UdpSocket; -use tokio::sync::watch; +use tokio::sync::{watch, oneshot}; use tokio_util::udp::UdpFramed; type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>; @@ -58,17 +57,80 @@ pub async fn handle( send_ping(&mut sink, connection_info.socket_addr).await; let sink = Arc::new(Mutex::new(sink)); + + let phase_watcher = state.lock().unwrap().phase_receiver(); join!( - listen(Arc::clone(&state), source), - send_voice(state, sink, connection_info.socket_addr), + listen(Arc::clone(&state), source, phase_watcher.clone()), + send_voice(state, sink, connection_info.socket_addr, phase_watcher), ); + + debug!("Fully disconnected UPD stream"); } async fn listen( state: Arc<Mutex<State>>, mut source: UdpReceiver, + mut phase_watcher: watch::Receiver<StatePhase>, ) { - while let Some(packet) = source.next().await { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + tx.send(true); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = source.next().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some(packet)) => { + let (packet, _src_addr) = match packet { + Ok(packet) => packet, + Err(err) => { + warn!("Got an invalid UDP packet: {}", err); + // To be expected, considering this is the internet, just ignore it + continue; + } + }; + match packet { + VoicePacket::Ping { .. } => { + // Note: A normal application would handle these and only use UDP for voice + // once it has received one. + continue; + } + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state.lock().unwrap().audio().decode_packet(session_id, payload); + } + } + } + } + } + }; + + join!(main_block, phase_transition_block); + + debug!("UDP listener process killed"); + + /*while let Some(packet) = source.next().await { let (packet, _src_addr) = match packet { Ok(packet) => packet, Err(err) => { @@ -93,7 +155,7 @@ async fn listen( state.lock().unwrap().audio().decode_packet(session_id, payload); } } - } + }*/ } async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) { @@ -116,25 +178,57 @@ async fn send_voice( state: Arc<Mutex<State>>, sink: Arc<Mutex<UdpSender>>, server_addr: SocketAddr, + mut phase_watcher: watch::Receiver<StatePhase>, ) { let mut receiver = state.lock().unwrap().audio_mut().take_receiver().unwrap(); - let mut count = 0; - while let Some(payload) = receiver.recv().await { - let reply = VoicePacket::Audio { - _dst: std::marker::PhantomData, - target: 0, // normal speech - session_id: (), // unused for server-bound packets - seq_num: count, - payload, - position_info: None, - }; - count += 1; - sink.lock() - .unwrap() - .send((reply, server_addr)) - .await - .unwrap(); - } + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + tx.send(true); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + let mut count = 0; + loop { + let packet_recv = receiver.recv().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some(payload)) => { + let reply = VoicePacket::Audio { + _dst: std::marker::PhantomData, + target: 0, // normal speech + session_id: (), // unused for server-bound packets + seq_num: count, + payload, + position_info: None, + }; + count += 1; + sink.lock() + .unwrap() + .send((reply, server_addr)) + .await + .unwrap(); + } + } + } + }; + + join!(main_block, phase_transition_block); + + debug!("UDP listener process killed"); } diff --git a/mumd/src/state.rs b/mumd/src/state.rs index ef1cd6d..0de29f1 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -100,7 +100,10 @@ impl State { server_state: self.server.clone(), }))) } - _ => { (false, Ok(None)) } + Command::ServerDisconnect => { + self.phase_watcher.0.broadcast(StatePhase::Disconnected).unwrap(); + (false, Ok(None)) + } } } |
