diff options
| author | Eskil Queseth <eskilq@kth.se> | 2020-10-14 19:48:05 +0200 |
|---|---|---|
| committer | Eskil Queseth <eskilq@kth.se> | 2020-10-14 19:48:05 +0200 |
| commit | a40d365aacf118b33c07f3353f277eb96c4536a8 (patch) | |
| tree | 1a5e623da01745b3d2a2d1b1d5958a22cd0e382a /mumd/src/state.rs | |
| parent | c0855405832ce47f75fa6e1ff7a33e51a8b36903 (diff) | |
| parent | 6ac72067a75d5e1904226efb5c45bcf0e54a0ae5 (diff) | |
| download | mum-a40d365aacf118b33c07f3353f277eb96c4536a8.tar.gz | |
Merge remote-tracking branch 'origin/commands' into main
Diffstat (limited to 'mumd/src/state.rs')
| -rw-r--r-- | mumd/src/state.rs | 259 |
1 files changed, 225 insertions, 34 deletions
diff --git a/mumd/src/state.rs b/mumd/src/state.rs index 1ef8467..b6fe780 100644 --- a/mumd/src/state.rs +++ b/mumd/src/state.rs @@ -1,8 +1,197 @@ +use crate::audio::Audio; +use crate::command::{Command, CommandResponse}; +use crate::network::ConnectionInfo; use log::*; use mumble_protocol::control::msgs; -use std::collections::HashMap; +use mumble_protocol::control::ControlPacket; +use mumble_protocol::voice::Serverbound; use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::net::ToSocketAddrs; +use tokio::sync::{mpsc, watch}; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum StatePhase { + Disconnected, + Connecting, + Connected, +} + +pub struct State { + server: Option<Server>, + audio: Audio, + + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + command_sender: mpsc::UnboundedSender<Command>, + connection_info_sender: watch::Sender<Option<ConnectionInfo>>, + + phase_watcher: (watch::Sender<StatePhase>, watch::Receiver<StatePhase>), + + username: Option<String>, + session_id: Option<u32>, +} + +impl State { + pub fn new( + packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>, + command_sender: mpsc::UnboundedSender<Command>, + connection_info_sender: watch::Sender<Option<ConnectionInfo>>, + ) -> Self { + Self { + server: None, + audio: Audio::new(), + packet_sender, + command_sender, + connection_info_sender, + phase_watcher: watch::channel(StatePhase::Disconnected), + username: None, + session_id: None, + } + } + + //TODO? move bool inside Result + pub async fn handle_command( + &mut self, + command: Command, + ) -> (bool, Result<Option<CommandResponse>, ()>) { + match command { + Command::ChannelJoin { channel_id } => { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + warn!("Not connected"); + return (false, Err(())); + } + let mut msg = msgs::UserState::new(); + msg.set_session(self.session_id.unwrap()); + msg.set_channel_id(channel_id); + self.packet_sender.send(msg.into()).unwrap(); + (false, Ok(None)) + } + Command::ChannelList => { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + warn!("Not connected"); + return (false, Err(())); + } + ( + false, + Ok(Some(CommandResponse::ChannelList { + channels: self.server.as_ref().unwrap().channels.clone(), + })), + ) + } + Command::ServerConnect { + host, + port, + username, + accept_invalid_cert, + } => { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Disconnected) { + warn!("Tried to connect to a server while already connected"); + return (false, Err(())); + } + self.server = Some(Server::new()); + self.username = Some(username); + self.phase_watcher + .0 + .broadcast(StatePhase::Connecting) + .unwrap(); + let socket_addr = (host.as_ref(), port) + .to_socket_addrs() + .expect("Failed to parse server address") + .next() + .expect("Failed to resolve server address"); + self.connection_info_sender + .broadcast(Some(ConnectionInfo::new( + socket_addr, + host, + accept_invalid_cert, + ))) + .unwrap(); + (true, Ok(None)) + } + Command::Status => { + if !matches!(*self.phase_receiver().borrow(), StatePhase::Connected) { + warn!("Not connected"); + return (false, Err(())); + } + ( + false, + Ok(Some(CommandResponse::Status { + username: self.username.clone(), + server_state: self.server.clone().unwrap(), + })), + ) + } + Command::ServerDisconnect => { + self.session_id = None; + self.username = None; + self.server = None; + + self.phase_watcher + .0 + .broadcast(StatePhase::Disconnected) + .unwrap(); + (false, Ok(None)) + } + } + } + + pub fn parse_initial_user_state(&mut self, msg: msgs::UserState) { + if !msg.has_session() { + warn!("Can't parse user state without session"); + return; + } + if !msg.has_name() { + warn!("Missing name in initial user state"); + } else if msg.get_name() == self.username.as_ref().unwrap() { + match self.session_id { + None => { + debug!("Found our session id: {}", msg.get_session()); + self.session_id = Some(msg.get_session()); + } + Some(session) => { + if session != msg.get_session() { + error!( + "Got two different session IDs ({} and {}) for ourselves", + session, + msg.get_session() + ); + } else { + debug!("Got our session ID twice"); + } + } + } + } + self.server.as_mut().unwrap().parse_user_state(msg); + } + pub fn initialized(&self) { + self.phase_watcher + .0 + .broadcast(StatePhase::Connected) + .unwrap(); + } + + pub fn audio(&self) -> &Audio { + &self.audio + } + pub fn audio_mut(&mut self) -> &mut Audio { + &mut self.audio + } + pub fn packet_sender(&self) -> mpsc::UnboundedSender<ControlPacket<Serverbound>> { + self.packet_sender.clone() + } + pub fn phase_receiver(&self) -> watch::Receiver<StatePhase> { + self.phase_watcher.1.clone() + } + pub fn server_mut(&mut self) -> Option<&mut Server> { + self.server.as_mut() + } + pub fn username(&self) -> Option<&String> { + self.username.as_ref() + } +} + +#[derive(Clone, Debug)] pub struct Server { channels: HashMap<u32, Channel>, users: HashMap<u32, User>, @@ -18,41 +207,49 @@ impl Server { } } - pub fn parse_server_sync(&mut self, mut msg: Box<msgs::ServerSync>) { + pub fn parse_server_sync(&mut self, mut msg: msgs::ServerSync) { if msg.has_welcome_text() { self.welcome_text = Some(msg.take_welcome_text()); } } - pub fn parse_channel_state(&mut self, msg: Box<msgs::ChannelState>) { + pub fn parse_channel_state(&mut self, msg: msgs::ChannelState) { if !msg.has_channel_id() { warn!("Can't parse channel state without channel id"); return; } match self.channels.entry(msg.get_channel_id()) { - Entry::Vacant(e) => { e.insert(Channel::new(msg)); }, + Entry::Vacant(e) => { + e.insert(Channel::new(msg)); + } Entry::Occupied(mut e) => e.get_mut().parse_channel_state(msg), } } - pub fn parse_channel_remove(&mut self, msg: Box<msgs::ChannelRemove>) { + pub fn parse_channel_remove(&mut self, msg: msgs::ChannelRemove) { if !msg.has_channel_id() { warn!("Can't parse channel remove without channel id"); return; } match self.channels.entry(msg.get_channel_id()) { - Entry::Vacant(_) => { warn!("Attempted to remove channel that doesn't exist"); } - Entry::Occupied(e) => { e.remove(); } + Entry::Vacant(_) => { + warn!("Attempted to remove channel that doesn't exist"); + } + Entry::Occupied(e) => { + e.remove(); + } } } - pub fn parse_user_state(&mut self, msg: Box<msgs::UserState>) { + pub fn parse_user_state(&mut self, msg: msgs::UserState) { if !msg.has_session() { warn!("Can't parse user state without session"); return; } match self.users.entry(msg.get_session()) { - Entry::Vacant(e) => { e.insert(User::new(msg)); }, + Entry::Vacant(e) => { + e.insert(User::new(msg)); + } Entry::Occupied(mut e) => e.get_mut().parse_user_state(msg), } } @@ -66,7 +263,7 @@ impl Server { } } - +#[derive(Clone, Debug)] pub struct Channel { description: Option<String>, links: Vec<u32>, @@ -77,7 +274,7 @@ pub struct Channel { } impl Channel { - pub fn new(mut msg: Box<msgs::ChannelState>) -> Self { + pub fn new(mut msg: msgs::ChannelState) -> Self { Self { description: if msg.has_description() { Some(msg.take_description()) @@ -96,7 +293,7 @@ impl Channel { } } - pub fn parse_channel_state(&mut self, mut msg: Box<msgs::ChannelState>) { + pub fn parse_channel_state(&mut self, mut msg: msgs::ChannelState) { if msg.has_description() { self.description = Some(msg.take_description()); } @@ -120,6 +317,7 @@ impl Channel { } } +#[derive(Clone, Debug)] pub struct User { channel: u32, comment: Option<String>, @@ -128,15 +326,15 @@ pub struct User { priority_speaker: bool, recording: bool, - suppress: bool, // by me + suppress: bool, // by me self_mute: bool, // by self self_deaf: bool, // by self - mute: bool, // by admin - deaf: bool, // by admin + mute: bool, // by admin + deaf: bool, // by admin } impl User { - pub fn new(mut msg: Box<msgs::UserState>) -> Self { + pub fn new(mut msg: msgs::UserState) -> Self { Self { channel: msg.get_channel_id(), comment: if msg.has_comment() { @@ -150,24 +348,17 @@ impl User { None }, name: msg.take_name(), - priority_speaker: msg.has_priority_speaker() - && msg.get_priority_speaker(), - recording: msg.has_recording() - && msg.get_recording(), - suppress: msg.has_suppress() - && msg.get_suppress(), - self_mute: msg.has_self_mute() - && msg.get_self_mute(), - self_deaf: msg.has_self_deaf() - && msg.get_self_deaf(), - mute: msg.has_mute() - && msg.get_mute(), - deaf: msg.has_deaf() - && msg.get_deaf(), - } - } - - pub fn parse_user_state(&mut self, mut msg: Box<msgs::UserState>) { + priority_speaker: msg.has_priority_speaker() && msg.get_priority_speaker(), + recording: msg.has_recording() && msg.get_recording(), + suppress: msg.has_suppress() && msg.get_suppress(), + self_mute: msg.has_self_mute() && msg.get_self_mute(), + self_deaf: msg.has_self_deaf() && msg.get_self_deaf(), + mute: msg.has_mute() && msg.get_mute(), + deaf: msg.has_deaf() && msg.get_deaf(), + } + } + + pub fn parse_user_state(&mut self, mut msg: msgs::UserState) { if msg.has_channel_id() { self.channel = msg.get_channel_id(); } |
