aboutsummaryrefslogtreecommitdiffstats
path: root/mumd/src/network/tcp.rs
diff options
context:
space:
mode:
authorEskil Q <eskilq@kth.se>2021-01-07 12:39:49 +0100
committerEskil Q <eskilq@kth.se>2021-01-07 12:39:49 +0100
commitab407d694e5a8ce6f831f8a84fc32dbdf6685aac (patch)
tree7566a607574c71a8c30e62088412c0b2f5f2ae64 /mumd/src/network/tcp.rs
parentf6a8a126e67ff1a89dcbdb35033e1f324add50dc (diff)
downloadmum-ab407d694e5a8ce6f831f8a84fc32dbdf6685aac.tar.gz
lift waiting for disconnection
Diffstat (limited to 'mumd/src/network/tcp.rs')
-rw-r--r--mumd/src/network/tcp.rs402
1 files changed, 174 insertions, 228 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs
index 3e4cbf3..e639dd0 100644
--- a/mumd/src/network/tcp.rs
+++ b/mumd/src/network/tcp.rs
@@ -2,17 +2,15 @@ use crate::network::ConnectionInfo;
use crate::state::{State, StatePhase};
use log::*;
-use futures::{join, SinkExt, Stream, StreamExt};
+use futures::{FutureExt, SinkExt, Stream, StreamExt};
use futures_util::stream::{SplitSink, SplitStream};
use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket};
use mumble_protocol::crypt::ClientCryptState;
use mumble_protocol::voice::VoicePacket;
use mumble_protocol::{Clientbound, Serverbound};
-use std::cell::RefCell;
use std::collections::HashMap;
use std::convert::{Into, TryInto};
use std::net::SocketAddr;
-use std::rc::Rc;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, watch, Mutex};
@@ -21,6 +19,7 @@ use tokio_native_tls::{TlsConnector, TlsStream};
use tokio_util::codec::{Decoder, Framed};
use super::{run_until, VoiceStreamType};
+use futures_util::future::join5;
type TcpSender = SplitSink<
Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>,
@@ -76,24 +75,34 @@ pub async fn handle(
info!("Logging in...");
- //TODO force exit all futures on disconnection
- join!(
- send_pings(packet_sender.clone(), 10, phase_watcher.clone()),
- listen(
- Arc::clone(&state),
- stream,
- crypt_state_sender.clone(),
- Arc::clone(&event_queue),
- phase_watcher.clone(),
- ),
- send_voice(
- packet_sender.clone(),
- Arc::clone(&input_receiver),
- phase_watcher.clone(),
- ),
- send_packets(sink, &mut packet_receiver, phase_watcher.clone()),
- register_events(&mut tcp_event_register_receiver, event_queue, phase_watcher),
- );
+ run_until(
+ |phase| matches!(phase, StatePhase::Disconnected),
+ join5(
+ send_pings(packet_sender.clone(), 10),
+ listen(
+ Arc::clone(&state),
+ stream,
+ crypt_state_sender.clone(),
+ Arc::clone(&event_queue),
+ ),
+ send_voice(
+ packet_sender.clone(),
+ Arc::clone(&input_receiver),
+ phase_watcher.clone(),
+ ),
+ send_packets(sink, &mut packet_receiver),
+ register_events(&mut tcp_event_register_receiver, Arc::clone(&event_queue)),
+ ).map(|_| ()),
+ || async {},
+ phase_watcher,
+ ).await;
+
+ if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(TcpEventData::Disconnected);
+ }
+ }
debug!("Fully disconnected TCP stream, waiting for new connection info");
}
@@ -135,50 +144,24 @@ async fn authenticate(sink: &mut TcpSender, username: String) {
async fn send_pings(
packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
delay_seconds: u64,
- phase_watcher: watch::Receiver<StatePhase>,
) {
let mut interval = time::interval(Duration::from_secs(delay_seconds));
-
- run_until(
- |phase| matches!(phase, StatePhase::Disconnected),
- async {
- loop {
- interval.tick().await;
- trace!("Sending TCP ping");
- let msg = msgs::Ping::new();
- packet_sender.send(msg.into()).unwrap();
- }
- },
- || async {},
- phase_watcher,
- )
- .await;
-
- debug!("Ping sender process killed");
+ loop {
+ interval.tick().await;
+ trace!("Sending TCP ping");
+ let msg = msgs::Ping::new();
+ packet_sender.send(msg.into()).unwrap();
+ }
}
async fn send_packets(
- sink: TcpSender,
+ mut sink: TcpSender,
packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
- phase_watcher: watch::Receiver<StatePhase>,
) {
- let sink = Rc::new(RefCell::new(sink));
- run_until(
- |phase| matches!(phase, StatePhase::Disconnected),
- async {
- loop {
- let packet = packet_receiver.recv().await.unwrap();
- sink.borrow_mut().send(packet).await.unwrap();
- }
- },
- || async {
- sink.borrow_mut().close().await.unwrap();
- },
- phase_watcher,
- )
- .await;
-
- debug!("TCP packet sender killed");
+ loop {
+ let packet = packet_receiver.recv().await.unwrap();
+ sink.send(packet).await.unwrap();
+ }
}
async fn send_voice(
@@ -186,41 +169,33 @@ async fn send_voice(
receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
phase_watcher: watch::Receiver<StatePhase>,
) {
- let inner_phase_watcher = phase_watcher.clone();
- run_until(
- |phase| matches!(phase, StatePhase::Disconnected),
- async {
- loop {
- let mut inner_phase_watcher_2 = inner_phase_watcher.clone();
+ loop {
+ let mut inner_phase_watcher = phase_watcher.clone();
+ loop {
+ inner_phase_watcher.changed().await.unwrap();
+ if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) {
+ break;
+ }
+ }
+ run_until(
+ |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)),
+ async {
loop {
- inner_phase_watcher_2.changed().await.unwrap();
- if matches!(*inner_phase_watcher.borrow(), StatePhase::Connected(VoiceStreamType::TCP)) {
- break;
- }
+ packet_sender.send(
+ receiver
+ .lock()
+ .await
+ .next()
+ .await
+ .unwrap()
+ .into())
+ .unwrap();
}
- run_until(
- |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::TCP)),
- async {
- loop {
- packet_sender.send(
- receiver
- .lock()
- .await
- .next()
- .await
- .unwrap()
- .into())
- .unwrap();
- }
- },
- || async {},
- inner_phase_watcher.clone(),
- ).await;
- }
- },
- || async {},
- phase_watcher,
- ).await;
+ },
+ || async {},
+ inner_phase_watcher.clone(),
+ ).await;
+ }
}
async fn listen(
@@ -228,158 +203,129 @@ async fn listen(
mut stream: TcpReceiver,
crypt_state_sender: mpsc::Sender<ClientCryptState>,
event_queue: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
- phase_watcher: watch::Receiver<StatePhase>,
) {
- let crypt_state = Rc::new(RefCell::new(None));
- let crypt_state_sender = Rc::new(RefCell::new(Some(crypt_state_sender)));
+ let mut crypt_state = None;
+ let mut crypt_state_sender = Some(crypt_state_sender);
- run_until(
- |phase| matches!(phase, StatePhase::Disconnected),
- async {
- loop {
- let packet = stream.next().await.unwrap();
- match packet.unwrap() {
- ControlPacket::TextMessage(msg) => {
- info!(
- "Got message from user with session ID {}: {}",
- msg.get_actor(),
- msg.get_message()
- );
- }
- ControlPacket::CryptSetup(msg) => {
- debug!("Crypt setup");
- // Wait until we're fully connected before initiating UDP voice
- *crypt_state.borrow_mut() = Some(ClientCryptState::new_from(
- msg.get_key()
- .try_into()
- .expect("Server sent private key with incorrect size"),
- msg.get_client_nonce()
- .try_into()
- .expect("Server sent client_nonce with incorrect size"),
- msg.get_server_nonce()
- .try_into()
- .expect("Server sent server_nonce with incorrect size"),
- ));
- }
- ControlPacket::ServerSync(msg) => {
- info!("Logged in");
- if let Some(sender) = crypt_state_sender.borrow_mut().take() {
- let _ = sender
- .send(
- crypt_state
- .borrow_mut()
- .take()
- .expect("Server didn't send us any CryptSetup packet!"),
- )
- .await;
- }
- if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) {
- let old = std::mem::take(vec);
- for handler in old {
- handler(TcpEventData::Connected(&msg));
- }
- }
- let mut state = state.lock().await;
- let server = state.server_mut().unwrap();
- server.parse_server_sync(*msg);
- match &server.welcome_text {
- Some(s) => info!("Welcome: {}", s),
- None => info!("No welcome received"),
- }
- for channel in server.channels().values() {
- info!("Found channel {}", channel.name());
- }
- state.initialized();
- }
- ControlPacket::Reject(msg) => {
- warn!("Login rejected: {:?}", msg);
- }
- ControlPacket::UserState(msg) => {
- state.lock().await.parse_user_state(*msg);
- }
- ControlPacket::UserRemove(msg) => {
- state.lock().await.remove_client(*msg);
- }
- ControlPacket::ChannelState(msg) => {
- debug!("Channel state received");
- state
- .lock()
- .await
- .server_mut()
- .unwrap()
- .parse_channel_state(*msg); //TODO parse initial if initial
+ loop {
+ let packet = stream.next().await.unwrap();
+ match packet.unwrap() {
+ ControlPacket::TextMessage(msg) => {
+ info!(
+ "Got message from user with session ID {}: {}",
+ msg.get_actor(),
+ msg.get_message()
+ );
+ }
+ ControlPacket::CryptSetup(msg) => {
+ debug!("Crypt setup");
+ // Wait until we're fully connected before initiating UDP voice
+ crypt_state = Some(ClientCryptState::new_from(
+ msg.get_key()
+ .try_into()
+ .expect("Server sent private key with incorrect size"),
+ msg.get_client_nonce()
+ .try_into()
+ .expect("Server sent client_nonce with incorrect size"),
+ msg.get_server_nonce()
+ .try_into()
+ .expect("Server sent server_nonce with incorrect size"),
+ ));
+ }
+ ControlPacket::ServerSync(msg) => {
+ info!("Logged in");
+ if let Some(sender) = crypt_state_sender.take() {
+ let _ = sender
+ .send(
+ crypt_state
+ .take()
+ .expect("Server didn't send us any CryptSetup packet!"),
+ )
+ .await;
+ }
+ if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Connected) {
+ let old = std::mem::take(vec);
+ for handler in old {
+ handler(TcpEventData::Connected(&msg));
}
- ControlPacket::ChannelRemove(msg) => {
+ }
+ let mut state = state.lock().await;
+ let server = state.server_mut().unwrap();
+ server.parse_server_sync(*msg);
+ match &server.welcome_text {
+ Some(s) => info!("Welcome: {}", s),
+ None => info!("No welcome received"),
+ }
+ for channel in server.channels().values() {
+ info!("Found channel {}", channel.name());
+ }
+ state.initialized();
+ }
+ ControlPacket::Reject(msg) => {
+ warn!("Login rejected: {:?}", msg);
+ }
+ ControlPacket::UserState(msg) => {
+ state.lock().await.parse_user_state(*msg);
+ }
+ ControlPacket::UserRemove(msg) => {
+ state.lock().await.remove_client(*msg);
+ }
+ ControlPacket::ChannelState(msg) => {
+ debug!("Channel state received");
+ state
+ .lock()
+ .await
+ .server_mut()
+ .unwrap()
+ .parse_channel_state(*msg); //TODO parse initial if initial
+ }
+ ControlPacket::ChannelRemove(msg) => {
+ state
+ .lock()
+ .await
+ .server_mut()
+ .unwrap()
+ .parse_channel_remove(*msg);
+ }
+ ControlPacket::UDPTunnel(msg) => {
+ match *msg {
+ VoicePacket::Ping { .. } => {}
+ VoicePacket::Audio {
+ session_id,
+ // seq_num,
+ payload,
+ // position_info,
+ ..
+ } => {
state
.lock()
.await
- .server_mut()
- .unwrap()
- .parse_channel_remove(*msg);
- }
- ControlPacket::UDPTunnel(msg) => {
- match *msg {
- VoicePacket::Ping { .. } => {}
- VoicePacket::Audio {
+ .audio()
+ .decode_packet_payload(
+ VoiceStreamType::TCP,
session_id,
- // seq_num,
- payload,
- // position_info,
- ..
- } => {
- state
- .lock()
- .await
- .audio()
- .decode_packet_payload(
- VoiceStreamType::TCP,
- session_id,
- payload);
- }
- }
- }
- packet => {
- debug!("Received unhandled ControlPacket {:#?}", packet);
+ payload);
}
}
}
- },
- || async {
- if let Some(vec) = event_queue.lock().await.get_mut(&TcpEvent::Disconnected) {
- let old = std::mem::take(vec);
- for handler in old {
- handler(TcpEventData::Disconnected);
- }
+ packet => {
+ debug!("Received unhandled ControlPacket {:#?}", packet);
}
- },
- phase_watcher,
- )
- .await;
-
- debug!("Killing TCP listener block");
+ }
+ }
}
async fn register_events(
tcp_event_register_receiver: &mut mpsc::UnboundedReceiver<(TcpEvent, TcpEventCallback)>,
event_data: Arc<Mutex<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
- phase_watcher: watch::Receiver<StatePhase>,
) {
- let tcp_event_register_receiver = Rc::new(RefCell::new(tcp_event_register_receiver));
- run_until(
- |phase| matches!(phase, StatePhase::Disconnected),
- async {
- loop {
- let (event, handler) = tcp_event_register_receiver.borrow_mut().recv().await.unwrap();
- event_data
- .lock()
- .await
- .entry(event)
- .or_default()
- .push(handler);
- }
- },
- || async {},
- phase_watcher,
- )
- .await;
+ loop {
+ let (event, handler) = tcp_event_register_receiver.recv().await.unwrap();
+ event_data
+ .lock()
+ .await
+ .entry(event)
+ .or_default()
+ .push(handler);
+ }
}