aboutsummaryrefslogtreecommitdiffstats
path: root/mumd
diff options
context:
space:
mode:
Diffstat (limited to 'mumd')
-rw-r--r--mumd/src/audio.rs10
-rw-r--r--mumd/src/client.rs4
-rw-r--r--mumd/src/command.rs8
-rw-r--r--mumd/src/main.rs2
-rw-r--r--mumd/src/network/tcp.rs100
-rw-r--r--mumd/src/network/udp.rs69
6 files changed, 118 insertions, 75 deletions
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs
index 3f03e61..bdc8377 100644
--- a/mumd/src/audio.rs
+++ b/mumd/src/audio.rs
@@ -31,7 +31,7 @@ use std::{
};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
-use tokio::sync::watch;
+use tokio::sync::{watch};
const SAMPLE_RATE: u32 = 48000;
@@ -132,11 +132,11 @@ impl Audio {
let err_fn = |err| error!("An error occurred on the output audio stream: {}", err);
- let user_volumes = Arc::new(Mutex::new(HashMap::new()));
+ let user_volumes = Arc::new(std::sync::Mutex::new(HashMap::new()));
let (output_volume_sender, output_volume_receiver) = watch::channel::<f32>(output_volume);
- let play_sounds = Arc::new(Mutex::new(VecDeque::new()));
+ let play_sounds = Arc::new(std::sync::Mutex::new(VecDeque::new()));
- let client_streams = Arc::new(Mutex::new(HashMap::new()));
+ let client_streams = Arc::new(std::sync::Mutex::new(HashMap::new()));
let output_stream = match output_supported_sample_format {
SampleFormat::F32 => output_device.build_output_stream(
&output_config,
@@ -292,7 +292,7 @@ impl Audio {
.collect();
}
- pub fn decode_packet(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) {
+ pub fn decode_packet_payload(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) {
match self.client_streams.lock().unwrap().entry((stream_type, session_id)) {
Entry::Occupied(mut entry) => {
entry
diff --git a/mumd/src/client.rs b/mumd/src/client.rs
index 3613061..222e2a7 100644
--- a/mumd/src/client.rs
+++ b/mumd/src/client.rs
@@ -6,8 +6,8 @@ use futures_util::join;
use ipc_channel::ipc::IpcSender;
use mumble_protocol::{Serverbound, control::ControlPacket, crypt::ClientCryptState};
use mumlib::command::{Command, CommandResponse};
-use std::sync::{Arc, Mutex};
-use tokio::sync::{mpsc, watch};
+use std::sync::Arc;
+use tokio::sync::{mpsc, watch, Mutex};
pub async fn handle(
command_receiver: mpsc::UnboundedReceiver<(
diff --git a/mumd/src/command.rs b/mumd/src/command.rs
index e77b34b..b099ae1 100644
--- a/mumd/src/command.rs
+++ b/mumd/src/command.rs
@@ -9,8 +9,8 @@ use ipc_channel::ipc::IpcSender;
use log::*;
use mumble_protocol::{Serverbound, control::ControlPacket};
use mumlib::command::{Command, CommandResponse};
-use std::sync::{Arc, Mutex};
-use tokio::sync::{mpsc, oneshot, watch};
+use std::sync::Arc;
+use tokio::sync::{mpsc, oneshot, watch, Mutex};
pub async fn handle(
state: Arc<Mutex<State>>,
@@ -26,9 +26,11 @@ pub async fn handle(
debug!("Begin listening for commands");
while let Some((command, response_sender)) = command_receiver.recv().await {
debug!("Received command {:?}", command);
- let mut state = state.lock().unwrap();
+ debug!("locking state");
+ let mut state = state.lock().await;
let event = state.handle_command(command, &mut packet_sender, &mut connection_info_sender);
drop(state);
+ debug!("unlocking state");
match event {
ExecutionContext::TcpEvent(event, generator) => {
let (tx, rx) = oneshot::channel();
diff --git a/mumd/src/main.rs b/mumd/src/main.rs
index 67481f9..a8cb230 100644
--- a/mumd/src/main.rs
+++ b/mumd/src/main.rs
@@ -14,7 +14,7 @@ use std::fs;
use tokio::sync::mpsc;
use tokio::task::spawn_blocking;
-#[tokio::main]
+#[tokio::main(worker_threads = 4)]
async fn main() {
setup_logger(std::io::stderr(), true);
notify::init();
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 982e747..6f18473 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -14,9 +14,9 @@ use std::collections::HashMap;
use std::convert::{Into, TryInto};
use std::net::SocketAddr;
use std::rc::Rc;
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
use tokio::net::TcpStream;
-use tokio::sync::{mpsc, watch};
+use tokio::sync::{mpsc, watch, Mutex};
use tokio::time::{self, Duration};
use tokio_native_tls::{TlsConnector, TlsStream};
use tokio_util::codec::{Decoder, Framed};
@@ -68,7 +68,7 @@ pub async fn handle(
.await;
// Handshake (omitting `Version` message for brevity)
- let state_lock = state.lock().unwrap();
+ let state_lock = state.lock().await;
authenticate(&mut sink, state_lock.username().unwrap().to_string()).await;
let phase_watcher = state_lock.phase_receiver();
let input_receiver = state_lock.audio().input_receiver();
@@ -138,19 +138,16 @@ async fn send_pings(
delay_seconds: u64,
phase_watcher: watch::Receiver<StatePhase>,
) {
- let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs(
- delay_seconds,
- ))));
- let packet_sender = Rc::new(RefCell::new(packet_sender));
+ let mut interval = time::interval(Duration::from_secs(delay_seconds));
run_until(
|phase| matches!(phase, StatePhase::Disconnected),
async {
loop {
- interval.borrow_mut().tick().await;
+ interval.tick().await;
trace!("Sending ping");
let msg = msgs::Ping::new();
- packet_sender.borrow_mut().send(msg.into()).unwrap();
+ packet_sender.send(msg.into()).unwrap();
}
},
|| async {},
@@ -167,12 +164,11 @@ async fn send_packets(
phase_watcher: watch::Receiver<StatePhase>,
) {
let sink = Rc::new(RefCell::new(sink));
- let packet_receiver = Rc::new(RefCell::new(packet_receiver));
run_until(
|phase| matches!(phase, StatePhase::Disconnected),
async {
loop {
- let packet = packet_receiver.borrow_mut().recv().await.unwrap();
+ let packet = packet_receiver.recv().await.unwrap();
sink.borrow_mut().send(packet).await.unwrap();
}
},
@@ -188,31 +184,40 @@ async fn send_packets(
async fn send_voice(
packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
- receiver: Arc<tokio::sync::Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
+ receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
phase_watcher: watch::Receiver<StatePhase>,
) {
let inner_phase_watcher = phase_watcher.clone();
run_until(
|phase| matches!(phase, StatePhase::Disconnected),
async {
- run_until(
- |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)),
- async {
- loop {
- packet_sender.send(receiver
- .lock()
- .await
- .next()
- .await
- .unwrap()
- .into())
- .unwrap();
+ loop {
+ let mut inner_phase_watcher_2 = inner_phase_watcher.clone();
+ loop {
+ inner_phase_watcher_2.changed().await.unwrap();
+ if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) {
+ break;
}
- },
- || async {},
- inner_phase_watcher.clone(),
- ).await;
- debug!("Stopped sending TCP voice");
+ }
+ run_until(
+ |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)),
+ async {
+ loop {
+ packet_sender.send(
+ receiver
+ .lock()
+ .await
+ .next()
+ .await
+ .unwrap()
+ .into())
+ .unwrap();
+ }
+ },
+ || async {},
+ inner_phase_watcher.clone(),
+ ).await;
+ }
},
|| async {},
phase_watcher,
@@ -269,13 +274,13 @@ async fn listen(
)
.await;
}
- if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) {
+ if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) {
let old = std::mem::take(vec);
for handler in old {
handler(TcpEventData::Connected(&msg));
}
}
- let mut state = state.lock().unwrap();
+ let mut state = state.lock().await;
let server = state.server_mut().unwrap();
server.parse_server_sync(*msg);
match &server.welcome_text {
@@ -291,16 +296,16 @@ async fn listen(
warn!("Login rejected: {:?}", msg);
}
ControlPacket::UserState(msg) => {
- state.lock().unwrap().parse_user_state(*msg);
+ state.lock().await.parse_user_state(*msg);
}
ControlPacket::UserRemove(msg) => {
- state.lock().unwrap().remove_client(*msg);
+ state.lock().await.remove_client(*msg);
}
ControlPacket::ChannelState(msg) => {
debug!("Channel state received");
state
.lock()
- .unwrap()
+ .await
.server_mut()
.unwrap()
.parse_channel_state(*msg); //TODO parse initial if initial
@@ -308,11 +313,32 @@ async fn listen(
ControlPacket::ChannelRemove(msg) => {
state
.lock()
- .unwrap()
+ .await
.server_mut()
.unwrap()
.parse_channel_remove(*msg);
}
+ ControlPacket::UDPTunnel(msg) => {
+ match *msg {
+ VoicePacket::Ping { .. } => {}
+ VoicePacket::Audio {
+ session_id,
+ // seq_num,
+ payload,
+ // position_info,
+ ..
+ } => {
+ state
+ .lock()
+ .await
+ .audio()
+ .decode_packet_payload(
+ VoiceStreamType::TCP,
+ session_id,
+ payload);
+ }
+ }
+ }
packet => {
debug!("Received unhandled ControlPacket {:#?}", packet);
}
@@ -320,7 +346,7 @@ async fn listen(
}
},
|| async {
- if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) {
+ if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) {
let old = std::mem::take(vec);
for handler in old {
handler(TcpEventData::Disconnected);
@@ -347,7 +373,7 @@ async fn register_events(
let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap();
event_data
.lock()
- .unwrap()
+ .await
.entry(event)
.or_default()
.push(handler);
diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs
index d35a255..25ec8d5 100644
--- a/mumd/src/network/udp.rs
+++ b/mumd/src/network/udp.rs
@@ -13,9 +13,9 @@ use std::convert::TryFrom;
use std::net::{Ipv6Addr, SocketAddr};
use std::rc::Rc;
use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc};
use tokio::net::UdpSocket;
-use tokio::sync::{mpsc, oneshot, watch};
+use tokio::sync::{mpsc, oneshot, watch, Mutex};
use tokio::time::{interval, Duration};
use tokio_util::udp::UdpFramed;
@@ -31,7 +31,7 @@ pub async fn handle(
mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
mut crypt_state_receiver: mpsc::Receiver<ClientCryptState>,
) {
- let receiver = state.lock().unwrap().audio().input_receiver();
+ let receiver = state.lock().await.audio().input_receiver();
loop {
let connection_info = 'data: loop {
@@ -47,7 +47,7 @@ pub async fn handle(
let sink = Arc::new(Mutex::new(sink));
let source = Arc::new(Mutex::new(source));
- let phase_watcher = state.lock().unwrap().phase_receiver();
+ let phase_watcher = state.lock().await.phase_receiver();
let last_ping_recv = AtomicU64::new(0);
join!(
listen(
@@ -107,8 +107,8 @@ async fn new_crypt_state(
.await
.expect("Failed to bind UDP socket");
let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split();
- *sink.lock().unwrap() = new_sink;
- *source.lock().unwrap() = new_source;
+ *sink.lock().await = new_sink;
+ *source.lock().await = new_source;
}
}
}
@@ -134,13 +134,14 @@ async fn listen(
let rx = rx.fuse();
pin_mut!(rx);
loop {
- let mut source = source.lock().unwrap();
+ let mut source = source.lock().await;
let packet_recv = source.next().fuse();
pin_mut!(packet_recv);
let exitor = select! {
data = packet_recv => Some(data),
_ = rx => None
};
+ drop(source);
match exitor {
None => {
break;
@@ -160,9 +161,10 @@ async fn listen(
};
match packet {
VoicePacket::Ping { timestamp } => {
+ // debug!("Sending UDP voice");
state
- .lock()
- .unwrap()
+ .lock() //TODO clean up unnecessary lock by only updating phase if it should change
+ .await
.broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP));
last_ping_recv.store(timestamp, Ordering::Relaxed);
}
@@ -175,9 +177,9 @@ async fn listen(
} => {
state
.lock()
- .unwrap()
+ .await
.audio()
- .decode_packet(VoiceStreamType::UDP, session_id, payload);
+ .decode_packet_payload(VoiceStreamType::UDP, session_id, payload);
}
}
}
@@ -198,19 +200,21 @@ async fn send_pings(
) {
let mut last_send = None;
let mut interval = interval(Duration::from_millis(1000));
+ interval.tick().await; //this is so we get rid of the first instant resolve
loop {
let last_recv = last_ping_recv.load(Ordering::Relaxed);
if last_send.is_some() && last_send.unwrap() != last_recv {
+ debug!("Sending TCP voice");
state
.lock()
- .unwrap()
+ .await
.broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP));
}
match sink
.lock()
- .unwrap()
- .send((VoicePacket::Ping { timestamp: 0 }, server_addr))
+ .await
+ .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr))
.await
{
Ok(_) => {
@@ -228,22 +232,33 @@ async fn send_voice(
sink: Arc<Mutex<UdpSender>>,
server_addr: SocketAddr,
phase_watcher: watch::Receiver<StatePhase>,
- receiver: Arc<tokio::sync::Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
+ receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
) {
let inner_phase_watcher = phase_watcher.clone();
run_until(
|phase| matches!(phase, StatePhase::Disconnected),
async {
- run_until(
- |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)),
- async {
- debug!("Sending UDP audio");
- sink.lock().unwrap().send((receiver.lock().await.next().await.unwrap(), server_addr)).await.unwrap();
- debug!("Sent UDP audio");
- },
- || async {},
- inner_phase_watcher.clone(),
- ).await;
+ loop {
+ let mut inner_phase_watcher_2 = inner_phase_watcher.clone();
+ loop {
+ inner_phase_watcher_2.changed().await.unwrap();
+ if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) {
+ break;
+ }
+ }
+ run_until(
+ |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)),
+ async {
+ let mut receiver = receiver.lock().await;
+ loop {
+ let sending = (receiver.next().await.unwrap(), server_addr);
+ sink.lock().await.send(sending).await.unwrap();
+ }
+ },
+ || async {},
+ inner_phase_watcher.clone(),
+ ).await;
+ }
},
|| async {},
phase_watcher,
@@ -266,7 +281,7 @@ pub async fn handle_pings(
let packet = PingPacket { id };
let packet: [u8; 12] = packet.into();
udp_socket.send_to(&packet, &socket_addr).await.unwrap();
- pending.lock().unwrap().insert(id, handle);
+ pending.lock().await.insert(id, handle);
}
};
@@ -277,7 +292,7 @@ pub async fn handle_pings(
let packet = PongPacket::try_from(buf.as_slice()).unwrap();
- if let Some(handler) = pending.lock().unwrap().remove(&packet.id) {
+ if let Some(handler) = pending.lock().await.remove(&packet.id) {
handler(packet);
}
}