diff options
| author | Eskil Q <eskilq@kth.se> | 2021-01-06 23:50:09 +0100 |
|---|---|---|
| committer | Eskil Q <eskilq@kth.se> | 2021-01-06 23:50:09 +0100 |
| commit | 92d5b21bf0f910f219c473002f83ba93ddcbee6d (patch) | |
| tree | 5280eb78c1e75e711ba5091c4ddeb6fa0ac79f69 /mumd/src/network | |
| parent | 02e6f2b84d72294b29a1698c1b73fbb5697815da (diff) | |
| download | mum-92d5b21bf0f910f219c473002f83ba93ddcbee6d.tar.gz | |
fix deadlock
Diffstat (limited to 'mumd/src/network')
| -rw-r--r-- | mumd/src/network/tcp.rs | 100 | ||||
| -rw-r--r-- | mumd/src/network/udp.rs | 69 |
2 files changed, 105 insertions, 64 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 982e747..6f18473 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -14,9 +14,9 @@ use std::collections::HashMap; use std::convert::{Into, TryInto}; use std::net::SocketAddr; use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use tokio::net::TcpStream; -use tokio::sync::{mpsc, watch}; +use tokio::sync::{mpsc, watch, Mutex}; use tokio::time::{self, Duration}; use tokio_native_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; @@ -68,7 +68,7 @@ pub async fn handle( .await; // Handshake (omitting `Version` message for brevity) - let state_lock = state.lock().unwrap(); + let state_lock = state.lock().await; authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; let phase_watcher = state_lock.phase_receiver(); let input_receiver = state_lock.audio().input_receiver(); @@ -138,19 +138,16 @@ async fn send_pings( delay_seconds: u64, phase_watcher: watch::Receiver<StatePhase>, ) { - let interval = Rc::new(RefCell::new(time::interval(Duration::from_secs( - delay_seconds, - )))); - let packet_sender = Rc::new(RefCell::new(packet_sender)); + let mut interval = time::interval(Duration::from_secs(delay_seconds)); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { loop { - interval.borrow_mut().tick().await; + interval.tick().await; trace!("Sending ping"); let msg = msgs::Ping::new(); - packet_sender.borrow_mut().send(msg.into()).unwrap(); + packet_sender.send(msg.into()).unwrap(); } }, || async {}, @@ -167,12 +164,11 @@ async fn send_packets( phase_watcher: watch::Receiver<StatePhase>, ) { let sink = Rc::new(RefCell::new(sink)); - let packet_receiver = Rc::new(RefCell::new(packet_receiver)); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { loop { - let packet = packet_receiver.borrow_mut().recv().await.unwrap(); + let packet = packet_receiver.recv().await.unwrap(); sink.borrow_mut().send(packet).await.unwrap(); } }, @@ -188,31 +184,40 @@ async fn send_packets( async fn send_voice( packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, - receiver: Arc<tokio::sync::Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, + receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, phase_watcher: watch::Receiver<StatePhase>, ) { let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), - async { - loop { - packet_sender.send(receiver - .lock() - .await - .next() - .await - .unwrap() - .into()) - .unwrap(); + loop { + let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + inner_phase_watcher_2.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) { + break; } - }, - || async {}, - inner_phase_watcher.clone(), - ).await; - debug!("Stopped sending TCP voice"); + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)), + async { + loop { + packet_sender.send( + receiver + .lock() + .await + .next() + .await + .unwrap() + .into()) + .unwrap(); + } + }, + || async {}, + inner_phase_watcher.clone(), + ).await; + } }, || async {}, phase_watcher, @@ -269,13 +274,13 @@ async fn listen( ) .await; } - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Connected) { + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) { let old = std::mem::take(vec); for handler in old { handler(TcpEventData::Connected(&msg)); } } - let mut state = state.lock().unwrap(); + let mut state = state.lock().await; let server = state.server_mut().unwrap(); server.parse_server_sync(*msg); match &server.welcome_text { @@ -291,16 +296,16 @@ async fn listen( warn!("Login rejected: {:?}", msg); } ControlPacket::UserState(msg) => { - state.lock().unwrap().parse_user_state(*msg); + state.lock().await.parse_user_state(*msg); } ControlPacket::UserRemove(msg) => { - state.lock().unwrap().remove_client(*msg); + state.lock().await.remove_client(*msg); } ControlPacket::ChannelState(msg) => { debug!("Channel state received"); state .lock() - .unwrap() + .await .server_mut() .unwrap() .parse_channel_state(*msg); //TODO parse initial if initial @@ -308,11 +313,32 @@ async fn listen( ControlPacket::ChannelRemove(msg) => { state .lock() - .unwrap() + .await .server_mut() .unwrap() .parse_channel_remove(*msg); } + ControlPacket::UDPTunnel(msg) => { + match *msg { + VoicePacket::Ping { .. } => {} + VoicePacket::Audio { + session_id, + // seq_num, + payload, + // position_info, + .. + } => { + state + .lock() + .await + .audio() + .decode_packet_payload( + VoiceStreamType::TCP, + session_id, + payload); + } + } + } packet => { debug!("Received unhandled ControlPacket {:#?}", packet); } @@ -320,7 +346,7 @@ async fn listen( } }, || async { - if let Some(vec) = event_queue.lock().unwrap().get_mut(&TcpEvent::Disconnected) { + if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) { let old = std::mem::take(vec); for handler in old { handler(TcpEventData::Disconnected); @@ -347,7 +373,7 @@ async fn register_events( let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap(); event_data .lock() - .unwrap() + .await .entry(event) .or_default() .push(handler); diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index d35a255..25ec8d5 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -13,9 +13,9 @@ use std::convert::TryFrom; use std::net::{Ipv6Addr, SocketAddr}; use std::rc::Rc; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc}; use tokio::net::UdpSocket; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; @@ -31,7 +31,7 @@ pub async fn handle( mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>, mut crypt_state_receiver: mpsc::Receiver<ClientCryptState>, ) { - let receiver = state.lock().unwrap().audio().input_receiver(); + let receiver = state.lock().await.audio().input_receiver(); loop { let connection_info = 'data: loop { @@ -47,7 +47,7 @@ pub async fn handle( let sink = Arc::new(Mutex::new(sink)); let source = Arc::new(Mutex::new(source)); - let phase_watcher = state.lock().unwrap().phase_receiver(); + let phase_watcher = state.lock().await.phase_receiver(); let last_ping_recv = AtomicU64::new(0); join!( listen( @@ -107,8 +107,8 @@ async fn new_crypt_state( .await .expect("Failed to bind UDP socket"); let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split(); - *sink.lock().unwrap() = new_sink; - *source.lock().unwrap() = new_source; + *sink.lock().await = new_sink; + *source.lock().await = new_source; } } } @@ -134,13 +134,14 @@ async fn listen( let rx = rx.fuse(); pin_mut!(rx); loop { - let mut source = source.lock().unwrap(); + let mut source = source.lock().await; let packet_recv = source.next().fuse(); pin_mut!(packet_recv); let exitor = select! { data = packet_recv => Some(data), _ = rx => None }; + drop(source); match exitor { None => { break; @@ -160,9 +161,10 @@ async fn listen( }; match packet { VoicePacket::Ping { timestamp } => { + // debug!("Sending UDP voice"); state - .lock() - .unwrap() + .lock() //TODO clean up unnecessary lock by only updating phase if it should change + .await .broadcast_phase(StatePhase::Connected(VoiceStreamType::UDP)); last_ping_recv.store(timestamp, Ordering::Relaxed); } @@ -175,9 +177,9 @@ async fn listen( } => { state .lock() - .unwrap() + .await .audio() - .decode_packet(VoiceStreamType::UDP, session_id, payload); + .decode_packet_payload(VoiceStreamType::UDP, session_id, payload); } } } @@ -198,19 +200,21 @@ async fn send_pings( ) { let mut last_send = None; let mut interval = interval(Duration::from_millis(1000)); + interval.tick().await; //this is so we get rid of the first instant resolve loop { let last_recv = last_ping_recv.load(Ordering::Relaxed); if last_send.is_some() && last_send.unwrap() != last_recv { + debug!("Sending TCP voice"); state .lock() - .unwrap() + .await .broadcast_phase(StatePhase::Connected(VoiceStreamType::TCP)); } match sink .lock() - .unwrap() - .send((VoicePacket::Ping { timestamp: 0 }, server_addr)) + .await + .send((VoicePacket::Ping { timestamp: last_recv + 1 }, server_addr)) .await { Ok(_) => { @@ -228,22 +232,33 @@ async fn send_voice( sink: Arc<Mutex<UdpSender>>, server_addr: SocketAddr, phase_watcher: watch::Receiver<StatePhase>, - receiver: Arc<tokio::sync::Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, + receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>, ) { let inner_phase_watcher = phase_watcher.clone(); run_until( |phase| matches!(phase, StatePhase::Disconnected), async { - run_until( - |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), - async { - debug!("Sending UDP audio"); - sink.lock().unwrap().send((receiver.lock().await.next().await.unwrap(), server_addr)).await.unwrap(); - debug!("Sent UDP audio"); - }, - || async {}, - inner_phase_watcher.clone(), - ).await; + loop { + let mut inner_phase_watcher_2 = inner_phase_watcher.clone(); + loop { + inner_phase_watcher_2.changed().await.unwrap(); + if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::UDP)) { + break; + } + } + run_until( + |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::UDP)), + async { + let mut receiver = receiver.lock().await; + loop { + let sending = (receiver.next().await.unwrap(), server_addr); + sink.lock().await.send(sending).await.unwrap(); + } + }, + || async {}, + inner_phase_watcher.clone(), + ).await; + } }, || async {}, phase_watcher, @@ -266,7 +281,7 @@ pub async fn handle_pings( let packet = PingPacket { id }; let packet: [u8; 12] = packet.into(); udp_socket.send_to(&packet, &socket_addr).await.unwrap(); - pending.lock().unwrap().insert(id, handle); + pending.lock().await.insert(id, handle); } }; @@ -277,7 +292,7 @@ pub async fn handle_pings( let packet = PongPacket::try_from(buf.as_slice()).unwrap(); - if let Some(handler) = pending.lock().unwrap().remove(&packet.id) { + if let Some(handler) = pending.lock().await.remove(&packet.id) { handler(packet); } } |
