aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG1
-rw-r--r--mumd/src/audio.rs59
-rw-r--r--mumd/src/audio/output.rs6
-rw-r--r--mumd/src/client.rs4
-rw-r--r--mumd/src/command.rs6
-rw-r--r--mumd/src/network.rs50
-rw-r--r--mumd/src/network/tcp.rs406
-rw-r--r--mumd/src/network/udp.rs271
-rw-r--r--mumd/src/state.rs34
9 files changed, 423 insertions, 414 deletions
diff --git a/CHANGELOG b/CHANGELOG
index 5d6d64b..468d9a6 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -19,6 +19,7 @@ Added
~~~~~
* Added a noise gate
+* Added tunneling audio through TCP if UDP connection goes down
// Changed
// ~~~~~~~
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs
index 680433c..598dde6 100644
--- a/mumd/src/audio.rs
+++ b/mumd/src/audio.rs
@@ -2,6 +2,7 @@ pub mod input;
pub mod output;
use crate::audio::output::SaturatingAdd;
+use crate::network::VoiceStreamType;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{SampleFormat, SampleRate, StreamConfig};
@@ -75,14 +76,14 @@ pub struct Audio {
_output_stream: cpal::Stream,
_input_stream: cpal::Stream,
- input_channel_receiver: Arc<Mutex<Box<dyn Stream<Item = VoicePacket<Serverbound>> + Unpin>>>,
+ input_channel_receiver: Arc<tokio::sync::Mutex<Box<dyn Stream<Item = VoicePacket<Serverbound>> + Unpin>>>,
input_volume_sender: watch::Sender<f32>,
output_volume_sender: watch::Sender<f32>,
user_volumes: Arc<Mutex<HashMap<u32, (f32, bool)>>>,
- client_streams: Arc<Mutex<HashMap<u32, output::ClientStream>>>,
+ client_streams: Arc<Mutex<HashMap<(VoiceStreamType, u32), output::ClientStream>>>,
sounds: HashMap<NotificationEvents, Vec<f32>>,
play_sounds: Arc<Mutex<VecDeque<f32>>>,
@@ -131,11 +132,11 @@ impl Audio {
let err_fn = |err| error!("An error occurred on the output audio stream: {}", err);
- let user_volumes = Arc::new(Mutex::new(HashMap::new()));
+ let user_volumes = Arc::new(std::sync::Mutex::new(HashMap::new()));
let (output_volume_sender, output_volume_receiver) = watch::channel::<f32>(output_volume);
- let play_sounds = Arc::new(Mutex::new(VecDeque::new()));
+ let play_sounds = Arc::new(std::sync::Mutex::new(VecDeque::new()));
- let client_streams = Arc::new(Mutex::new(HashMap::new()));
+ let client_streams = Arc::new(std::sync::Mutex::new(HashMap::new()));
let output_stream = match output_supported_sample_format {
SampleFormat::F32 => output_device.build_output_stream(
&output_config,
@@ -226,7 +227,7 @@ impl Audio {
_output_stream: output_stream,
_input_stream: input_stream,
input_volume_sender,
- input_channel_receiver: Arc::new(Mutex::new(Box::new(opus_stream))),
+ input_channel_receiver: Arc::new(tokio::sync::Mutex::new(Box::new(opus_stream))),
client_streams,
sounds: HashMap::new(),
output_volume_sender,
@@ -291,8 +292,8 @@ impl Audio {
.collect();
}
- pub fn decode_packet(&self, session_id: u32, payload: VoicePacketPayload) {
- match self.client_streams.lock().unwrap().entry(session_id) {
+ pub fn decode_packet_payload(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) {
+ match self.client_streams.lock().unwrap().entry((stream_type, session_id)) {
Entry::Occupied(mut entry) => {
entry
.get_mut()
@@ -305,34 +306,38 @@ impl Audio {
}
pub fn add_client(&self, session_id: u32) {
- match self.client_streams.lock().unwrap().entry(session_id) {
- Entry::Occupied(_) => {
- warn!("Session id {} already exists", session_id);
- }
- Entry::Vacant(entry) => {
- entry.insert(output::ClientStream::new(
- self.output_config.sample_rate.0,
- self.output_config.channels,
- ));
+ for stream_type in [VoiceStreamType::TCP, VoiceStreamType::UDP].iter() {
+ match self.client_streams.lock().unwrap().entry((*stream_type, session_id)) {
+ Entry::Occupied(_) => {
+ warn!("Session id {} already exists", session_id);
+ }
+ Entry::Vacant(entry) => {
+ entry.insert(output::ClientStream::new(
+ self.output_config.sample_rate.0,
+ self.output_config.channels,
+ ));
+ }
}
}
}
pub fn remove_client(&self, session_id: u32) {
- match self.client_streams.lock().unwrap().entry(session_id) {
- Entry::Occupied(entry) => {
- entry.remove();
- }
- Entry::Vacant(_) => {
- warn!(
- "Tried to remove session id {} that doesn't exist",
- session_id
- );
+ for stream_type in [VoiceStreamType::TCP, VoiceStreamType::UDP].iter() {
+ match self.client_streams.lock().unwrap().entry((*stream_type, session_id)) {
+ Entry::Occupied(entry) => {
+ entry.remove();
+ }
+ Entry::Vacant(_) => {
+ warn!(
+ "Tried to remove session id {} that doesn't exist",
+ session_id
+ );
+ }
}
}
}
- pub fn take_receiver(&mut self) -> Arc<Mutex<Box<dyn Stream<Item = VoicePacket<Serverbound>> + Unpin>>> {
+ pub fn input_receiver(&self) -> Arc<tokio::sync::Mutex<Box<dyn Stream<Item = VoicePacket<Serverbound>> + Unpin>>> {
Arc::clone(&self.input_channel_receiver)
}
diff --git a/mumd/src/audio/output.rs b/mumd/src/audio/output.rs
index 5e0cb8d..421d395 100644
--- a/mumd/src/audio/output.rs
+++ b/mumd/src/audio/output.rs
@@ -1,3 +1,5 @@
+use crate::network::VoiceStreamType;
+
use cpal::{OutputCallbackInfo, Sample};
use mumble_protocol::voice::VoicePacketPayload;
use opus::Channels;
@@ -73,7 +75,7 @@ impl SaturatingAdd for u16 {
pub fn curry_callback<T: Sample + AddAssign + SaturatingAdd + std::fmt::Display>(
effect_sound: Arc<Mutex<VecDeque<f32>>>,
- user_bufs: Arc<Mutex<HashMap<u32, ClientStream>>>,
+ user_bufs: Arc<Mutex<HashMap<(VoiceStreamType, u32), ClientStream>>>,
output_volume_receiver: watch::Receiver<f32>,
user_volumes: Arc<Mutex<HashMap<u32, (f32, bool)>>>,
) -> impl FnMut(&mut [T], &OutputCallbackInfo) + Send + 'static {
@@ -86,7 +88,7 @@ pub fn curry_callback<T: Sample + AddAssign + SaturatingAdd + std::fmt::Display>
let mut effects_sound = effect_sound.lock().unwrap();
let mut user_bufs = user_bufs.lock().unwrap();
- for (id, client_stream) in &mut *user_bufs {
+ for ((_, id), client_stream) in &mut *user_bufs {
let (user_volume, muted) = user_volumes
.lock()
.unwrap()
diff --git a/mumd/src/client.rs b/mumd/src/client.rs
index 3613061..222e2a7 100644
--- a/mumd/src/client.rs
+++ b/mumd/src/client.rs
@@ -6,8 +6,8 @@ use futures_util::join;
use ipc_channel::ipc::IpcSender;
use mumble_protocol::{Serverbound, control::ControlPacket, crypt::ClientCryptState};
use mumlib::command::{Command, CommandResponse};
-use std::sync::{Arc, Mutex};
-use tokio::sync::{mpsc, watch};
+use std::sync::Arc;
+use tokio::sync::{mpsc, watch, Mutex};
pub async fn handle(
command_receiver: mpsc::UnboundedReceiver<(
diff --git a/mumd/src/command.rs b/mumd/src/command.rs
index e77b34b..653d1fa 100644
--- a/mumd/src/command.rs
+++ b/mumd/src/command.rs
@@ -9,8 +9,8 @@ use ipc_channel::ipc::IpcSender;
use log::*;
use mumble_protocol::{Serverbound, control::ControlPacket};
use mumlib::command::{Command, CommandResponse};
-use std::sync::{Arc, Mutex};
-use tokio::sync::{mpsc, oneshot, watch};
+use std::sync::Arc;
+use tokio::sync::{mpsc, oneshot, watch, Mutex};
pub async fn handle(
state: Arc<Mutex<State>>,
@@ -26,7 +26,7 @@ pub async fn handle(
debug!("Begin listening for commands");
while let Some((command, response_sender)) = command_receiver.recv().await {
debug!("Received command {:?}", command);
- let mut state = state.lock().unwrap();
+ let mut state = state.lock().await;
let event = state.handle_command(command, &mut packet_sender, &mut connection_info_sender);
drop(state);
match event {
diff --git a/mumd/src/network.rs b/mumd/src/network.rs
index 1a31ee2..6c67b3a 100644
--- a/mumd/src/network.rs
+++ b/mumd/src/network.rs
@@ -1,7 +1,17 @@
pub mod tcp;
pub mod udp;
+use futures::Future;
+use futures::FutureExt;
+use futures::channel::oneshot;
+use futures::join;
+use futures::pin_mut;
+use futures::select;
+use log::*;
use std::net::SocketAddr;
+use tokio::sync::watch;
+
+use crate::state::StatePhase;
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
@@ -19,3 +29,43 @@ impl ConnectionInfo {
}
}
}
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
+pub enum VoiceStreamType {
+ TCP,
+ UDP,
+}
+
+async fn run_until<F>(
+ phase_checker: impl Fn(StatePhase) -> bool,
+ fut: F,
+ mut phase_watcher: watch::Receiver<StatePhase>,
+) where
+ F: Future<Output = ()>,
+{
+ let (tx, rx) = oneshot::channel();
+ let phase_transition_block = async {
+ loop {
+ phase_watcher.changed().await.unwrap();
+ if phase_checker(*phase_watcher.borrow()) {
+ break;
+ }
+ }
+ if tx.send(true).is_err() {
+ warn!("future resolved before it could be cancelled");
+ }
+ };
+
+ let main_block = async {
+ let rx = rx.fuse();
+ pin_mut!(rx);
+ let fut = fut.fuse();
+ pin_mut!(fut);
+ select! {
+ _ = fut => (),
+ _ = rx => (),
+ };
+ };
+
+ join!(main_block, phase_transition_block);
+}
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 47ea311..3a32b9f 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -2,24 +2,25 @@ use crate::network::ConnectionInfo;
use crate::state::{State, StatePhase};
use log::*;
-use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt};
+use futures::{FutureExt, SinkExt, Stream, StreamExt};
use futures_util::stream::{SplitSink, SplitStream};
use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket};
use mumble_protocol::crypt::ClientCryptState;
+use mumble_protocol::voice::VoicePacket;
use mumble_protocol::{Clientbound, Serverbound};
-use std::cell::RefCell;
use std::collections::HashMap;
use std::convert::{Into, TryInto};
-use std::future::Future;
use std::net::SocketAddr;
-use std::rc::Rc;
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
use tokio::net::TcpStream;
-use tokio::sync::{mpsc, oneshot, watch};
+use tokio::sync::{mpsc, watch, Mutex};
use tokio::time::{self, Duration};
use tokio_native_tls::{TlsConnector, TlsStream};
use tokio_util::codec::{Decoder, Framed};
+use super::{run_until, VoiceStreamType};
+use futures_util::future::join5;
+
type TcpSender = SplitSink<
Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>,
ControlPacket<Serverbound>,
@@ -65,26 +66,42 @@ pub async fn handle(
.await;
// Handshake (omitting `Version` message for brevity)
- let state_lock = state.lock().unwrap();
+ let state_lock = state.lock().await;
authenticate(&mut sink, state_lock.username().unwrap().to_string()).await;
let phase_watcher = state_lock.phase_receiver();
+ let input_receiver = state_lock.audio().input_receiver();
drop(state_lock);
let event_queue = Arc::new(Mutex::new(HashMap::new()));
info!("Logging in...");
- join!(
- send_pings(packet_sender.clone(), 10, phase_watcher.clone()),
- listen(
- Arc::clone(&state),
- stream,
- crypt_state_sender.clone(),
- Arc::clone(&event_queue),
- phase_watcher.clone(),
- ),
- send_packets(sink, &mut packet_receiver, phase_watcher.clone()),
- register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher),
- );
+ run_until(
+ |phase| matches!(phase, StatePhase::Disconnected),
+ join5(
+ send_pings(packet_sender.clone(), 10),
+ listen(
+ Arc::clone(&state),
+ stream,
+ crypt_state_sender.clone(),
+ Arc::clone(&event_queue),
+ ),
+ send_voice(
+ packet_sender.clone(),
+ Arc::clone(&input_receiver),
+ phase_watcher.clone(),
+ ),
+ send_packets(sink, &mut packet_receiver),
+ register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)),
+ ).map(|_| ()),
+ phase_watcher,
+ ).await;
+
+ if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(TcpEventData::Disconnected);
+ }
+ }
debug!("Fully disconnected TCP stream, waiting for new connection info");
}
@@ -126,232 +143,187 @@ async fn authenticate(sink: &mut TcpSender, username: String) {
async fn send_pings(
packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
delay_seconds: u64,
- phase_watcher: watch::Receiver<StatePhase>,
) {
- let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs(
- delay_seconds,
- ))));
- let packet_sender = Rc::new(RefCell::new(packet_sender));
-
- run_until_disconnection(
- || async { Some(interval.borrow_mut().tick().await) },
- |_| async {
- trace!("Sending ping");
+ let mut interval = time::interval(Duration::from_secs(delay_seconds));
+ loop {
+ interval.tick().await;
+ trace!("Sending TCP ping");
let msg = msgs::Ping::new();
- packet_sender.borrow_mut().send(msg.into()).unwrap();
- },
- || async {},
- phase_watcher,
- )
- .await;
-
- debug!("Ping sender process killed");
+ packet_sender.send(msg.into()).unwrap();
+ }
}
async fn send_packets(
- sink: TcpSender,
+ mut sink: TcpSender,
packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
- phase_watcher: watch::Receiver<StatePhase>,
) {
- let sink = Rc::new(RefCell::new(sink));
- let packet_receiver = Rc::new(RefCell::new(packet_receiver));
- run_until_disconnection(
- || async { packet_receiver.borrow_mut().recv().await },
- |packet| async {
- sink.borrow_mut().send(packet).await.unwrap();
- },
- || async {
- sink.borrow_mut().close().await.unwrap();
- },
- phase_watcher,
- )
- .await;
+ loop {
+ let packet = packet_receiver.recv().await.unwrap();
+ sink.send(packet).await.unwrap();
+ }
+}
- debug!("TCP packet sender killed");
+async fn send_voice(
+ packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
+ receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
+ phase_watcher: watch::Receiver<StatePhase>,
+) {
+ loop {
+ let mut inner_phase_watcher = phase_watcher.clone();
+ loop {
+ inner_phase_watcher.changed().await.unwrap();
+ if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) {
+ break;
+ }
+ }
+ run_until(
+ |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)),
+ async {
+ loop {
+ packet_sender.send(
+ receiver
+ .lock()
+ .await
+ .next()
+ .await
+ .unwrap()
+ .into())
+ .unwrap();
+ }
+ },
+ inner_phase_watcher.clone(),
+ ).await;
+ }
}
async fn listen(
state: Arc<Mutex<State>>,
- stream: TcpReceiver,
+ mut stream: TcpReceiver,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
- phase_watcher: watch::Receiver<StatePhase>,
) {
- let crypt_state = Rc::new(RefCell::new(None));
- let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender)));
+ let mut crypt_state = None;
+ let mut crypt_state_sender = Some(crypt_state_sender);
- let stream = Rc::new(RefCell::new(stream));
- run_until_disconnection(
- || async { stream.borrow_mut().next().await },
- |packet| async {
- match packet.unwrap() {
- ControlPacket::TextMessage(msg) => {
- info!(
- "Got message from user with session ID {}: {}",
- msg.get_actor(),
- msg.get_message()
- );
- }
- ControlPacket::CryptSetup(msg) => {
- debug!("Crypt setup");
- // Wait until we're fully connected before initiating UDP voice
- *crypt_state.borrow_mut() = Some(ClientCryptState::new_from(
- msg.get_key()
- .try_into()
- .expect("Server sent private key with incorrect size"),
- msg.get_client_nonce()
- .try_into()
- .expect("Server sent client_nonce with incorrect size"),
- msg.get_server_nonce()
- .try_into()
- .expect("Server sent server_nonce with incorrect size"),
- ));
+ loop {
+ let packet = stream.next().await.unwrap();
+ match packet.unwrap() {
+ ControlPacket::TextMessage(msg) => {
+ info!(
+ "Got message from user with session ID {}: {}",
+ msg.get_actor(),
+ msg.get_message()
+ );
+ }
+ ControlPacket::CryptSetup(msg) => {
+ debug!("Crypt setup");
+ // Wait until we're fully connected before initiating UDP voice
+ crypt_state = Some(ClientCryptState::new_from(
+ msg.get_key()
+ .try_into()
+ .expect("Server sent private key with incorrect size"),
+ msg.get_client_nonce()
+ .try_into()
+ .expect("Server sent client_nonce with incorrect size"),
+ msg.get_server_nonce()
+ .try_into()
+ .expect("Server sent server_nonce with incorrect size"),
+ ));
+ }
+ ControlPacket::ServerSync(msg) => {
+ info!("Logged in");
+ if let Some(sender) = crypt_state_sender.take() {
+ let _ = sender
+ .send(
+ crypt_state
+ .take()
+ .expect("Server didn't send us any CryptSetup packet!"),
+ )
+ .await;
}
- ControlPacket::ServerSync(msg) => {
- info!("Logged in");
- if let Some(sender) = crypt_state_sender.borrow_mut().take() {
- let _ = sender
- .send(
- crypt_state
- .borrow_mut()
- .take()
- .expect("Server didn't send us any CryptSetup packet!"),
- )
- .await;
+ if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(TcpEventData::Connected(&msg));
}
- if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) {
- let old = std::mem::take(vec);
- for handler in old {
- handler(TcpEventData::Connected(&msg));
- }
- }
- let mut state = state.lock().unwrap();
- let server = state.server_mut().unwrap();
- server.parse_server_sync(*msg);
- match &server.welcome_text {
- Some(s) => info!("Welcome: {}", s),
- None => info!("No welcome received"),
- }
- for channel in server.channels().values() {
- info!("Found channel {}", channel.name());
- }
- state.initialized();
- }
- ControlPacket::Reject(msg) => {
- warn!("Login rejected: {:?}", msg);
}
- ControlPacket::UserState(msg) => {
- state.lock().unwrap().parse_user_state(*msg);
+ let mut state = state.lock().await;
+ let server = state.server_mut().unwrap();
+ server.parse_server_sync(*msg);
+ match &server.welcome_text {
+ Some(s) => info!("Welcome: {}", s),
+ None => info!("No welcome received"),
}
- ControlPacket::UserRemove(msg) => {
- state.lock().unwrap().remove_client(*msg);
- }
- ControlPacket::ChannelState(msg) => {
- debug!("Channel state received");
- state
- .lock()
- .unwrap()
- .server_mut()
- .unwrap()
- .parse_channel_state(*msg); //TODO parse initial if initial
- }
- ControlPacket::ChannelRemove(msg) => {
- state
- .lock()
- .unwrap()
- .server_mut()
- .unwrap()
- .parse_channel_remove(*msg);
- }
- packet => {
- debug!("Received unhandled ControlPacket {:#?}", packet);
+ for channel in server.channels().values() {
+ info!("Found channel {}", channel.name());
}
+ state.initialized();
+ }
+ ControlPacket::Reject(msg) => {
+ warn!("Login rejected: {:?}", msg);
+ }
+ ControlPacket::UserState(msg) => {
+ state.lock().await.parse_user_state(*msg);
+ }
+ ControlPacket::UserRemove(msg) => {
+ state.lock().await.remove_client(*msg);
+ }
+ ControlPacket::ChannelState(msg) => {
+ debug!("Channel state received");
+ state
+ .lock()
+ .await
+ .server_mut()
+ .unwrap()
+ .parse_channel_state(*msg); //TODO parse initial if initial
+ }
+ ControlPacket::ChannelRemove(msg) => {
+ state
+ .lock()
+ .await
+ .server_mut()
+ .unwrap()
+ .parse_channel_remove(*msg);
}
- },
- || async {
- if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) {
- let old = std::mem::take(vec);
- for handler in old {
- handler(TcpEventData::Disconnected);
+ ControlPacket::UDPTunnel(msg) => {
+ match *msg {
+ VoicePacket::Ping { .. } => {}
+ VoicePacket::Audio {
+ session_id,
+ // seq_num,
+ payload,
+ // position_info,
+ ..
+ } => {
+ state
+ .lock()
+ .await
+ .audio()
+ .decode_packet_payload(
+ VoiceStreamType::TCP,
+ session_id,
+ payload);
+ }
}
}
- },
- phase_watcher,
- )
- .await;
-
- debug!("Killing TCP listener block");
+ packet => {
+ debug!("Received unhandled ControlPacket {:#?}", packet);
+ }
+ }
+ }
}
async fn register_events(
tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
- phase_watcher: watch::Receiver<StatePhase>,
) {
- let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver));
- run_until_disconnection(
- || async { tcp_event_register_receiver.borrow_mut().recv().await },
- |(event, handler)| async {
- event_data
- .lock()
- .unwrap()
- .entry(event)
- .or_default()
- .push(handler);
- },
- || async {},
- phase_watcher,
- )
- .await;
-}
-
-async fn run_until_disconnection<T, F, G, H>(
- mut generator: impl FnMut() -> F,
- mut handler: impl FnMut(T) -> G,
- mut shutdown: impl FnMut() -> H,
- mut phase_watcher: watch::Receiver<StatePhase>,
-) where
- F: Future<Output = Option<T>>,
- G: Future<Output = ()>,
- H: Future<Output = ()>,
-{
- let (tx, rx) = oneshot::channel();
- let phase_transition_block = async {
- loop {
- phase_watcher.changed().await.unwrap();
- if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) {
- break;
- }
- }
- tx.send(true).unwrap();
- };
-
- let main_block = async {
- let rx = rx.fuse();
- pin_mut!(rx);
- loop {
- let packet_recv = generator().fuse();
- pin_mut!(packet_recv);
- let exitor = select! {
- data = packet_recv => Some(data),
- _ = rx => None
- };
- match exitor {
- None => {
- break;
- }
- Some(None) => {
- //warn!("Channel closed before disconnect command"); //TODO make me informative
- break;
- }
- Some(Some(data)) => {
- handler(data).await;
- }
- }
- }
-
- shutdown().await;
- };
-
- join!(main_block, phase_transition_block);
+ loop {
+ let (event, handler) = tcp_event_register_receiver.recv().await.unwrap();
+ event_data
+ .lock()
+ .await
+ .entry(event)
+ .or_default()
+ .push(handler);
+ }
}
diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs
index 0c00029..5f24b51 100644
--- a/mumd/src/network/udp.rs
+++ b/mumd/src/network/udp.rs
@@ -1,23 +1,27 @@
use crate::network::ConnectionInfo;
use crate::state::{State, StatePhase};
-use bytes::Bytes;
-use futures::{join, pin_mut, select, FutureExt, SinkExt, StreamExt, Stream};
+use futures::{join, FutureExt, SinkExt, StreamExt, Stream};
use futures_util::stream::{SplitSink, SplitStream};
use log::*;
use mumble_protocol::crypt::ClientCryptState;
use mumble_protocol::ping::{PingPacket, PongPacket};
-use mumble_protocol::voice::{VoicePacket, VoicePacketPayload};
+use mumble_protocol::voice::VoicePacket;
use mumble_protocol::Serverbound;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::net::{Ipv6Addr, SocketAddr};
use std::rc::Rc;
-use std::sync::{Arc, Mutex};
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::Arc;
use tokio::net::UdpSocket;
-use tokio::sync::{mpsc, oneshot, watch};
+use tokio::sync::{mpsc, watch, Mutex};
+use tokio::time::{interval, Duration};
use tokio_util::udp::UdpFramed;
+use super::{run_until, VoiceStreamType};
+use futures_util::future::join4;
+
pub type PingRequest = (u64, SocketAddr, Box<dyn FnOnce(PongPacket)>);
type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>;
@@ -28,7 +32,7 @@ pub async fn handle(
mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
mut crypt_state_receiver: mpsc::Receiver<ClientCryptState>,
) {
- let receiver = state.lock().unwrap().audio_mut().take_receiver();
+ let receiver = state.lock().await.audio().input_receiver();
loop {
let connection_info = 'data: loop {
@@ -39,28 +43,38 @@ pub async fn handle(
}
return;
};
- let (mut sink, source) = connect(&mut crypt_state_receiver).await;
-
- // Note: A normal application would also send periodic Ping packets, and its own audio
- // via UDP. We instead trick the server into accepting us by sending it one
- // dummy voice packet.
- send_ping(&mut sink, connection_info.socket_addr).await;
+ let (sink, source) = connect(&mut crypt_state_receiver).await;
let sink = Arc::new(Mutex::new(sink));
let source = Arc::new(Mutex::new(source));
- let phase_watcher = state.lock().unwrap().phase_receiver();
- let mut audio_receiver_lock = receiver.lock().unwrap();
- join!(
- listen(Arc::clone(&state), Arc::clone(&source), phase_watcher.clone()),
- send_voice(
- Arc::clone(&sink),
- connection_info.socket_addr,
- phase_watcher,
- &mut *audio_receiver_lock
- ),
- new_crypt_state(&mut crypt_state_receiver, sink, source)
- );
+ let phase_watcher = state.lock().await.phase_receiver();
+ let last_ping_recv = AtomicU64::new(0);
+
+ run_until(
+ |phase| matches!(phase, StatePhase::Disconnected),
+ join4(
+ listen(
+ Arc::clone(&state),
+ Arc::clone(&source),
+ &last_ping_recv,
+ ),
+ send_voice(
+ Arc::clone(&sink),
+ connection_info.socket_addr,
+ phase_watcher.clone(),
+ Arc::clone(&receiver),
+ ),
+ send_pings(
+ Arc::clone(&state),
+ Arc::clone(&sink),
+ connection_info.socket_addr,
+ &last_ping_recv,
+ ),
+ new_crypt_state(&mut crypt_state_receiver, sink, source),
+ ).map(|_| ()),
+ phase_watcher,
+ ).await;
debug!("Fully disconnected UDP stream, waiting for new connection info");
}
@@ -98,8 +112,8 @@ async fn new_crypt_state(
.await
.expect("Failed to bind UDP socket");
let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split();
- *sink.lock().unwrap() = new_sink;
- *source.lock().unwrap() = new_source;
+ *sink.lock().await = new_sink;
+ *source.lock().await = new_source;
}
}
}
@@ -107,143 +121,104 @@ async fn new_crypt_state(
async fn listen(
state: Arc<Mutex<State>>,
source: Arc<Mutex<UdpReceiver>>,
- mut phase_watcher: watch::Receiver<StatePhase>,
+ last_ping_recv: &AtomicU64,
) {
- let (tx, rx) = oneshot::channel();
- let phase_transition_block = async {
- loop {
- phase_watcher.changed().await.unwrap();
- if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) {
- break;
+ loop {
+ let packet = source.lock().await.next().await.unwrap();
+ let (packet, _src_addr) = match packet {
+ Ok(packet) => packet,
+ Err(err) => {
+ warn!("Got an invalid UDP packet: {}", err);
+ // To be expected, considering this is the internet, just ignore it
+ continue;
}
- }
- tx.send(true).unwrap();
- };
-
- let main_block = async {
- let rx = rx.fuse();
- pin_mut!(rx);
- loop {
- let mut source = source.lock().unwrap();
- let packet_recv = source.next().fuse();
- pin_mut!(packet_recv);
- let exitor = select! {
- data = packet_recv => Some(data),
- _ = rx => None
- };
- match exitor {
- None => {
- break;
- }
- Some(None) => {
- warn!("Channel closed before disconnect command");
- break;
- }
- Some(Some(packet)) => {
- let (packet, _src_addr) = match packet {
- Ok(packet) => packet,
- Err(err) => {
- warn!("Got an invalid UDP packet: {}", err);
- // To be expected, considering this is the internet, just ignore it
- continue;
- }
- };
- match packet {
- VoicePacket::Ping { .. } => {
- // Note: A normal application would handle these and only use UDP for voice
- // once it has received one.
- continue;
- }
- VoicePacket::Audio {
- session_id,
- // seq_num,
- payload,
- // position_info,
- ..
- } => {
- state
- .lock()
- .unwrap()
- .audio()
- .decode_packet(session_id, payload);
- }
- }
- }
+ };
+ match packet {
+ VoicePacket::Ping { timestamp } => {
+ state
+ .lock() //TODO clean up unnecessary lock by only updating phase if it should change
+ .await
+ .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP));
+ last_ping_recv.store(timestamp, Ordering::Relaxed);
+ }
+ VoicePacket::Audio {
+ session_id,
+ // seq_num,
+ payload,
+ // position_info,
+ ..
+ } => {
+ state
+ .lock() //TODO change so that we only have to lock audio and not the whole state
+ .await
+ .audio()
+ .decode_packet_payload(VoiceStreamType::UDP, session_id, payload);
}
}
- };
-
- join!(main_block, phase_transition_block);
-
- debug!("UDP listener process killed");
+ }
}
-async fn send_ping(sink: &mut UdpSender, server_addr: SocketAddr) {
- sink.send((
- VoicePacket::Audio {
- _dst: std::marker::PhantomData,
- target: 0,
- session_id: (),
- seq_num: 0,
- payload: VoicePacketPayload::Opus(Bytes::from([0u8; 128].as_ref()), true),
- position_info: None,
- },
- server_addr,
- ))
- .await
- .unwrap();
+async fn send_pings(
+ state: Arc<Mutex<State>>,
+ sink: Arc<Mutex<UdpSender>>,
+ server_addr: SocketAddr,
+ last_ping_recv: &AtomicU64,
+) {
+ let mut last_send = None;
+ let mut interval = interval(Duration::from_millis(1000));
+
+ loop {
+ interval.tick().await;
+ let last_recv = last_ping_recv.load(Ordering::Relaxed);
+ if last_send.is_some() && last_send.unwrap() != last_recv {
+ debug!("Sending TCP voice");
+ state
+ .lock()
+ .await
+ .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP));
+ }
+ match sink
+ .lock()
+ .await
+ .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr))
+ .await
+ {
+ Ok(_) => {
+ last_send = Some(last_recv + 1);
+ },
+ Err(e) => {
+ debug!("Error sending UDP ping: {}", e);
+ }
+ }
+ }
}
async fn send_voice(
sink: Arc<Mutex<UdpSender>>,
server_addr: SocketAddr,
- mut phase_watcher: watch::Receiver<StatePhase>,
- receiver: &mut (dyn Stream<Item = VoicePacket<Serverbound>> + Unpin),
+ phase_watcher: watch::Receiver<StatePhase>,
+ receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
) {
- pin_mut!(receiver);
- let (tx, rx) = oneshot::channel();
- let phase_transition_block = async {
+ loop {
+ let mut inner_phase_watcher = phase_watcher.clone();
loop {
- phase_watcher.changed().await.unwrap();
- if matches!(*phase_watcher.borrow(), StatePhase::Disconnected) {
+ inner_phase_watcher.changed().await.unwrap();
+ if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) {
break;
}
}
- tx.send(true).unwrap();
- };
-
- let main_block = async {
- let rx = rx.fuse();
- pin_mut!(rx);
- loop {
- let packet_recv = receiver.next().fuse();
- pin_mut!(packet_recv);
- let exitor = select! {
- data = packet_recv => Some(data),
- _ = rx => None
- };
- match exitor {
- None => {
- break;
- }
- Some(None) => {
- warn!("Channel closed before disconnect command");
- break;
+ run_until(
+ |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)),
+ async {
+ let mut receiver = receiver.lock().await;
+ loop {
+ let sending = (receiver.next().await.unwrap(), server_addr);
+ sink.lock().await.send(sending).await.unwrap();
}
- Some(Some(reply)) => {
- sink.lock()
- .unwrap()
- .send((reply, server_addr))
- .await
- .unwrap();
- }
- }
- }
- };
-
- join!(main_block, phase_transition_block);
-
- debug!("UDP sender process killed");
+ },
+ phase_watcher.clone(),
+ ).await;
+ }
}
pub async fn handle_pings(
@@ -260,7 +235,7 @@ pub async fn handle_pings(
let packet = PingPacket { id };
let packet: [u8; 12] = packet.into();
udp_socket.send_to(&packet, &socket_addr).await.unwrap();
- pending.lock().unwrap().insert(id, handle);
+ pending.lock().await.insert(id, handle);
}
};
@@ -271,7 +246,7 @@ pub async fn handle_pings(
let packet = PongPacket::try_from(buf.as_slice()).unwrap();
- if let Some(handler) = pending.lock().unwrap().remove(&packet.id) {
+ if let Some(handler) = pending.lock().await.remove(&packet.id) {
handler(packet);
}
}
diff --git a/mumd/src/state.rs b/mumd/src/state.rs
index 84247bc..2ed73b2 100644
--- a/mumd/src/state.rs
+++ b/mumd/src/state.rs
@@ -3,11 +3,11 @@ pub mod server;
pub mod user;
use crate::audio::{Audio, NotificationEvents};
-use crate::network::ConnectionInfo;
+use crate::network::{ConnectionInfo, VoiceStreamType};
+use crate::network::tcp::{TcpEvent, TcpEventData};
use crate::notify;
use crate::state::server::Server;
-use crate::network::tcp::{TcpEvent, TcpEventData};
use log::*;
use mumble_protocol::control::msgs;
use mumble_protocol::control::ControlPacket;
@@ -45,11 +45,11 @@ pub enum ExecutionContext {
),
}
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StatePhase {
Disconnected,
Connecting,
- Connected,
+ Connected(VoiceStreamType),
}
pub struct State {
@@ -85,7 +85,7 @@ impl State {
) -> ExecutionContext {
match command {
Command::ChannelJoin { channel_identifier } => {
- if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
return now!(Err(Error::DisconnectedError));
}
@@ -135,7 +135,7 @@ impl State {
now!(Ok(None))
}
Command::ChannelList => {
- if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
return now!(Err(Error::DisconnectedError));
}
let list = channel::into_channel(
@@ -149,7 +149,7 @@ impl State {
now!(Ok(None))
}
Command::DeafenSelf(toggle) => {
- if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
return now!(Err(Error::DisconnectedError));
}
@@ -207,7 +207,7 @@ impl State {
now!(Ok(None))
}
Command::MuteOther(string, toggle) => {
- if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
return now!(Err(Error::DisconnectedError));
}
@@ -242,7 +242,7 @@ impl State {
return now!(Ok(None));
}
Command::MuteSelf(toggle) => {
- if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
return now!(Err(Error::DisconnectedError));
}
@@ -354,7 +354,7 @@ impl State {
})
}
Command::ServerDisconnect => {
- if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
return now!(Err(Error::DisconnectedError));
}
@@ -388,7 +388,7 @@ impl State {
}),
),
Command::Status => {
- if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
return now!(Err(Error::DisconnectedError));
}
let state = self.server.as_ref().unwrap().into();
@@ -397,7 +397,7 @@ impl State {
})))
}
Command::UserVolumeSet(string, volume) => {
- if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) {
+ if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
return now!(Err(Error::DisconnectedError));
}
let user_id = match self
@@ -448,7 +448,7 @@ impl State {
self.audio_mut().add_client(session);
// send notification only if we've passed the connecting phase
- if *self.phase_receiver().borrow() == StatePhase::Connected {
+ if matches!(*self.phase_receiver().borrow(), StatePhase::Connected(_)) {
let channel_id = msg.get_channel_id();
if channel_id
@@ -578,11 +578,15 @@ impl State {
}
}
- pub fn initialized(&self) {
+ pub fn broadcast_phase(&self, phase: StatePhase) {
self.phase_watcher
.0
- .send(StatePhase::Connected)
+ .send(phase)
.unwrap();
+ }
+
+ pub fn initialized(&self) {
+ self.broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP));
self.audio.play_effect(NotificationEvents::ServerConnect);
}