From 577f25642e3fe1fcdc43ac24a8d68d6fd226302f Mon Sep 17 00:00:00 2001
From: Paul Makles <paulmakles@gmail.com>
Date: Mon, 13 Apr 2020 16:04:41 +0100
Subject: [PATCH] Re-write notifications system.

---
 Rocket.toml                         |   2 +-
 src/database/message.rs             |  34 +++--
 src/main.rs                         |   8 +-
 src/notifications/events/message.rs |  10 ++
 src/notifications/events/mod.rs     |   8 ++
 src/notifications/mod.rs            |  21 +++
 src/notifications/pubsub.rs         | 112 +++++++++++++++
 src/notifications/state.rs          | 210 ++++++++++++++++++++++++++++
 src/notifications/ws.rs             | 119 ++++++++++++++++
 src/websocket/mod.rs                | 193 -------------------------
 10 files changed, 506 insertions(+), 211 deletions(-)
 create mode 100644 src/notifications/events/message.rs
 create mode 100644 src/notifications/events/mod.rs
 create mode 100644 src/notifications/mod.rs
 create mode 100644 src/notifications/pubsub.rs
 create mode 100644 src/notifications/state.rs
 create mode 100644 src/notifications/ws.rs
 delete mode 100644 src/websocket/mod.rs

diff --git a/Rocket.toml b/Rocket.toml
index bc6d28f..a493d7f 100644
--- a/Rocket.toml
+++ b/Rocket.toml
@@ -4,4 +4,4 @@ port = 5500
 
 [production]
 address = "192.168.0.10"
-port = 5500
+port = 3000
diff --git a/src/database/message.rs b/src/database/message.rs
index e47be11..f87bb54 100644
--- a/src/database/message.rs
+++ b/src/database/message.rs
@@ -1,6 +1,9 @@
 use super::get_collection;
 use crate::guards::channel::ChannelRef;
 use crate::routes::channel::ChannelType;
+use crate::notifications;
+use crate::notifications::events::Notification::MessageCreate;
+use crate::notifications::events::message::Create;
 
 use bson::{doc, to_bson, UtcDateTime};
 use serde::{Deserialize, Serialize};
