From 7e848151aea0ad579acbd51125907d96cc67438b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Sat, 10 Apr 2021 19:28:37 +0200 Subject: timeout server pings --- mumd/src/network/udp.rs | 64 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 13 deletions(-) (limited to 'mumd/src/network/udp.rs') diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs index d267007..59620a3 100644 --- a/mumd/src/network/udp.rs +++ b/mumd/src/network/udp.rs @@ -10,18 +10,18 @@ use mumble_protocol::crypt::ClientCryptState; use mumble_protocol::ping::{PingPacket, PongPacket}; use mumble_protocol::voice::VoicePacket; use mumble_protocol::Serverbound; -use std::collections::HashMap; +use std::collections::{hash_map::Entry, HashMap}; use std::convert::TryFrom; use std::net::{Ipv6Addr, SocketAddr}; use std::sync::{atomic::{AtomicU64, Ordering}, Arc, RwLock}; -use tokio::{join, net::UdpSocket}; -use tokio::sync::{mpsc, watch, Mutex}; +use tokio::{join, net::UdpSocket, select}; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; use tokio::time::{interval, Duration}; use tokio_util::udp::UdpFramed; use super::{run_until, VoiceStreamType}; -pub type PingRequest = (u64, SocketAddr, Box); +pub type PingRequest = (u64, SocketAddr, Box) + Send>); type UdpSender = SplitSink, (VoicePacket, SocketAddr)>; type UdpReceiver = SplitStream>; @@ -226,32 +226,70 @@ pub async fn handle_pings( .await .expect("Failed to bind UDP socket"); - let pending = Mutex::new(HashMap::new()); + let pending = Arc::new(Mutex::new(HashMap::new())); - let sender_handle = async { + let sender = async { while let Some((id, socket_addr, handle)) = ping_request_receiver.recv().await { debug!("Sending ping {} to {}", id, socket_addr); let packet = PingPacket { id }; let packet: [u8; 12] = packet.into(); udp_socket.send_to(&packet, &socket_addr).await.unwrap(); - pending.lock().await.insert(id, handle); + let (tx, rx) = oneshot::channel(); + match pending.lock().await.entry(id) { + Entry::Occupied(_) => { + warn!("Tried to send duplicate ping {}", id); + continue; + } + Entry::Vacant(v) => { + v.insert(tx); + } + } + + tokio::spawn(async move { + let rx = rx.fuse(); + let timeout = async { + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + }; + handle(select! { + r = rx => match r { + Ok(r) => Some(r), + Err(_) => { + warn!("Ping response sender dropped"); + None + } + }, + _ = timeout => None, + }); + }); } }; - let receiver_handle = async { + let receiver = async { let mut buf = vec![0; 24]; + while let Ok(read) = udp_socket.recv(&mut buf).await { - assert_eq!(read, 24); + if read != 24 { + warn!("Ping response had length {}, expected 24", read); + continue; + } + assert_eq!(read, 24); // just checked let packet = PongPacket::try_from(buf.as_slice()).unwrap(); - if let Some(handler) = pending.lock().await.remove(&packet.id) { - handler(packet); + match pending.lock().await.entry(packet.id) { + Entry::Occupied(o) => { + let id = *o.key(); + if o.remove().send(packet).is_err() { + debug!("Received response to ping {} too late", id); + } + } + Entry::Vacant(v) => { + warn!("Received ping {} that we didn't send", v.key()); + } } } }; debug!("Waiting for ping requests"); - - join!(sender_handle, receiver_handle); + join!(sender, receiver); } -- cgit v1.2.1