aboutsummaryrefslogtreecommitdiffstats
path: root/mumd/src
diff options
context:
space:
mode:
Diffstat (limited to 'mumd/src')
-rw-r--r--mumd/src/command.rs10
-rw-r--r--mumd/src/main.rs8
-rw-r--r--mumd/src/network/udp.rs69
-rw-r--r--mumd/src/state.rs4
4 files changed, 68 insertions, 23 deletions
diff --git a/mumd/src/command.rs b/mumd/src/command.rs
index 7eec388..1337dce 100644
--- a/mumd/src/command.rs
+++ b/mumd/src/command.rs
@@ -8,7 +8,7 @@ use crate::state::{ExecutionContext, State};
use log::*;
use mumble_protocol::{Serverbound, control::ControlPacket};
use mumlib::command::{Command, CommandResponse};
-use std::sync::{Arc, RwLock};
+use std::sync::{atomic::{AtomicU64, Ordering}, Arc, RwLock};
use tokio::sync::{mpsc, oneshot, watch};
pub async fn handle(
@@ -23,6 +23,7 @@ pub async fn handle(
mut connection_info_sender: watch::Sender<Option<ConnectionInfo>>,
) {
debug!("Begin listening for commands");
+ let ping_count = AtomicU64::new(0);
while let Some((command, response_sender)) = command_receiver.recv().await {
debug!("Received command {:?}", command);
let mut state = state.write().unwrap();
@@ -47,10 +48,13 @@ pub async fn handle(
response_sender.send(generator()).unwrap();
}
ExecutionContext::Ping(generator, converter) => {
- match generator() {
+ let ret = generator();
+ debug!("Ping generated: {:?}", ret);
+ match ret {
Ok(addr) => {
+ let id = ping_count.fetch_add(1, Ordering::Relaxed);
let res = ping_request_sender.send((
- 0,
+ id,
addr,
Box::new(move |packet| {
response_sender.send(converter(packet)).unwrap();
diff --git a/mumd/src/main.rs b/mumd/src/main.rs
index d7bc2c0..f298070 100644
--- a/mumd/src/main.rs
+++ b/mumd/src/main.rs
@@ -109,7 +109,13 @@ async fn receive_commands(
sender.send((command, tx)).unwrap();
- let response = rx.await.unwrap();
+ let response = match rx.await {
+ Ok(r) => r,
+ Err(_) => {
+ error!("Internal command response sender dropped");
+ Ok(None)
+ }
+ };
let mut serialized = BytesMut::new();
bincode::serialize_into((&mut serialized).writer(), &response).unwrap();
diff --git a/mumd/src/network/udp.rs b/mumd/src/network/udp.rs
index 3ca77af..0958912 100644
--- a/mumd/src/network/udp.rs
+++ b/mumd/src/network/udp.rs
@@ -3,27 +3,25 @@ use crate::network::ConnectionInfo;
use crate::state::{State, StatePhase};
use futures_util::{FutureExt, SinkExt, StreamExt};
+use futures_util::future::join4;
use futures_util::stream::{SplitSink, SplitStream, Stream};
use log::*;
use mumble_protocol::crypt::ClientCryptState;
use mumble_protocol::ping::{PingPacket, PongPacket};
use mumble_protocol::voice::VoicePacket;
use mumble_protocol::Serverbound;
-use std::collections::HashMap;
+use std::collections::{hash_map::Entry, HashMap};
use std::convert::TryFrom;
use std::net::{Ipv6Addr, SocketAddr};
-use std::rc::Rc;
-use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::{Arc, RwLock};
+use std::sync::{atomic::{AtomicU64, Ordering}, Arc, RwLock};
use tokio::{join, net::UdpSocket};
-use tokio::sync::{mpsc, watch, Mutex};
-use tokio::time::{interval, Duration};
+use tokio::sync::{mpsc, oneshot, watch, Mutex};
+use tokio::time::{interval, timeout, Duration};
use tokio_util::udp::UdpFramed;
use super::{run_until, VoiceStreamType};
-use futures_util::future::join4;
-pub type PingRequest = (u64, SocketAddr, Box<dyn FnOnce(PongPacket)>);
+pub type PingRequest = (u64, SocketAddr, Box<dyn FnOnce(Option<PongPacket>) + Send>);
type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>;
type UdpReceiver = SplitStream<UdpFramed<ClientCryptState>>;
@@ -228,31 +226,68 @@ pub async fn handle_pings(
.await
.expect("Failed to bind UDP socket");
- let pending = Rc::new(Mutex::new(HashMap::new()));
+ let pending = Mutex::new(HashMap::new());
- let sender_handle = async {
+ let sender = async {
while let Some((id, socket_addr, handle)) = ping_request_receiver.recv().await {
+ debug!("Sending ping with id {} to {}", id, socket_addr);
let packet = PingPacket { id };
let packet: [u8; 12] = packet.into();
udp_socket.send_to(&packet, &socket_addr).await.unwrap();
- pending.lock().await.insert(id, handle);
+ let (tx, rx) = oneshot::channel();
+ match pending.lock().await.entry(id) {
+ Entry::Occupied(_) => {
+ warn!("Tried to send duplicate ping with id {}", id);
+ continue;
+ }
+ Entry::Vacant(v) => {
+ v.insert(tx);
+ }
+ }
+
+ tokio::spawn(async move {
+ handle(
+ match timeout(Duration::from_secs(1), rx).await {
+ Ok(Ok(r)) => Some(r),
+ Ok(Err(_)) => {
+ warn!("Ping response sender for server {}, ping id {} dropped", socket_addr, id);
+ None
+ }
+ Err(_) => {
+ debug!("Server {} timed out when sending ping id {}", socket_addr, id);
+ None
+ }
+ }
+ );
+ });
}
};
- let receiver_handle = async {
+ let receiver = async {
let mut buf = vec![0; 24];
+
while let Ok(read) = udp_socket.recv(&mut buf).await {
- assert_eq!(read, 24);
+ if read != 24 {
+ warn!("Ping response had length {}, expected 24", read);
+ continue;
+ }
let packet = PongPacket::try_from(buf.as_slice()).unwrap();
- if let Some(handler) = pending.lock().await.remove(&packet.id) {
- handler(packet);
+ match pending.lock().await.entry(packet.id) {
+ Entry::Occupied(o) => {
+ let id = *o.key();
+ if o.remove().send(packet).is_err() {
+ debug!("Received response to ping with id {} too late", id);
+ }
+ }
+ Entry::Vacant(v) => {
+ warn!("Received ping with id {} that we didn't send", v.key());
+ }
}
}
};
debug!("Waiting for ping requests");
-
- join!(sender_handle, receiver_handle);
+ join!(sender, receiver);
}
diff --git a/mumd/src/state.rs b/mumd/src/state.rs
index 132da74..45e7301 100644
--- a/mumd/src/state.rs
+++ b/mumd/src/state.rs
@@ -43,7 +43,7 @@ pub enum ExecutionContext {
Now(Box<dyn FnOnce() -> mumlib::error::Result<Option<CommandResponse>>>),
Ping(
Box<dyn FnOnce() -> mumlib::error::Result<SocketAddr>>,
- Box<dyn FnOnce(PongPacket) -> mumlib::error::Result<Option<CommandResponse>>>,
+ Box<dyn FnOnce(Option<PongPacket>) -> mumlib::error::Result<Option<CommandResponse>> + Send>,
),
}
@@ -390,7 +390,7 @@ impl State {
}
}),
Box::new(move |pong| {
- Ok(Some(CommandResponse::ServerStatus {
+ Ok(pong.map(|pong| CommandResponse::ServerStatus {
version: pong.version,
users: pong.users,
max_users: pong.max_users,