@@ -34,6 +37,22 @@ impl Message {
             .insert_one(to_bson(&self).unwrap().as_document().unwrap().clone(), None)
             .is_ok()
         {
+            let data = MessageCreate(
+                Create {
+                    id: self.id.clone(),
+                    nonce: self.nonce.clone(),
+                    channel: self.channel.clone(),
+                    author: self.author.clone(),
+                    content: self.content.clone(),
+                }
+            );
+
+            match target.channel_type {
+                0..=1 => notifications::send_message(target.recipients.clone(), None, data),
+                2 => notifications::send_message(target.recipients.clone(), None, data),
+                _ => unreachable!()
+            };
+
             let short_content: String = self.content.chars().take(24).collect();
 
             // !! this stuff can be async
@@ -68,21 +87,6 @@ impl Message {
             } else {
                 true
             }
-
-        /*websocket::queue_message(
-            get_recipients(&target),
-            json!({
-                "type": "message",
-                "data": {
-                    "id": id.clone(),
-                    "nonce": nonce,
-                    "channel": target.id,
-                    "author": user.id,
-                    "content": content,
-                },
-            })
-            .to_string(),
-        );*/
         } else {
             false
         }
diff --git a/src/main.rs b/src/main.rs
index 876bb51..691f57d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -9,9 +9,9 @@ extern crate bitfield;
 pub mod database;
 pub mod email;
 pub mod guards;
+pub mod notifications;
 pub mod routes;
 pub mod util;
-pub mod websocket;
 
 use dotenv;
 use rocket_cors::AllowedOrigins;
@@ -22,7 +22,11 @@ fn main() {
     database::connect();
 
     thread::spawn(|| {
-        websocket::launch_server();
+        notifications::pubsub::launch_subscriber();
+    });
+
+    thread::spawn(|| {
+        notifications::ws::launch_server();
     });
 
     let cors = rocket_cors::CorsOptions {
diff --git a/src/notifications/events/message.rs b/src/notifications/events/message.rs
new file mode 100644
index 0000000..5729488
--- /dev/null
+++ b/src/notifications/events/message.rs
@@ -0,0 +1,10 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Create {
+    pub id: String,
+    pub nonce: Option<String>,
+    pub channel: String,
+    pub author: String,
+    pub content: String,
+}
diff --git a/src/notifications/events/mod.rs b/src/notifications/events/mod.rs
new file mode 100644
index 0000000..2143edd
--- /dev/null
+++ b/src/notifications/events/mod.rs
@@ -0,0 +1,8 @@
+use serde::{Deserialize, Serialize};
+
+pub mod message;
+
+#[derive(Serialize, Deserialize, Debug)]
+pub enum Notification {
+    MessageCreate(message::Create),
+}
diff --git a/src/notifications/mod.rs b/src/notifications/mod.rs
new file mode 100644
index 0000000..e5a04c4
--- /dev/null
+++ b/src/notifications/mod.rs
@@ -0,0 +1,21 @@
+pub mod events;
+pub mod pubsub;
+pub mod state;
+pub mod ws;
+
+pub fn send_message<U: Into<Option<Vec<String>>>, G: Into<Option<String>>>(
+    users: U,
+    guild: G,
+    data: events::Notification,
+) -> bool {
+    let users = users.into();
+    let guild = guild.into();
+
+    if pubsub::send_message(users.clone(), guild.clone(), data) {
+        state::send_message(users, guild, "bruh".to_string());
+
+        true
+    } else {
+        false
+    }
+}
diff --git a/src/notifications/pubsub.rs b/src/notifications/pubsub.rs
new file mode 100644
index 0000000..757fa10
--- /dev/null
+++ b/src/notifications/pubsub.rs
@@ -0,0 +1,112 @@
+use super::events::Notification;
+use crate::database::get_collection;
+
+use bson::{doc, from_bson, to_bson, Bson};
+use mongodb::options::{CursorType, FindOneOptions, FindOptions};
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+use ulid::Ulid;
+
+use once_cell::sync::OnceCell;
+static SOURCEID: OnceCell<String> = OnceCell::new();
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct PubSubMessage {
+    #[serde(rename = "_id")]
+    id: String,
+    source: String,
+
+    user_recipients: Option<Vec<String>>,
+    target_guild: Option<String>,
+
+    notification_type: String,
+    data: Notification,
+}
+
+pub fn send_message(
+    users: Option<Vec<String>>,
+    guild: Option<String>,
+    data: Notification,
+) -> bool {
+    let message = PubSubMessage {
+        id: Ulid::new().to_string(),
+        source: SOURCEID.get().unwrap().to_string(),
+        user_recipients: users.into(),
+        target_guild: guild.into(),
+        notification_type: match data {
+            Notification::MessageCreate(_) => "message_create",
+        }
+        .to_string(),
+        data,
+    };
+
+    if get_collection("pubsub")
+        .insert_one(
+            to_bson(&message)
+                .expect("Failed to serialize pubsub message.")
+                .as_document()
+                .expect("Failed to convert to a document.")
+                .clone(),
+            None,
+        )
+        .is_ok()
+    {
+        true
+    } else {
+        false
+    }
+}
+
+pub fn launch_subscriber() {
+    let source = Ulid::new().to_string();
+    SOURCEID
+        .set(source.clone())
+        .expect("Failed to create and set source ID.");
+
+    let pubsub = get_collection("pubsub");
+    if let Ok(result) = pubsub.find_one(
+        doc! {},
+        FindOneOptions::builder().sort(doc! { "_id": -1 }).build(),
+    ) {
+        let query = if let Some(doc) = result {
+            doc! { "_id": { "$gt": doc.get_str("_id").unwrap() } }
+        } else {
+            doc! {}
+        };
+
+        if let Ok(mut cursor) = pubsub.find(
+            query,
+            FindOptions::builder()
+                .cursor_type(CursorType::TailableAwait)
+                .no_cursor_timeout(true)
+                .max_await_time(Duration::from_secs(1200))
+                .build(),
+        ) {
+            loop {
+                while let Some(item) = cursor.next() {
+                    if let Ok(doc) = item {
+                        if let Ok(message) =
+                            from_bson(Bson::Document(doc)) as Result<PubSubMessage, _>
+                        {
+                            if &message.source != &source {
+                                super::state::send_message(
+                                    message.user_recipients,
+                                    message.target_guild,
+                                    json!(message.data).to_string(),
+                                );
+                            }
+                        } else {
+                            eprintln!("Failed to deserialize pubsub message.");
+                        }
+                    } else {
+                        eprintln!("Failed to unwrap a document from pubsub.");
+                    }
+                }
+            }
+        } else {
+            eprintln!("Failed to open subscriber cursor.");
+        }
+    } else {
+        eprintln!("Failed to fetch latest document from pubsub collection.");
+    }
+}
diff --git a/src/notifications/state.rs b/src/notifications/state.rs
new file mode 100644
index 0000000..8f4127f
--- /dev/null
+++ b/src/notifications/state.rs
@@ -0,0 +1,210 @@
+use super::events::Notification;
+use crate::database;
+use crate::util::vec_to_set;
+
+use bson::doc;
+use hashbrown::{HashMap, HashSet};
+use mongodb::options::FindOneOptions;
+use once_cell::sync::OnceCell;
+use std::sync::RwLock;
+use ws::Sender;
+
+pub enum StateResult {
+    DatabaseError,
+    InvalidToken,
+    Success(String),
+}
+
+static mut CONNECTIONS: OnceCell<RwLock<HashMap<String, Sender>>> = OnceCell::new();
+
+pub fn add_connection(id: String, sender: Sender) {
+    unsafe {
+        CONNECTIONS
+            .get()
+            .unwrap()
+            .write()
+            .unwrap()
+            .insert(id, sender);
+    }
+}
+
+pub struct User {
+    connections: HashSet<String>,
+    guilds: HashSet<String>,
+}
+
+impl User {
+    pub fn new() -> User {
+        User {
+            connections: HashSet::new(),
+            guilds: HashSet::new(),
+        }
+    }
+}
+
+pub struct Guild {
+    users: HashSet<String>,
+}
+
+impl Guild {
+    pub fn new() -> Guild {
+        Guild {
+            users: HashSet::new(),
+        }
+    }
+}
+
+pub struct GlobalState {
+    users: HashMap<String, User>,
+    guilds: HashMap<String, Guild>,
+}
+
+impl GlobalState {
+    pub fn new() -> GlobalState {
+        GlobalState {
+            users: HashMap::new(),
+            guilds: HashMap::new(),
+        }
+    }
+
+    pub fn push_to_guild(&mut self, guild: String, user: String) {
+        if !self.guilds.contains_key(&guild) {
+            self.guilds.insert(guild.clone(), Guild::new());
+        }
+
+        self.guilds.get_mut(&guild).unwrap().users.insert(user);
+    }
+
+    pub fn try_authenticate(&mut self, connection: String, access_token: String) -> StateResult {
+        if let Ok(result) = database::get_collection("users").find_one(
+            doc! {
+                "access_token": access_token,
+            },
+            FindOneOptions::builder()
+                .projection(doc! { "_id": 1 })
+                .build(),
+        ) {
+            if let Some(user) = result {
+                let user_id = user.get_str("_id").unwrap();
+
+                if self.users.contains_key(user_id) {
+                    self.users
+                        .get_mut(user_id)
+                        .unwrap()
+                        .connections
+                        .insert(connection);
+
+                    return StateResult::Success(user_id.to_string());
+                }
+
+                if let Ok(results) =
+                    database::get_collection("members").find(doc! { "_id.user": &user_id }, None)
+                {
+                    let mut guilds = vec![];
+                    for result in results {
+                        if let Ok(entry) = result {
+                            guilds.push(
+                                entry
+                                    .get_document("_id")
+                                    .unwrap()
+                                    .get_str("guild")
+                                    .unwrap()
+                                    .to_string(),
+                            );
+                        }
+                    }
+
+                    let mut user = User::new();
+                    for guild in guilds {
+                        user.guilds.insert(guild.clone());
+                        self.push_to_guild(guild, user_id.to_string());
+                    }
+
+                    user.connections.insert(connection);
+                    self.users.insert(user_id.to_string(), user);
+
+                    StateResult::Success(user_id.to_string())
+                } else {
+                    StateResult::DatabaseError
+                }
+            } else {
+                StateResult::InvalidToken
+            }
+        } else {
+            StateResult::DatabaseError
+        }
+    }
+
+    pub fn disconnect<U: Into<Option<String>>>(&mut self, user_id: U, connection: String) {
+        if let Some(user_id) = user_id.into() {
+            let user = self.users.get_mut(&user_id).unwrap();
+            user.connections.remove(&connection);
+
+            if user.connections.len() == 0 {
+                for guild in &user.guilds {
+                    self.guilds.get_mut(guild).unwrap().users.remove(&user_id);
+                }
+
+                self.users.remove(&user_id);
+            }
+        }
+
+        unsafe {
+            CONNECTIONS
+                .get()
+                .unwrap()
+                .write()
+                .unwrap()
+                .remove(&connection);
+        }
+    }
+}
+
+pub static mut DATA: OnceCell<RwLock<GlobalState>> = OnceCell::new();
+
+pub fn init() {
+    unsafe {
+        if CONNECTIONS.set(RwLock::new(HashMap::new())).is_err() {
+            panic!("Failed to set global connections map.");
+        }
+
+        if DATA.set(RwLock::new(GlobalState::new())).is_err() {
+            panic!("Failed to set global state.");
+        }
+    }
+}
+
+pub fn send_message(
+    users: Option<Vec<String>>,
+    guild: Option<String>,
+    data: String,
+) {
+    let state = unsafe { DATA.get().unwrap().read().unwrap() };
+    let mut connections = HashSet::new();
+    
+    let mut users = vec_to_set(&users.unwrap_or(vec![]));
+    if let Some(guild) = guild {
+        if let Some(entry) = state.guilds.get(&guild) {
+            for user in &entry.users {
+                users.insert(user.to_string());
+            }
+        }
+    }
+
+    for user in users {
+        if let Some(entry) = state.users.get(&user) {
+            for connection in &entry.connections {
+                connections.insert(connection.clone());
+            }
+        }
+    }
+
+    let targets = unsafe { CONNECTIONS.get().unwrap().read().unwrap() };
+    for conn in connections {
+        if let Some(sender) = targets.get(&conn) {
+            if sender.send(data.clone()).is_err() {
+                eprintln!("Failed to send a notification to a websocket. [{}]", &conn);
+            }
+        }
+    }
+}
diff --git a/src/notifications/ws.rs b/src/notifications/ws.rs
new file mode 100644
index 0000000..c853aea
--- /dev/null
+++ b/src/notifications/ws.rs
@@ -0,0 +1,119 @@
+use super::state::{self, StateResult};
+
+use serde_json::{from_str, json, Value};
+use std::env;
+use ulid::Ulid;
+use ws::{listen, CloseCode, Error, Handler, Handshake, Message, Result, Sender};
+
+struct Server {
+    sender: Sender,
+    user_id: Option<String>,
+    id: String,
+}
+
+impl Handler for Server {
+    fn on_open(&mut self, _: Handshake) -> Result<()> {
+        state::add_connection(self.id.clone(), self.sender.clone());
+        Ok(())
+    }
+
+    fn on_message(&mut self, msg: Message) -> Result<()> {
+        if let Message::Text(text) = msg {
+            if let Ok(data) = from_str(&text) as std::result::Result<Value, _> {
+                if let Value::String(packet_type) = &data["type"] {
+                    if packet_type == "authenticate" {
+                        if self.user_id.is_some() {
+                            self.sender.send(
+                                json!({
+                                    "type": "authenticate",
+                                    "success": false,
+                                    "error": "Already authenticated!"
+                                })
+                                .to_string(),
+                            )
+                        } else if let Value::String(token) = &data["token"] {
+                            let mut state = unsafe { state::DATA.get().unwrap().write().unwrap() };
+
+                            match state.try_authenticate(self.id.clone(), token.to_string()) {
+                                StateResult::Success(user_id) => {
+                                    self.user_id = Some(user_id);
+                                    self.sender.send(
+                                        json!({
+                                            "type": "authenticate",
+                                            "success": true,
+                                        })
+                                        .to_string(),
+                                    )
+                                }
+                                StateResult::DatabaseError => self.sender.send(
+                                    json!({
+                                        "type": "authenticate",
+                                        "success": false,
+                                        "error": "Had database error."
+                                    })
+                                    .to_string(),
+                                ),
+                                StateResult::InvalidToken => self.sender.send(
+                                    json!({
+                                        "type": "authenticate",
+                                        "success": false,
+                                        "error": "Invalid token."
+                                    })
+                                    .to_string(),
+                                ),
+                            }
+                        } else {
+                            self.sender.send(
+                                json!({
+                                    "type": "authenticate",
+                                    "success": false,
+                                    "error": "Token not present."
+                                })
+                                .to_string(),
+                            )
+                        }
+                    } else {
+                        Ok(())
+                    }
+                } else {
+                    Ok(())
+                }
+            } else {
+                Ok(())
+            }
+        } else {
+            Ok(())
+        }
+    }
+
+    fn on_close(&mut self, _code: CloseCode, _reason: &str) {
+        unsafe {
+            state::DATA
+                .get()
+                .unwrap()
+                .write()
+                .unwrap()
+                .disconnect(self.user_id.clone(), self.id.clone());
+        }
+
+        println!("User disconnected. [{}]", self.id);
+    }
+
+    fn on_error(&mut self, err: Error) {
+        println!("The server encountered an error: {:?}", err);
+    }
+}
+
+pub fn launch_server() {
+    state::init();
+
+    listen(
+        env::var("WS_HOST").unwrap_or("0.0.0.0:9999".to_string()),
+        |sender| Server {
+            sender,
+            user_id: None,
+            id: Ulid::new().to_string(),
+        },
+    )
+    .unwrap()
+}
diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs
deleted file mode 100644
index 6a7f521..0000000
--- a/src/websocket/mod.rs
+++ /dev/null
@@ -1,193 +0,0 @@
-extern crate ws;
-
-use crate::database;
-
-use hashbrown::HashMap;
-use std::sync::RwLock;
-use ulid::Ulid;
-
-use bson::doc;
-use serde_json::{from_str, json, Value};
-
-use ws::{listen, CloseCode, Error, Handler, Handshake, Message, Result, Sender};
-
-struct Cell {
-    id: String,
-    out: Sender,
-}
-
-use once_cell::sync::OnceCell;
-static mut CLIENTS: OnceCell<RwLock<HashMap<String, Vec<Cell>>>> = OnceCell::new();
-
-struct Server {
-    out: Sender,
-    id: Option<String>,
-    internal: String,
-}
-
-impl Handler for Server {
-    fn on_open(&mut self, _: Handshake) -> Result<()> {
-        Ok(())
-    }
-
-    fn on_message(&mut self, msg: Message) -> Result<()> {
-        if let Message::Text(text) = msg {
-            let data: Value = from_str(&text).unwrap();
-
-            if let Value::String(packet_type) = &data["type"] {
-                match packet_type.as_str() {
-                    "authenticate" => {
-                        if self.id.is_some() {
-                            self.out.send(
-                                json!({
-                                    "type": "authenticate",
-                                    "success": false,
-                                    "error": "Already authenticated!"
-                                })
-                                .to_string(),
-                            )
-                        } else if let Value::String(token) = &data["token"] {
-                            let col = database::get_collection("users");
-
-                            match col.find_one(doc! { "access_token": token }, None).unwrap() {
-                                Some(u) => {
-                                    let id = u.get_str("_id").expect("Missing id.");
-
-                                    unsafe {
-                                        let mut map = CLIENTS.get_mut().unwrap().write().unwrap();
-                                        let cell = Cell {
-                                            id: self.internal.clone(),
-                                            out: self.out.clone(),
-                                        };
-                                        if map.contains_key(&id.to_string()) {
-                                            map.get_mut(&id.to_string()).unwrap().push(cell);
-                                        } else {
-                                            map.insert(id.to_string(), vec![cell]);
-                                        }
-                                    }
-
-                                    println!(
-                                        "Websocket client connected. [ID: {} // {}]",
-                                        id.to_string(),
-                                        self.internal
-                                    );
-
-                                    self.id = Some(id.to_string());
-                                    self.out.send(
-                                        json!({
-                                            "type": "authenticate",
-                                            "success": true
-                                        })
-                                        .to_string(),
-                                    )
-                                }
-                                None => self.out.send(
-                                    json!({
-                                        "type": "authenticate",
-                                        "success": false,
-                                        "error": "Invalid authentication token."
-                                    })
-                                    .to_string(),
-                                ),
-                            }
-                        } else {
-                            self.out.send(
-                                json!({
-                                    "type": "authenticate",
-                                    "success": false,
-                                    "error": "Missing authentication token."
-                                })
-                                .to_string(),
-                            )
-                        }
-                    }
-                    _ => Ok(()),
-                }
-            } else {
-                Ok(())
-            }
-        } else {
-            Ok(())
-        }
-    }
-
-    fn on_close(&mut self, code: CloseCode, reason: &str) {
-        match code {
-            CloseCode::Normal => println!("The client is done with the connection."),
-            CloseCode::Away => println!("The client is leaving the site."),
-            CloseCode::Abnormal => {
-                println!("Closing handshake failed! Unable to obtain closing status from client.")
-            }
-            _ => println!("The client encountered an error: {}", reason),
-        }
-
-        if let Some(id) = &self.id {
-            println!(
-                "Websocket client disconnected. [ID: {} // {}]",
-                id, self.internal
-            );
-            unsafe {
-                let mut map = CLIENTS.get_mut().unwrap().write().unwrap();
-                let arr = map.get_mut(&id.clone()).unwrap();
-
-                if arr.len() == 1 {
-                    map.remove(&id.clone());
-                } else {
-                    let index = arr.iter().position(|x| x.id == self.internal).unwrap();
-                    arr.remove(index);
-                    println!(
-                        "User [{}] is still connected {} times",
-                        self.id.as_ref().unwrap(),
-                        arr.len()
-                    );
-                }
-            }
-        }
-    }
-
-    fn on_error(&mut self, err: Error) {
-        println!("The server encountered an error: {:?}", err);
-    }
-}
-
-pub fn launch_server() {
-    unsafe {
-        if CLIENTS.set(RwLock::new(HashMap::new())).is_err() {
-            panic!("Failed to set CLIENTS map!");
-        }
-    }
-
-    listen("192.168.0.10:9999", |out| Server {
-        out: out,
-        id: None,
-        internal: Ulid::new().to_string(),
-    })
-    .unwrap()
-}
-
-pub fn send_message(id: String, message: String) -> std::result::Result<(), ()> {
-    unsafe {
-        let map = CLIENTS.get().unwrap().read().unwrap();
-        if map.contains_key(&id) {
-            let arr = map.get(&id).unwrap();
-
-            for item in arr {
-                if item.out.send(message.clone()).is_err() {
-                    return Err(());
-                }
-            }
-        }
-
-        Ok(())
-    }
-}
-
-// ! TODO: WRITE THREADED QUEUE SYSTEM
-// ! FETCH RECIPIENTS HERE INSTEAD OF IN METHOD
-
-pub fn queue_message(ids: Vec<String>, message: String) {
-    for id in ids {
-        send_message(id, message.clone())
-            .expect("uhhhhhhhhhh can i get uhhhhhhhhhhhhhhhhhh mcdonald cheese burger with fries");
-    }
-}
-- 
GitLab