aboutsummaryrefslogtreecommitdiffstats
path: root/mumd/src
diff options
context:
space:
mode:
authorEskil Queseth <eskilq@kth.se>2020-10-14 19:48:05 +0200
committerEskil Queseth <eskilq@kth.se>2020-10-14 19:48:05 +0200
commita40d365aacf118b33c07f3353f277eb96c4536a8 (patch)
tree1a5e623da01745b3d2a2d1b1d5958a22cd0e382a /mumd/src
parentc0855405832ce47f75fa6e1ff7a33e51a8b36903 (diff)
parent6ac72067a75d5e1904226efb5c45bcf0e54a0ae5 (diff)
downloadmum-a40d365aacf118b33c07f3353f277eb96c4536a8.tar.gz
Merge remote-tracking branch 'origin/commands' into main
Diffstat (limited to 'mumd/src')
-rw-r--r--mumd/src/audio.rs71
-rw-r--r--mumd/src/command.rs42
-rw-r--r--mumd/src/main.rs103
-rw-r--r--mumd/src/network/mod.rs19
-rw-r--r--mumd/src/network/tcp.rs370
-rw-r--r--mumd/src/network/udp.rs236
-rw-r--r--mumd/src/state.rs259
7 files changed, 829 insertions, 271 deletions
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs
index 9b794a6..1445415 100644
--- a/mumd/src/audio.rs
+++ b/mumd/src/audio.rs
@@ -1,6 +1,5 @@
use bytes::Bytes;
-use cpal::traits::DeviceTrait;
-use cpal::traits::HostTrait;
+use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{
InputCallbackInfo, OutputCallbackInfo, Sample, SampleFormat, SampleRate, Stream, StreamConfig,
};
@@ -28,9 +27,9 @@ pub struct Audio {
pub input_config: StreamConfig,
pub input_stream: Stream,
pub input_buffer: Arc<Mutex<VecDeque<f32>>>,
- input_channel_receiver: Option<Receiver<VoicePacketPayload>>,
+ input_channel_receiver: Option<Receiver<VoicePacketPayload>>, //TODO unbounded? mbe ring buffer and drop the first packet
- client_streams: Arc<Mutex<HashMap<u32, ClientStream>>>,
+ client_streams: Arc<Mutex<HashMap<u32, ClientStream>>>, //TODO move to user state
}
//TODO split into input/output
@@ -104,31 +103,39 @@ impl Audio {
let input_stream = match input_supported_sample_format {
SampleFormat::F32 => input_device.build_input_stream(
&input_config,
- input_callback::<f32>(input_encoder,
- input_sender,
- input_config.sample_rate.0,
- 10.0),
+ input_callback::<f32>(
+ input_encoder,
+ input_sender,
+ input_config.sample_rate.0,
+ 4, // 10 ms
+ ),
err_fn,
),
SampleFormat::I16 => input_device.build_input_stream(
&input_config,
- input_callback::<i16>(input_encoder,
- input_sender,
- input_config.sample_rate.0,
- 10.0),
+ input_callback::<i16>(
+ input_encoder,
+ input_sender,
+ input_config.sample_rate.0,
+ 4, // 10 ms
+ ),
err_fn,
),
SampleFormat::U16 => input_device.build_input_stream(
&input_config,
- input_callback::<u16>(input_encoder,
- input_sender,
- input_config.sample_rate.0,
- 10.0),
+ input_callback::<u16>(
+ input_encoder,
+ input_sender,
+ input_config.sample_rate.0,
+ 4, // 10 ms
+ ),
err_fn,
),
}
.unwrap();
+ output_stream.play().unwrap();
+
Self {
output_config,
output_stream,
@@ -206,7 +213,8 @@ impl ClientStream {
match payload {
VoicePacketPayload::Opus(bytes, _eot) => {
let mut out: Vec<f32> = vec![0.0; 720 * channels * 4]; //720 is because that is the max size of packet we can get that we want to decode
- let parsed = self.opus_decoder
+ let parsed = self
+ .opus_decoder
.decode_float(&bytes, &mut out, false)
.expect("Error decoding");
out.truncate(parsed);
@@ -268,16 +276,19 @@ fn input_callback<T: Sample>(
mut opus_encoder: opus::Encoder,
mut input_sender: Sender<VoicePacketPayload>,
sample_rate: u32,
- opus_frame_size_ms: f32,
+ opus_frame_size_blocks: u32, // blocks of 2.5ms
) -> impl FnMut(&[T], &InputCallbackInfo) + Send + 'static {
- if ! ( opus_frame_size_ms == 2.5
- || opus_frame_size_ms == 5.0
- || opus_frame_size_ms == 10.0
- || opus_frame_size_ms == 20.0) {
- panic!("Unsupported opus frame size {}", opus_frame_size_ms);
+ if !(opus_frame_size_blocks == 1
+ || opus_frame_size_blocks == 2
+ || opus_frame_size_blocks == 4
+ || opus_frame_size_blocks == 8)
+ {
+ panic!(
+ "Unsupported amount of opus frame blocks {}",
+ opus_frame_size_blocks
+ );
}
- let opus_frame_size = (opus_frame_size_ms * sample_rate as f32) as u32 / 1000;
-
+ let opus_frame_size = opus_frame_size_blocks * sample_rate / 400;
let buf = Arc::new(Mutex::new(VecDeque::new()));
move |data: &[T], _info: &InputCallbackInfo| {
@@ -292,9 +303,13 @@ fn input_callback<T: Sample>(
.unwrap();
opus_buf.truncate(result);
let bytes = Bytes::copy_from_slice(&opus_buf);
- input_sender
- .try_send(VoicePacketPayload::Opus(bytes, false))
- .unwrap(); //TODO handle full buffer / disconnect
+ match input_sender.try_send(VoicePacketPayload::Opus(bytes, false)) {
+ //TODO handle full buffer / disconnect
+ Ok(_) => {}
+ Err(_e) => {
+ //warn!("Error sending audio packet: {:?}", e);
+ }
+ }
*buf = tail;
}
}
diff --git a/mumd/src/command.rs b/mumd/src/command.rs
index 5d6cca4..b4bd1b7 100644
--- a/mumd/src/command.rs
+++ b/mumd/src/command.rs
@@ -1,4 +1,12 @@
-enum Command {
+use crate::state::{Channel, Server, State, StatePhase};
+
+use log::*;
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+use tokio::sync::mpsc;
+
+#[derive(Clone, Debug)]
+pub enum Command {
ChannelJoin {
channel_id: u32,
},
@@ -12,3 +20,35 @@ enum Command {
ServerDisconnect,
Status,
}
+
+#[derive(Debug)]
+pub enum CommandResponse {
+ ChannelList {
+ channels: HashMap<u32, Channel>,
+ },
+ Status {
+ username: Option<String>,
+ server_state: Server,
+ },
+}
+
+pub async fn handle(
+ state: Arc<Mutex<State>>,
+ mut command_receiver: mpsc::UnboundedReceiver<Command>,
+ command_response_sender: mpsc::UnboundedSender<Result<Option<CommandResponse>, ()>>,
+) {
+ //TODO err if not connected
+ while let Some(command) = command_receiver.recv().await {
+ debug!("Parsing command {:?}", command);
+ let mut state = state.lock().unwrap();
+ let (wait_for_connected, command_response) = state.handle_command(command).await;
+ if wait_for_connected {
+ let mut watcher = state.phase_receiver();
+ drop(state);
+ while !matches!(watcher.recv().await.unwrap(), StatePhase::Connected) {}
+ }
+ command_response_sender.send(command_response).unwrap();
+ }
+
+ debug!("Finished handling commands");
+}
diff --git a/mumd/src/main.rs b/mumd/src/main.rs
index 2a0fcbd..f837a52 100644
--- a/mumd/src/main.rs
+++ b/mumd/src/main.rs
@@ -1,46 +1,55 @@
mod audio;
-mod network;
mod command;
+mod network;
mod state;
-use crate::audio::Audio;
-use crate::state::Server;
+
+use crate::command::{Command, CommandResponse};
+use crate::network::ConnectionInfo;
+use crate::state::State;
use argparse::ArgumentParser;
use argparse::Store;
use argparse::StoreTrue;
use colored::*;
-use cpal::traits::StreamTrait;
-use futures::channel::oneshot;
use futures::join;
use log::*;
+use mumble_protocol::control::ControlPacket;
use mumble_protocol::crypt::ClientCryptState;
-use std::net::ToSocketAddrs;
-use std::sync::Arc;
-use std::sync::Mutex;
+use mumble_protocol::voice::Serverbound;
+use std::sync::{Arc, Mutex};
+use std::time::Duration;
+use tokio::sync::{mpsc, watch};
#[tokio::main]
async fn main() {
// setup logger
fern::Dispatch::new()
.format(|out, message, record| {
+ let message = message.to_string();
out.finish(format_args!(
- "{} {}:{} {}",
+ "{} {}:{}{}{}",
//TODO runtime flag that disables color
match record.level() {
Level::Error => "ERROR".red(),
- Level::Warn => "WARN ".yellow(),
- Level::Info => "INFO ".normal(),
+ Level::Warn => "WARN ".yellow(),
+ Level::Info => "INFO ".normal(),
Level::Debug => "DEBUG".green(),
Level::Trace => "TRACE".normal(),
},
record.file().unwrap(),
record.line().unwrap(),
+ if message.chars().any(|e| e == '\n') {
+ "\n"
+ } else {
+ " "
+ },
message
))
})
.level(log::LevelFilter::Debug)
.chain(std::io::stderr())
- .apply().unwrap();
+ .apply()
+ .unwrap();
// Handle command line arguments
let mut server_host = "".to_string();
@@ -64,37 +73,69 @@ async fn main() {
);
ap.parse_args_or_exit();
}
- let server_addr = (server_host.as_ref(), server_port)
- .to_socket_addrs()
- .expect("Failed to parse server address")
- .next()
- .expect("Failed to resolve server address");
// Oneshot channel for setting UDP CryptState from control task
// For simplicity we don't deal with re-syncing, real applications would have to.
- let (crypt_state_sender, crypt_state_receiver) = oneshot::channel::<ClientCryptState>();
+ let (crypt_state_sender, crypt_state_receiver) = mpsc::channel::<ClientCryptState>(1); // crypt state should always be consumed before sending a new one
+ let (packet_sender, packet_receiver) = mpsc::unbounded_channel::<ControlPacket<Serverbound>>();
+ let (command_sender, command_receiver) = mpsc::unbounded_channel::<Command>();
+ let (command_response_sender, command_response_receiver) =
+ mpsc::unbounded_channel::<Result<Option<CommandResponse>, ()>>();
+ let (connection_info_sender, connection_info_receiver) =
+ watch::channel::<Option<ConnectionInfo>>(None);
- let audio = Audio::new();
- audio.output_stream.play().unwrap();
- let audio = Arc::new(Mutex::new(audio));
-
- let server_state = Arc::new(Mutex::new(Server::new()));
+ let state = State::new(
+ packet_sender,
+ command_sender.clone(),
+ connection_info_sender,
+ );
+ let state = Arc::new(Mutex::new(state));
// Run it
join!(
network::tcp::handle(
- server_state,
- server_addr,
- server_host,
- username,
- accept_invalid_cert,
+ Arc::clone(&state),
+ connection_info_receiver.clone(),
crypt_state_sender,
- Arc::clone(&audio),
+ packet_receiver,
),
network::udp::handle(
- server_addr,
+ Arc::clone(&state),
+ connection_info_receiver.clone(),
crypt_state_receiver,
- audio,
),
+ command::handle(state, command_receiver, command_response_sender,),
+ send_commands(
+ command_sender,
+ Command::ServerConnect {
+ host: server_host,
+ port: server_port,
+ username: username.clone(),
+ accept_invalid_cert
+ }
+ ),
+ receive_command_responses(command_response_receiver,),
);
}
+
+async fn send_commands(command_sender: mpsc::UnboundedSender<Command>, connect_command: Command) {
+ command_sender.send(connect_command.clone()).unwrap();
+ tokio::time::delay_for(Duration::from_secs(2)).await;
+ command_sender.send(Command::ServerDisconnect).unwrap();
+ tokio::time::delay_for(Duration::from_secs(2)).await;
+ command_sender.send(connect_command.clone()).unwrap();
+ tokio::time::delay_for(Duration::from_secs(2)).await;
+ command_sender.send(Command::ServerDisconnect).unwrap();
+
+ debug!("Finished sending commands");
+}
+
+async fn receive_command_responses(
+ mut command_response_receiver: mpsc::UnboundedReceiver<Result<Option<CommandResponse>, ()>>,
+) {
+ while let Some(command_response) = command_response_receiver.recv().await {
+ debug!("{:?}", command_response);
+ }
+
+ debug!("Finished receiving commands");
+}
diff --git a/mumd/src/network/mod.rs b/mumd/src/network/mod.rs
index f7a6a76..1a31ee2 100644
--- a/mumd/src/network/mod.rs
+++ b/mumd/src/network/mod.rs
@@ -1,2 +1,21 @@
pub mod tcp;
pub mod udp;
+
+use std::net::SocketAddr;
+
+#[derive(Clone, Debug)]
+pub struct ConnectionInfo {
+ socket_addr: SocketAddr,
+ hostname: String,
+ accept_invalid_cert: bool,
+}
+
+impl ConnectionInfo {
+ pub fn new(socket_addr: SocketAddr, hostname: String, accept_invalid_cert: bool) -> Self {
+ Self {
+ socket_addr,
+ hostname,
+ accept_invalid_cert,
+ }
+ }
+}
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index dde98aa..6a369e5 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -1,17 +1,17 @@
-use crate::audio::Audio;
-use crate::state::Server;
+use crate::network::ConnectionInfo;
+use crate::state::{State, StatePhase};
use log::*;
-use futures::channel::oneshot;
-use futures::{join, SinkExt, StreamExt};
+use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt};
use futures_util::stream::{SplitSink, SplitStream};
use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket};
use mumble_protocol::crypt::ClientCryptState;
use mumble_protocol::{Clientbound, Serverbound};
use std::convert::{Into, TryInto};
-use std::net::{SocketAddr};
+use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use tokio::net::TcpStream;
+use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::{self, Duration};
use tokio_tls::{TlsConnector, TlsStream};
use tokio_util::codec::{Decoder, Framed};
@@ -24,26 +24,52 @@ type TcpReceiver =
SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>;
pub async fn handle(
- server: Arc<Mutex<Server>>,
- server_addr: SocketAddr,
- server_host: String,
- username: String,
- accept_invalid_cert: bool,
- crypt_state_sender: oneshot::Sender<ClientCryptState>,
- audio: Arc<Mutex<Audio>>,
+ state: Arc<Mutex<State>>,
+ mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
+ crypt_state_sender: mpsc::Sender<ClientCryptState>,
+ mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
) {
- let (sink, stream) = connect(server_addr, server_host, accept_invalid_cert).await;
- let sink = Arc::new(Mutex::new(sink));
+ loop {
+ let connection_info = loop {
+ match connection_info_receiver.recv().await {
+ None => {
+ return;
+ }
+ Some(None) => {}
+ Some(Some(connection_info)) => {
+ break connection_info;
+ }
+ }
+ };
+ let (mut sink, stream) = connect(
+ connection_info.socket_addr,
+ connection_info.hostname,
+ connection_info.accept_invalid_cert,
+ )
+ .await;
+
+ // Handshake (omitting `Version` message for brevity)
+ let state_lock = state.lock().unwrap();
+ authenticate(&mut sink, state_lock.username().unwrap().to_string()).await;
+ let phase_watcher = state_lock.phase_receiver();
+ let packet_sender = state_lock.packet_sender();
+ drop(state_lock);
- // Handshake (omitting `Version` message for brevity)
- authenticate(Arc::clone(&sink), username).await;
+ info!("Logging in...");
- info!("Logging in...");
+ join!(
+ send_pings(packet_sender, 10, phase_watcher.clone()),
+ listen(
+ Arc::clone(&state),
+ stream,
+ crypt_state_sender.clone(),
+ phase_watcher.clone()
+ ),
+ send_packets(sink, &mut packet_receiver, phase_watcher),
+ );
- join!(
- send_pings(Arc::clone(&sink), 10),
- listen(server, sink, stream, crypt_state_sender, audio),
- );
+ debug!("Fully disconnected TCP stream, waiting for new connection info");
+ }
}
async fn connect(
@@ -72,109 +98,239 @@ async fn connect(
ClientControlCodec::new().framed(tls_stream).split()
}
-async fn authenticate(sink: Arc<Mutex<TcpSender>>, username: String) {
+async fn authenticate(sink: &mut TcpSender, username: String) {
let mut msg = msgs::Authenticate::new();
msg.set_username(username);
msg.set_opus(true);
- sink.lock().unwrap().send(msg.into()).await.unwrap();
+ sink.send(msg.into()).await.unwrap();
}
-async fn send_pings(sink: Arc<Mutex<TcpSender>>, delay_seconds: u64) {
+async fn send_pings(
+ packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
+ delay_seconds: u64,
+ mut phase_watcher: watch::Receiver<StatePhase>,
+) {
+ let (tx, rx) = oneshot::channel();
+ let phase_transition_block = async {
+ while !matches!(
+ phase_watcher.recv().await.unwrap(),
+ StatePhase::Disconnected
+ ) {}
+ tx.send(true).unwrap();
+ };
+
let mut interval = time::interval(Duration::from_secs(delay_seconds));
- loop {
- interval.tick().await;
- trace!("Sending ping");
- let msg = msgs::Ping::new();
- sink.lock().unwrap().send(msg.into()).await.unwrap();
- }
+ let main_block = async {
+ let rx = rx.fuse();
+ pin_mut!(rx);
+ loop {
+ let interval_waiter = interval.tick().fuse();
+ pin_mut!(interval_waiter);
+ let exitor = select! {
+ data = interval_waiter => Some(data),
+ _ = rx => None
+ };
+
+ match exitor {
+ Some(_) => {
+ trace!("Sending ping");
+ let msg = msgs::Ping::new();
+ packet_sender.send(msg.into()).unwrap();
+ }
+ None => break,
+ }
+ }
+ };
+
+ join!(main_block, phase_transition_block);
+
+ debug!("Ping sender process killed");
+}
+
+async fn send_packets(
+ mut sink: TcpSender,
+ packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
+ mut phase_watcher: watch::Receiver<StatePhase>,
+) {
+ let (tx, rx) = oneshot::channel();
+ let phase_transition_block = async {
+ while !matches!(
+ phase_watcher.recv().await.unwrap(),
+ StatePhase::Disconnected
+ ) {}
+ tx.send(true).unwrap();
+ };
+
+ let main_block = async {
+ let rx = rx.fuse();
+ pin_mut!(rx);
+ loop {
+ let packet_recv = packet_receiver.recv().fuse();
+ pin_mut!(packet_recv);
+ let exitor = select! {
+ data = packet_recv => Some(data),
+ _ = rx => None
+ };
+ match exitor {
+ None => {
+ break;
+ }
+ Some(None) => {
+ warn!("Channel closed before disconnect command");
+ break;
+ }
+ Some(Some(packet)) => {
+ sink.send(packet).await.unwrap();
+ }
+ }
+ }
+
+ //clears queue of remaining packets
+ while packet_receiver.try_recv().is_ok() {}
+
+ sink.close().await.unwrap();
+ };
+
+ join!(main_block, phase_transition_block);
+
+ debug!("TCP packet sender killed");
}
async fn listen(
- server: Arc<Mutex<Server>>,
- sink: Arc<Mutex<TcpSender>>,
+ state: Arc<Mutex<State>>,
mut stream: TcpReceiver,
- crypt_state_sender: oneshot::Sender<ClientCryptState>,
- audio: Arc<Mutex<Audio>>,
+ crypt_state_sender: mpsc::Sender<ClientCryptState>,
+ mut phase_watcher: watch::Receiver<StatePhase>,
) {
let mut crypt_state = None;
let mut crypt_state_sender = Some(crypt_state_sender);
- while let Some(packet) = stream.next().await {
- //TODO handle types separately
- match packet.unwrap() {
- ControlPacket::TextMessage(mut msg) => {
- info!(
- "Got message from user with session ID {}: {}",
- msg.get_actor(),
- msg.get_message()
- );
- // Send reply back to server
- let mut response = msgs::TextMessage::new();
- response.mut_session().push(msg.get_actor());
- response.set_message(msg.take_message());
- let mut lock = sink.lock().unwrap();
- lock.send(response.into()).await.unwrap();
- }
- ControlPacket::CryptSetup(msg) => {
- debug!("Crypt setup");
- // Wait until we're fully connected before initiating UDP voice
- crypt_state = Some(ClientCryptState::new_from(
- msg.get_key()
- .try_into()
- .expect("Server sent private key with incorrect size"),
- msg.get_client_nonce()
- .try_into()
- .expect("Server sent client_nonce with incorrect size"),
- msg.get_server_nonce()
- .try_into()
- .expect("Server sent server_nonce with incorrect size"),
- ));
- }
- ControlPacket::ServerSync(msg) => {
- info!("Logged in");
- if let Some(sender) = crypt_state_sender.take() {
- let _ = sender.send(
- crypt_state
- .take()
- .expect("Server didn't send us any CryptSetup packet!"),
- );
+ let (tx, rx) = oneshot::channel();
+ let phase_transition_block = async {
+ while !matches!(
+ phase_watcher.recv().await.unwrap(),
+ StatePhase::Disconnected
+ ) {}
+ tx.send(true).unwrap();
+ };
+
+ let listener_block = async {
+ let rx = rx.fuse();
+ pin_mut!(rx);
+ loop {
+ let packet_recv = stream.next().fuse();
+ pin_mut!(packet_recv);
+ let exitor = select! {
+ data = packet_recv => Some(data),
+ _ = rx => None
+ };
+ match exitor {
+ None => {
+ break;
}
- let mut server = server.lock().unwrap();
- server.parse_server_sync(msg);
- match &server.welcome_text {
- Some(s) => info!("Welcome: {}", s),
- None => info!("No welcome received"),
+ Some(None) => {
+ warn!("Channel closed before disconnect command");
+ break;
}
- for (_, channel) in server.channels() {
- info!("Found channel {}", channel.name());
+ Some(Some(packet)) => {
+ //TODO handle types separately
+ match packet.unwrap() {
+ ControlPacket::TextMessage(msg) => {
+ info!(
+ "Got message from user with session ID {}: {}",
+ msg.get_actor(),
+ msg.get_message()
+ );
+ }
+ ControlPacket::CryptSetup(msg) => {
+ debug!("Crypt setup");
+ // Wait until we're fully connected before initiating UDP voice
+ crypt_state = Some(ClientCryptState::new_from(
+ msg.get_key()
+ .try_into()
+ .expect("Server sent private key with incorrect size"),
+ msg.get_client_nonce()
+ .try_into()
+ .expect("Server sent client_nonce with incorrect size"),
+ msg.get_server_nonce()
+ .try_into()
+ .expect("Server sent server_nonce with incorrect size"),
+ ));
+ }
+ ControlPacket::ServerSync(msg) => {
+ info!("Logged in");
+ if let Some(mut sender) = crypt_state_sender.take() {
+ let _ = sender
+ .send(
+ crypt_state
+ .take()
+ .expect("Server didn't send us any CryptSetup packet!"),
+ )
+ .await;
+ }
+ let mut state = state.lock().unwrap();
+ let server = state.server_mut().unwrap();
+ server.parse_server_sync(*msg);
+ match &server.welcome_text {
+ Some(s) => info!("Welcome: {}", s),
+ None => info!("No welcome received"),
+ }
+ for channel in server.channels().values() {
+ info!("Found channel {}", channel.name());
+ }
+ state.initialized();
+ }
+ ControlPacket::Reject(msg) => {
+ warn!("Login rejected: {:?}", msg);
+ }
+ ControlPacket::UserState(msg) => {
+ let mut state = state.lock().unwrap();
+ let session = msg.get_session();
+ state.audio_mut().add_client(msg.get_session()); //TODO
+ if *state.phase_receiver().borrow() == StatePhase::Connecting {
+ state.parse_initial_user_state(*msg);
+ } else {
+ state.server_mut().unwrap().parse_user_state(*msg);
+ }
+ let server = state.server_mut().unwrap();
+ let user = server.users().get(&session).unwrap();
+ info!("User {} connected to {}", user.name(), user.channel());
+ }
+ ControlPacket::UserRemove(msg) => {
+ info!("User {} left", msg.get_session());
+ state
+ .lock()
+ .unwrap()
+ .audio_mut()
+ .remove_client(msg.get_session());
+ }
+ ControlPacket::ChannelState(msg) => {
+ debug!("Channel state received");
+ state
+ .lock()
+ .unwrap()
+ .server_mut()
+ .unwrap()
+ .parse_channel_state(*msg); //TODO parse initial if initial
+ }
+ ControlPacket::ChannelRemove(msg) => {
+ state
+ .lock()
+ .unwrap()
+ .server_mut()
+ .unwrap()
+ .parse_channel_remove(*msg);
+ }
+ _ => {}
+ }
}
- sink.lock().unwrap().send(msgs::UserList::new().into()).await.unwrap();
- }
- ControlPacket::Reject(msg) => {
- warn!("Login rejected: {:?}", msg);
- }
- ControlPacket::UserState(msg) => {
- audio.lock().unwrap().add_client(msg.get_session());
- let mut server = server.lock().unwrap();
- let session = msg.get_session();
- server.parse_user_state(msg);
- let user = server.users().get(&session).unwrap();
- info!("User {} connected to {}",
- user.name(),
- user.channel());
- }
- ControlPacket::UserRemove(msg) => {
- info!("User {} left", msg.get_session());
- audio.lock().unwrap().remove_client(msg.get_session());
- }
- ControlPacket::ChannelState(msg) => {
- debug!("Channel state received");
- server.lock().unwrap().parse_channel_state(msg);
}
- ControlPacket::ChannelRemove(msg) => {
- server.lock().unwrap().parse_channel_remove(msg);
- }
- _ => {}
}
- }
+
+ //TODO? clean up stream
+ };
+
+ join!(phase_transition_block, listener_block);
+
+ debug!("Killing TCP listener block");
}
diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs
index 39f16b6..4f96c4c 100644
--- a/mumd/src/network/udp.rs
+++ b/mumd/src/network/udp.rs
@@ -1,9 +1,9 @@
-use crate::audio::Audio;
+use crate::network::ConnectionInfo;
+use crate::state::{State, StatePhase};
use log::*;
use bytes::Bytes;
-use futures::channel::oneshot;
-use futures::{join, SinkExt, StreamExt};
+use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt};
use futures_util::stream::{SplitSink, SplitStream};
use mumble_protocol::crypt::ClientCryptState;
use mumble_protocol::voice::{VoicePacket, VoicePacketPayload};
@@ -11,13 +11,57 @@ use mumble_protocol::Serverbound;
use std::net::{Ipv6Addr, SocketAddr};
use std::sync::{Arc, Mutex};
use tokio::net::UdpSocket;
+use tokio::sync::{mpsc, oneshot, watch};
use tokio_util::udp::UdpFramed;
type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>;
type UdpReceiver = SplitStream<UdpFramed<ClientCryptState>>;
+pub async fn handle(
+ state: Arc<Mutex<State>>,
+ mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
+ mut crypt_state: mpsc::Receiver<ClientCryptState>,
+) {
+ let mut receiver = state.lock().unwrap().audio_mut().take_receiver().unwrap();
+
+ loop {
+ let connection_info = loop {
+ match connection_info_receiver.recv().await {
+ None => {
+ return;
+ }
+ Some(None) => {}
+ Some(Some(connection_info)) => {
+ break connection_info;
+ }
+ }
+ };
+ let (mut sink, source) = connect(&mut crypt_state).await;
+
+ // Note: A normal application would also send periodic Ping packets, and its own audio
+ // via UDP. We instead trick the server into accepting us by sending it one
+ // dummy voice packet.
+ send_ping(&mut sink, connection_info.socket_addr).await;
+
+ let sink = Arc::new(Mutex::new(sink));
+
+ let phase_watcher = state.lock().unwrap().phase_receiver();
+ join!(
+ listen(Arc::clone(&state), source, phase_watcher.clone()),
+ send_voice(
+ sink,
+ connection_info.socket_addr,
+ phase_watcher,
+ &mut receiver
+ ),
+ );
+
+ debug!("Fully disconnected UDP stream, waiting for new connection info");
+ }
+}
+
pub async fn connect(
- crypt_state: oneshot::Receiver<ClientCryptState>,
+ crypt_state: &mut mpsc::Receiver<ClientCryptState>,
) -> (UdpSender, UdpReceiver) {
// Bind UDP socket
let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16))
@@ -25,10 +69,10 @@ pub async fn connect(
.expect("Failed to bind UDP socket");
// Wait for initial CryptState
- let crypt_state = match crypt_state.await {
- Ok(crypt_state) => crypt_state,
+ let crypt_state = match crypt_state.recv().await {
+ Some(crypt_state) => crypt_state,
// disconnected before we received the CryptSetup packet, oh well
- Err(_) => panic!("disconnect before crypt packet received"), //TODO exit gracefully
+ None => panic!("Disconnect before crypt packet received"), //TODO exit gracefully
};
debug!("UDP connected");
@@ -37,36 +81,74 @@ pub async fn connect(
}
async fn listen(
- _sink: Arc<Mutex<UdpSender>>,
+ state: Arc<Mutex<State>>,
mut source: UdpReceiver,
- audio: Arc<Mutex<Audio>>,
+ mut phase_watcher: watch::Receiver<StatePhase>,
) {
- while let Some(packet) = source.next().await {
- let (packet, _src_addr) = match packet {
- Ok(packet) => packet,
- Err(err) => {
- warn!("Got an invalid UDP packet: {}", err);
- // To be expected, considering this is the internet, just ignore it
- continue;
- }
- };
- match packet {
- VoicePacket::Ping { .. } => {
- // Note: A normal application would handle these and only use UDP for voice
- // once it has received one.
- continue;
- }
- VoicePacket::Audio {
- session_id,
- // seq_num,
- payload,
- // position_info,
- ..
- } => {
- audio.lock().unwrap().decode_packet(session_id, payload);
+ let (tx, rx) = oneshot::channel();
+ let phase_transition_block = async {
+ while !matches!(
+ phase_watcher.recv().await.unwrap(),
+ StatePhase::Disconnected
+ ) {}
+ tx.send(true).unwrap();
+ };
+
+ let main_block = async {
+ let rx = rx.fuse();
+ pin_mut!(rx);
+ loop {
+ let packet_recv = source.next().fuse();
+ pin_mut!(packet_recv);
+ let exitor = select! {
+ data = packet_recv => Some(data),
+ _ = rx => None
+ };
+ match exitor {
+ None => {
+ break;
+ }
+ Some(None) => {
+ warn!("Channel closed before disconnect command");
+ break;
+ }
+ Some(Some(packet)) => {
+ let (packet, _src_addr) = match packet {
+ Ok(packet) => packet,
+ Err(err) => {
+ warn!("Got an invalid UDP packet: {}", err);
+ // To be expected, considering this is the internet, just ignore it
+ continue;
+ }
+ };
+ match packet {
+ VoicePacket::Ping { .. } => {
+ // Note: A normal application would handle these and only use UDP for voice
+ // once it has received one.
+ continue;
+ }
+ VoicePacket::Audio {
+ session_id,
+ // seq_num,
+ payload,
+ // position_info,
+ ..
+ } => {
+ state
+ .lock()
+ .unwrap()
+ .audio()
+ .decode_packet(session_id, payload);
+ }
+ }
+ }
}
}
- }
+ };
+
+ join!(main_block, phase_transition_block);
+
+ debug!("UDP listener process killed");
}
async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) {
@@ -88,44 +170,58 @@ async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) {
async fn send_voice(
sink: Arc<Mutex<UdpSender>>,
server_addr: SocketAddr,
- audio: Arc<Mutex<Audio>>,
+ mut phase_watcher: watch::Receiver<StatePhase>,
+ receiver: &mut mpsc::Receiver<VoicePacketPayload>,
) {
- let mut receiver = audio.lock().unwrap().take_receiver().unwrap();
+ let (tx, rx) = oneshot::channel();
+ let phase_transition_block = async {
+ while !matches!(
+ phase_watcher.recv().await.unwrap(),
+ StatePhase::Disconnected
+ ) {}
+ tx.send(true).unwrap();
+ };
- let mut count = 0;
- while let Some(payload) = receiver.recv().await {
- let reply = VoicePacket::Audio {
- _dst: std::marker::PhantomData,
- target: 0, // normal speech
- session_id: (), // unused for server-bound packets
- seq_num: count,
- payload,
- position_info: None,
- };
- count += 1;
- sink.lock()
- .unwrap()
- .send((reply, server_addr))
- .await
- .unwrap();
- }
-}
+ let main_block = async {
+ let rx = rx.fuse();
+ pin_mut!(rx);
+ let mut count = 0;
+ loop {
+ let packet_recv = receiver.recv().fuse();
+ pin_mut!(packet_recv);
+ let exitor = select! {
+ data = packet_recv => Some(data),
+ _ = rx => None
+ };
+ match exitor {
+ None => {
+ break;
+ }
+ Some(None) => {
+ warn!("Channel closed before disconnect command");
+ break;
+ }
+ Some(Some(payload)) => {
+ let reply = VoicePacket::Audio {
+ _dst: std::marker::PhantomData,
+ target: 0, // normal speech
+ session_id: (), // unused for server-bound packets
+ seq_num: count,
+ payload,
+ position_info: None,
+ };
+ count += 1;
+ sink.lock()
+ .unwrap()
+ .send((reply, server_addr))
+ .await
+ .unwrap();
+ }
+ }
+ }
+ };
-pub async fn handle(
- server_addr: SocketAddr,
- crypt_state: oneshot::Receiver<ClientCryptState>,
- audio: Arc<Mutex<Audio>>,
-) {
- let (mut sink, source) = connect(crypt_state).await;
-
- // Note: A normal application would also send periodic Ping packets, and its own audio
- // via UDP. We instead trick the server into accepting us by sending it one
- // dummy voice packet.
- send_ping(&mut sink, server_addr).await;
-
- let sink = Arc::new(Mutex::new(sink));
- join!(
- listen(Arc::clone(&sink), source, Arc::clone(&audio)),
- send_voice(sink, server_addr, audio)
- );
+ join!(main_block, phase_transition_block);
+
+ debug!("UDP sender process killed");
}
diff --git a/mumd/src/state.rs b/mumd/src/state.rs
index 1ef8467..b6fe780 100644
--- a/mumd/src/state.rs
+++ b/mumd/src/state.rs
@@ -1,8 +1,197 @@
+use crate::audio::Audio;
+use crate::command::{Command, CommandResponse};
+use crate::network::ConnectionInfo;
use log::*;
use mumble_protocol::control::msgs;
-use std::collections::HashMap;
+use mumble_protocol::control::ControlPacket;
+use mumble_protocol::voice::Serverbound;
use std::collections::hash_map::Entry;
+use std::collections::HashMap;
+use std::net::ToSocketAddrs;
+use tokio::sync::{mpsc, watch};
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum StatePhase {
+ Disconnected,
+ Connecting,
+ Connected,
+}
+
+pub struct State {
+ server: Option<Server>,
+ audio: Audio,
+
+ packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
+ command_sender: mpsc::UnboundedSender<Command>,
+ connection_info_sender: watch::Sender<Option<ConnectionInfo>>,
+
+ phase_watcher: (watch::Sender<StatePhase>, watch::Receiver<StatePhase>),
+
+ username: Option<String>,
+ session_id: Option<u32>,
+}
+
+impl State {
+ pub fn new(
+ packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
+ command_sender: mpsc::UnboundedSender<Command>,
+ connection_info_sender: watch::Sender<Option<ConnectionInfo>>,
+ ) -> Self {
+ Self {
+ server: None,
+ audio: Audio::new(),
+ packet_sender,
+ command_sender,
+ connection_info_sender,
+ phase_watcher: watch::channel(StatePhase::Disconnected),
+ username: None,
+ session_id: None,
+ }
+ }
+
+ //TODO? move bool inside Result
+ pub async fn handle_command(
+ &mut self,
+ command: Command,
+ ) -> (bool, Result<Option<CommandResponse>, ()>) {
+ match command {
+ Command::ChannelJoin { channel_id } => {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ warn!("Not connected");
+ return (false, Err(()));
+ }
+ let mut msg = msgs::UserState::new();
+ msg.set_session(self.session_id.unwrap());
+ msg.set_channel_id(channel_id);
+ self.packet_sender.send(msg.into()).unwrap();
+ (false, Ok(None))
+ }
+ Command::ChannelList => {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ warn!("Not connected");
+ return (false, Err(()));
+ }
+ (
+ false,
+ Ok(Some(CommandResponse::ChannelList {
+ channels: self.server.as_ref().unwrap().channels.clone(),
+ })),
+ )
+ }
+ Command::ServerConnect {
+ host,
+ port,
+ username,
+ accept_invalid_cert,
+ } => {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) {
+ warn!("Tried to connect to a server while already connected");
+ return (false, Err(()));
+ }
+ self.server = Some(Server::new());
+ self.username = Some(username);
+ self.phase_watcher
+ .0
+ .broadcast(StatePhase::Connecting)
+ .unwrap();
+ let socket_addr = (host.as_ref(), port)
+ .to_socket_addrs()
+ .expect("Failed to parse server address")
+ .next()
+ .expect("Failed to resolve server address");
+ self.connection_info_sender
+ .broadcast(Some(ConnectionInfo::new(
+ socket_addr,
+ host,
+ accept_invalid_cert,
+ )))
+ .unwrap();
+ (true, Ok(None))
+ }
+ Command::Status => {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ warn!("Not connected");
+ return (false, Err(()));
+ }
+ (
+ false,
+ Ok(Some(CommandResponse::Status {
+ username: self.username.clone(),
+ server_state: self.server.clone().unwrap(),
+ })),
+ )
+ }
+ Command::ServerDisconnect => {
+ self.session_id = None;
+ self.username = None;
+ self.server = None;
+
+ self.phase_watcher
+ .0
+ .broadcast(StatePhase::Disconnected)
+ .unwrap();
+ (false, Ok(None))
+ }
+ }
+ }
+
+ pub fn parse_initial_user_state(&mut self, msg: msgs::UserState) {
+ if !msg.has_session() {
+ warn!("Can't parse user state without session");
+ return;
+ }
+ if !msg.has_name() {
+ warn!("Missing name in initial user state");
+ } else if msg.get_name() == self.username.as_ref().unwrap() {
+ match self.session_id {
+ None => {
+ debug!("Found our session id: {}", msg.get_session());
+ self.session_id = Some(msg.get_session());
+ }
+ Some(session) => {
+ if session != msg.get_session() {
+ error!(
+ "Got two different session IDs ({} and {}) for ourselves",
+ session,
+ msg.get_session()
+ );
+ } else {
+ debug!("Got our session ID twice");
+ }
+ }
+ }
+ }
+ self.server.as_mut().unwrap().parse_user_state(msg);
+ }
+ pub fn initialized(&self) {
+ self.phase_watcher
+ .0
+ .broadcast(StatePhase::Connected)
+ .unwrap();
+ }
+
+ pub fn audio(&self) -> &Audio {
+ &self.audio
+ }
+ pub fn audio_mut(&mut self) -> &mut Audio {
+ &mut self.audio
+ }
+ pub fn packet_sender(&self) -> mpsc::UnboundedSender<ControlPacket<Serverbound>> {
+ self.packet_sender.clone()
+ }
+ pub fn phase_receiver(&self) -> watch::Receiver<StatePhase> {
+ self.phase_watcher.1.clone()
+ }
+ pub fn server_mut(&mut self) -> Option<&mut Server> {
+ self.server.as_mut()
+ }
+ pub fn username(&self) -> Option<&String> {
+ self.username.as_ref()
+ }
+}
+
+#[derive(Clone, Debug)]
pub struct Server {
channels: HashMap<u32, Channel>,
users: HashMap<u32, User>,
@@ -18,41 +207,49 @@ impl Server {
}
}
- pub fn parse_server_sync(&mut self, mut msg: Box<msgs::ServerSync>) {
+ pub fn parse_server_sync(&mut self, mut msg: msgs::ServerSync) {
if msg.has_welcome_text() {
self.welcome_text = Some(msg.take_welcome_text());
}
}
- pub fn parse_channel_state(&mut self, msg: Box<msgs::ChannelState>) {
+ pub fn parse_channel_state(&mut self, msg: msgs::ChannelState) {
if !msg.has_channel_id() {
warn!("Can't parse channel state without channel id");
return;
}
match self.channels.entry(msg.get_channel_id()) {
- Entry::Vacant(e) => { e.insert(Channel::new(msg)); },
+ Entry::Vacant(e) => {
+ e.insert(Channel::new(msg));
+ }
Entry::Occupied(mut e) => e.get_mut().parse_channel_state(msg),
}
}
- pub fn parse_channel_remove(&mut self, msg: Box<msgs::ChannelRemove>) {
+ pub fn parse_channel_remove(&mut self, msg: msgs::ChannelRemove) {
if !msg.has_channel_id() {
warn!("Can't parse channel remove without channel id");
return;
}
match self.channels.entry(msg.get_channel_id()) {
- Entry::Vacant(_) => { warn!("Attempted to remove channel that doesn't exist"); }
- Entry::Occupied(e) => { e.remove(); }
+ Entry::Vacant(_) => {
+ warn!("Attempted to remove channel that doesn't exist");
+ }
+ Entry::Occupied(e) => {
+ e.remove();
+ }
}
}
- pub fn parse_user_state(&mut self, msg: Box<msgs::UserState>) {
+ pub fn parse_user_state(&mut self, msg: msgs::UserState) {
if !msg.has_session() {
warn!("Can't parse user state without session");
return;
}
match self.users.entry(msg.get_session()) {
- Entry::Vacant(e) => { e.insert(User::new(msg)); },
+ Entry::Vacant(e) => {
+ e.insert(User::new(msg));
+ }
Entry::Occupied(mut e) => e.get_mut().parse_user_state(msg),
}
}
@@ -66,7 +263,7 @@ impl Server {
}
}
-
+#[derive(Clone, Debug)]
pub struct Channel {
description: Option<String>,
links: Vec<u32>,
@@ -77,7 +274,7 @@ pub struct Channel {
}
impl Channel {
- pub fn new(mut msg: Box<msgs::ChannelState>) -> Self {
+ pub fn new(mut msg: msgs::ChannelState) -> Self {
Self {
description: if msg.has_description() {
Some(msg.take_description())
@@ -96,7 +293,7 @@ impl Channel {
}
}
- pub fn parse_channel_state(&mut self, mut msg: Box<msgs::ChannelState>) {
+ pub fn parse_channel_state(&mut self, mut msg: msgs::ChannelState) {
if msg.has_description() {
self.description = Some(msg.take_description());
}
@@ -120,6 +317,7 @@ impl Channel {
}
}
+#[derive(Clone, Debug)]
pub struct User {
channel: u32,
comment: Option<String>,
@@ -128,15 +326,15 @@ pub struct User {
priority_speaker: bool,
recording: bool,
- suppress: bool, // by me
+ suppress: bool, // by me
self_mute: bool, // by self
self_deaf: bool, // by self
- mute: bool, // by admin
- deaf: bool, // by admin
+ mute: bool, // by admin
+ deaf: bool, // by admin
}
impl User {
- pub fn new(mut msg: Box<msgs::UserState>) -> Self {
+ pub fn new(mut msg: msgs::UserState) -> Self {
Self {
channel: msg.get_channel_id(),
comment: if msg.has_comment() {
@@ -150,24 +348,17 @@ impl User {
None
},
name: msg.take_name(),
- priority_speaker: msg.has_priority_speaker()
- && msg.get_priority_speaker(),
- recording: msg.has_recording()
- && msg.get_recording(),
- suppress: msg.has_suppress()
- && msg.get_suppress(),
- self_mute: msg.has_self_mute()
- && msg.get_self_mute(),
- self_deaf: msg.has_self_deaf()
- && msg.get_self_deaf(),
- mute: msg.has_mute()
- && msg.get_mute(),
- deaf: msg.has_deaf()
- && msg.get_deaf(),
- }
- }
-
- pub fn parse_user_state(&mut self, mut msg: Box<msgs::UserState>) {
+ priority_speaker: msg.has_priority_speaker() && msg.get_priority_speaker(),
+ recording: msg.has_recording() && msg.get_recording(),
+ suppress: msg.has_suppress() && msg.get_suppress(),
+ self_mute: msg.has_self_mute() && msg.get_self_mute(),
+ self_deaf: msg.has_self_deaf() && msg.get_self_deaf(),
+ mute: msg.has_mute() && msg.get_mute(),
+ deaf: msg.has_deaf() && msg.get_deaf(),
+ }
+ }
+
+ pub fn parse_user_state(&mut self, mut msg: msgs::UserState) {
if msg.has_channel_id() {
self.channel = msg.get_channel_id();
}