From af56f5e2d8ac0a2e253f776822b9a4e13419d2c6 Mon Sep 17 00:00:00 2001
From: Paul Makles <paulmakles@gmail.com>
Date: Tue, 29 Dec 2020 23:25:52 +0000
Subject: [PATCH] Add hive to main join!().

---
 src/main.rs                    |  9 +++++++--
 src/notifications/events.rs    | 28 ++++++++++++++++++++++++----
 src/notifications/hive.rs      | 21 ++++++++++++++++-----
 src/notifications/websocket.rs | 31 ++++++++++++++++++++++++++++++-
 src/routes/mod.rs              |  2 +-
 src/routes/root.rs             | 10 ----------
 6 files changed, 78 insertions(+), 23 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index 46ed96f..f20e9f4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -36,13 +36,18 @@ async fn entry() {
 
     util::variables::preflight_checks();
     database::connect().await;
+    notifications::hive::init_hive().await;
 
     ctrlc::set_handler(move || {
         // Force ungraceful exit to avoid hang.
         std::process::exit(0);
     }).expect("Error setting Ctrl-C handler");
-    
-    join!(launch_web(), notifications::websocket::launch_server());
+
+    join!(
+        launch_web(),
+        notifications::websocket::launch_server(),
+        notifications::hive::listen(),
+    );
 }
 
 async fn launch_web() {
diff --git a/src/notifications/events.rs b/src/notifications/events.rs
index 5ed7960..8814ada 100644
--- a/src/notifications/events.rs
+++ b/src/notifications/events.rs
@@ -1,9 +1,29 @@
 use serde::{Deserialize, Serialize};
+use rauth::auth::Session;
+use snafu::Snafu;
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Serialize, Deserialize, Debug, Snafu)]
 #[serde(tag = "type")]
-pub enum Notification {
-    MessageCreate {
+pub enum WebSocketError {
+    #[snafu(display("This error has not been labelled."))]
+    LabelMe,
+
+    #[snafu(display("Internal server error."))]
+    InternalError,
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type")]
+pub enum ServerboundNotification {
+    Authenticate(Session)
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(tag = "type")]
+pub enum ClientboundNotification {
+    Error(WebSocketError),
+
+    /*MessageCreate {
         id: String,
         nonce: Option<String>,
         channel: String,
@@ -57,7 +77,7 @@ pub enum Notification {
 
     GuildDelete {
         id: String,
-    },
+    },*/
 
     UserRelationship {
         id: String,
diff --git a/src/notifications/hive.rs b/src/notifications/hive.rs
index ff2f04a..ec23101 100644
--- a/src/notifications/hive.rs
+++ b/src/notifications/hive.rs
@@ -1,14 +1,14 @@
-use super::events::Notification;
-// use super::websocket;
+use super::events::ClientboundNotification;
 use crate::database::get_collection;
 
 use hive_pubsub::backend::mongo::MongodbPubSub;
 use hive_pubsub::PubSub;
 use once_cell::sync::OnceCell;
 use serde_json::to_string;
+use futures::FutureExt;
 use log::{error, debug};
 
-static HIVE: OnceCell<MongodbPubSub<String, String, Notification>> = OnceCell::new();
+static HIVE: OnceCell<MongodbPubSub<String, String, ClientboundNotification>> = OnceCell::new();
 
 pub async fn init_hive() {
     let hive = MongodbPubSub::new(
@@ -20,7 +20,7 @@ pub async fn init_hive() {
                 error!("Failed to serialise notification.");
             }
         },
-        get_collection("hive"),
+        get_collection("pubsub"),
     );
 
     if HIVE.set(hive).is_err() {
@@ -28,7 +28,18 @@ pub async fn init_hive() {
     }
 }
 
-pub fn publish(topic: &String, data: Notification) -> Result<(), String> {
+pub async fn listen() {
+    HIVE.get()
+        .unwrap()
+        .listen()
+        .fuse()
+        .await
+        .expect("Hive hit an error");
+    
+    dbg!("a");
+}
+
+pub fn publish(topic: &String, data: ClientboundNotification) -> Result<(), String> {
     let hive = HIVE.get().unwrap();
     hive.publish(topic, data)
 }
diff --git a/src/notifications/websocket.rs b/src/notifications/websocket.rs
index dd98bd3..1b1593a 100644
--- a/src/notifications/websocket.rs
+++ b/src/notifications/websocket.rs
@@ -1,10 +1,25 @@
 use crate::util::variables::WS_HOST;
 
 use log::info;
+use ulid::Ulid;
 use async_std::task;
 use futures::prelude::*;
+use std::str::from_utf8;
+use std::sync::{Arc, RwLock};
+use many_to_many::ManyToMany;
+use std::collections::HashMap;
+use futures::stream::SplitSink;
+use async_tungstenite::WebSocketStream;
+use async_tungstenite::tungstenite::Message;
 use async_std::net::{TcpListener, TcpStream};
 
+lazy_static! {
+    static ref CONNECTIONS: Arc<RwLock<HashMap<String, SplitSink<WebSocketStream<TcpStream>, Message>>>> =
+        Arc::new(RwLock::new(HashMap::new()));
+    static ref USERS: Arc<RwLock<ManyToMany<String, String>>> =
+        Arc::new(RwLock::new(ManyToMany::new()));
+}
+
 pub async fn launch_server() {
     let try_socket = TcpListener::bind(WS_HOST.to_string()).await;
     let listener = try_socket.expect("Failed to bind");
@@ -23,6 +38,20 @@ async fn accept(stream: TcpStream) {
 
     info!("User established WebSocket connection from {}.", addr);
 
+    let id = Ulid::new().to_string();
     let (write, read) = ws_stream.split();
-    read.forward(write).await.expect("Failed to forward message")
+
+    CONNECTIONS
+        .write()
+        .unwrap()
+        .insert(id, write);
+
+    read
+        .for_each(|message| async {
+            let data = message.unwrap().into_data();
+            // if you mess with the data, you get the bazooki
+            let string = from_utf8(&data).unwrap();
+            println!("{}", string);
+        })
+        .await;
 }
diff --git a/src/routes/mod.rs b/src/routes/mod.rs
index 35f354f..2789410 100644
--- a/src/routes/mod.rs
+++ b/src/routes/mod.rs
@@ -10,7 +10,7 @@ mod channels;
 
 pub fn mount(rocket: Rocket) -> Rocket {
     rocket
-        .mount("/", routes![root::root, root::teapot])
+        .mount("/", routes![root::root])
         .mount("/onboard", onboard::routes())
         .mount("/users", users::routes())
         .mount("/channels", channels::routes())
diff --git a/src/routes/root.rs b/src/routes/root.rs
index bcede3c..e3f9fc7 100644
--- a/src/routes/root.rs
+++ b/src/routes/root.rs
@@ -3,7 +3,6 @@ use crate::util::variables::{DISABLE_REGISTRATION, HCAPTCHA_SITEKEY, USE_EMAIL,
 use rocket_contrib::json::JsonValue;
 use mongodb::bson::doc;
 
-/// root
 #[get("/")]
 pub async fn root() -> JsonValue {
     json!({
@@ -18,12 +17,3 @@ pub async fn root() -> JsonValue {
         }
     })
 }
-
-/// I'm a teapot.
-#[delete("/")]
-pub async fn teapot() -> JsonValue {
-    json!({
-        "teapot": true,
-        "can_delete": false
-    })
-}
-- 
GitLab