diff options
| author | Eskil Q <eskilq@kth.se> | 2021-01-06 23:50:09 +0100 |
|---|---|---|
| committer | Eskil Q <eskilq@kth.se> | 2021-01-06 23:50:09 +0100 |
| commit | 92d5b21bf0f910f219c473002f83ba93ddcbee6d (patch) | |
| tree | 5280eb78c1e75e711ba5091c4ddeb6fa0ac79f69 | |
| parent | 02e6f2b84d72294b29a1698c1b73fbb5697815da (diff) | |
| download | mum-92d5b21bf0f910f219c473002f83ba93ddcbee6d.tar.gz | |
fix deadlock
| -rw-r--r-- | mumd/src/audio.rs | 10 | ||||
| -rw-r--r-- | mumd/src/client.rs | 4 | ||||
| -rw-r--r-- | mumd/src/command.rs | 8 | ||||
| -rw-r--r-- | mumd/src/main.rs | 2 | ||||
| -rw-r--r-- | mumd/src/network/tcp.rs | 100 | ||||
| -rw-r--r-- | mumd/src/network/udp.rs | 69 |
6 files changed, 118 insertions, 75 deletions
diff --git a/mumd/src/audio.rs b/mumd/src/audio.rs index 3f03e61..bdc8377 100644 --- a/mumd/src/audio.rs +++ b/mumd/src/audio.rs @@ -31,7 +31,7 @@ use std::{ }; use strum::IntoEnumIterator; use strum_macros::EnumIter; -use tokio::sync::watch; +use tokio::sync::{watch}; const SAMPLE_RATE: u32 = 48000; @@ -132,11 +132,11 @@ impl Audio { let err_fn = |err| error!("An error occurred on the output audio stream: {}", err); - let user_volumes = Arc::new(Mutex::new(HashMap::new())); + let user_volumes = Arc::new(std::sync::Mutex::new(HashMap::new())); let (output_volume_sender, output_volume_receiver) = watch::channel::<f32>(output_volume); - let play_sounds = Arc::new(Mutex::new(VecDeque::new())); + let play_sounds = Arc::new(std::sync::Mutex::new(VecDeque::new())); - let client_streams = Arc::new(Mutex::new(HashMap::new())); + let client_streams = Arc::new(std::sync::Mutex::new(HashMap::new())); let output_stream = match output_supported_sample_format { SampleFormat::F32 => output_device.build_output_stream( &output_config, @@ -292,7 +292,7 @@ impl Audio { .collect(); } - pub fn decode_packet(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) { + pub fn decode_packet_payload(&self, stream_type: VoiceStreamType, session_id: u32, payload: VoicePacketPayload) { match self.client_streams.lock().unwrap().entry((stream_type, session_id)) { Entry::Occupied(mut entry) => { entry diff --git a/mumd/src/client.rs b/mumd/src/client.rs index 3613061..222e2a7 100644 --- a/mumd/src/client.rs +++ b/mumd/src/client.rs @@ -6,8 +6,8 @@ use futures_util::join; use ipc_channel::ipc::IpcSender; use mumble_protocol::{Serverbound, control::ControlPacket, crypt::ClientCryptState}; use mumlib::command::{Command, CommandResponse}; -use std::sync::{Arc, Mutex}; -use tokio::sync::{mpsc, watch}; +use std::sync::Arc; +use tokio::sync::{mpsc, watch, Mutex}; pub async fn handle( command_receiver: mpsc::UnboundedReceiver<( diff --git a/mumd/src/command.rs b/mumd/src/command.rs index e77b34b..b099ae1 100644 --- a/mumd/src/command.rs +++ b/mumd/src/command.rs @@ -9,8 +9,8 @@ use ipc_channel::ipc::IpcSender; use log::*; use mumble_protocol::{Serverbound, control::ControlPacket}; use mumlib::command::{Command, CommandResponse}; -use std::sync::{Arc, Mutex}; -use tokio::sync::{mpsc, oneshot, watch}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; pub async fn handle( state: Arc<Mutex<State>>, @@ -26,9 +26,11 @@ pub async fn handle( debug!("Begin listening for commands"); while let Some((command, response_sender)) = command_receiver.recv().await { debug!("Received command {:?}", command); - let mut state = state.lock().unwrap(); + debug!("locking state"); + let mut state = state.lock().await; let event = state.handle_command(command, &mut packet_sender, &mut connection_info_sender); drop(state); + debug!("unlocking state"); match event { ExecutionContext::TcpEvent(event, generator) => { let (tx, rx) = oneshot::channel(); diff --git a/mumd/src/main.rs b/mumd/src/main.rs index 67481f9..a8cb230 100644 --- a/mumd/src/main.rs +++ b/mumd/src/main.rs @@ -14,7 +14,7 @@ use std::fs; use tokio::sync::mpsc; use tokio::task::spawn_blocking; -#[tokio::main] +#[tokio::main(worker_threads = 4)] async fn main() { setup_logger(std::io::stderr(), true); notify::init(); diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 982e747..6f18473 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -14,9 +14,9 @@ use std::collections::HashMap; use std::convert::{Into, TryInto}; use std::net::SocketAddr; use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use tokio::net::TcpStream; -use tokio::sync::{mpsc, watch}; +use tokio::sync::{mpsc, watch, Mutex}; use tokio::time::{self, Duration}; use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; @@ -68,7 +68,7 @@ pub async fn handle( .await; // Handshake (omitting `Version` message for brevity) - let state_lock = state.lock().unwrap(); + let state_lock = state.lock().await; authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; let phase_watcher = state_lock.phase_receiver(); let input_receiver = state_lock.audio().input_receiver(); @@ -138,19 +138,16 @@ async fn send_pings( delay_seconds: u64, phase_watcher: watch::Receiver<StatePhase>, ) { - let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs( - delay_seconds, - )))); - let packet_sender = Rc::new(RefCell::new(packet_sender)); + let mut interval = time::interval(Duration::from_secs(delay_seconds)); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { loop { - interval.borrow_mut().tick().await; + interval.tick().await; trace!("Sending ping"); let msg = msgs::Ping::new(); - packet_sender.borrow_mut().send(msg.into()).unwrap(); + packet_sender.send(msg.into()).unwrap(); } }, || async {}, @@ -167,12 +164,11 @@ async fn send_packets( phase_watcher: watch::Receiver<StatePhase>, ) { let sink = Rc::new(RefCell::new(sink)); - let packet_receiver = Rc::new(RefCell::new(packet_receiver)); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { loop { - let packet = packet_receiver.borrow_mut().recv().await.unwrap(); + let packet = packet_receiver.recv().await.unwrap(); sink.borrow_mut().send(packet).await.unwrap(); } }, @@ -188,31 +184,40 @@ async fn send_packets( async fn send_voice( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, - receiver: Arc<tokio::sync::Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, + receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, phase_watcher: watch::Receiver<StatePhase>, ) { let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), - async { - loop { - packet_sender.send(receiver - .lock() - .await - .next() - .await - .unwrap() - .into()) - .unwrap(); + loop { + let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + inner_phase_watcher_2.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { + break; } - }, - || async {}, - inner_phase_watcher.clone(), - ).await; - debug!("Stopped sending TCP voice"); + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + async { + loop { + packet_sender.send( + receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); + } + }, + || async {}, + inner_phase_watcher.clone(), + ).await; + } }, || async {}, phase_watcher, @@ -269,13 +274,13 @@ async fn listen( ) .await; } - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { let old = std::mem::take(vec); for handler in old { handler(TcpEventData::Connected(&msg)); } } - let mut state = state.lock().unwrap(); + let mut state = state.lock().await; let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); match &server.welcome_text { @@ -291,16 +296,16 @@ async fn listen( warn!("Login rejected: {:?}", msg); } ControlPacket::UserState(msg) => { - state.lock().unwrap().parse_user_state(*msg); + state.lock().await.parse_user_state(*msg); } ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); + state.lock().await.remove_client(*msg); } ControlPacket::ChannelState(msg) => { debug!("Channel state received"); state .lock() - .unwrap() + .await .server_mut() .unwrap() .parse_channel_state(*msg); //TODO parse initial if initial @@ -308,11 +313,32 @@ async fn listen( ControlPacket::ChannelRemove(msg) => { state .lock() - .unwrap() + .await .server_mut() .unwrap() .parse_channel_remove(*msg); } + ControlPacket::UDPTunnel(msg) => { + match *msg { + VoicePacket::Ping { .. } => {} + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() + .await + .audio() + .decode_packet_payload( + VoiceStreamType::TCP, + session_id, + payload); + } + } + } packet => { debug!("Received unhandled ControlPacket {:#?}", packet); } @@ -320,7 +346,7 @@ async fn listen( } }, || async { - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { let old = std::mem::take(vec); for handler in old { handler(TcpEventData::Disconnected); @@ -347,7 +373,7 @@ async fn register_events( let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap(); event_data .lock() - .unwrap() + .await .entry(event) .or_default() .push(handler); diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index d35a255..25ec8d5 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -13,9 +13,9 @@ use std::convert::TryFrom; use std::net::{Ipv6Addr, SocketAddr}; use std::rc::Rc; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc}; use tokio::net::UdpSocket; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; @@ -31,7 +31,7 @@ pub async fn handle( mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, mut crypt_state_receiver: mpsc::Receiver<ClientCryptState>, ) { - let receiver = state.lock().unwrap().audio().input_receiver(); + let receiver = state.lock().await.audio().input_receiver(); loop { let connection_info = 'data: loop { @@ -47,7 +47,7 @@ pub async fn handle( let sink = Arc::new(Mutex::new(sink)); let source = Arc::new(Mutex::new(source)); - let phase_watcher = state.lock().unwrap().phase_receiver(); + let phase_watcher = state.lock().await.phase_receiver(); let last_ping_recv = AtomicU64::new(0); join!( listen( @@ -107,8 +107,8 @@ async fn new_crypt_state( .await .expect("Failed to bind UDP socket"); let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().unwrap() = new_sink; - *source.lock().unwrap() = new_source; + *sink.lock().await = new_sink; + *source.lock().await = new_source; } } } @@ -134,13 +134,14 @@ async fn listen( let rx = rx.fuse(); pin_mut!(rx); loop { - let mut source = source.lock().unwrap(); + let mut source = source.lock().await; let packet_recv = source.next().fuse(); pin_mut!(packet_recv); let exitor = select! { data = packet_recv => Some(data), _ = rx => None }; + drop(source); match exitor { None => { break; @@ -160,9 +161,10 @@ async fn listen( }; match packet { VoicePacket::Ping { timestamp } => { + // debug!("Sending UDP voice"); state - .lock() - .unwrap() + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); last_ping_recv.store(timestamp, Ordering::Relaxed); } @@ -175,9 +177,9 @@ async fn listen( } => { state .lock() - .unwrap() + .await .audio() - .decode_packet(VoiceStreamType::UDP, session_id, payload); + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); } } } @@ -198,19 +200,21 @@ async fn send_pings( ) { let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); + interval.tick().await; //this is so we get rid of the first instant resolve loop { let last_recv = last_ping_recv.load(Ordering::Relaxed); if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); state .lock() - .unwrap() + .await .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); } match sink .lock() - .unwrap() - .send((VoicePacket::Ping { timestamp: 0 }, server_addr)) + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) .await { Ok(_) => { @@ -228,22 +232,33 @@ async fn send_voice( sink: Arc<Mutex<UdpSender>>, server_addr: SocketAddr, phase_watcher: watch::Receiver<StatePhase>, - receiver: Arc<tokio::sync::Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, + receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, ) { let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), - async { - debug!("Sending UDP audio"); - sink.lock().unwrap().send((receiver.lock().await.next().await.unwrap(), server_addr)).await.unwrap(); - debug!("Sent UDP audio"); - }, - || async {}, - inner_phase_watcher.clone(), - ).await; + loop { + let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + inner_phase_watcher_2.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + async { + let mut receiver = receiver.lock().await; + loop { + let sending = (receiver.next().await.unwrap(), server_addr); + sink.lock().await.send(sending).await.unwrap(); + } + }, + || async {}, + inner_phase_watcher.clone(), + ).await; + } }, || async {}, phase_watcher, @@ -266,7 +281,7 @@ pub async fn handle_pings( let packet = PingPacket { id }; let packet: [u8; 12] = packet.into(); udp_socket.send_to(&packet, &socket_addr).await.unwrap(); - pending.lock().unwrap().insert(id, handle); + pending.lock().await.insert(id, handle); } }; @@ -277,7 +292,7 @@ pub async fn handle_pings( let packet = PongPacket::try_from(buf.as_slice()).unwrap(); - if let Some(handler) = pending.lock().unwrap().remove(&packet.id) { + if let Some(handler) = pending.lock().await.remove(&packet.id) { handler(packet); } } |
