diff options
Diffstat (limited to 'mumd')
| -rw-r--r-- | mumd/src/audio.rs | 166 | ||||
| -rw-r--r-- | mumd/src/audio/input.rs | 52 | ||||
| -rw-r--r-- | mumd/src/audio/output.rs | 90 | ||||
| -rw-r--r-- | mumd/src/command.rs | 36 | ||||
| -rw-r--r-- | mumd/src/main.rs | 4 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 391 | ||||
| -rw-r--r-- | mumd/src/state.rs | 120 | ||||
| -rw-r--r-- | mumd/src/state/server.rs | 4 | ||||
| -rw-r--r-- | mumd/src/state/user.rs | 40 |
9 files changed, 509 insertions, 394 deletions
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index bbde547..8609a91 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -1,34 +1,25 @@ -use bytes::Bytes; +pub mod input; +pub mod output; + use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; -use cpal::{ - InputCallbackInfo, OutputCallbackInfo, Sample, SampleFormat, SampleRate, Stream, StreamConfig, -}; +use cpal::{SampleFormat, SampleRate, Stream, StreamConfig}; use log::*; use mumble_protocol::voice::VoicePacketPayload; use opus::Channels; use std::collections::hash_map::Entry; use std::collections::HashMap; -use std::collections::VecDeque; -use std::ops::AddAssign; -use std::sync::Arc; -use std::sync::Mutex; -use tokio::sync::mpsc::{self, Receiver, Sender}; -use tokio::sync::watch; - -struct ClientStream { - buffer: VecDeque<f32>, //TODO ring buffer? - opus_decoder: opus::Decoder, -} +use std::sync::{Arc, Mutex}; +use tokio::sync::{mpsc, watch}; pub struct Audio { output_config: StreamConfig, _output_stream: Stream, _input_stream: Stream, - input_channel_receiver: Option<Receiver<VoicePacketPayload>>, + input_channel_receiver: Option<mpsc::Receiver<VoicePacketPayload>>, input_volume_sender: watch::Sender<f32>, - client_streams: Arc<Mutex<HashMap<u32, ClientStream>>>, + client_streams: Arc<Mutex<HashMap<u32, output::ClientStream>>>, } impl Audio { @@ -66,17 +57,17 @@ impl Audio { let output_stream = match output_supported_sample_format { SampleFormat::F32 => output_device.build_output_stream( &output_config, - output_curry_callback::<f32>(Arc::clone(&client_streams)), + output::curry_callback::<f32>(Arc::clone(&client_streams)), err_fn, ), SampleFormat::I16 => output_device.build_output_stream( &output_config, - output_curry_callback::<i16>(Arc::clone(&client_streams)), + output::curry_callback::<i16>(Arc::clone(&client_streams)), err_fn, ), SampleFormat::U16 => output_device.build_output_stream( &output_config, - output_curry_callback::<u16>(Arc::clone(&client_streams)), + output::curry_callback::<u16>(Arc::clone(&client_streams)), err_fn, ), } @@ -102,7 +93,7 @@ impl Audio { let input_stream = match input_supported_sample_format { SampleFormat::F32 => input_device.build_input_stream( &input_config, - input_callback::<f32>( + input::callback::<f32>( input_encoder, input_sender, input_config.sample_rate.0, @@ -113,7 +104,7 @@ impl Audio { ), SampleFormat::I16 => input_device.build_input_stream( &input_config, - input_callback::<i16>( + input::callback::<i16>( input_encoder, input_sender, input_config.sample_rate.0, @@ -124,7 +115,7 @@ impl Audio { ), SampleFormat::U16 => input_device.build_input_stream( &input_config, - input_callback::<u16>( + input::callback::<u16>( input_encoder, input_sender, input_config.sample_rate.0, @@ -167,7 +158,7 @@ impl Audio { warn!("Session id {} already exists", session_id); } Entry::Vacant(entry) => { - entry.insert(ClientStream::new( + entry.insert(output::ClientStream::new( self.output_config.sample_rate.0, self.output_config.channels, )); @@ -189,7 +180,7 @@ impl Audio { } } - pub fn take_receiver(&mut self) -> Option<Receiver<VoicePacketPayload>> { + pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<VoicePacketPayload>> { self.input_channel_receiver.take() } @@ -201,128 +192,3 @@ impl Audio { self.input_volume_sender.broadcast(input_volume).unwrap(); } } - -impl ClientStream { - fn new(sample_rate: u32, channels: u16) -> Self { - Self { - buffer: VecDeque::new(), - opus_decoder: opus::Decoder::new( - sample_rate, - match channels { - 1 => Channels::Mono, - 2 => Channels::Stereo, - _ => unimplemented!("Only 1 or 2 channels supported, got {}", channels), - }, - ) - .unwrap(), - } - } - - fn decode_packet(&mut self, payload: VoicePacketPayload, channels: usize) { - match payload { - VoicePacketPayload::Opus(bytes, _eot) => { - let mut out: Vec<f32> = vec![0.0; 720 * channels * 4]; //720 is because that is the max size of packet we can get that we want to decode - let parsed = self - .opus_decoder - .decode_float(&bytes, &mut out, false) - .expect("Error decoding"); - out.truncate(parsed); - self.buffer.extend(out); - } - _ => { - unimplemented!("Payload type not supported"); - } - } - } -} - -trait SaturatingAdd { - fn saturating_add(self, rhs: Self) -> Self; -} - -impl SaturatingAdd for f32 { - fn saturating_add(self, rhs: Self) -> Self { - match self + rhs { - a if a < -1.0 => -1.0, - a if a > 1.0 => 1.0, - a => a, - } - } -} - -impl SaturatingAdd for i16 { - fn saturating_add(self, rhs: Self) -> Self { - i16::saturating_add(self, rhs) - } -} - -impl SaturatingAdd for u16 { - fn saturating_add(self, rhs: Self) -> Self { - u16::saturating_add(self, rhs) - } -} - -fn output_curry_callback<T: Sample + AddAssign + SaturatingAdd>( - buf: Arc<Mutex<HashMap<u32, ClientStream>>>, -) -> impl FnMut(&mut [T], &OutputCallbackInfo) + Send + 'static { - move |data: &mut [T], _info: &OutputCallbackInfo| { - for sample in data.iter_mut() { - *sample = Sample::from(&0.0); - } - - let mut lock = buf.lock().unwrap(); - for client_stream in lock.values_mut() { - for sample in data.iter_mut() { - *sample = sample.saturating_add(Sample::from( - &client_stream.buffer.pop_front().unwrap_or(0.0), - )); - } - } - } -} - -fn input_callback<T: Sample>( - mut opus_encoder: opus::Encoder, - mut input_sender: Sender<VoicePacketPayload>, - sample_rate: u32, - input_volume_receiver: watch::Receiver<f32>, - opus_frame_size_blocks: u32, // blocks of 2.5ms -) -> impl FnMut(&[T], &InputCallbackInfo) + Send + 'static { - if !(opus_frame_size_blocks == 1 - || opus_frame_size_blocks == 2 - || opus_frame_size_blocks == 4 - || opus_frame_size_blocks == 8) - { - panic!( - "Unsupported amount of opus frame blocks {}", - opus_frame_size_blocks - ); - } - let opus_frame_size = opus_frame_size_blocks * sample_rate / 400; - - let buf = Arc::new(Mutex::new(VecDeque::new())); - move |data: &[T], _info: &InputCallbackInfo| { - let mut buf = buf.lock().unwrap(); - let input_volume = *input_volume_receiver.borrow(); - let out: Vec<f32> = data.iter().map(|e| e.to_f32()) - .map(|e| e * input_volume) - .collect(); - buf.extend(out); - while buf.len() >= opus_frame_size as usize { - let tail = buf.split_off(opus_frame_size as usize); - let mut opus_buf: Vec<u8> = vec![0; opus_frame_size as usize]; - let result = opus_encoder - .encode_float(&Vec::from(buf.clone()), &mut opus_buf) - .unwrap(); - opus_buf.truncate(result); - let bytes = Bytes::copy_from_slice(&opus_buf); - match input_sender.try_send(VoicePacketPayload::Opus(bytes, false)) { - Ok(_) => {} - Err(_e) => { - //warn!("Error sending audio packet: {:?}", e); - } - } - *buf = tail; - } - } -} diff --git a/mumd/src/audio/input.rs b/mumd/src/audio/input.rs new file mode 100644 index 0000000..4e95360 --- /dev/null +++ b/mumd/src/audio/input.rs @@ -0,0 +1,52 @@ +use bytes::Bytes; +use cpal::{InputCallbackInfo, Sample}; +use mumble_protocol::voice::VoicePacketPayload; +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; +use tokio::sync::{mpsc, watch}; + +pub fn callback<T: Sample>( + mut opus_encoder: opus::Encoder, + mut input_sender: mpsc::Sender<VoicePacketPayload>, + sample_rate: u32, + input_volume_receiver: watch::Receiver<f32>, + opus_frame_size_blocks: u32, // blocks of 2.5ms +) -> impl FnMut(&[T], &InputCallbackInfo) + Send + 'static { + if !(opus_frame_size_blocks == 1 + || opus_frame_size_blocks == 2 + || opus_frame_size_blocks == 4 + || opus_frame_size_blocks == 8) + { + panic!( + "Unsupported amount of opus frame blocks {}", + opus_frame_size_blocks + ); + } + let opus_frame_size = opus_frame_size_blocks * sample_rate / 400; + + let buf = Arc::new(Mutex::new(VecDeque::new())); + move |data: &[T], _info: &InputCallbackInfo| { + let mut buf = buf.lock().unwrap(); + let input_volume = *input_volume_receiver.borrow(); + let out: Vec<f32> = data.iter().map(|e| e.to_f32()) + .map(|e| e * input_volume) + .collect(); + buf.extend(out); + while buf.len() >= opus_frame_size as usize { + let tail = buf.split_off(opus_frame_size as usize); + let mut opus_buf: Vec<u8> = vec![0; opus_frame_size as usize]; + let result = opus_encoder + .encode_float(&Vec::from(buf.clone()), &mut opus_buf) + .unwrap(); + opus_buf.truncate(result); + let bytes = Bytes::copy_from_slice(&opus_buf); + match input_sender.try_send(VoicePacketPayload::Opus(bytes, false)) { + Ok(_) => {} + Err(_e) => { + //warn!("Error sending audio packet: {:?}", e); + } + } + *buf = tail; + } + } +} diff --git a/mumd/src/audio/output.rs b/mumd/src/audio/output.rs new file mode 100644 index 0000000..94e4b21 --- /dev/null +++ b/mumd/src/audio/output.rs @@ -0,0 +1,90 @@ +use cpal::{OutputCallbackInfo, Sample}; +use mumble_protocol::voice::VoicePacketPayload; +use opus::Channels; +use std::collections::{HashMap, VecDeque}; +use std::ops::AddAssign; +use std::sync::{Arc, Mutex}; + +pub struct ClientStream { + buffer: VecDeque<f32>, //TODO ring buffer? + opus_decoder: opus::Decoder, +} + +impl ClientStream { + pub fn new(sample_rate: u32, channels: u16) -> Self { + Self { + buffer: VecDeque::new(), + opus_decoder: opus::Decoder::new( + sample_rate, + match channels { + 1 => Channels::Mono, + 2 => Channels::Stereo, + _ => unimplemented!("Only 1 or 2 channels supported, got {}", channels), + }, + ) + .unwrap(), + } + } + + pub fn decode_packet(&mut self, payload: VoicePacketPayload, channels: usize) { + match payload { + VoicePacketPayload::Opus(bytes, _eot) => { + let mut out: Vec<f32> = vec![0.0; 720 * channels * 4]; //720 is because that is the max size of packet we can get that we want to decode + let parsed = self + .opus_decoder + .decode_float(&bytes, &mut out, false) + .expect("Error decoding"); + out.truncate(parsed); + self.buffer.extend(out); + } + _ => { + unimplemented!("Payload type not supported"); + } + } + } +} + +pub trait SaturatingAdd { + fn saturating_add(self, rhs: Self) -> Self; +} + +impl SaturatingAdd for f32 { + fn saturating_add(self, rhs: Self) -> Self { + match self + rhs { + a if a < -1.0 => -1.0, + a if a > 1.0 => 1.0, + a => a, + } + } +} + +impl SaturatingAdd for i16 { + fn saturating_add(self, rhs: Self) -> Self { + i16::saturating_add(self, rhs) + } +} + +impl SaturatingAdd for u16 { + fn saturating_add(self, rhs: Self) -> Self { + u16::saturating_add(self, rhs) + } +} + +pub fn curry_callback<T: Sample + AddAssign + SaturatingAdd>( + buf: Arc<Mutex<HashMap<u32, ClientStream>>>, +) -> impl FnMut(&mut [T], &OutputCallbackInfo) + Send + 'static { + move |data: &mut [T], _info: &OutputCallbackInfo| { + for sample in data.iter_mut() { + *sample = Sample::from(&0.0); + } + + let mut lock = buf.lock().unwrap(); + for client_stream in lock.values_mut() { + for sample in data.iter_mut() { + *sample = sample.saturating_add(Sample::from( + &client_stream.buffer.pop_front().unwrap_or(0.0), + )); + } + } + } +} diff --git a/mumd/src/command.rs b/mumd/src/command.rs index a035a26..075bfaf 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -1,10 +1,11 @@ -use crate::state::{State, StatePhase}; +use crate::state::State; use ipc_channel::ipc::IpcSender; use log::*; use mumlib::command::{Command, CommandResponse}; use std::sync::{Arc, Mutex}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; +use crate::network::tcp::{TcpEvent, TcpEventCallback}; pub async fn handle( state: Arc<Mutex<State>>, @@ -12,23 +13,26 @@ pub async fn handle( Command, IpcSender<mumlib::error::Result<Option<CommandResponse>>>, )>, + tcp_event_register_sender: mpsc::UnboundedSender<(TcpEvent, TcpEventCallback)>, ) { debug!("Begin listening for commands"); - while let Some(command) = command_receiver.recv().await { - debug!("Received command {:?}", command.0); + while let Some((command, response_sender)) = command_receiver.recv().await { + debug!("Received command {:?}", command); let mut state = state.lock().unwrap(); - let (wait_for_connected, command_response) = state.handle_command(command.0).await; - if wait_for_connected { - let mut watcher = state.phase_receiver(); - drop(state); - while !matches!(watcher.recv().await.unwrap(), StatePhase::Connected) {} + let (event, generator) = state.handle_command(command).await; + drop(state); + if let Some(event) = event { + let (tx, rx) = oneshot::channel(); + //TODO handle this error + let _ = tcp_event_register_sender.send((event, Box::new(move |e| { + let response = generator(Some(e)); + response_sender.send(response).unwrap(); + tx.send(()).unwrap(); + }))); + + rx.await.unwrap(); + } else { + response_sender.send(generator(None)).unwrap(); } - command.1.send(command_response).unwrap(); } - //TODO err if not connected - //while let Some(command) = command_receiver.recv().await { - // debug!("Parsing command {:?}", command); - //} - - //debug!("Finished handling commands"); } diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 75726f8..e88eede 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -33,6 +33,7 @@ async fn main() { )>(); let (connection_info_sender, connection_info_receiver) = watch::channel::<Option<ConnectionInfo>>(None); + let (response_sender, response_receiver) = mpsc::unbounded_channel(); let state = State::new(packet_sender, connection_info_sender); let state = Arc::new(Mutex::new(state)); @@ -43,13 +44,14 @@ async fn main() { connection_info_receiver.clone(), crypt_state_sender, packet_receiver, + response_receiver, ), network::udp::handle( Arc::clone(&state), connection_info_receiver.clone(), crypt_state_receiver, ), - command::handle(state, command_receiver,), + command::handle(state, command_receiver, response_sender), spawn_blocking(move || { // IpcSender is blocking receive_oneshot_commands(command_sender); diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 88d2b59..c2cb234 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -15,6 +15,10 @@ use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use std::collections::HashMap; +use std::future::Future; +use std::rc::Rc; +use std::cell::RefCell; type TcpSender = SplitSink< Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>, @@ -23,11 +27,25 @@ type TcpSender = SplitSink< type TcpReceiver = SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>; +pub(crate) type TcpEventCallback = Box<dyn FnOnce(&TcpEventData)>; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum TcpEvent { + Connected, //fires when the client has connected to a server + Disconnected, //fires when the client has disconnected from a server +} + +pub enum TcpEventData<'a> { + Connected(&'a msgs::ServerSync), + Disconnected, +} + pub async fn handle( state: Arc<Mutex<State>>, mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, crypt_state_sender: mpsc::Sender<ClientCryptState>, mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, + mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, ) { loop { let connection_info = loop { @@ -54,6 +72,7 @@ pub async fn handle( let phase_watcher = state_lock.phase_receiver(); let packet_sender = state_lock.packet_sender(); drop(state_lock); + let event_queue = Arc::new(Mutex::new(HashMap::new())); info!("Logging in..."); @@ -63,9 +82,11 @@ pub async fn handle( Arc::clone(&state), stream, crypt_state_sender.clone(), - phase_watcher.clone() + Arc::clone(&event_queue), + phase_watcher.clone(), ), - send_packets(sink, &mut packet_receiver, phase_watcher), + send_packets(sink, &mut packet_receiver, phase_watcher.clone()), + register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher), ); debug!("Fully disconnected TCP stream, waiting for new connection info"); @@ -108,103 +129,207 @@ async fn authenticate(sink: &mut TcpSender, username: String) { async fn send_pings( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, delay_seconds: u64, - mut phase_watcher: watch::Receiver<StatePhase>, + phase_watcher: watch::Receiver<StatePhase>, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - while !matches!( - phase_watcher.recv().await.unwrap(), - StatePhase::Disconnected - ) {} - tx.send(true).unwrap(); - }; + let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs(delay_seconds)))); + let packet_sender = Rc::new(RefCell::new(packet_sender)); - let mut interval = time::interval(Duration::from_secs(delay_seconds)); - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let interval_waiter = interval.tick().fuse(); - pin_mut!(interval_waiter); - let exitor = select! { - data = interval_waiter => Some(data), - _ = rx => None - }; - - match exitor { - Some(_) => { - trace!("Sending ping"); - let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); - } - None => break, - } - } - }; - - join!(main_block, phase_transition_block); + run_until_disconnection( + || async { + Some(interval.borrow_mut().tick().await) + }, + |_| async { + trace!("Sending ping"); + let msg = msgs::Ping::new(); + packet_sender.borrow_mut().send(msg.into()).unwrap(); + }, + || async {}, + phase_watcher, + ).await; debug!("Ping sender process killed"); } async fn send_packets( - mut sink: TcpSender, + sink: TcpSender, packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, - mut phase_watcher: watch::Receiver<StatePhase>, + phase_watcher: watch::Receiver<StatePhase>, ) { - let (tx, rx) = oneshot::channel(); - let phase_transition_block = async { - while !matches!( - phase_watcher.recv().await.unwrap(), - StatePhase::Disconnected - ) {} - tx.send(true).unwrap(); - }; - - let main_block = async { - let rx = rx.fuse(); - pin_mut!(rx); - loop { - let packet_recv = packet_receiver.recv().fuse(); - pin_mut!(packet_recv); - let exitor = select! { - data = packet_recv => Some(data), - _ = rx => None - }; - match exitor { - None => { - break; - } - Some(None) => { - warn!("Channel closed before disconnect command"); - break; - } - Some(Some(packet)) => { - sink.send(packet).await.unwrap(); - } - } - } + let sink = Rc::new(RefCell::new(sink)); + let packet_receiver = Rc::new(RefCell::new(packet_receiver)); + run_until_disconnection( + || async { + packet_receiver.borrow_mut().recv().await + }, + |packet| async { + sink.borrow_mut().send(packet).await.unwrap(); + }, + || async { + //clears queue of remaining packets + while packet_receiver.borrow_mut().try_recv().is_ok() {} - //clears queue of remaining packets - while packet_receiver.try_recv().is_ok() {} - - sink.close().await.unwrap(); - }; - - join!(main_block, phase_transition_block); + sink.borrow_mut().close().await.unwrap(); + }, + phase_watcher, + ).await; debug!("TCP packet sender killed"); } async fn listen( state: Arc<Mutex<State>>, - mut stream: TcpReceiver, + stream: TcpReceiver, crypt_state_sender: mpsc::Sender<ClientCryptState>, - mut phase_watcher: watch::Receiver<StatePhase>, + event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + phase_watcher: watch::Receiver<StatePhase>, ) { - let mut crypt_state = None; - let mut crypt_state_sender = Some(crypt_state_sender); + let crypt_state = Rc::new(RefCell::new(None)); + let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender))); + let stream = Rc::new(RefCell::new(stream)); + run_until_disconnection( + || async { + stream.borrow_mut().next().await + }, + |packet| async { + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + *crypt_state.borrow_mut() = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(mut sender) = crypt_state_sender.borrow_mut().take() { + let _ = sender + .send( + crypt_state + .borrow_mut() + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; + } + if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Connected(&msg)); + } + } + let mut state = state.lock().unwrap(); + let server = state.server_mut().unwrap(); + server.parse_server_sync(*msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), + } + for channel in server.channels().values() { + info!("Found channel {}", channel.name()); + } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + let mut state = state.lock().unwrap(); + let session = msg.get_session(); + if *state.phase_receiver().borrow() == StatePhase::Connecting { + state.audio_mut().add_client(msg.get_session()); + state.parse_user_state(*msg); + } else { + state.parse_user_state(*msg); + } + let server = state.server_mut().unwrap(); + let user = server.users().get(&session).unwrap(); + info!("User {} connected to {}", user.name(), user.channel()); + } + ControlPacket::UserRemove(msg) => { + info!("User {} left", msg.get_session()); + state + .lock() + .unwrap() + .audio_mut() + .remove_client(msg.get_session()); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_state(*msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_remove(*msg); + } + _ => {} + } + }, + || async { + if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + let old = std::mem::take(vec); + for handler in old { + handler(&TcpEventData::Disconnected); + } + } + }, + phase_watcher, + ).await; + + debug!("Killing TCP listener block"); +} + +async fn register_events( + tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>, + event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>, + phase_watcher: watch::Receiver<StatePhase>, +) { + let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver)); + run_until_disconnection( + || async { + tcp_event_register_receiver.borrow_mut().recv().await + }, + |(event, handler)| async { event_data.lock().unwrap().entry(event).or_default().push(handler); }, + || async {}, + phase_watcher, + ).await; +} + +async fn run_until_disconnection<T, F, G, H>( + mut generator: impl FnMut() -> F, + mut handler: impl FnMut(T) -> G, + mut shutdown: impl FnMut() -> H, + mut phase_watcher: watch::Receiver<StatePhase>, +) + where + F: Future<Output = Option<T>>, + G: Future<Output = ()>, + H: Future<Output = ()>, +{ let (tx, rx) = oneshot::channel(); let phase_transition_block = async { while !matches!( @@ -214,11 +339,11 @@ async fn listen( tx.send(true).unwrap(); }; - let listener_block = async { + let main_block = async { let rx = rx.fuse(); pin_mut!(rx); loop { - let packet_recv = stream.next().fuse(); + let packet_recv = generator().fuse(); pin_mut!(packet_recv); let exitor = select! { data = packet_recv => Some(data), @@ -229,107 +354,17 @@ async fn listen( break; } Some(None) => { - warn!("Channel closed before disconnect command"); + //warn!("Channel closed before disconnect command"); //TODO make me informative break; } - Some(Some(packet)) => { - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - crypt_state = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(mut sender) = crypt_state_sender.take() { - let _ = sender - .send( - crypt_state - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ) - .await; - } - let mut state = state.lock().unwrap(); - let server = state.server_mut().unwrap(); - server.parse_server_sync(*msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), - } - for channel in server.channels().values() { - info!("Found channel {}", channel.name()); - } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - let mut state = state.lock().unwrap(); - let session = msg.get_session(); - if *state.phase_receiver().borrow() == StatePhase::Connecting { - state.audio_mut().add_client(msg.get_session()); - state.parse_initial_user_state(*msg); - } else { - state.server_mut().unwrap().parse_user_state(*msg); - } - let server = state.server_mut().unwrap(); - let user = server.users().get(&session).unwrap(); - info!("User {} connected to {}", user.name(), user.channel()); - } - ControlPacket::UserRemove(msg) => { - info!("User {} left", msg.get_session()); - state - .lock() - .unwrap() - .audio_mut() - .remove_client(msg.get_session()); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_state(*msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state - .lock() - .unwrap() - .server_mut() - .unwrap() - .parse_channel_remove(*msg); - } - _ => {} - } + Some(Some(data)) => { + handler(data).await; } } } - //TODO? clean up stream + shutdown().await; }; - join!(phase_transition_block, listener_block); - - debug!("Killing TCP listener block"); -} + join!(main_block, phase_transition_block); +}
\ No newline at end of file diff --git a/mumd/src/state.rs b/mumd/src/state.rs index d355ef5..f9ed077 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -15,6 +15,7 @@ use mumlib::config::Config; use mumlib::error::{ChannelIdentifierError, Error}; use std::net::ToSocketAddrs; use tokio::sync::{mpsc, watch}; +use crate::network::tcp::{TcpEvent, TcpEventData}; #[derive(Clone, Debug, Eq, PartialEq)] pub enum StatePhase { @@ -56,11 +57,11 @@ impl State { pub async fn handle_command( &mut self, command: Command, - ) -> (bool, mumlib::error::Result<Option<CommandResponse>>) { + ) -> (Option<TcpEvent>, Box<dyn FnOnce(Option<&TcpEventData>) -> mumlib::error::Result<Option<CommandResponse>>>) { match command { Command::ChannelJoin { channel_identifier } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } let channels = self.server() @@ -78,33 +79,34 @@ impl State { .filter(|e| e.1.ends_with(&channel_identifier.to_lowercase())) .collect::<Vec<_>>(); match soft_matches.len() { - 0 => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + 0 => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid)))), 1 => *soft_matches.get(0).unwrap().0, - _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))), + _ => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid)))), } }, 1 => *matches.get(0).unwrap().0, - _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))), + _ => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous)))), }; let mut msg = msgs::UserState::new(); msg.set_session(self.server.as_ref().unwrap().session_id().unwrap()); msg.set_channel_id(id); self.packet_sender.send(msg.into()).unwrap(); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::ChannelList => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } + let list = channel::into_channel( + self.server.as_ref().unwrap().channels(), + self.server.as_ref().unwrap().users(), + ); ( - false, - Ok(Some(CommandResponse::ChannelList { - channels: channel::into_channel( - self.server.as_ref().unwrap().channels(), - self.server.as_ref().unwrap().users(), - ), - })), + None, + Box::new(move |_| Ok(Some(CommandResponse::ChannelList { + channels: list, + }))), ) } Command::ServerConnect { @@ -114,7 +116,7 @@ impl State { accept_invalid_cert, } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { - return (false, Err(Error::AlreadyConnectedError)); + return (None, Box::new(|_| Err(Error::AlreadyConnectedError))); } let mut server = Server::new(); *server.username_mut() = Some(username); @@ -132,7 +134,7 @@ impl State { Ok(Some(v)) => v, _ => { warn!("Error parsing server addr"); - return (false, Err(Error::InvalidServerAddrError(host, port))); + return (None, Box::new(move |_| Err(Error::InvalidServerAddrError(host, port)))); } }; self.connection_info_sender @@ -142,22 +144,35 @@ impl State { accept_invalid_cert, ))) .unwrap(); - (true, Ok(None)) + (Some(TcpEvent::Connected), Box::new(|e| { //runs the closure when the client is connected + if let Some(TcpEventData::Connected(msg)) = e { + Ok(Some(CommandResponse::ServerConnect { + welcome_message: if msg.has_welcome_text() { + Some(msg.get_welcome_text().to_string()) + } else { + None + } + })) + } else { + unreachable!("callback should be provided with a TcpEventData::Connected"); + } + })) } Command::Status => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } + let state = self.server.as_ref().unwrap().into(); ( - false, - Ok(Some(CommandResponse::Status { - server_state: self.server.as_ref().unwrap().into(), //guaranteed not to panic because if we are connected, server is guaranteed to be Some - })), + None, + Box::new(move |_| Ok(Some(CommandResponse::Status { + server_state: state, //guaranteed not to panic because if we are connected, server is guaranteed to be Some + }))), ) } Command::ServerDisconnect => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { - return (false, Err(Error::DisconnectedError)); + return (None, Box::new(|_| Err(Error::DisconnectedError))); } self.server = None; @@ -167,46 +182,54 @@ impl State { .0 .broadcast(StatePhase::Disconnected) .unwrap(); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::InputVolumeSet(volume) => { self.audio.set_input_volume(volume); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } Command::ConfigReload => { self.reload_config(); - (false, Ok(None)) + (None, Box::new(|_| Ok(None))) } } } - pub fn parse_initial_user_state(&mut self, msg: msgs::UserState) { + pub fn parse_user_state(&mut self, msg: msgs::UserState) -> Option<mumlib::state::UserDiff> { if !msg.has_session() { warn!("Can't parse user state without session"); - return; + return None; } - if !msg.has_name() { - warn!("Missing name in initial user state"); - } else if msg.get_name() == self.server.as_ref().unwrap().username().unwrap() { - match self.server.as_ref().unwrap().session_id() { - None => { - debug!("Found our session id: {}", msg.get_session()); - *self.server_mut().unwrap().session_id_mut() = Some(msg.get_session()); - } - Some(session) => { - if session != msg.get_session() { - error!( - "Got two different session IDs ({} and {}) for ourselves", - session, - msg.get_session() - ); - } else { - debug!("Got our session ID twice"); - } - } + let sess = msg.get_session(); + // check if this is initial state + if !self.server().unwrap().users().contains_key(&sess) { + if !msg.has_name() { + warn!("Missing name in initial user state"); + } else if msg.get_name() == self.server().unwrap().username().unwrap() { + // this is us + *self.server_mut().unwrap().session_id_mut() = Some(sess); + } else { + // this is someone else + self.audio_mut().add_client(sess); } + self.server_mut().unwrap().users_mut().insert(sess, user::User::new(msg)); + None + } else { + let user = self.server_mut().unwrap().users_mut().get_mut(&sess).unwrap(); + let diff = mumlib::state::UserDiff::from(msg); + user.apply_user_diff(&diff); + Some(diff) } - self.server.as_mut().unwrap().parse_user_state(msg); + } + + pub fn remove_client(&mut self, msg: msgs::UserRemove) { + if !msg.has_session() { + warn!("Tried to remove user state without session"); + return; + } + self.audio().remove_client(msg.get_session()); + self.server_mut().unwrap().users_mut().remove(&msg.get_session()); + info!("User {} disconnected", msg.get_session()); } pub fn reload_config(&mut self) { @@ -252,4 +275,3 @@ impl State { self.server.as_ref().map(|e| e.username()).flatten() } } - diff --git a/mumd/src/state/server.rs b/mumd/src/state/server.rs index b7cabb7..b99c7e6 100644 --- a/mumd/src/state/server.rs +++ b/mumd/src/state/server.rs @@ -98,6 +98,10 @@ impl Server { &self.users } + pub fn users_mut(&mut self) -> &mut HashMap<u32, User> { + &mut self.users + } + pub fn username(&self) -> Option<&str> { self.username.as_ref().map(|e| e.as_str()) } diff --git a/mumd/src/state/user.rs b/mumd/src/state/user.rs index bb4e101..679d0ff 100644 --- a/mumd/src/state/user.rs +++ b/mumd/src/state/user.rs @@ -1,3 +1,4 @@ +use log::*; use mumble_protocol::control::msgs; use serde::{Deserialize, Serialize}; @@ -78,6 +79,45 @@ impl User { } } + pub fn apply_user_diff(&mut self, diff: &mumlib::state::UserDiff) { + debug!("applying user diff\n{:#?}", diff); + if let Some(comment) = diff.comment.clone() { + self.comment = Some(comment); + } + if let Some(hash) = diff.hash.clone() { + self.hash = Some(hash); + } + if let Some(name) = diff.name.clone() { + self.name = name; + } + if let Some(priority_speaker) = diff.priority_speaker { + self.priority_speaker = priority_speaker; + } + if let Some(recording) = diff.recording { + self.recording = recording; + } + if let Some(suppress) = diff.suppress { + self.suppress = suppress; + } + if let Some(self_mute) = diff.self_mute { + self.self_mute = self_mute; + } + if let Some(self_deaf) = diff.self_deaf { + self.self_deaf = self_deaf; + } + if let Some(mute) = diff.mute { + self.mute = mute; + } + if let Some(deaf) = diff.deaf { + self.deaf = deaf; + } + + if let Some(channel_id) = diff.channel_id { + self.channel = channel_id; + } + } + + pub fn name(&self) -> &str { &self.name } |
