aboutsummaryrefslogtreecommitdiffstats
path: root/mumd
diff options
context:
space:
mode:
Diffstat (limited to 'mumd')
-rw-r--r--mumd/src/audio.rs166
-rw-r--r--mumd/src/audio/input.rs52
-rw-r--r--mumd/src/audio/output.rs90
-rw-r--r--mumd/src/command.rs36
-rw-r--r--mumd/src/main.rs4
-rw-r--r--mumd/src/network/tcp.rs391
-rw-r--r--mumd/src/state.rs120
-rw-r--r--mumd/src/state/server.rs4
-rw-r--r--mumd/src/state/user.rs40
9 files changed, 509 insertions, 394 deletions
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs
index bbde547..8609a91 100644
--- a/mumd/src/audio.rs
+++ b/mumd/src/audio.rs
@@ -1,34 +1,25 @@
-use bytes::Bytes;
+pub mod input;
+pub mod output;
+
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
-use cpal::{
- InputCallbackInfo, OutputCallbackInfo, Sample, SampleFormat, SampleRate, Stream, StreamConfig,
-};
+use cpal::{SampleFormat, SampleRate, Stream, StreamConfig};
use log::*;
use mumble_protocol::voice::VoicePacketPayload;
use opus::Channels;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
-use std::collections::VecDeque;
-use std::ops::AddAssign;
-use std::sync::Arc;
-use std::sync::Mutex;
-use tokio::sync::mpsc::{self, Receiver, Sender};
-use tokio::sync::watch;
-
-struct ClientStream {
- buffer: VecDeque<f32>, //TODO ring buffer?
- opus_decoder: opus::Decoder,
-}
+use std::sync::{Arc, Mutex};
+use tokio::sync::{mpsc, watch};
pub struct Audio {
output_config: StreamConfig,
_output_stream: Stream,
_input_stream: Stream,
- input_channel_receiver: Option<Receiver<VoicePacketPayload>>,
+ input_channel_receiver: Option<mpsc::Receiver<VoicePacketPayload>>,
input_volume_sender: watch::Sender<f32>,
- client_streams: Arc<Mutex<HashMap<u32, ClientStream>>>,
+ client_streams: Arc<Mutex<HashMap<u32, output::ClientStream>>>,
}
impl Audio {
@@ -66,17 +57,17 @@ impl Audio {
let output_stream = match output_supported_sample_format {
SampleFormat::F32 => output_device.build_output_stream(
&output_config,
- output_curry_callback::<f32>(Arc::clone(&client_streams)),
+ output::curry_callback::<f32>(Arc::clone(&client_streams)),
err_fn,
),
SampleFormat::I16 => output_device.build_output_stream(
&output_config,
- output_curry_callback::<i16>(Arc::clone(&client_streams)),
+ output::curry_callback::<i16>(Arc::clone(&client_streams)),
err_fn,
),
SampleFormat::U16 => output_device.build_output_stream(
&output_config,
- output_curry_callback::<u16>(Arc::clone(&client_streams)),
+ output::curry_callback::<u16>(Arc::clone(&client_streams)),
err_fn,
),
}
@@ -102,7 +93,7 @@ impl Audio {
let input_stream = match input_supported_sample_format {
SampleFormat::F32 => input_device.build_input_stream(
&input_config,
- input_callback::<f32>(
+ input::callback::<f32>(
input_encoder,
input_sender,
input_config.sample_rate.0,
@@ -113,7 +104,7 @@ impl Audio {
),
SampleFormat::I16 => input_device.build_input_stream(
&input_config,
- input_callback::<i16>(
+ input::callback::<i16>(
input_encoder,
input_sender,
input_config.sample_rate.0,
@@ -124,7 +115,7 @@ impl Audio {
),
SampleFormat::U16 => input_device.build_input_stream(
&input_config,
- input_callback::<u16>(
+ input::callback::<u16>(
input_encoder,
input_sender,
input_config.sample_rate.0,
@@ -167,7 +158,7 @@ impl Audio {
warn!("Session id {} already exists", session_id);
}
Entry::Vacant(entry) => {
- entry.insert(ClientStream::new(
+ entry.insert(output::ClientStream::new(
self.output_config.sample_rate.0,
self.output_config.channels,
));
@@ -189,7 +180,7 @@ impl Audio {
}
}
- pub fn take_receiver(&mut self) -> Option<Receiver<VoicePacketPayload>> {
+ pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<VoicePacketPayload>> {
self.input_channel_receiver.take()
}
@@ -201,128 +192,3 @@ impl Audio {
self.input_volume_sender.broadcast(input_volume).unwrap();
}
}
-
-impl ClientStream {
- fn new(sample_rate: u32, channels: u16) -> Self {
- Self {
- buffer: VecDeque::new(),
- opus_decoder: opus::Decoder::new(
- sample_rate,
- match channels {
- 1 => Channels::Mono,
- 2 => Channels::Stereo,
- _ => unimplemented!("Only 1 or 2 channels supported, got {}", channels),
- },
- )
- .unwrap(),
- }
- }
-
- fn decode_packet(&mut self, payload: VoicePacketPayload, channels: usize) {
- match payload {
- VoicePacketPayload::Opus(bytes, _eot) => {
- let mut out: Vec<f32> = vec![0.0; 720 * channels * 4]; //720 is because that is the max size of packet we can get that we want to decode
- let parsed = self
- .opus_decoder
- .decode_float(&bytes, &mut out, false)
- .expect("Error decoding");
- out.truncate(parsed);
- self.buffer.extend(out);
- }
- _ => {
- unimplemented!("Payload type not supported");
- }
- }
- }
-}
-
-trait SaturatingAdd {
- fn saturating_add(self, rhs: Self) -> Self;
-}
-
-impl SaturatingAdd for f32 {
- fn saturating_add(self, rhs: Self) -> Self {
- match self + rhs {
- a if a < -1.0 => -1.0,
- a if a > 1.0 => 1.0,
- a => a,
- }
- }
-}
-
-impl SaturatingAdd for i16 {
- fn saturating_add(self, rhs: Self) -> Self {
- i16::saturating_add(self, rhs)
- }
-}
-
-impl SaturatingAdd for u16 {
- fn saturating_add(self, rhs: Self) -> Self {
- u16::saturating_add(self, rhs)
- }
-}
-
-fn output_curry_callback<T: Sample + AddAssign + SaturatingAdd>(
- buf: Arc<Mutex<HashMap<u32, ClientStream>>>,
-) -> impl FnMut(&mut [T], &OutputCallbackInfo) + Send + 'static {
- move |data: &mut [T], _info: &OutputCallbackInfo| {
- for sample in data.iter_mut() {
- *sample = Sample::from(&0.0);
- }
-
- let mut lock = buf.lock().unwrap();
- for client_stream in lock.values_mut() {
- for sample in data.iter_mut() {
- *sample = sample.saturating_add(Sample::from(
- &client_stream.buffer.pop_front().unwrap_or(0.0),
- ));
- }
- }
- }
-}
-
-fn input_callback<T: Sample>(
- mut opus_encoder: opus::Encoder,
- mut input_sender: Sender<VoicePacketPayload>,
- sample_rate: u32,
- input_volume_receiver: watch::Receiver<f32>,
- opus_frame_size_blocks: u32, // blocks of 2.5ms
-) -> impl FnMut(&[T], &InputCallbackInfo) + Send + 'static {
- if !(opus_frame_size_blocks == 1
- || opus_frame_size_blocks == 2
- || opus_frame_size_blocks == 4
- || opus_frame_size_blocks == 8)
- {
- panic!(
- "Unsupported amount of opus frame blocks {}",
- opus_frame_size_blocks
- );
- }
- let opus_frame_size = opus_frame_size_blocks * sample_rate / 400;
-
- let buf = Arc::new(Mutex::new(VecDeque::new()));
- move |data: &[T], _info: &InputCallbackInfo| {
- let mut buf = buf.lock().unwrap();
- let input_volume = *input_volume_receiver.borrow();
- let out: Vec<f32> = data.iter().map(|e| e.to_f32())
- .map(|e| e * input_volume)
- .collect();
- buf.extend(out);
- while buf.len() >= opus_frame_size as usize {
- let tail = buf.split_off(opus_frame_size as usize);
- let mut opus_buf: Vec<u8> = vec![0; opus_frame_size as usize];
- let result = opus_encoder
- .encode_float(&Vec::from(buf.clone()), &mut opus_buf)
- .unwrap();
- opus_buf.truncate(result);
- let bytes = Bytes::copy_from_slice(&opus_buf);
- match input_sender.try_send(VoicePacketPayload::Opus(bytes, false)) {
- Ok(_) => {}
- Err(_e) => {
- //warn!("Error sending audio packet: {:?}", e);
- }
- }
- *buf = tail;
- }
- }
-}
diff --git a/mumd/src/audio/input.rs b/mumd/src/audio/input.rs
new file mode 100644
index 0000000..4e95360
--- /dev/null
+++ b/mumd/src/audio/input.rs
@@ -0,0 +1,52 @@
+use bytes::Bytes;
+use cpal::{InputCallbackInfo, Sample};
+use mumble_protocol::voice::VoicePacketPayload;
+use std::collections::VecDeque;
+use std::sync::{Arc, Mutex};
+use tokio::sync::{mpsc, watch};
+
+pub fn callback<T: Sample>(
+ mut opus_encoder: opus::Encoder,
+ mut input_sender: mpsc::Sender<VoicePacketPayload>,
+ sample_rate: u32,
+ input_volume_receiver: watch::Receiver<f32>,
+ opus_frame_size_blocks: u32, // blocks of 2.5ms
+) -> impl FnMut(&[T], &InputCallbackInfo) + Send + 'static {
+ if !(opus_frame_size_blocks == 1
+ || opus_frame_size_blocks == 2
+ || opus_frame_size_blocks == 4
+ || opus_frame_size_blocks == 8)
+ {
+ panic!(
+ "Unsupported amount of opus frame blocks {}",
+ opus_frame_size_blocks
+ );
+ }
+ let opus_frame_size = opus_frame_size_blocks * sample_rate / 400;
+
+ let buf = Arc::new(Mutex::new(VecDeque::new()));
+ move |data: &[T], _info: &InputCallbackInfo| {
+ let mut buf = buf.lock().unwrap();
+ let input_volume = *input_volume_receiver.borrow();
+ let out: Vec<f32> = data.iter().map(|e| e.to_f32())
+ .map(|e| e * input_volume)
+ .collect();
+ buf.extend(out);
+ while buf.len() >= opus_frame_size as usize {
+ let tail = buf.split_off(opus_frame_size as usize);
+ let mut opus_buf: Vec<u8> = vec![0; opus_frame_size as usize];
+ let result = opus_encoder
+ .encode_float(&Vec::from(buf.clone()), &mut opus_buf)
+ .unwrap();
+ opus_buf.truncate(result);
+ let bytes = Bytes::copy_from_slice(&opus_buf);
+ match input_sender.try_send(VoicePacketPayload::Opus(bytes, false)) {
+ Ok(_) => {}
+ Err(_e) => {
+ //warn!("Error sending audio packet: {:?}", e);
+ }
+ }
+ *buf = tail;
+ }
+ }
+}
diff --git a/mumd/src/audio/output.rs b/mumd/src/audio/output.rs
new file mode 100644
index 0000000..94e4b21
--- /dev/null
+++ b/mumd/src/audio/output.rs
@@ -0,0 +1,90 @@
+use cpal::{OutputCallbackInfo, Sample};
+use mumble_protocol::voice::VoicePacketPayload;
+use opus::Channels;
+use std::collections::{HashMap, VecDeque};
+use std::ops::AddAssign;
+use std::sync::{Arc, Mutex};
+
+pub struct ClientStream {
+ buffer: VecDeque<f32>, //TODO ring buffer?
+ opus_decoder: opus::Decoder,
+}
+
+impl ClientStream {
+ pub fn new(sample_rate: u32, channels: u16) -> Self {
+ Self {
+ buffer: VecDeque::new(),
+ opus_decoder: opus::Decoder::new(
+ sample_rate,
+ match channels {
+ 1 => Channels::Mono,
+ 2 => Channels::Stereo,
+ _ => unimplemented!("Only 1 or 2 channels supported, got {}", channels),
+ },
+ )
+ .unwrap(),
+ }
+ }
+
+ pub fn decode_packet(&mut self, payload: VoicePacketPayload, channels: usize) {
+ match payload {
+ VoicePacketPayload::Opus(bytes, _eot) => {
+ let mut out: Vec<f32> = vec![0.0; 720 * channels * 4]; //720 is because that is the max size of packet we can get that we want to decode
+ let parsed = self
+ .opus_decoder
+ .decode_float(&bytes, &mut out, false)
+ .expect("Error decoding");
+ out.truncate(parsed);
+ self.buffer.extend(out);
+ }
+ _ => {
+ unimplemented!("Payload type not supported");
+ }
+ }
+ }
+}
+
+pub trait SaturatingAdd {
+ fn saturating_add(self, rhs: Self) -> Self;
+}
+
+impl SaturatingAdd for f32 {
+ fn saturating_add(self, rhs: Self) -> Self {
+ match self + rhs {
+ a if a < -1.0 => -1.0,
+ a if a > 1.0 => 1.0,
+ a => a,
+ }
+ }
+}
+
+impl SaturatingAdd for i16 {
+ fn saturating_add(self, rhs: Self) -> Self {
+ i16::saturating_add(self, rhs)
+ }
+}
+
+impl SaturatingAdd for u16 {
+ fn saturating_add(self, rhs: Self) -> Self {
+ u16::saturating_add(self, rhs)
+ }
+}
+
+pub fn curry_callback<T: Sample + AddAssign + SaturatingAdd>(
+ buf: Arc<Mutex<HashMap<u32, ClientStream>>>,
+) -> impl FnMut(&mut [T], &OutputCallbackInfo) + Send + 'static {
+ move |data: &mut [T], _info: &OutputCallbackInfo| {
+ for sample in data.iter_mut() {
+ *sample = Sample::from(&0.0);
+ }
+
+ let mut lock = buf.lock().unwrap();
+ for client_stream in lock.values_mut() {
+ for sample in data.iter_mut() {
+ *sample = sample.saturating_add(Sample::from(
+ &client_stream.buffer.pop_front().unwrap_or(0.0),
+ ));
+ }
+ }
+ }
+}
diff --git a/mumd/src/command.rs b/mumd/src/command.rs
index a035a26..075bfaf 100644
--- a/mumd/src/command.rs
+++ b/mumd/src/command.rs
@@ -1,10 +1,11 @@
-use crate::state::{State, StatePhase};
+use crate::state::State;
use ipc_channel::ipc::IpcSender;
use log::*;
use mumlib::command::{Command, CommandResponse};
use std::sync::{Arc, Mutex};
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, oneshot};
+use crate::network::tcp::{TcpEvent, TcpEventCallback};
pub async fn handle(
state: Arc<Mutex<State>>,
@@ -12,23 +13,26 @@ pub async fn handle(
Command,
IpcSender<mumlib::error::Result<Option<CommandResponse>>>,
)>,
+ tcp_event_register_sender: mpsc::UnboundedSender<(TcpEvent, TcpEventCallback)>,
) {
debug!("Begin listening for commands");
- while let Some(command) = command_receiver.recv().await {
- debug!("Received command {:?}", command.0);
+ while let Some((command, response_sender)) = command_receiver.recv().await {
+ debug!("Received command {:?}", command);
let mut state = state.lock().unwrap();
- let (wait_for_connected, command_response) = state.handle_command(command.0).await;
- if wait_for_connected {
- let mut watcher = state.phase_receiver();
- drop(state);
- while !matches!(watcher.recv().await.unwrap(), StatePhase::Connected) {}
+ let (event, generator) = state.handle_command(command).await;
+ drop(state);
+ if let Some(event) = event {
+ let (tx, rx) = oneshot::channel();
+ //TODO handle this error
+ let _ = tcp_event_register_sender.send((event, Box::new(move |e| {
+ let response = generator(Some(e));
+ response_sender.send(response).unwrap();
+ tx.send(()).unwrap();
+ })));
+
+ rx.await.unwrap();
+ } else {
+ response_sender.send(generator(None)).unwrap();
}
- command.1.send(command_response).unwrap();
}
- //TODO err if not connected
- //while let Some(command) = command_receiver.recv().await {
- // debug!("Parsing command {:?}", command);
- //}
-
- //debug!("Finished handling commands");
}
diff --git a/mumd/src/main.rs b/mumd/src/main.rs
index 75726f8..e88eede 100644
--- a/mumd/src/main.rs
+++ b/mumd/src/main.rs
@@ -33,6 +33,7 @@ async fn main() {
)>();
let (connection_info_sender, connection_info_receiver) =
watch::channel::<Option<ConnectionInfo>>(None);
+ let (response_sender, response_receiver) = mpsc::unbounded_channel();
let state = State::new(packet_sender, connection_info_sender);
let state = Arc::new(Mutex::new(state));
@@ -43,13 +44,14 @@ async fn main() {
connection_info_receiver.clone(),
crypt_state_sender,
packet_receiver,
+ response_receiver,
),
network::udp::handle(
Arc::clone(&state),
connection_info_receiver.clone(),
crypt_state_receiver,
),
- command::handle(state, command_receiver,),
+ command::handle(state, command_receiver, response_sender),
spawn_blocking(move || {
// IpcSender is blocking
receive_oneshot_commands(command_sender);
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 88d2b59..c2cb234 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -15,6 +15,10 @@ use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::{self, Duration};
use tokio_tls::{TlsConnector, TlsStream};
use tokio_util::codec::{Decoder, Framed};
+use std::collections::HashMap;
+use std::future::Future;
+use std::rc::Rc;
+use std::cell::RefCell;
type TcpSender = SplitSink<
Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>,
@@ -23,11 +27,25 @@ type TcpSender = SplitSink<
type TcpReceiver =
SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>;
+pub(crate) type TcpEventCallback = Box<dyn FnOnce(&TcpEventData)>;
+
+#[derive(Debug, Clone, Hash, Eq, PartialEq)]
+pub enum TcpEvent {
+ Connected, //fires when the client has connected to a server
+ Disconnected, //fires when the client has disconnected from a server
+}
+
+pub enum TcpEventData<'a> {
+ Connected(&'a msgs::ServerSync),
+ Disconnected,
+}
+
pub async fn handle(
state: Arc<Mutex<State>>,
mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
+ mut tcp_event_register_receiver: mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
) {
loop {
let connection_info = loop {
@@ -54,6 +72,7 @@ pub async fn handle(
let phase_watcher = state_lock.phase_receiver();
let packet_sender = state_lock.packet_sender();
drop(state_lock);
+ let event_queue = Arc::new(Mutex::new(HashMap::new()));
info!("Logging in...");
@@ -63,9 +82,11 @@ pub async fn handle(
Arc::clone(&state),
stream,
crypt_state_sender.clone(),
- phase_watcher.clone()
+ Arc::clone(&event_queue),
+ phase_watcher.clone(),
),
- send_packets(sink, &mut packet_receiver, phase_watcher),
+ send_packets(sink, &mut packet_receiver, phase_watcher.clone()),
+ register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher),
);
debug!("Fully disconnected TCP stream, waiting for new connection info");
@@ -108,103 +129,207 @@ async fn authenticate(sink: &mut TcpSender, username: String) {
async fn send_pings(
packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
delay_seconds: u64,
- mut phase_watcher: watch::Receiver<StatePhase>,
+ phase_watcher: watch::Receiver<StatePhase>,
) {
- let (tx, rx) = oneshot::channel();
- let phase_transition_block = async {
- while !matches!(
- phase_watcher.recv().await.unwrap(),
- StatePhase::Disconnected
- ) {}
- tx.send(true).unwrap();
- };
+ let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs(delay_seconds))));
+ let packet_sender = Rc::new(RefCell::new(packet_sender));
- let mut interval = time::interval(Duration::from_secs(delay_seconds));
- let main_block = async {
- let rx = rx.fuse();
- pin_mut!(rx);
- loop {
- let interval_waiter = interval.tick().fuse();
- pin_mut!(interval_waiter);
- let exitor = select! {
- data = interval_waiter => Some(data),
- _ = rx => None
- };
-
- match exitor {
- Some(_) => {
- trace!("Sending ping");
- let msg = msgs::Ping::new();
- packet_sender.send(msg.into()).unwrap();
- }
- None => break,
- }
- }
- };
-
- join!(main_block, phase_transition_block);
+ run_until_disconnection(
+ || async {
+ Some(interval.borrow_mut().tick().await)
+ },
+ |_| async {
+ trace!("Sending ping");
+ let msg = msgs::Ping::new();
+ packet_sender.borrow_mut().send(msg.into()).unwrap();
+ },
+ || async {},
+ phase_watcher,
+ ).await;
debug!("Ping sender process killed");
}
async fn send_packets(
- mut sink: TcpSender,
+ sink: TcpSender,
packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
- mut phase_watcher: watch::Receiver<StatePhase>,
+ phase_watcher: watch::Receiver<StatePhase>,
) {
- let (tx, rx) = oneshot::channel();
- let phase_transition_block = async {
- while !matches!(
- phase_watcher.recv().await.unwrap(),
- StatePhase::Disconnected
- ) {}
- tx.send(true).unwrap();
- };
-
- let main_block = async {
- let rx = rx.fuse();
- pin_mut!(rx);
- loop {
- let packet_recv = packet_receiver.recv().fuse();
- pin_mut!(packet_recv);
- let exitor = select! {
- data = packet_recv => Some(data),
- _ = rx => None
- };
- match exitor {
- None => {
- break;
- }
- Some(None) => {
- warn!("Channel closed before disconnect command");
- break;
- }
- Some(Some(packet)) => {
- sink.send(packet).await.unwrap();
- }
- }
- }
+ let sink = Rc::new(RefCell::new(sink));
+ let packet_receiver = Rc::new(RefCell::new(packet_receiver));
+ run_until_disconnection(
+ || async {
+ packet_receiver.borrow_mut().recv().await
+ },
+ |packet| async {
+ sink.borrow_mut().send(packet).await.unwrap();
+ },
+ || async {
+ //clears queue of remaining packets
+ while packet_receiver.borrow_mut().try_recv().is_ok() {}
- //clears queue of remaining packets
- while packet_receiver.try_recv().is_ok() {}
-
- sink.close().await.unwrap();
- };
-
- join!(main_block, phase_transition_block);
+ sink.borrow_mut().close().await.unwrap();
+ },
+ phase_watcher,
+ ).await;
debug!("TCP packet sender killed");
}
async fn listen(
state: Arc<Mutex<State>>,
- mut stream: TcpReceiver,
+ stream: TcpReceiver,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
- mut phase_watcher: watch::Receiver<StatePhase>,
+ event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ phase_watcher: watch::Receiver<StatePhase>,
) {
- let mut crypt_state = None;
- let mut crypt_state_sender = Some(crypt_state_sender);
+ let crypt_state = Rc::new(RefCell::new(None));
+ let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender)));
+ let stream = Rc::new(RefCell::new(stream));
+ run_until_disconnection(
+ || async {
+ stream.borrow_mut().next().await
+ },
+ |packet| async {
+ match packet.unwrap() {
+ ControlPacket::TextMessage(msg) => {
+ info!(
+ "Got message from user with session ID {}: {}",
+ msg.get_actor(),
+ msg.get_message()
+ );
+ }
+ ControlPacket::CryptSetup(msg) => {
+ debug!("Crypt setup");
+ // Wait until we're fully connected before initiating UDP voice
+ *crypt_state.borrow_mut() = Some(ClientCryptState::new_from(
+ msg.get_key()
+ .try_into()
+ .expect("Server sent private key with incorrect size"),
+ msg.get_client_nonce()
+ .try_into()
+ .expect("Server sent client_nonce with incorrect size"),
+ msg.get_server_nonce()
+ .try_into()
+ .expect("Server sent server_nonce with incorrect size"),
+ ));
+ }
+ ControlPacket::ServerSync(msg) => {
+ info!("Logged in");
+ if let Some(mut sender) = crypt_state_sender.borrow_mut().take() {
+ let _ = sender
+ .send(
+ crypt_state
+ .borrow_mut()
+ .take()
+ .expect("Server didn't send us any CryptSetup packet!"),
+ )
+ .await;
+ }
+ if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(&TcpEventData::Connected(&msg));
+ }
+ }
+ let mut state = state.lock().unwrap();
+ let server = state.server_mut().unwrap();
+ server.parse_server_sync(*msg);
+ match &server.welcome_text {
+ Some(s) => info!("Welcome: {}", s),
+ None => info!("No welcome received"),
+ }
+ for channel in server.channels().values() {
+ info!("Found channel {}", channel.name());
+ }
+ state.initialized();
+ }
+ ControlPacket::Reject(msg) => {
+ warn!("Login rejected: {:?}", msg);
+ }
+ ControlPacket::UserState(msg) => {
+ let mut state = state.lock().unwrap();
+ let session = msg.get_session();
+ if *state.phase_receiver().borrow() == StatePhase::Connecting {
+ state.audio_mut().add_client(msg.get_session());
+ state.parse_user_state(*msg);
+ } else {
+ state.parse_user_state(*msg);
+ }
+ let server = state.server_mut().unwrap();
+ let user = server.users().get(&session).unwrap();
+ info!("User {} connected to {}", user.name(), user.channel());
+ }
+ ControlPacket::UserRemove(msg) => {
+ info!("User {} left", msg.get_session());
+ state
+ .lock()
+ .unwrap()
+ .audio_mut()
+ .remove_client(msg.get_session());
+ }
+ ControlPacket::ChannelState(msg) => {
+ debug!("Channel state received");
+ state
+ .lock()
+ .unwrap()
+ .server_mut()
+ .unwrap()
+ .parse_channel_state(*msg); //TODO parse initial if initial
+ }
+ ControlPacket::ChannelRemove(msg) => {
+ state
+ .lock()
+ .unwrap()
+ .server_mut()
+ .unwrap()
+ .parse_channel_remove(*msg);
+ }
+ _ => {}
+ }
+ },
+ || async {
+ if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(&TcpEventData::Disconnected);
+ }
+ }
+ },
+ phase_watcher,
+ ).await;
+
+ debug!("Killing TCP listener block");
+}
+
+async fn register_events(
+ tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
+ event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
+ phase_watcher: watch::Receiver<StatePhase>,
+) {
+ let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver));
+ run_until_disconnection(
+ || async {
+ tcp_event_register_receiver.borrow_mut().recv().await
+ },
+ |(event, handler)| async { event_data.lock().unwrap().entry(event).or_default().push(handler); },
+ || async {},
+ phase_watcher,
+ ).await;
+}
+
+async fn run_until_disconnection<T, F, G, H>(
+ mut generator: impl FnMut() -> F,
+ mut handler: impl FnMut(T) -> G,
+ mut shutdown: impl FnMut() -> H,
+ mut phase_watcher: watch::Receiver<StatePhase>,
+)
+ where
+ F: Future<Output = Option<T>>,
+ G: Future<Output = ()>,
+ H: Future<Output = ()>,
+{
let (tx, rx) = oneshot::channel();
let phase_transition_block = async {
while !matches!(
@@ -214,11 +339,11 @@ async fn listen(
tx.send(true).unwrap();
};
- let listener_block = async {
+ let main_block = async {
let rx = rx.fuse();
pin_mut!(rx);
loop {
- let packet_recv = stream.next().fuse();
+ let packet_recv = generator().fuse();
pin_mut!(packet_recv);
let exitor = select! {
data = packet_recv => Some(data),
@@ -229,107 +354,17 @@ async fn listen(
break;
}
Some(None) => {
- warn!("Channel closed before disconnect command");
+ //warn!("Channel closed before disconnect command"); //TODO make me informative
break;
}
- Some(Some(packet)) => {
- match packet.unwrap() {
- ControlPacket::TextMessage(msg) => {
- info!(
- "Got message from user with session ID {}: {}",
- msg.get_actor(),
- msg.get_message()
- );
- }
- ControlPacket::CryptSetup(msg) => {
- debug!("Crypt setup");
- // Wait until we're fully connected before initiating UDP voice
- crypt_state = Some(ClientCryptState::new_from(
- msg.get_key()
- .try_into()
- .expect("Server sent private key with incorrect size"),
- msg.get_client_nonce()
- .try_into()
- .expect("Server sent client_nonce with incorrect size"),
- msg.get_server_nonce()
- .try_into()
- .expect("Server sent server_nonce with incorrect size"),
- ));
- }
- ControlPacket::ServerSync(msg) => {
- info!("Logged in");
- if let Some(mut sender) = crypt_state_sender.take() {
- let _ = sender
- .send(
- crypt_state
- .take()
- .expect("Server didn't send us any CryptSetup packet!"),
- )
- .await;
- }
- let mut state = state.lock().unwrap();
- let server = state.server_mut().unwrap();
- server.parse_server_sync(*msg);
- match &server.welcome_text {
- Some(s) => info!("Welcome: {}", s),
- None => info!("No welcome received"),
- }
- for channel in server.channels().values() {
- info!("Found channel {}", channel.name());
- }
- state.initialized();
- }
- ControlPacket::Reject(msg) => {
- warn!("Login rejected: {:?}", msg);
- }
- ControlPacket::UserState(msg) => {
- let mut state = state.lock().unwrap();
- let session = msg.get_session();
- if *state.phase_receiver().borrow() == StatePhase::Connecting {
- state.audio_mut().add_client(msg.get_session());
- state.parse_initial_user_state(*msg);
- } else {
- state.server_mut().unwrap().parse_user_state(*msg);
- }
- let server = state.server_mut().unwrap();
- let user = server.users().get(&session).unwrap();
- info!("User {} connected to {}", user.name(), user.channel());
- }
- ControlPacket::UserRemove(msg) => {
- info!("User {} left", msg.get_session());
- state
- .lock()
- .unwrap()
- .audio_mut()
- .remove_client(msg.get_session());
- }
- ControlPacket::ChannelState(msg) => {
- debug!("Channel state received");
- state
- .lock()
- .unwrap()
- .server_mut()
- .unwrap()
- .parse_channel_state(*msg); //TODO parse initial if initial
- }
- ControlPacket::ChannelRemove(msg) => {
- state
- .lock()
- .unwrap()
- .server_mut()
- .unwrap()
- .parse_channel_remove(*msg);
- }
- _ => {}
- }
+ Some(Some(data)) => {
+ handler(data).await;
}
}
}
- //TODO? clean up stream
+ shutdown().await;
};
- join!(phase_transition_block, listener_block);
-
- debug!("Killing TCP listener block");
-}
+ join!(main_block, phase_transition_block);
+} \ No newline at end of file
diff --git a/mumd/src/state.rs b/mumd/src/state.rs
index d355ef5..f9ed077 100644
--- a/mumd/src/state.rs
+++ b/mumd/src/state.rs
@@ -15,6 +15,7 @@ use mumlib::config::Config;
use mumlib::error::{ChannelIdentifierError, Error};
use std::net::ToSocketAddrs;
use tokio::sync::{mpsc, watch};
+use crate::network::tcp::{TcpEvent, TcpEventData};
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StatePhase {
@@ -56,11 +57,11 @@ impl State {
pub async fn handle_command(
&mut self,
command: Command,
- ) -> (bool, mumlib::error::Result<Option<CommandResponse>>) {
+ ) -> (Option<TcpEvent>, Box<dyn FnOnce(Option<&TcpEventData>) -> mumlib::error::Result<Option<CommandResponse>>>) {
match command {
Command::ChannelJoin { channel_identifier } => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (false, Err(Error::DisconnectedError));
+ return (None, Box::new(|_| Err(Error::DisconnectedError)));
}
let channels = self.server()
@@ -78,33 +79,34 @@ impl State {
.filter(|e| e.1.ends_with(&channel_identifier.to_lowercase()))
.collect::<Vec<_>>();
match soft_matches.len() {
- 0 => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))),
+ 0 => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid)))),
1 => *soft_matches.get(0).unwrap().0,
- _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid))),
+ _ => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Invalid)))),
}
},
1 => *matches.get(0).unwrap().0,
- _ => return (false, Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous))),
+ _ => return (None, Box::new(|_| Err(Error::ChannelIdentifierError(channel_identifier, ChannelIdentifierError::Ambiguous)))),
};
let mut msg = msgs::UserState::new();
msg.set_session(self.server.as_ref().unwrap().session_id().unwrap());
msg.set_channel_id(id);
self.packet_sender.send(msg.into()).unwrap();
- (false, Ok(None))
+ (None, Box::new(|_| Ok(None)))
}
Command::ChannelList => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (false, Err(Error::DisconnectedError));
+ return (None, Box::new(|_| Err(Error::DisconnectedError)));
}
+ let list = channel::into_channel(
+ self.server.as_ref().unwrap().channels(),
+ self.server.as_ref().unwrap().users(),
+ );
(
- false,
- Ok(Some(CommandResponse::ChannelList {
- channels: channel::into_channel(
- self.server.as_ref().unwrap().channels(),
- self.server.as_ref().unwrap().users(),
- ),
- })),
+ None,
+ Box::new(move |_| Ok(Some(CommandResponse::ChannelList {
+ channels: list,
+ }))),
)
}
Command::ServerConnect {
@@ -114,7 +116,7 @@ impl State {
accept_invalid_cert,
} => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) {
- return (false, Err(Error::AlreadyConnectedError));
+ return (None, Box::new(|_| Err(Error::AlreadyConnectedError)));
}
let mut server = Server::new();
*server.username_mut() = Some(username);
@@ -132,7 +134,7 @@ impl State {
Ok(Some(v)) => v,
_ => {
warn!("Error parsing server addr");
- return (false, Err(Error::InvalidServerAddrError(host, port)));
+ return (None, Box::new(move |_| Err(Error::InvalidServerAddrError(host, port))));
}
};
self.connection_info_sender
@@ -142,22 +144,35 @@ impl State {
accept_invalid_cert,
)))
.unwrap();
- (true, Ok(None))
+ (Some(TcpEvent::Connected), Box::new(|e| { //runs the closure when the client is connected
+ if let Some(TcpEventData::Connected(msg)) = e {
+ Ok(Some(CommandResponse::ServerConnect {
+ welcome_message: if msg.has_welcome_text() {
+ Some(msg.get_welcome_text().to_string())
+ } else {
+ None
+ }
+ }))
+ } else {
+ unreachable!("callback should be provided with a TcpEventData::Connected");
+ }
+ }))
}
Command::Status => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (false, Err(Error::DisconnectedError));
+ return (None, Box::new(|_| Err(Error::DisconnectedError)));
}
+ let state = self.server.as_ref().unwrap().into();
(
- false,
- Ok(Some(CommandResponse::Status {
- server_state: self.server.as_ref().unwrap().into(), //guaranteed not to panic because if we are connected, server is guaranteed to be Some
- })),
+ None,
+ Box::new(move |_| Ok(Some(CommandResponse::Status {
+ server_state: state, //guaranteed not to panic because if we are connected, server is guaranteed to be Some
+ }))),
)
}
Command::ServerDisconnect => {
if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
- return (false, Err(Error::DisconnectedError));
+ return (None, Box::new(|_| Err(Error::DisconnectedError)));
}
self.server = None;
@@ -167,46 +182,54 @@ impl State {
.0
.broadcast(StatePhase::Disconnected)
.unwrap();
- (false, Ok(None))
+ (None, Box::new(|_| Ok(None)))
}
Command::InputVolumeSet(volume) => {
self.audio.set_input_volume(volume);
- (false, Ok(None))
+ (None, Box::new(|_| Ok(None)))
}
Command::ConfigReload => {
self.reload_config();
- (false, Ok(None))
+ (None, Box::new(|_| Ok(None)))
}
}
}
- pub fn parse_initial_user_state(&mut self, msg: msgs::UserState) {
+ pub fn parse_user_state(&mut self, msg: msgs::UserState) -> Option<mumlib::state::UserDiff> {
if !msg.has_session() {
warn!("Can't parse user state without session");
- return;
+ return None;
}
- if !msg.has_name() {
- warn!("Missing name in initial user state");
- } else if msg.get_name() == self.server.as_ref().unwrap().username().unwrap() {
- match self.server.as_ref().unwrap().session_id() {
- None => {
- debug!("Found our session id: {}", msg.get_session());
- *self.server_mut().unwrap().session_id_mut() = Some(msg.get_session());
- }
- Some(session) => {
- if session != msg.get_session() {
- error!(
- "Got two different session IDs ({} and {}) for ourselves",
- session,
- msg.get_session()
- );
- } else {
- debug!("Got our session ID twice");
- }
- }
+ let sess = msg.get_session();
+ // check if this is initial state
+ if !self.server().unwrap().users().contains_key(&sess) {
+ if !msg.has_name() {
+ warn!("Missing name in initial user state");
+ } else if msg.get_name() == self.server().unwrap().username().unwrap() {
+ // this is us
+ *self.server_mut().unwrap().session_id_mut() = Some(sess);
+ } else {
+ // this is someone else
+ self.audio_mut().add_client(sess);
}
+ self.server_mut().unwrap().users_mut().insert(sess, user::User::new(msg));
+ None
+ } else {
+ let user = self.server_mut().unwrap().users_mut().get_mut(&sess).unwrap();
+ let diff = mumlib::state::UserDiff::from(msg);
+ user.apply_user_diff(&diff);
+ Some(diff)
}
- self.server.as_mut().unwrap().parse_user_state(msg);
+ }
+
+ pub fn remove_client(&mut self, msg: msgs::UserRemove) {
+ if !msg.has_session() {
+ warn!("Tried to remove user state without session");
+ return;
+ }
+ self.audio().remove_client(msg.get_session());
+ self.server_mut().unwrap().users_mut().remove(&msg.get_session());
+ info!("User {} disconnected", msg.get_session());
}
pub fn reload_config(&mut self) {
@@ -252,4 +275,3 @@ impl State {
self.server.as_ref().map(|e| e.username()).flatten()
}
}
-
diff --git a/mumd/src/state/server.rs b/mumd/src/state/server.rs
index b7cabb7..b99c7e6 100644
--- a/mumd/src/state/server.rs
+++ b/mumd/src/state/server.rs
@@ -98,6 +98,10 @@ impl Server {
&self.users
}
+ pub fn users_mut(&mut self) -> &mut HashMap<u32, User> {
+ &mut self.users
+ }
+
pub fn username(&self) -> Option<&str> {
self.username.as_ref().map(|e| e.as_str())
}
diff --git a/mumd/src/state/user.rs b/mumd/src/state/user.rs
index bb4e101..679d0ff 100644
--- a/mumd/src/state/user.rs
+++ b/mumd/src/state/user.rs
@@ -1,3 +1,4 @@
+use log::*;
use mumble_protocol::control::msgs;
use serde::{Deserialize, Serialize};
@@ -78,6 +79,45 @@ impl User {
}
}
+ pub fn apply_user_diff(&mut self, diff: &mumlib::state::UserDiff) {
+ debug!("applying user diff\n{:#?}", diff);
+ if let Some(comment) = diff.comment.clone() {
+ self.comment = Some(comment);
+ }
+ if let Some(hash) = diff.hash.clone() {
+ self.hash = Some(hash);
+ }
+ if let Some(name) = diff.name.clone() {
+ self.name = name;
+ }
+ if let Some(priority_speaker) = diff.priority_speaker {
+ self.priority_speaker = priority_speaker;
+ }
+ if let Some(recording) = diff.recording {
+ self.recording = recording;
+ }
+ if let Some(suppress) = diff.suppress {
+ self.suppress = suppress;
+ }
+ if let Some(self_mute) = diff.self_mute {
+ self.self_mute = self_mute;
+ }
+ if let Some(self_deaf) = diff.self_deaf {
+ self.self_deaf = self_deaf;
+ }
+ if let Some(mute) = diff.mute {
+ self.mute = mute;
+ }
+ if let Some(deaf) = diff.deaf {
+ self.deaf = deaf;
+ }
+
+ if let Some(channel_id) = diff.channel_id {
+ self.channel = channel_id;
+ }
+ }
+
+
pub fn name(&self) -> &str {
&self.name
}