From 577f25642e3fe1fcdc43ac24a8d68d6fd226302f Mon Sep 17 00:00:00 2001 From: Paul Makles <paulmakles@gmail.com> Date: Mon, 13 Apr 2020 16:04:41 +0100 Subject: [PATCH] Re-write notifications system. --- Rocket.toml | 2 +- src/database/message.rs | 34 +++-- src/main.rs | 8 +- src/notifications/events/message.rs | 10 ++ src/notifications/events/mod.rs | 8 ++ src/notifications/mod.rs | 21 +++ src/notifications/pubsub.rs | 112 +++++++++++++++ src/notifications/state.rs | 210 ++++++++++++++++++++++++++++ src/notifications/ws.rs | 119 ++++++++++++++++ src/websocket/mod.rs | 193 ------------------------- 10 files changed, 506 insertions(+), 211 deletions(-) create mode 100644 src/notifications/events/message.rs create mode 100644 src/notifications/events/mod.rs create mode 100644 src/notifications/mod.rs create mode 100644 src/notifications/pubsub.rs create mode 100644 src/notifications/state.rs create mode 100644 src/notifications/ws.rs delete mode 100644 src/websocket/mod.rs diff --git a/Rocket.toml b/Rocket.toml index bc6d28f..a493d7f 100644 --- a/Rocket.toml +++ b/Rocket.toml @@ -4,4 +4,4 @@ port = 5500 [production] address = "192.168.0.10" -port = 5500 +port = 3000 diff --git a/src/database/message.rs b/src/database/message.rs index e47be11..f87bb54 100644 --- a/src/database/message.rs +++ b/src/database/message.rs @@ -1,6 +1,9 @@ use super::get_collection; use crate::guards::channel::ChannelRef; use crate::routes::channel::ChannelType; +use crate::notifications; +use crate::notifications::events::Notification::MessageCreate; +use crate::notifications::events::message::Create; use bson::{doc, to_bson, UtcDateTime}; use serde::{Deserialize, Serialize}; @@ -34,6 +37,22 @@ impl Message { .insert_one(to_bson(&self).unwrap().as_document().unwrap().clone(), None) .is_ok() { + let data = MessageCreate( + Create { + id: self.id.clone(), + nonce: self.nonce.clone(), + channel: self.channel.clone(), + author: self.author.clone(), + content: self.content.clone(), + } + ); + + match target.channel_type { + 0..=1 => notifications::send_message(target.recipients.clone(), None, data), + 2 => notifications::send_message(target.recipients.clone(), None, data), + _ => unreachable!() + }; + let short_content: String = self.content.chars().take(24).collect(); // !! this stuff can be async @@ -68,21 +87,6 @@ impl Message { } else { true } - - /*websocket::queue_message( - get_recipients(&target), - json!({ - "type": "message", - "data": { - "id": id.clone(), - "nonce": nonce, - "channel": target.id, - "author": user.id, - "content": content, - }, - }) - .to_string(), - );*/ } else { false } diff --git a/src/main.rs b/src/main.rs index 876bb51..691f57d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,9 +9,9 @@ extern crate bitfield; pub mod database; pub mod email; pub mod guards; +pub mod notifications; pub mod routes; pub mod util; -pub mod websocket; use dotenv; use rocket_cors::AllowedOrigins; @@ -22,7 +22,11 @@ fn main() { database::connect(); thread::spawn(|| { - websocket::launch_server(); + notifications::pubsub::launch_subscriber(); + }); + + thread::spawn(|| { + notifications::ws::launch_server(); }); let cors = rocket_cors::CorsOptions { diff --git a/src/notifications/events/message.rs b/src/notifications/events/message.rs new file mode 100644 index 0000000..5729488 --- /dev/null +++ b/src/notifications/events/message.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct Create { + pub id: String, + pub nonce: Option<String>, + pub channel: String, + pub author: String, + pub content: String, +} diff --git a/src/notifications/events/mod.rs b/src/notifications/events/mod.rs new file mode 100644 index 0000000..2143edd --- /dev/null +++ b/src/notifications/events/mod.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +pub mod message; + +#[derive(Serialize, Deserialize, Debug)] +pub enum Notification { + MessageCreate(message::Create), +} diff --git a/src/notifications/mod.rs b/src/notifications/mod.rs new file mode 100644 index 0000000..e5a04c4 --- /dev/null +++ b/src/notifications/mod.rs @@ -0,0 +1,21 @@ +pub mod events; +pub mod pubsub; +pub mod state; +pub mod ws; + +pub fn send_message<U: Into<Option<Vec<String>>>, G: Into<Option<String>>>( + users: U, + guild: G, + data: events::Notification, +) -> bool { + let users = users.into(); + let guild = guild.into(); + + if pubsub::send_message(users.clone(), guild.clone(), data) { + state::send_message(users, guild, "bruh".to_string()); + + true + } else { + false + } +} diff --git a/src/notifications/pubsub.rs b/src/notifications/pubsub.rs new file mode 100644 index 0000000..757fa10 --- /dev/null +++ b/src/notifications/pubsub.rs @@ -0,0 +1,112 @@ +use super::events::Notification; +use crate::database::get_collection; + +use bson::{doc, from_bson, to_bson, Bson}; +use mongodb::options::{CursorType, FindOneOptions, FindOptions}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use ulid::Ulid; + +use once_cell::sync::OnceCell; +static SOURCEID: OnceCell<String> = OnceCell::new(); + +#[derive(Serialize, Deserialize, Debug)] +pub struct PubSubMessage { + #[serde(rename = "_id")] + id: String, + source: String, + + user_recipients: Option<Vec<String>>, + target_guild: Option<String>, + + notification_type: String, + data: Notification, +} + +pub fn send_message( + users: Option<Vec<String>>, + guild: Option<String>, + data: Notification, +) -> bool { + let message = PubSubMessage { + id: Ulid::new().to_string(), + source: SOURCEID.get().unwrap().to_string(), + user_recipients: users.into(), + target_guild: guild.into(), + notification_type: match data { + Notification::MessageCreate(_) => "message_create", + } + .to_string(), + data, + }; + + if get_collection("pubsub") + .insert_one( + to_bson(&message) + .expect("Failed to serialize pubsub message.") + .as_document() + .expect("Failed to convert to a document.") + .clone(), + None, + ) + .is_ok() + { + true + } else { + false + } +} + +pub fn launch_subscriber() { + let source = Ulid::new().to_string(); + SOURCEID + .set(source.clone()) + .expect("Failed to create and set source ID."); + + let pubsub = get_collection("pubsub"); + if let Ok(result) = pubsub.find_one( + doc! {}, + FindOneOptions::builder().sort(doc! { "_id": -1 }).build(), + ) { + let query = if let Some(doc) = result { + doc! { "_id": { "$gt": doc.get_str("_id").unwrap() } } + } else { + doc! {} + }; + + if let Ok(mut cursor) = pubsub.find( + query, + FindOptions::builder() + .cursor_type(CursorType::TailableAwait) + .no_cursor_timeout(true) + .max_await_time(Duration::from_secs(1200)) + .build(), + ) { + loop { + while let Some(item) = cursor.next() { + if let Ok(doc) = item { + if let Ok(message) = + from_bson(Bson::Document(doc)) as Result<PubSubMessage, _> + { + if &message.source != &source { + super::state::send_message( + message.user_recipients, + message.target_guild, + json!(message.data).to_string(), + ); + } + } else { + eprintln!("Failed to deserialize pubsub message."); + } + } else { + eprintln!("Failed to unwrap a document from pubsub."); + } + } + } + } else { + eprintln!("Failed to open subscriber cursor."); + } + } else { + eprintln!("Failed to fetch latest document from pubsub collection."); + } +} diff --git a/src/notifications/state.rs b/src/notifications/state.rs new file mode 100644 index 0000000..8f4127f --- /dev/null +++ b/src/notifications/state.rs @@ -0,0 +1,210 @@ +use super::events::Notification; +use crate::database; +use crate::util::vec_to_set; + +use bson::doc; +use hashbrown::{HashMap, HashSet}; +use mongodb::options::FindOneOptions; +use once_cell::sync::OnceCell; +use std::sync::RwLock; +use ws::Sender; + +pub enum StateResult { + DatabaseError, + InvalidToken, + Success(String), +} + +static mut CONNECTIONS: OnceCell<RwLock<HashMap<String, Sender>>> = OnceCell::new(); + +pub fn add_connection(id: String, sender: Sender) { + unsafe { + CONNECTIONS + .get() + .unwrap() + .write() + .unwrap() + .insert(id, sender); + } +} + +pub struct User { + connections: HashSet<String>, + guilds: HashSet<String>, +} + +impl User { + pub fn new() -> User { + User { + connections: HashSet::new(), + guilds: HashSet::new(), + } + } +} + +pub struct Guild { + users: HashSet<String>, +} + +impl Guild { + pub fn new() -> Guild { + Guild { + users: HashSet::new(), + } + } +} + +pub struct GlobalState { + users: HashMap<String, User>, + guilds: HashMap<String, Guild>, +} + +impl GlobalState { + pub fn new() -> GlobalState { + GlobalState { + users: HashMap::new(), + guilds: HashMap::new(), + } + } + + pub fn push_to_guild(&mut self, guild: String, user: String) { + if !self.guilds.contains_key(&guild) { + self.guilds.insert(guild.clone(), Guild::new()); + } + + self.guilds.get_mut(&guild).unwrap().users.insert(user); + } + + pub fn try_authenticate(&mut self, connection: String, access_token: String) -> StateResult { + if let Ok(result) = database::get_collection("users").find_one( + doc! { + "access_token": access_token, + }, + FindOneOptions::builder() + .projection(doc! { "_id": 1 }) + .build(), + ) { + if let Some(user) = result { + let user_id = user.get_str("_id").unwrap(); + + if self.users.contains_key(user_id) { + self.users + .get_mut(user_id) + .unwrap() + .connections + .insert(connection); + + return StateResult::Success(user_id.to_string()); + } + + if let Ok(results) = + database::get_collection("members").find(doc! { "_id.user": &user_id }, None) + { + let mut guilds = vec![]; + for result in results { + if let Ok(entry) = result { + guilds.push( + entry + .get_document("_id") + .unwrap() + .get_str("guild") + .unwrap() + .to_string(), + ); + } + } + + let mut user = User::new(); + for guild in guilds { + user.guilds.insert(guild.clone()); + self.push_to_guild(guild, user_id.to_string()); + } + + user.connections.insert(connection); + self.users.insert(user_id.to_string(), user); + + StateResult::Success(user_id.to_string()) + } else { + StateResult::DatabaseError + } + } else { + StateResult::InvalidToken + } + } else { + StateResult::DatabaseError + } + } + + pub fn disconnect<U: Into<Option<String>>>(&mut self, user_id: U, connection: String) { + if let Some(user_id) = user_id.into() { + let user = self.users.get_mut(&user_id).unwrap(); + user.connections.remove(&connection); + + if user.connections.len() == 0 { + for guild in &user.guilds { + self.guilds.get_mut(guild).unwrap().users.remove(&user_id); + } + + self.users.remove(&user_id); + } + } + + unsafe { + CONNECTIONS + .get() + .unwrap() + .write() + .unwrap() + .remove(&connection); + } + } +} + +pub static mut DATA: OnceCell<RwLock<GlobalState>> = OnceCell::new(); + +pub fn init() { + unsafe { + if CONNECTIONS.set(RwLock::new(HashMap::new())).is_err() { + panic!("Failed to set global connections map."); + } + + if DATA.set(RwLock::new(GlobalState::new())).is_err() { + panic!("Failed to set global state."); + } + } +} + +pub fn send_message( + users: Option<Vec<String>>, + guild: Option<String>, + data: String, +) { + let state = unsafe { DATA.get().unwrap().read().unwrap() }; + let mut connections = HashSet::new(); + + let mut users = vec_to_set(&users.unwrap_or(vec![])); + if let Some(guild) = guild { + if let Some(entry) = state.guilds.get(&guild) { + for user in &entry.users { + users.insert(user.to_string()); + } + } + } + + for user in users { + if let Some(entry) = state.users.get(&user) { + for connection in &entry.connections { + connections.insert(connection.clone()); + } + } + } + + let targets = unsafe { CONNECTIONS.get().unwrap().read().unwrap() }; + for conn in connections { + if let Some(sender) = targets.get(&conn) { + if sender.send(data.clone()).is_err() { + eprintln!("Failed to send a notification to a websocket. [{}]", &conn); + } + } + } +} diff --git a/src/notifications/ws.rs b/src/notifications/ws.rs new file mode 100644 index 0000000..c853aea --- /dev/null +++ b/src/notifications/ws.rs @@ -0,0 +1,119 @@ +use super::state::{self, StateResult}; + +use serde_json::{from_str, json, Value}; +use std::env; +use ulid::Ulid; +use ws::{listen, CloseCode, Error, Handler, Handshake, Message, Result, Sender}; + +struct Server { + sender: Sender, + user_id: Option<String>, + id: String, +} + +impl Handler for Server { + fn on_open(&mut self, _: Handshake) -> Result<()> { + state::add_connection(self.id.clone(), self.sender.clone()); + Ok(()) + } + + fn on_message(&mut self, msg: Message) -> Result<()> { + if let Message::Text(text) = msg { + if let Ok(data) = from_str(&text) as std::result::Result<Value, _> { + if let Value::String(packet_type) = &data["type"] { + if packet_type == "authenticate" { + if self.user_id.is_some() { + self.sender.send( + json!({ + "type": "authenticate", + "success": false, + "error": "Already authenticated!" + }) + .to_string(), + ) + } else if let Value::String(token) = &data["token"] { + let mut state = unsafe { state::DATA.get().unwrap().write().unwrap() }; + + match state.try_authenticate(self.id.clone(), token.to_string()) { + StateResult::Success(user_id) => { + self.user_id = Some(user_id); + self.sender.send( + json!({ + "type": "authenticate", + "success": true, + }) + .to_string(), + ) + } + StateResult::DatabaseError => self.sender.send( + json!({ + "type": "authenticate", + "success": false, + "error": "Had database error." + }) + .to_string(), + ), + StateResult::InvalidToken => self.sender.send( + json!({ + "type": "authenticate", + "success": false, + "error": "Invalid token." + }) + .to_string(), + ), + } + } else { + self.sender.send( + json!({ + "type": "authenticate", + "success": false, + "error": "Token not present." + }) + .to_string(), + ) + } + } else { + Ok(()) + } + } else { + Ok(()) + } + } else { + Ok(()) + } + } else { + Ok(()) + } + } + + fn on_close(&mut self, _code: CloseCode, _reason: &str) { + unsafe { + state::DATA + .get() + .unwrap() + .write() + .unwrap() + .disconnect(self.user_id.clone(), self.id.clone()); + } + + println!("User disconnected. [{}]", self.id); + } + + fn on_error(&mut self, err: Error) { + println!("The server encountered an error: {:?}", err); + } +} + +pub fn launch_server() { + state::init(); + + listen( + env::var("WS_HOST").unwrap_or("0.0.0.0:9999".to_string()), + |sender| Server { + sender, + user_id: None, + id: Ulid::new().to_string(), + }, + ) + .unwrap() +} diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs deleted file mode 100644 index 6a7f521..0000000 --- a/src/websocket/mod.rs +++ /dev/null @@ -1,193 +0,0 @@ -extern crate ws; - -use crate::database; - -use hashbrown::HashMap; -use std::sync::RwLock; -use ulid::Ulid; - -use bson::doc; -use serde_json::{from_str, json, Value}; - -use ws::{listen, CloseCode, Error, Handler, Handshake, Message, Result, Sender}; - -struct Cell { - id: String, - out: Sender, -} - -use once_cell::sync::OnceCell; -static mut CLIENTS: OnceCell<RwLock<HashMap<String, Vec<Cell>>>> = OnceCell::new(); - -struct Server { - out: Sender, - id: Option<String>, - internal: String, -} - -impl Handler for Server { - fn on_open(&mut self, _: Handshake) -> Result<()> { - Ok(()) - } - - fn on_message(&mut self, msg: Message) -> Result<()> { - if let Message::Text(text) = msg { - let data: Value = from_str(&text).unwrap(); - - if let Value::String(packet_type) = &data["type"] { - match packet_type.as_str() { - "authenticate" => { - if self.id.is_some() { - self.out.send( - json!({ - "type": "authenticate", - "success": false, - "error": "Already authenticated!" - }) - .to_string(), - ) - } else if let Value::String(token) = &data["token"] { - let col = database::get_collection("users"); - - match col.find_one(doc! { "access_token": token }, None).unwrap() { - Some(u) => { - let id = u.get_str("_id").expect("Missing id."); - - unsafe { - let mut map = CLIENTS.get_mut().unwrap().write().unwrap(); - let cell = Cell { - id: self.internal.clone(), - out: self.out.clone(), - }; - if map.contains_key(&id.to_string()) { - map.get_mut(&id.to_string()).unwrap().push(cell); - } else { - map.insert(id.to_string(), vec![cell]); - } - } - - println!( - "Websocket client connected. [ID: {} // {}]", - id.to_string(), - self.internal - ); - - self.id = Some(id.to_string()); - self.out.send( - json!({ - "type": "authenticate", - "success": true - }) - .to_string(), - ) - } - None => self.out.send( - json!({ - "type": "authenticate", - "success": false, - "error": "Invalid authentication token." - }) - .to_string(), - ), - } - } else { - self.out.send( - json!({ - "type": "authenticate", - "success": false, - "error": "Missing authentication token." - }) - .to_string(), - ) - } - } - _ => Ok(()), - } - } else { - Ok(()) - } - } else { - Ok(()) - } - } - - fn on_close(&mut self, code: CloseCode, reason: &str) { - match code { - CloseCode::Normal => println!("The client is done with the connection."), - CloseCode::Away => println!("The client is leaving the site."), - CloseCode::Abnormal => { - println!("Closing handshake failed! Unable to obtain closing status from client.") - } - _ => println!("The client encountered an error: {}", reason), - } - - if let Some(id) = &self.id { - println!( - "Websocket client disconnected. [ID: {} // {}]", - id, self.internal - ); - unsafe { - let mut map = CLIENTS.get_mut().unwrap().write().unwrap(); - let arr = map.get_mut(&id.clone()).unwrap(); - - if arr.len() == 1 { - map.remove(&id.clone()); - } else { - let index = arr.iter().position(|x| x.id == self.internal).unwrap(); - arr.remove(index); - println!( - "User [{}] is still connected {} times", - self.id.as_ref().unwrap(), - arr.len() - ); - } - } - } - } - - fn on_error(&mut self, err: Error) { - println!("The server encountered an error: {:?}", err); - } -} - -pub fn launch_server() { - unsafe { - if CLIENTS.set(RwLock::new(HashMap::new())).is_err() { - panic!("Failed to set CLIENTS map!"); - } - } - - listen("192.168.0.10:9999", |out| Server { - out: out, - id: None, - internal: Ulid::new().to_string(), - }) - .unwrap() -} - -pub fn send_message(id: String, message: String) -> std::result::Result<(), ()> { - unsafe { - let map = CLIENTS.get().unwrap().read().unwrap(); - if map.contains_key(&id) { - let arr = map.get(&id).unwrap(); - - for item in arr { - if item.out.send(message.clone()).is_err() { - return Err(()); - } - } - } - - Ok(()) - } -} - -// ! TODO: WRITE THREADED QUEUE SYSTEM -// ! FETCH RECIPIENTS HERE INSTEAD OF IN METHOD - -pub fn queue_message(ids: Vec<String>, message: String) { - for id in ids { - send_message(id, message.clone()) - .expect("uhhhhhhhhhh can i get uhhhhhhhhhhhhhhhhhh mcdonald cheese burger with fries"); - } -} -- GitLab