diff options
Diffstat (limited to 'mumd/src')
| -rw-r--r-- | mumd/src/audio.rs | 49 | ||||
| -rw-r--r-- | mumd/src/command.rs | 4 | ||||
| -rw-r--r-- | mumd/src/main.rs | 48 | ||||
| -rw-r--r-- | mumd/src/network/mod.rs | 6 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 82 | ||||
| -rw-r--r-- | mumd/src/network/udp.rs | 36 | ||||
| -rw-r--r-- | mumd/src/state.rs | 136 |
7 files changed, 234 insertions, 127 deletions
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index aa06a9d..58424b6 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -103,26 +103,32 @@ impl Audio { let input_stream = match input_supported_sample_format { SampleFormat::F32 => input_device.build_input_stream( &input_config, - input_callback::<f32>(input_encoder, - input_sender, - input_config.sample_rate.0, - 10.0), + input_callback::<f32>( + input_encoder, + input_sender, + input_config.sample_rate.0, + 10.0, + ), err_fn, ), SampleFormat::I16 => input_device.build_input_stream( &input_config, - input_callback::<i16>(input_encoder, - input_sender, - input_config.sample_rate.0, - 10.0), + input_callback::<i16>( + input_encoder, + input_sender, + input_config.sample_rate.0, + 10.0, + ), err_fn, ), SampleFormat::U16 => input_device.build_input_stream( &input_config, - input_callback::<u16>(input_encoder, - input_sender, - input_config.sample_rate.0, - 10.0), + input_callback::<u16>( + input_encoder, + input_sender, + input_config.sample_rate.0, + 10.0, + ), err_fn, ), } @@ -207,7 +213,8 @@ impl ClientStream { match payload { VoicePacketPayload::Opus(bytes, _eot) => { let mut out: Vec<f32> = vec![0.0; 720 * channels * 4]; //720 is because that is the max size of packet we can get that we want to decode - let parsed = self.opus_decoder + let parsed = self + .opus_decoder .decode_float(&bytes, &mut out, false) .expect("Error decoding"); out.truncate(parsed); @@ -271,15 +278,15 @@ fn input_callback<T: Sample>( sample_rate: u32, opus_frame_size_ms: f32, ) -> impl FnMut(&[T], &InputCallbackInfo) + Send + 'static { - if ! ( opus_frame_size_ms == 2.5 - || opus_frame_size_ms == 5.0 - || opus_frame_size_ms == 10.0 - || opus_frame_size_ms == 20.0) { + if !(opus_frame_size_ms == 2.5 + || opus_frame_size_ms == 5.0 + || opus_frame_size_ms == 10.0 + || opus_frame_size_ms == 20.0) + { panic!("Unsupported opus frame size {}", opus_frame_size_ms); } let opus_frame_size = (opus_frame_size_ms * sample_rate as f32) as u32 / 1000; - let buf = Arc::new(Mutex::new(VecDeque::new())); move |data: &[T], _info: &InputCallbackInfo| { let mut buf = buf.lock().unwrap(); @@ -293,9 +300,9 @@ fn input_callback<T: Sample>( .unwrap(); opus_buf.truncate(result); let bytes = Bytes::copy_from_slice(&opus_buf); - match input_sender - .try_send(VoicePacketPayload::Opus(bytes, false)) { //TODO handle full buffer / disconnect - Ok(_) => {}, + match input_sender.try_send(VoicePacketPayload::Opus(bytes, false)) { + //TODO handle full buffer / disconnect + Ok(_) => {} Err(_e) => { //warn!("Error sending audio packet: {:?}", e); } diff --git a/mumd/src/command.rs b/mumd/src/command.rs index 1104671..b4bd1b7 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -1,9 +1,9 @@ use crate::state::{Channel, Server, State, StatePhase}; +use log::*; use std::collections::HashMap; use std::sync::{Arc, Mutex}; use tokio::sync::mpsc; -use log::*; #[derive(Clone, Debug)] pub enum Command { @@ -29,7 +29,7 @@ pub enum CommandResponse { Status { username: Option<String>, server_state: Server, - } + }, } pub async fn handle( diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 6d435fa..797b71f 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -1,10 +1,10 @@ mod audio; -mod network; mod command; +mod network; mod state; -use crate::network::ConnectionInfo; use crate::command::{Command, CommandResponse}; +use crate::network::ConnectionInfo; use crate::state::State; use argparse::ArgumentParser; @@ -17,8 +17,8 @@ use mumble_protocol::control::ControlPacket; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::voice::Serverbound; use std::sync::{Arc, Mutex}; -use tokio::sync::{mpsc, watch}; use std::time::Duration; +use tokio::sync::{mpsc, watch}; #[tokio::main] async fn main() { @@ -31,20 +31,25 @@ async fn main() { //TODO runtime flag that disables color match record.level() { Level::Error => "ERROR".red(), - Level::Warn => "WARN ".yellow(), - Level::Info => "INFO ".normal(), + Level::Warn => "WARN ".yellow(), + Level::Info => "INFO ".normal(), Level::Debug => "DEBUG".green(), Level::Trace => "TRACE".normal(), }, record.file().unwrap(), record.line().unwrap(), - if message.chars().any(|e| e == '\n') { "\n" } else { " " }, + if message.chars().any(|e| e == '\n') { + "\n" + } else { + " " + }, message )) }) .level(log::LevelFilter::Debug) .chain(std::io::stderr()) - .apply().unwrap(); + .apply() + .unwrap(); // Handle command line arguments let mut server_host = "".to_string(); @@ -74,10 +79,16 @@ async fn main() { let (crypt_state_sender, crypt_state_receiver) = mpsc::channel::<ClientCryptState>(1); // crypt state should always be consumed before sending a new one let (packet_sender, packet_receiver) = mpsc::unbounded_channel::<ControlPacket<Serverbound>>(); let (command_sender, command_receiver) = mpsc::unbounded_channel::<Command>(); - let (command_response_sender, command_response_receiver) = mpsc::unbounded_channel::<Result<Option<CommandResponse>, ()>>(); - let (connection_info_sender, connection_info_receiver) = watch::channel::<Option<ConnectionInfo>>(None); + let (command_response_sender, command_response_receiver) = + mpsc::unbounded_channel::<Result<Option<CommandResponse>, ()>>(); + let (connection_info_sender, connection_info_receiver) = + watch::channel::<Option<ConnectionInfo>>(None); - let state = State::new(packet_sender, command_sender.clone(), connection_info_sender); + let state = State::new( + packet_sender, + command_sender.clone(), + connection_info_sender, + ); let state = Arc::new(Mutex::new(state)); // Run it @@ -93,18 +104,17 @@ async fn main() { connection_info_receiver.clone(), crypt_state_receiver, ), - command::handle( - state, - command_receiver, - command_response_sender, - ), + command::handle(state, command_receiver, command_response_sender,), send_commands( command_sender, - Command::ServerConnect{host: server_host, port: server_port, username: username.clone(), accept_invalid_cert} - ), - receive_command_responses( - command_response_receiver, + Command::ServerConnect { + host: server_host, + port: server_port, + username: username.clone(), + accept_invalid_cert + } ), + receive_command_responses(command_response_receiver,), ); } diff --git a/mumd/src/network/mod.rs b/mumd/src/network/mod.rs index 777faad..1a31ee2 100644 --- a/mumd/src/network/mod.rs +++ b/mumd/src/network/mod.rs @@ -11,11 +11,7 @@ pub struct ConnectionInfo { } impl ConnectionInfo { - pub fn new( - socket_addr: SocketAddr, - hostname: String, - accept_invalid_cert: bool, - ) -> Self { + pub fn new(socket_addr: SocketAddr, hostname: String, accept_invalid_cert: bool) -> Self { Self { socket_addr, hostname, diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 0a53266..e096843 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -2,16 +2,16 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; -use futures::{join, select, pin_mut, SinkExt, StreamExt, FutureExt}; +use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::{Clientbound, Serverbound}; use std::convert::{Into, TryInto}; -use std::net::{SocketAddr}; +use std::net::SocketAddr; use std::sync::{Arc, Mutex}; use tokio::net::TcpStream; -use tokio::sync::{mpsc, watch, oneshot}; +use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; @@ -32,15 +32,21 @@ pub async fn handle( loop { let connection_info = loop { match connection_info_receiver.recv().await { - None => { return; } + None => { + return; + } Some(None) => {} - Some(Some(connection_info)) => { break connection_info; } + Some(Some(connection_info)) => { + break connection_info; + } } }; - let (mut sink, stream) = connect(connection_info.socket_addr, - connection_info.hostname, - connection_info.accept_invalid_cert) - .await; + let (mut sink, stream) = connect( + connection_info.socket_addr, + connection_info.hostname, + connection_info.accept_invalid_cert, + ) + .await; // Handshake (omitting `Version` message for brevity) let state_lock = state.lock().unwrap(); @@ -53,7 +59,12 @@ pub async fn handle( join!( send_pings(packet_sender, 10, phase_watcher.clone()), - listen(Arc::clone(&state), stream, crypt_state_sender.clone(), phase_watcher.clone()), + listen( + Arc::clone(&state), + stream, + crypt_state_sender.clone(), + phase_watcher.clone() + ), send_packets(sink, &mut packet_receiver, phase_watcher), ); @@ -101,7 +112,10 @@ async fn send_pings( ) { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { - while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} tx.send(true).unwrap(); }; @@ -140,7 +154,10 @@ async fn send_packets( ) { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { - while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} tx.send(true).unwrap(); }; @@ -190,7 +207,10 @@ async fn listen( let (tx, rx) = oneshot::channel(); let phase_transition_block = async { - while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} tx.send(true).unwrap(); }; @@ -240,11 +260,13 @@ async fn listen( ControlPacket::ServerSync(msg) => { info!("Logged in"); if let Some(mut sender) = crypt_state_sender.take() { - let _ = sender.send( - crypt_state - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ).await; + let _ = sender + .send( + crypt_state + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ) + .await; } let mut state = state.lock().unwrap(); let server = state.server_mut().unwrap(); @@ -272,20 +294,32 @@ async fn listen( } let server = state.server_mut().unwrap(); let user = server.users().get(&session).unwrap(); - info!("User {} connected to {}", - user.name(), - user.channel()); + info!("User {} connected to {}", user.name(), user.channel()); } ControlPacket::UserRemove(msg) => { info!("User {} left", msg.get_session()); - state.lock().unwrap().audio_mut().remove_client(msg.get_session()); + state + .lock() + .unwrap() + .audio_mut() + .remove_client(msg.get_session()); } ControlPacket::ChannelState(msg) => { debug!("Channel state received"); - state.lock().unwrap().server_mut().unwrap().parse_channel_state(msg); //TODO parse initial if initial + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_state(msg); //TODO parse initial if initial } ControlPacket::ChannelRemove(msg) => { - state.lock().unwrap().server_mut().unwrap().parse_channel_remove(msg); + state + .lock() + .unwrap() + .server_mut() + .unwrap() + .parse_channel_remove(msg); } _ => {} } diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index 31e33e3..4f96c4c 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -3,7 +3,7 @@ use crate::state::{State, StatePhase}; use log::*; use bytes::Bytes; -use futures::{join, select, pin_mut, SinkExt, StreamExt, FutureExt}; +use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::voice::{VoicePacket, VoicePacketPayload}; @@ -11,7 +11,7 @@ use mumble_protocol::Serverbound; use std::net::{Ipv6Addr, SocketAddr}; use std::sync::{Arc, Mutex}; use tokio::net::UdpSocket; -use tokio::sync::{watch, oneshot, mpsc}; +use tokio::sync::{mpsc, oneshot, watch}; use tokio_util::udp::UdpFramed; type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>; @@ -27,9 +27,13 @@ pub async fn handle( loop { let connection_info = loop { match connection_info_receiver.recv().await { - None => { return; } + None => { + return; + } Some(None) => {} - Some(Some(connection_info)) => { break connection_info; } + Some(Some(connection_info)) => { + break connection_info; + } } }; let (mut sink, source) = connect(&mut crypt_state).await; @@ -44,7 +48,12 @@ pub async fn handle( let phase_watcher = state.lock().unwrap().phase_receiver(); join!( listen(Arc::clone(&state), source, phase_watcher.clone()), - send_voice(sink, connection_info.socket_addr, phase_watcher, &mut receiver), + send_voice( + sink, + connection_info.socket_addr, + phase_watcher, + &mut receiver + ), ); debug!("Fully disconnected UDP stream, waiting for new connection info"); @@ -78,7 +87,10 @@ async fn listen( ) { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { - while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} tx.send(true).unwrap(); }; @@ -122,7 +134,11 @@ async fn listen( // position_info, .. } => { - state.lock().unwrap().audio().decode_packet(session_id, payload); + state + .lock() + .unwrap() + .audio() + .decode_packet(session_id, payload); } } } @@ -159,7 +175,10 @@ async fn send_voice( ) { let (tx, rx) = oneshot::channel(); let phase_transition_block = async { - while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + while !matches!( + phase_watcher.recv().await.unwrap(), + StatePhase::Disconnected + ) {} tx.send(true).unwrap(); }; @@ -206,4 +225,3 @@ async fn send_voice( debug!("UDP sender process killed"); } - diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 8371be9..69a462d 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -1,12 +1,12 @@ -use log::*; use crate::audio::Audio; use crate::command::{Command, CommandResponse}; use crate::network::ConnectionInfo; +use log::*; use mumble_protocol::control::msgs; use mumble_protocol::control::ControlPacket; use mumble_protocol::voice::Serverbound; -use std::collections::HashMap; use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::net::ToSocketAddrs; use tokio::sync::{mpsc, watch}; @@ -50,9 +50,12 @@ impl State { } //TODO? move bool inside Result - pub async fn handle_command(&mut self, command: Command) -> (bool, Result<Option<CommandResponse>, ()>) { + pub async fn handle_command( + &mut self, + command: Command, + ) -> (bool, Result<Option<CommandResponse>, ()>) { match command { - Command::ChannelJoin{channel_id} => { + Command::ChannelJoin { channel_id } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { warn!("Not connected"); return (false, Err(())); @@ -68,26 +71,41 @@ impl State { warn!("Not connected"); return (false, Err(())); } - (false, Ok(Some(CommandResponse::ChannelList{channels: self.server.as_ref().unwrap().channels.clone()}))) + ( + false, + Ok(Some(CommandResponse::ChannelList { + channels: self.server.as_ref().unwrap().channels.clone(), + })), + ) } - Command::ServerConnect{host, port, username, accept_invalid_cert} => { + Command::ServerConnect { + host, + port, + username, + accept_invalid_cert, + } => { if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { warn!("Tried to connect to a server while already connected"); return (false, Err(())); } self.server = Some(Server::new()); self.username = Some(username); - self.phase_watcher.0.broadcast(StatePhase::Connecting).unwrap(); + self.phase_watcher + .0 + .broadcast(StatePhase::Connecting) + .unwrap(); let socket_addr = (host.as_ref(), port) .to_socket_addrs() .expect("Failed to parse server address") .next() .expect("Failed to resolve server address"); - self.connection_info_sender.broadcast(Some(ConnectionInfo::new( - socket_addr, - host, - accept_invalid_cert, - ))).unwrap(); + self.connection_info_sender + .broadcast(Some(ConnectionInfo::new( + socket_addr, + host, + accept_invalid_cert, + ))) + .unwrap(); (true, Ok(None)) } Command::Status => { @@ -95,17 +113,23 @@ impl State { warn!("Not connected"); return (false, Err(())); } - (false, Ok(Some(CommandResponse::Status{ - username: self.username.clone(), - server_state: self.server.clone().unwrap(), - }))) + ( + false, + Ok(Some(CommandResponse::Status { + username: self.username.clone(), + server_state: self.server.clone().unwrap(), + })), + ) } Command::ServerDisconnect => { self.session_id = None; self.username = None; self.server = None; - self.phase_watcher.0.broadcast(StatePhase::Disconnected).unwrap(); + self.phase_watcher + .0 + .broadcast(StatePhase::Disconnected) + .unwrap(); (false, Ok(None)) } } @@ -127,9 +151,11 @@ impl State { } Some(session) => { if session != msg.get_session() { - error!("Got two different session IDs ({} and {}) for ourselves", + error!( + "Got two different session IDs ({} and {}) for ourselves", session, - msg.get_session()); + msg.get_session() + ); } else { debug!("Got our session ID twice"); } @@ -141,15 +167,30 @@ impl State { } pub fn initialized(&self) { - self.phase_watcher.0.broadcast(StatePhase::Connected).unwrap(); + self.phase_watcher + .0 + .broadcast(StatePhase::Connected) + .unwrap(); } - pub fn audio(&self) -> &Audio { &self.audio } - pub fn audio_mut(&mut self) -> &mut Audio { &mut self.audio } - pub fn packet_sender(&self) -> mpsc::UnboundedSender<ControlPacket<Serverbound>> { self.packet_sender.clone() } - pub fn phase_receiver(&self) -> watch::Receiver<StatePhase> { self.phase_watcher.1.clone() } - pub fn server_mut(&mut self) -> Option<&mut Server> { self.server.as_mut() } - pub fn username(&self) -> Option<&String> { self.username.as_ref() } + pub fn audio(&self) -> &Audio { + &self.audio + } + pub fn audio_mut(&mut self) -> &mut Audio { + &mut self.audio + } + pub fn packet_sender(&self) -> mpsc::UnboundedSender<ControlPacket<Serverbound>> { + self.packet_sender.clone() + } + pub fn phase_receiver(&self) -> watch::Receiver<StatePhase> { + self.phase_watcher.1.clone() + } + pub fn server_mut(&mut self) -> Option<&mut Server> { + self.server.as_mut() + } + pub fn username(&self) -> Option<&String> { + self.username.as_ref() + } } #[derive(Clone, Debug)] @@ -180,7 +221,9 @@ impl Server { return; } match self.channels.entry(msg.get_channel_id()) { - Entry::Vacant(e) => { e.insert(Channel::new(msg)); }, + Entry::Vacant(e) => { + e.insert(Channel::new(msg)); + } Entry::Occupied(mut e) => e.get_mut().parse_channel_state(msg), } } @@ -191,8 +234,12 @@ impl Server { return; } match self.channels.entry(msg.get_channel_id()) { - Entry::Vacant(_) => { warn!("Attempted to remove channel that doesn't exist"); } - Entry::Occupied(e) => { e.remove(); } + Entry::Vacant(_) => { + warn!("Attempted to remove channel that doesn't exist"); + } + Entry::Occupied(e) => { + e.remove(); + } } } @@ -202,7 +249,9 @@ impl Server { return; } match self.users.entry(msg.get_session()) { - Entry::Vacant(e) => { e.insert(User::new(msg)); }, + Entry::Vacant(e) => { + e.insert(User::new(msg)); + } Entry::Occupied(mut e) => e.get_mut().parse_user_state(msg), } } @@ -279,11 +328,11 @@ pub struct User { priority_speaker: bool, recording: bool, - suppress: bool, // by me + suppress: bool, // by me self_mute: bool, // by self self_deaf: bool, // by self - mute: bool, // by admin - deaf: bool, // by admin + mute: bool, // by admin + deaf: bool, // by admin } impl User { @@ -301,20 +350,13 @@ impl User { None }, name: msg.take_name(), - priority_speaker: msg.has_priority_speaker() - && msg.get_priority_speaker(), - recording: msg.has_recording() - && msg.get_recording(), - suppress: msg.has_suppress() - && msg.get_suppress(), - self_mute: msg.has_self_mute() - && msg.get_self_mute(), - self_deaf: msg.has_self_deaf() - && msg.get_self_deaf(), - mute: msg.has_mute() - && msg.get_mute(), - deaf: msg.has_deaf() - && msg.get_deaf(), + priority_speaker: msg.has_priority_speaker() && msg.get_priority_speaker(), + recording: msg.has_recording() && msg.get_recording(), + suppress: msg.has_suppress() && msg.get_suppress(), + self_mute: msg.has_self_mute() && msg.get_self_mute(), + self_deaf: msg.has_self_deaf() && msg.get_self_deaf(), + mute: msg.has_mute() && msg.get_mute(), + deaf: msg.has_deaf() && msg.get_deaf(), } } |
