diff options
Diffstat (limited to 'mumd/src')
| -rw-r--r-- | mumd/src/audio.rs | 71 | ||||
| -rw-r--r-- | mumd/src/command.rs | 42 | ||||
| -rw-r--r-- | mumd/src/main.rs | 103 | ||||
| -rw-r--r-- | mumd/src/network/mod.rs | 19 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 370 | ||||
| -rw-r--r-- | mumd/src/network/udp.rs | 236 | ||||
| -rw-r--r-- | mumd/src/state.rs | 259 |
7 files changed, 829 insertions, 271 deletions
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index 9b794a6..1445415 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -1,6 +1,5 @@ use bytes::Bytes; -use cpal::traits::DeviceTrait; -use cpal::traits::HostTrait; +use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::{ InputCallbackInfo, OutputCallbackInfo, Sample, SampleFormat, SampleRate, Stream, StreamConfig, }; @@ -28,9 +27,9 @@ pub struct Audio { pub input_config: StreamConfig, pub input_stream: Stream, pub input_buffer: Arc<Mutex<VecDeque<f32>>>, - input_channel_receiver: Option<Receiver<VoicePacketPayload>>, + input_channel_receiver: Option<Receiver<VoicePacketPayload>>, //TODO unbounded? mbe ring buffer and drop the first packet - client_streams: Arc<Mutex<HashMap<u32, ClientStream>>>, + client_streams: Arc<Mutex<HashMap<u32, ClientStream>>>, //TODO move to user state } //TODO split into input/output @@ -104,31 +103,39 @@ impl Audio { let input_stream = match input_supported_sample_format { SampleFormat::F32 => input_device.build_input_stream( &input_config, - input_callback::<f32>(input_encoder, - input_sender, - input_config.sample_rate.0, - 10.0), + input_callback::<f32>( + input_encoder, + input_sender, + input_config.sample_rate.0, + 4, // 10 ms + ), err_fn, ), SampleFormat::I16 => input_device.build_input_stream( &input_config, - input_callback::<i16>(input_encoder, - input_sender, - input_config.sample_rate.0, - 10.0), + input_callback::<i16>( + input_encoder, + input_sender, + input_config.sample_rate.0, + 4, // 10 ms + ), err_fn, ), SampleFormat::U16 => input_device.build_input_stream( &input_config, - input_callback::<u16>(input_encoder, - input_sender, - input_config.sample_rate.0, - 10.0), + input_callback::<u16>( + input_encoder, + input_sender, + input_config.sample_rate.0, + 4, // 10 ms + ), err_fn, ), } .unwrap(); + output_stream.play().unwrap(); + Self { output_config, output_stream, @@ -206,7 +213,8 @@ impl ClientStream { match payload { VoicePacketPayload::Opus(bytes, _eot) => { let mut out: Vec<f32> = vec![0.0; 720 * channels * 4]; //720 is because that is the max size of packet we can get that we want to decode - let parsed = self.opus_decoder + let parsed = self + .opus_decoder .decode_float(&bytes, &mut out, false) .expect("Error decoding"); out.truncate(parsed); @@ -268,16 +276,19 @@ fn input_callback<T: Sample>( mut opus_encoder: opus::Encoder, mut input_sender: Sender<VoicePacketPayload>, sample_rate: u32, - opus_frame_size_ms: f32, + opus_frame_size_blocks: u32, // blocks of 2.5ms ) -> impl FnMut(&[T], &InputCallbackInfo) + Send + 'static { - if ! ( opus_frame_size_ms == 2.5 - || opus_frame_size_ms == 5.0 - || opus_frame_size_ms == 10.0 - || opus_frame_size_ms == 20.0) { - panic!("Unsupported opus frame size {}", opus_frame_size_ms); + if !(opus_frame_size_blocks == 1 + || opus_frame_size_blocks == 2 + || opus_frame_size_blocks == 4 + || opus_frame_size_blocks == 8) + { + panic!( + "Unsupported amount of opus frame blocks {}", + opus_frame_size_blocks + ); } - let opus_frame_size = (opus_frame_size_ms * sample_rate as f32) as u32 / 1000; - + let opus_frame_size = opus_frame_size_blocks * sample_rate / 400; let buf = Arc::new(Mutex::new(VecDeque::new())); move |data: &[T], _info: &InputCallbackInfo| { @@ -292,9 +303,13 @@ fn input_callback<T: Sample>( .unwrap(); opus_buf.truncate(result); let bytes = Bytes::copy_from_slice(&opus_buf); - input_sender - .try_send(VoicePacketPayload::Opus(bytes, false)) - .unwrap(); //TODO handle full buffer / disconnect + match input_sender.try_send(VoicePacketPayload::Opus(bytes, false)) { + //TODO handle full buffer / disconnect + Ok(_) => {} + Err(_e) => { + //warn!("Error sending audio packet: {:?}", e); + } + } *buf = tail; } } diff --git a/mumd/src/command.rs b/mumd/src/command.rs index 5d6cca4..b4bd1b7 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -1,4 +1,12 @@ -enum Command { +use crate::state::{Channel, Server, State, StatePhase}; + +use log::*; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use tokio::sync::mpsc; + +#[derive(Clone, Debug)] +pub enum Command { ChannelJoin { channel_id: u32, }, @@ -12,3 +20,35 @@ enum Command { ServerDisconnect, Status, } + +#[derive(Debug)] +pub enum CommandResponse { + ChannelList { + channels: HashMap<u32, Channel>, + }, + Status { + username: Option<String>, + server_state: Server, + }, +} + +pub async fn handle( + state: Arc<Mutex<State>>, + mut command_receiver: mpsc::UnboundedReceiver<Command>, + command_response_sender: mpsc::UnboundedSender<Result<Option<CommandResponse>, ()>>, +) { + //TODO err if not connected + while let Some(command) = command_receiver.recv().await { + debug!("Parsing command {:?}", command); + let mut state = state.lock().unwrap(); + let (wait_for_connected, command_response) = state.handle_command(command).await; + if wait_for_connected { + let mut watcher = state.phase_receiver(); + drop(state); + while !matches!(watcher.recv().await.unwrap(), StatePhase::Connected) {} + } + command_response_sender.send(command_response).unwrap(); + } + + debug!("Finished handling commands"); +} diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 2a0fcbd..f837a52 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -1,46 +1,55 @@ mod audio; -mod network; mod command; +mod network; mod state; -use crate::audio::Audio; -use crate::state::Server; + +use crate::command::{Command, CommandResponse}; +use crate::network::ConnectionInfo; +use crate::state::State; use argparse::ArgumentParser; use argparse::Store; use argparse::StoreTrue; use colored::*; -use cpal::traits::StreamTrait; -use futures::channel::oneshot; use futures::join; use log::*; +use mumble_protocol::control::ControlPacket; use mumble_protocol::crypt::ClientCryptState; -use std::net::ToSocketAddrs; -use std::sync::Arc; -use std::sync::Mutex; +use mumble_protocol::voice::Serverbound; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio::sync::{mpsc, watch}; #[tokio::main] async fn main() { // setup logger fern::Dispatch::new() .format(|out, message, record| { + let message = message.to_string(); out.finish(format_args!( - "{} {}:{} {}", + "{} {}:{}{}{}", //TODO runtime flag that disables color match record.level() { Level::Error => "ERROR".red(), - Level::Warn => "WARN ".yellow(), - Level::Info => "INFO ".normal(), + Level::Warn => "WARN ".yellow(), + Level::Info => "INFO ".normal(), Level::Debug => "DEBUG".green(), Level::Trace => "TRACE".normal(), }, record.file().unwrap(), record.line().unwrap(), + if message.chars().any(|e| e == '\n') { + "\n" + } else { + " " + }, message )) }) .level(log::LevelFilter::Debug) .chain(std::io::stderr()) - .apply().unwrap(); + .apply() + .unwrap(); // Handle command line arguments let mut server_host = "".to_string(); @@ -64,37 +73,69 @@ async fn main() { ); ap.parse_args_or_exit(); } - let server_addr = (server_host.as_ref(), server_port) - .to_socket_addrs() - .expect("Failed to parse server address") - .next() - .expect("Failed to resolve server address"); // Oneshot channel for setting UDP CryptState from control task // For simplicity we don't deal with re-syncing, real applications would have to. - let (crypt_state_sender, crypt_state_receiver) = oneshot::channel::<ClientCryptState>(); + let (crypt_state_sender, crypt_state_receiver) = mpsc::channel::<ClientCryptState>(1); // crypt state should always be consumed before sending a new one + let (packet_sender, packet_receiver) = mpsc::unbounded_channel::<ControlPacket<Serverbound>>(); + let (command_sender, command_receiver) = mpsc::unbounded_channel::<Command>(); + let (command_response_sender, command_response_receiver) = + mpsc::unbounded_channel::<Result<Option<CommandResponse>, ()>>(); + let (connection_info_sender, connection_info_receiver) = + watch::channel::<Option<ConnectionInfo>>(None); - let audio = Audio::new(); - audio.output_stream.play().unwrap(); - let audio = Arc::new(Mutex::new(audio)); - - let server_state = Arc::new(Mutex::new(Server::new())); + let state = State::new( + packet_sender, + command_sender.clone(), + connection_info_sender, + ); + let state = Arc::new(Mutex::new(state)); // Run it join!( network::tcp::handle( - server_state, - server_addr, - server_host, - username, - accept_invalid_cert, + Arc::clone(&state), + connection_info_receiver.clone(), crypt_state_sender, - Arc::clone(&audio), + packet_receiver, ), network::udp::handle( - server_addr, + Arc::clone(&state), + connection_info_receiver.clone(), crypt_state_receiver, - audio, ), + command::handle(state, command_receiver, command_response_sender,), + send_commands( + command_sender, + Command::ServerConnect { + host: server_host, + port: server_port, + username: username.clone(), + accept_invalid_cert + } + ), + receive_command_responses(command_response_receiver,), ); } + +async fn send_commands(command_sender: mpsc::UnboundedSender<Command>, connect_command: Command) { + command_sender.send(connect_command.clone()).unwrap(); + tokio::time::delay_for(Duration::from_secs(2)).await; + command_sender.send(Command::ServerDisconnect).unwrap(); + tokio::time::delay_for(Duration::from_secs(2)).await; + command_sender.send(connect_command.clone()).unwrap(); + tokio::time::delay_for(Duration::from_secs(2)).await; + command_sender.send(Command::ServerDisconnect).unwrap(); + + debug!("Finished sending commands"); +} + +async fn receive_command_responses( + mut command_response_receiver: mpsc::UnboundedReceiver<Result<Option<CommandResponse>, ()>>, +) { + while let Some(command_response) = command_response_receiver.recv().await { + debug!("{:?}", command_response); + } + + debug!("Finished receiving commands"); +} diff --git a/mumd/src/network/mod.rs b/mumd/src/network/mod.rs index f7a6a76..1a31ee2 100644 --- a/mumd/src/network/mod.rs +++ b/mumd/src/network/mod.rs @@ -1,2 +1,21 @@ pub mod tcp; pub mod udp; + +use std::net::SocketAddr; + +#[derive(Clone, Debug)] +pub struct ConnectionInfo { + socket_addr: SocketAddr, + hostname: String, + accept_invalid_cert: bool, +} + +impl ConnectionInfo { + pub fn new(socket_addr: SocketAddr, hostname: String, accept_invalid_cert: bool) -> Self { + Self { + socket_addr, + hostname, + accept_invalid_cert, + } + } +} diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index dde98aa..6a369e5 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -1,17 +1,17 @@ -use crate::audio::Audio; -use crate::state::Server; +use crate::network::ConnectionInfo; +use crate::state::{State, StatePhase}; use log::*; -use futures::channel::oneshot; -use futures::{join, SinkExt, StreamExt}; +use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::{Clientbound, Serverbound}; use std::convert::{Into, TryInto}; -use std::net::{SocketAddr}; +use std::net::SocketAddr; use std::sync::{Arc, Mutex}; use tokio::net::TcpStream; +use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; @@ -24,26 +24,52 @@ type TcpReceiver = SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>; pub async fn handle( - server: Arc<Mutex<Server>>, - server_addr: SocketAddr, - server_host: String, - username: String, - accept_invalid_cert: bool, - crypt_state_sender: oneshot::Sender<ClientCryptState>, - audio: Arc<Mutex<Audio>>, + state: Arc<Mutex<State>>, + mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, + crypt_state_sender: mpsc::Sender<ClientCryptState>, + mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, ) { - let (sink, stream) = connect(server_addr, server_host, accept_invalid_cert).await; - let sink = Arc::new(Mutex::new(sink)); + loop { + let connection_info = loop { + match connection_info_receiver.recv().await { + None => { + return; + } + Some(None) => {} + Some(Some(connection_info)) => { + break connection_info; + } + } + }; + let (mut sink, stream) = connect( + connection_info.socket_addr, + connection_info.hostname, + connection_info.accept_invalid_cert, + ) + .await; + + // Handshake (omitting `Version` message for brevity) + let state_lock = state.lock().unwrap(); + authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; + let phase_watcher = state_lock.phase_receiver(); + let packet_sender = state_lock.packet_sender(); + drop(state_lock); - // Handshake (omitting `Version` message for brevity) - authenticate(Arc::clone(&sink), username).await; + info!("Logging in..."); - info!("Logging in..."); + join!( + send_pings(packet_sender, 10, phase_watcher.clone()), + listen( + Arc::clone(&state), + stream, + crypt_state_sender.clone(), + phase_watcher.clone() + ), + send_packets(sink, &mut packet_receiver, phase_watcher), + ); - join!( - send_pings(Arc::clone(&sink), 10), - listen(server, sink, stream, crypt_state_sender, audio), - ); + debug!("Fully disconnected TCP stream, waiting for new connection info"); + } } async fn connect( @@ -72,109 +98,239 @@ async fn connect( ClientControlCodec::new().framed(tls_stream).split() } -async fn authenticate(sink: Arc<Mutex<TcpSender>>, username: String) { +async fn authenticate(sink: &mut TcpSender, username: String) { let mut msg = msgs::Authenticate::new(); msg.set_username(username); msg.set_opus(true); - sink.lock().unwrap().send(msg.into()).await.unwrap(); + sink.send(msg.into()).await.unwrap(); } -async fn send_pings(sink: Arc<Mutex<TcpSender>>, delay_seconds: u64) { +async fn send_pings( + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + delay_seconds: u64, + mut phase_watcher: watch::Receiver<StatePhase>, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} + tx.send(true).unwrap(); + }; + let mut interval = time::interval(Duration::from_secs(delay_seconds)); - loop { - interval.tick().await; - trace!("Sending ping"); - let msg = msgs::Ping::new(); - sink.lock().unwrap().send(msg.into()).await.unwrap(); - } + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let interval_waiter = interval.tick().fuse(); + pin_mut!(interval_waiter); + let exitor = select! { + data = interval_waiter => Some(data), + _ = rx => None + }; + + match exitor { + Some(_) => { + trace!("Sending ping"); + let msg = msgs::Ping::new(); + packet_sender.send(msg.into()).unwrap(); + } + None => break, + } + } + }; + + join!(main_block, phase_transition_block); + + debug!("Ping sender process killed"); +} + +async fn send_packets( + mut sink: TcpSender, + packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, + mut phase_watcher: watch::Receiver<StatePhase>, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} + tx.send(true).unwrap(); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = packet_receiver.recv().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some(packet)) => { + sink.send(packet).await.unwrap(); + } + } + } + + //clears queue of remaining packets + while packet_receiver.try_recv().is_ok() {} + + sink.close().await.unwrap(); + }; + + join!(main_block, phase_transition_block); + + debug!("TCP packet sender killed"); } async fn listen( - server: Arc<Mutex<Server>>, - sink: Arc<Mutex<TcpSender>>, + state: Arc<Mutex<State>>, mut stream: TcpReceiver, - crypt_state_sender: oneshot::Sender<ClientCryptState>, - audio: Arc<Mutex<Audio>>, + crypt_state_sender: mpsc::Sender<ClientCryptState>, + mut phase_watcher: watch::Receiver<StatePhase>, ) { let mut crypt_state = None; let mut crypt_state_sender = Some(crypt_state_sender); - while let Some(packet) = stream.next().await { - //TODO handle types separately - match packet.unwrap() { - ControlPacket::TextMessage(mut msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - // Send reply back to server - let mut response = msgs::TextMessage::new(); - response.mut_session().push(msg.get_actor()); - response.set_message(msg.take_message()); - let mut lock = sink.lock().unwrap(); - lock.send(response.into()).await.unwrap(); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - crypt_state = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.take() { - let _ = sender.send( - crypt_state - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ); + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} + tx.send(true).unwrap(); + }; + + let listener_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = stream.next().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; } - let mut server = server.lock().unwrap(); - server.parse_server_sync(msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), + Some(None) => { + warn!("Channel closed before disconnect command"); + break; } - for (_, channel) in server.channels() { - info!("Found channel {}", channel.name()); + Some(Some(packet)) => { + //TODO handle types separately + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + crypt_state = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(mut sender) = crypt_state_sender.take() { + let _ = sender + .send( + crypt_state + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; + } + let mut state = state.lock().unwrap(); + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), + } + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); + } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + let mut state = state.lock().unwrap(); + let session = msg.get_session(); + state.audio_mut().add_client(msg.get_session()); //TODO + if *state.phase_receiver().borrow() == StatePhase::Connecting { + state.parse_initial_user_state(*msg); + } else { + state.server_mut().unwrap().parse_user_state(*msg); + } + let server = state.server_mut().unwrap(); + let user = server.users().get(&session).unwrap(); + info!("User {} connected to {}", user.name(), user.channel()); + } + ControlPacket::UserRemove(msg) => { + info!("User {} left", msg.get_session()); + state + .lock() + .unwrap() + .audio_mut() + .remove_client(msg.get_session()); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_remove(*msg); + } + _ => {} + } } - sink.lock().unwrap().send(msgs::UserList::new().into()).await.unwrap(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - audio.lock().unwrap().add_client(msg.get_session()); - let mut server = server.lock().unwrap(); - let session = msg.get_session(); - server.parse_user_state(msg); - let user = server.users().get(&session).unwrap(); - info!("User {} connected to {}", - user.name(), - user.channel()); - } - ControlPacket::UserRemove(msg) => { - info!("User {} left", msg.get_session()); - audio.lock().unwrap().remove_client(msg.get_session()); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - server.lock().unwrap().parse_channel_state(msg); } - ControlPacket::ChannelRemove(msg) => { - server.lock().unwrap().parse_channel_remove(msg); - } - _ => {} } - } + + //TODO? clean up stream + }; + + join!(phase_transition_block, listener_block); + + debug!("Killing TCP listener block"); } diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 39f16b6..4f96c4c 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -1,9 +1,9 @@ -use crate::audio::Audio; +use crate::network::ConnectionInfo; +use crate::state::{State, StatePhase}; use log::*; use bytes::Bytes; -use futures::channel::oneshot; -use futures::{join, SinkExt, StreamExt}; +use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::voice::{VoicePacket, VoicePacketPayload}; @@ -11,13 +11,57 @@ use mumble_protocol::Serverbound; use std::net::{Ipv6Addr, SocketAddr}; use std::sync::{Arc, Mutex}; use tokio::net::UdpSocket; +use tokio::sync::{mpsc, oneshot, watch}; use tokio_util::udp::UdpFramed; type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>; type UdpReceiver = SplitStream<UdpFramed<ClientCryptState>>; +pub async fn handle( + state: Arc<Mutex<State>>, + mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, + mut crypt_state: mpsc::Receiver<ClientCryptState>, +) { + let mut receiver = state.lock().unwrap().audio_mut().take_receiver().unwrap(); + + loop { + let connection_info = loop { + match connection_info_receiver.recv().await { + None => { + return; + } + Some(None) => {} + Some(Some(connection_info)) => { + break connection_info; + } + } + }; + let (mut sink, source) = connect(&mut crypt_state).await; + + // Note: A normal application would also send periodic Ping packets, and its own audio + // via UDP. We instead trick the server into accepting us by sending it one + // dummy voice packet. + send_ping(&mut sink, connection_info.socket_addr).await; + + let sink = Arc::new(Mutex::new(sink)); + + let phase_watcher = state.lock().unwrap().phase_receiver(); + join!( + listen(Arc::clone(&state), source, phase_watcher.clone()), + send_voice( + sink, + connection_info.socket_addr, + phase_watcher, + &mut receiver + ), + ); + + debug!("Fully disconnected UDP stream, waiting for new connection info"); + } +} + pub async fn connect( - crypt_state: oneshot::Receiver<ClientCryptState>, + crypt_state: &mut mpsc::Receiver<ClientCryptState>, ) -> (UdpSender, UdpReceiver) { // Bind UDP socket let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)) @@ -25,10 +69,10 @@ pub async fn connect( .expect("Failed to bind UDP socket"); // Wait for initial CryptState - let crypt_state = match crypt_state.await { - Ok(crypt_state) => crypt_state, + let crypt_state = match crypt_state.recv().await { + Some(crypt_state) => crypt_state, // disconnected before we received the CryptSetup packet, oh well - Err(_) => panic!("disconnect before crypt packet received"), //TODO exit gracefully + None => panic!("Disconnect before crypt packet received"), //TODO exit gracefully }; debug!("UDP connected"); @@ -37,36 +81,74 @@ pub async fn connect( } async fn listen( - _sink: Arc<Mutex<UdpSender>>, + state: Arc<Mutex<State>>, mut source: UdpReceiver, - audio: Arc<Mutex<Audio>>, + mut phase_watcher: watch::Receiver<StatePhase>, ) { - while let Some(packet) = source.next().await { - let (packet, _src_addr) = match packet { - Ok(packet) => packet, - Err(err) => { - warn!("Got an invalid UDP packet: {}", err); - // To be expected, considering this is the internet, just ignore it - continue; - } - }; - match packet { - VoicePacket::Ping { .. } => { - // Note: A normal application would handle these and only use UDP for voice - // once it has received one. - continue; - } - VoicePacket::Audio { - session_id, - // seq_num, - payload, - // position_info, - .. - } => { - audio.lock().unwrap().decode_packet(session_id, payload); + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} + tx.send(true).unwrap(); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = source.next().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some(packet)) => { + let (packet, _src_addr) = match packet { + Ok(packet) => packet, + Err(err) => { + warn!("Got an invalid UDP packet: {}", err); + // To be expected, considering this is the internet, just ignore it + continue; + } + }; + match packet { + VoicePacket::Ping { .. } => { + // Note: A normal application would handle these and only use UDP for voice + // once it has received one. + continue; + } + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() + .unwrap() + .audio() + .decode_packet(session_id, payload); + } + } + } } } - } + }; + + join!(main_block, phase_transition_block); + + debug!("UDP listener process killed"); } async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) { @@ -88,44 +170,58 @@ async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) { async fn send_voice( sink: Arc<Mutex<UdpSender>>, server_addr: SocketAddr, - audio: Arc<Mutex<Audio>>, + mut phase_watcher: watch::Receiver<StatePhase>, + receiver: &mut mpsc::Receiver<VoicePacketPayload>, ) { - let mut receiver = audio.lock().unwrap().take_receiver().unwrap(); + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} + tx.send(true).unwrap(); + }; - let mut count = 0; - while let Some(payload) = receiver.recv().await { - let reply = VoicePacket::Audio { - _dst: std::marker::PhantomData, - target: 0, // normal speech - session_id: (), // unused for server-bound packets - seq_num: count, - payload, - position_info: None, - }; - count += 1; - sink.lock() - .unwrap() - .send((reply, server_addr)) - .await - .unwrap(); - } -} + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + let mut count = 0; + loop { + let packet_recv = receiver.recv().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some(payload)) => { + let reply = VoicePacket::Audio { + _dst: std::marker::PhantomData, + target: 0, // normal speech + session_id: (), // unused for server-bound packets + seq_num: count, + payload, + position_info: None, + }; + count += 1; + sink.lock() + .unwrap() + .send((reply, server_addr)) + .await + .unwrap(); + } + } + } + }; -pub async fn handle( - server_addr: SocketAddr, - crypt_state: oneshot::Receiver<ClientCryptState>, - audio: Arc<Mutex<Audio>>, -) { - let (mut sink, source) = connect(crypt_state).await; - - // Note: A normal application would also send periodic Ping packets, and its own audio - // via UDP. We instead trick the server into accepting us by sending it one - // dummy voice packet. - send_ping(&mut sink, server_addr).await; - - let sink = Arc::new(Mutex::new(sink)); - join!( - listen(Arc::clone(&sink), source, Arc::clone(&audio)), - send_voice(sink, server_addr, audio) - ); + join!(main_block, phase_transition_block); + + debug!("UDP sender process killed"); } diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 1ef8467..b6fe780 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -1,8 +1,197 @@ +use crate::audio::Audio; +use crate::command::{Command, CommandResponse}; +use crate::network::ConnectionInfo; use log::*; use mumble_protocol::control::msgs; -use std::collections::HashMap; +use mumble_protocol::control::ControlPacket; +use mumble_protocol::voice::Serverbound; use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::net::ToSocketAddrs; +use tokio::sync::{mpsc, watch}; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum StatePhase { + Disconnected, + Connecting, + Connected, +} + +pub struct State { + server: Option<Server>, + audio: Audio, + + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + command_sender: mpsc::UnboundedSender<Command>, + connection_info_sender: watch::Sender<Option<ConnectionInfo>>, + + phase_watcher: (watch::Sender<StatePhase>, watch::Receiver<StatePhase>), + + username: Option<String>, + session_id: Option<u32>, +} + +impl State { + pub fn new( + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + command_sender: mpsc::UnboundedSender<Command>, + connection_info_sender: watch::Sender<Option<ConnectionInfo>>, + ) -> Self { + Self { + server: None, + audio: Audio::new(), + packet_sender, + command_sender, + connection_info_sender, + phase_watcher: watch::channel(StatePhase::Disconnected), + username: None, + session_id: None, + } + } + + //TODO? move bool inside Result + pub async fn handle_command( + &mut self, + command: Command, + ) -> (bool, Result<Option<CommandResponse>, ()>) { + match command { + Command::ChannelJoin { channel_id } => { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + warn!("Not connected"); + return (false, Err(())); + } + let mut msg = msgs::UserState::new(); + msg.set_session(self.session_id.unwrap()); + msg.set_channel_id(channel_id); + self.packet_sender.send(msg.into()).unwrap(); + (false, Ok(None)) + } + Command::ChannelList => { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + warn!("Not connected"); + return (false, Err(())); + } + ( + false, + Ok(Some(CommandResponse::ChannelList { + channels: self.server.as_ref().unwrap().channels.clone(), + })), + ) + } + Command::ServerConnect { + host, + port, + username, + accept_invalid_cert, + } => { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { + warn!("Tried to connect to a server while already connected"); + return (false, Err(())); + } + self.server = Some(Server::new()); + self.username = Some(username); + self.phase_watcher + .0 + .broadcast(StatePhase::Connecting) + .unwrap(); + let socket_addr = (host.as_ref(), port) + .to_socket_addrs() + .expect("Failed to parse server address") + .next() + .expect("Failed to resolve server address"); + self.connection_info_sender + .broadcast(Some(ConnectionInfo::new( + socket_addr, + host, + accept_invalid_cert, + ))) + .unwrap(); + (true, Ok(None)) + } + Command::Status => { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + warn!("Not connected"); + return (false, Err(())); + } + ( + false, + Ok(Some(CommandResponse::Status { + username: self.username.clone(), + server_state: self.server.clone().unwrap(), + })), + ) + } + Command::ServerDisconnect => { + self.session_id = None; + self.username = None; + self.server = None; + + self.phase_watcher + .0 + .broadcast(StatePhase::Disconnected) + .unwrap(); + (false, Ok(None)) + } + } + } + + pub fn parse_initial_user_state(&mut self, msg: msgs::UserState) { + if !msg.has_session() { + warn!("Can't parse user state without session"); + return; + } + if !msg.has_name() { + warn!("Missing name in initial user state"); + } else if msg.get_name() == self.username.as_ref().unwrap() { + match self.session_id { + None => { + debug!("Found our session id: {}", msg.get_session()); + self.session_id = Some(msg.get_session()); + } + Some(session) => { + if session != msg.get_session() { + error!( + "Got two different session IDs ({} and {}) for ourselves", + session, + msg.get_session() + ); + } else { + debug!("Got our session ID twice"); + } + } + } + } + self.server.as_mut().unwrap().parse_user_state(msg); + } + pub fn initialized(&self) { + self.phase_watcher + .0 + .broadcast(StatePhase::Connected) + .unwrap(); + } + + pub fn audio(&self) -> &Audio { + &self.audio + } + pub fn audio_mut(&mut self) -> &mut Audio { + &mut self.audio + } + pub fn packet_sender(&self) -> mpsc::UnboundedSender<ControlPacket<Serverbound>> { + self.packet_sender.clone() + } + pub fn phase_receiver(&self) -> watch::Receiver<StatePhase> { + self.phase_watcher.1.clone() + } + pub fn server_mut(&mut self) -> Option<&mut Server> { + self.server.as_mut() + } + pub fn username(&self) -> Option<&String> { + self.username.as_ref() + } +} + +#[derive(Clone, Debug)] pub struct Server { channels: HashMap<u32, Channel>, users: HashMap<u32, User>, @@ -18,41 +207,49 @@ impl Server { } } - pub fn parse_server_sync(&mut self, mut msg: Box<msgs::ServerSync>) { + pub fn parse_server_sync(&mut self, mut msg: msgs::ServerSync) { if msg.has_welcome_text() { self.welcome_text = Some(msg.take_welcome_text()); } } - pub fn parse_channel_state(&mut self, msg: Box<msgs::ChannelState>) { + pub fn parse_channel_state(&mut self, msg: msgs::ChannelState) { if !msg.has_channel_id() { warn!("Can't parse channel state without channel id"); return; } match self.channels.entry(msg.get_channel_id()) { - Entry::Vacant(e) => { e.insert(Channel::new(msg)); }, + Entry::Vacant(e) => { + e.insert(Channel::new(msg)); + } Entry::Occupied(mut e) => e.get_mut().parse_channel_state(msg), } } - pub fn parse_channel_remove(&mut self, msg: Box<msgs::ChannelRemove>) { + pub fn parse_channel_remove(&mut self, msg: msgs::ChannelRemove) { if !msg.has_channel_id() { warn!("Can't parse channel remove without channel id"); return; } match self.channels.entry(msg.get_channel_id()) { - Entry::Vacant(_) => { warn!("Attempted to remove channel that doesn't exist"); } - Entry::Occupied(e) => { e.remove(); } + Entry::Vacant(_) => { + warn!("Attempted to remove channel that doesn't exist"); + } + Entry::Occupied(e) => { + e.remove(); + } } } - pub fn parse_user_state(&mut self, msg: Box<msgs::UserState>) { + pub fn parse_user_state(&mut self, msg: msgs::UserState) { if !msg.has_session() { warn!("Can't parse user state without session"); return; } match self.users.entry(msg.get_session()) { - Entry::Vacant(e) => { e.insert(User::new(msg)); }, + Entry::Vacant(e) => { + e.insert(User::new(msg)); + } Entry::Occupied(mut e) => e.get_mut().parse_user_state(msg), } } @@ -66,7 +263,7 @@ impl Server { } } - +#[derive(Clone, Debug)] pub struct Channel { description: Option<String>, links: Vec<u32>, @@ -77,7 +274,7 @@ pub struct Channel { } impl Channel { - pub fn new(mut msg: Box<msgs::ChannelState>) -> Self { + pub fn new(mut msg: msgs::ChannelState) -> Self { Self { description: if msg.has_description() { Some(msg.take_description()) @@ -96,7 +293,7 @@ impl Channel { } } - pub fn parse_channel_state(&mut self, mut msg: Box<msgs::ChannelState>) { + pub fn parse_channel_state(&mut self, mut msg: msgs::ChannelState) { if msg.has_description() { self.description = Some(msg.take_description()); } @@ -120,6 +317,7 @@ impl Channel { } } +#[derive(Clone, Debug)] pub struct User { channel: u32, comment: Option<String>, @@ -128,15 +326,15 @@ pub struct User { priority_speaker: bool, recording: bool, - suppress: bool, // by me + suppress: bool, // by me self_mute: bool, // by self self_deaf: bool, // by self - mute: bool, // by admin - deaf: bool, // by admin + mute: bool, // by admin + deaf: bool, // by admin } impl User { - pub fn new(mut msg: Box<msgs::UserState>) -> Self { + pub fn new(mut msg: msgs::UserState) -> Self { Self { channel: msg.get_channel_id(), comment: if msg.has_comment() { @@ -150,24 +348,17 @@ impl User { None }, name: msg.take_name(), - priority_speaker: msg.has_priority_speaker() - && msg.get_priority_speaker(), - recording: msg.has_recording() - && msg.get_recording(), - suppress: msg.has_suppress() - && msg.get_suppress(), - self_mute: msg.has_self_mute() - && msg.get_self_mute(), - self_deaf: msg.has_self_deaf() - && msg.get_self_deaf(), - mute: msg.has_mute() - && msg.get_mute(), - deaf: msg.has_deaf() - && msg.get_deaf(), - } - } - - pub fn parse_user_state(&mut self, mut msg: Box<msgs::UserState>) { + priority_speaker: msg.has_priority_speaker() && msg.get_priority_speaker(), + recording: msg.has_recording() && msg.get_recording(), + suppress: msg.has_suppress() && msg.get_suppress(), + self_mute: msg.has_self_mute() && msg.get_self_mute(), + self_deaf: msg.has_self_deaf() && msg.get_self_deaf(), + mute: msg.has_mute() && msg.get_mute(), + deaf: msg.has_deaf() && msg.get_deaf(), + } + } + + pub fn parse_user_state(&mut self, mut msg: msgs::UserState) { if msg.has_channel_id() { self.channel = msg.get_channel_id(); } |
