diff options
| author | Eskil Queseth <eskilq@kth.se> | 2020-10-14 16:54:27 +0200 |
|---|---|---|
| committer | Eskil Queseth <eskilq@kth.se> | 2020-10-14 16:54:27 +0200 |
| commit | 7fb14d648aacd398f720f60236020dab6bf9fd35 (patch) | |
| tree | 52f4515aba225c25b006bdda82bf971a9a00f4bb /mumd/src/network/tcp.rs | |
| parent | dcb71982eab550535298b2d879a3a83820a0798a (diff) | |
| download | mum-7fb14d648aacd398f720f60236020dab6bf9fd35.tar.gz | |
add support for disconnect command
Diffstat (limited to 'mumd/src/network/tcp.rs')
| -rw-r--r-- | mumd/src/network/tcp.rs | 291 |
1 files changed, 199 insertions, 92 deletions
diff --git a/mumd/src/network/tcp.rs b/mumd/src/network/tcp.rs index 9fb5ae4..0aca19e 100644 --- a/mumd/src/network/tcp.rs +++ b/mumd/src/network/tcp.rs @@ -2,8 +2,8 @@ use crate::network::ConnectionInfo; use crate::state::{State, StatePhase}; use log::*; -use futures::channel::oneshot; -use futures::{join, SinkExt, StreamExt}; +use tokio::sync::oneshot; +use futures::{join, select, pin_mut, SinkExt, StreamExt, FutureExt}; use futures_util::stream::{SplitSink, SplitStream}; use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket}; use mumble_protocol::crypt::ClientCryptState; @@ -16,6 +16,7 @@ use tokio::sync::{mpsc, watch}; use tokio::time::{self, Duration}; use tokio_tls::{TlsConnector, TlsStream}; use tokio_util::codec::{Decoder, Framed}; +use futures_util::core_reexport::cell::RefCell; type TcpSender = SplitSink< Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>, @@ -43,15 +44,21 @@ pub async fn handle( .await; // Handshake (omitting `Version` message for brevity) - authenticate(&mut sink, state.lock().unwrap().username().unwrap().to_string()).await; + let mut state_lock = state.lock().unwrap(); + authenticate(&mut sink, state_lock.username().unwrap().to_string()).await; + let phase_watcher = state_lock.phase_receiver(); + let packet_sender = state_lock.packet_sender(); + drop(state_lock); info!("Logging in..."); join!( - send_pings(state.lock().unwrap().packet_sender(), 10), - listen(state, stream, crypt_state_sender), - send_packets(sink, packet_receiver), + send_pings(packet_sender, 10, phase_watcher.clone()), + listen(state, stream, crypt_state_sender, phase_watcher.clone()), + send_packets(sink, packet_receiver, phase_watcher), ); + + debug!("Fully disconnected TCP stream"); } async fn connect( @@ -87,109 +94,209 @@ async fn authenticate(sink: &mut TcpSender, username: String) { sink.send(msg.into()).await.unwrap(); } -async fn send_pings(packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, - delay_seconds: u64) { +async fn send_pings( + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + delay_seconds: u64, + mut phase_watcher: watch::Receiver<StatePhase>, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + tx.send(true); + }; + let mut interval = time::interval(Duration::from_secs(delay_seconds)); - loop { - interval.tick().await; - trace!("Sending ping"); - let msg = msgs::Ping::new(); - packet_sender.send(msg.into()).unwrap(); - } + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let interval_waiter = interval.tick().fuse(); + pin_mut!(interval_waiter); + let exitor = select! { + data = interval_waiter => Some(data), + _ = rx => None + }; + + match exitor { + Some(_) => { + trace!("Sending ping"); + let msg = msgs::Ping::new(); + packet_sender.send(msg.into()).unwrap(); + } + None => break, + } + } + }; + + join!(main_block, phase_transition_block); + + debug!("Ping sender process killed"); } -async fn send_packets(mut sink: TcpSender, - mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>) { +async fn send_packets( + mut sink: TcpSender, + mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>, + mut phase_watcher: watch::Receiver<StatePhase>, +) { + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + tx.send(true); + }; + + let main_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = packet_receiver.recv().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; + } + Some(None) => { + warn!("Channel closed before disconnect command"); + break; + } + Some(Some(packet)) => { + sink.send(packet).await.unwrap(); + } + } + } - while let Some(packet) = packet_receiver.recv().await { - sink.send(packet).await.unwrap(); - } + //clears queue of remaining packets + while let Ok(_) = packet_receiver.try_recv() {} + + sink.close().await.unwrap(); + }; + + join!(main_block, phase_transition_block); + + debug!("TCP packet sender killed"); } async fn listen( state: Arc<Mutex<State>>, mut stream: TcpReceiver, crypt_state_sender: oneshot::Sender<ClientCryptState>, + mut phase_watcher: watch::Receiver<StatePhase>, ) { let mut crypt_state = None; let mut crypt_state_sender = Some(crypt_state_sender); - while let Some(packet) = stream.next().await { - //TODO handle types separately - match packet.unwrap() { - ControlPacket::TextMessage(msg) => { - info!( - "Got message from user with session ID {}: {}", - msg.get_actor(), - msg.get_message() - ); - } - ControlPacket::CryptSetup(msg) => { - debug!("Crypt setup"); - // Wait until we're fully connected before initiating UDP voice - crypt_state = Some(ClientCryptState::new_from( - msg.get_key() - .try_into() - .expect("Server sent private key with incorrect size"), - msg.get_client_nonce() - .try_into() - .expect("Server sent client_nonce with incorrect size"), - msg.get_server_nonce() - .try_into() - .expect("Server sent server_nonce with incorrect size"), - )); - } - ControlPacket::ServerSync(msg) => { - info!("Logged in"); - if let Some(sender) = crypt_state_sender.take() { - let _ = sender.send( - crypt_state - .take() - .expect("Server didn't send us any CryptSetup packet!"), - ); - } - let mut state = state.lock().unwrap(); - let server = state.server_mut(); - server.parse_server_sync(msg); - match &server.welcome_text { - Some(s) => info!("Welcome: {}", s), - None => info!("No welcome received"), + let (tx, rx) = oneshot::channel(); + let phase_transition_block = async { + while !matches!(phase_watcher.recv().await.unwrap(), StatePhase::Disconnected) {} + tx.send(true); + }; + + let listener_block = async { + let rx = rx.fuse(); + pin_mut!(rx); + loop { + let packet_recv = stream.next().fuse(); + pin_mut!(packet_recv); + let exitor = select! { + data = packet_recv => Some(data), + _ = rx => None + }; + match exitor { + None => { + break; } - for (_, channel) in server.channels() { - info!("Found channel {}", channel.name()); + Some(None) => { + warn!("Channel closed before disconnect command"); + break; } - state.initialized(); - } - ControlPacket::Reject(msg) => { - warn!("Login rejected: {:?}", msg); - } - ControlPacket::UserState(msg) => { - let mut state = state.lock().unwrap(); - let session = msg.get_session(); - state.audio_mut().add_client(msg.get_session()); //TODO - if *state.phase_receiver().borrow() == StatePhase::Connecting { - state.parse_initial_user_state(msg); - } else { - state.server_mut().parse_user_state(msg); + Some(Some(packet)) => { + //TODO handle types separately + match packet.unwrap() { + ControlPacket::TextMessage(msg) => { + info!( + "Got message from user with session ID {}: {}", + msg.get_actor(), + msg.get_message() + ); + } + ControlPacket::CryptSetup(msg) => { + debug!("Crypt setup"); + // Wait until we're fully connected before initiating UDP voice + crypt_state = Some(ClientCryptState::new_from( + msg.get_key() + .try_into() + .expect("Server sent private key with incorrect size"), + msg.get_client_nonce() + .try_into() + .expect("Server sent client_nonce with incorrect size"), + msg.get_server_nonce() + .try_into() + .expect("Server sent server_nonce with incorrect size"), + )); + } + ControlPacket::ServerSync(msg) => { + info!("Logged in"); + if let Some(sender) = crypt_state_sender.take() { + let _ = sender.send( + crypt_state + .take() + .expect("Server didn't send us any CryptSetup packet!"), + ); + } + let mut state = state.lock().unwrap(); + let server = state.server_mut(); + server.parse_server_sync(msg); + match &server.welcome_text { + Some(s) => info!("Welcome: {}", s), + None => info!("No welcome received"), + } + for (_, channel) in server.channels() { + info!("Found channel {}", channel.name()); + } + state.initialized(); + } + ControlPacket::Reject(msg) => { + warn!("Login rejected: {:?}", msg); + } + ControlPacket::UserState(msg) => { + let mut state = state.lock().unwrap(); + let session = msg.get_session(); + state.audio_mut().add_client(msg.get_session()); //TODO + if *state.phase_receiver().borrow() == StatePhase::Connecting { + state.parse_initial_user_state(msg); + } else { + state.server_mut().parse_user_state(msg); + } + let server = state.server_mut(); + let user = server.users().get(&session).unwrap(); + info!("User {} connected to {}", + user.name(), + user.channel()); + } + ControlPacket::UserRemove(msg) => { + info!("User {} left", msg.get_session()); + state.lock().unwrap().audio_mut().remove_client(msg.get_session()); + } + ControlPacket::ChannelState(msg) => { + debug!("Channel state received"); + state.lock().unwrap().server_mut().parse_channel_state(msg); //TODO parse initial if initial + } + ControlPacket::ChannelRemove(msg) => { + state.lock().unwrap().server_mut().parse_channel_remove(msg); + } + _ => {} + } } - let server = state.server_mut(); - let user = server.users().get(&session).unwrap(); - info!("User {} connected to {}", - user.name(), - user.channel()); - } - ControlPacket::UserRemove(msg) => { - info!("User {} left", msg.get_session()); - state.lock().unwrap().audio_mut().remove_client(msg.get_session()); - } - ControlPacket::ChannelState(msg) => { - debug!("Channel state received"); - state.lock().unwrap().server_mut().parse_channel_state(msg); //TODO parse initial if initial - } - ControlPacket::ChannelRemove(msg) => { - state.lock().unwrap().server_mut().parse_channel_remove(msg); } - _ => {} } - } + + //TODO? clean up stream + }; + + join!(phase_transition_block, listener_block); + + debug!("Killing TCP listener block"); } |
