diff --git a/src/main.rs b/src/main.rs index 46ed96fcbb5e8a88e5237346586119bae7e2463a..f20e9f4de488d2c551b6a9dd52b266e67d24303a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,13 +36,18 @@ async fn entry() { util::variables::preflight_checks(); database::connect().await; + notifications::hive::init_hive().await; ctrlc::set_handler(move || { // Force ungraceful exit to avoid hang. std::process::exit(0); }).expect("Error setting Ctrl-C handler"); - - join!(launch_web(), notifications::websocket::launch_server()); + + join!( + launch_web(), + notifications::websocket::launch_server(), + notifications::hive::listen(), + ); } async fn launch_web() { diff --git a/src/notifications/events.rs b/src/notifications/events.rs index 5ed7960eae45b565b8a5711a7a4a810a80695b42..8814ada3cc87d8d19137a8ae9a9ec6b221b0209d 100644 --- a/src/notifications/events.rs +++ b/src/notifications/events.rs @@ -1,9 +1,29 @@ use serde::{Deserialize, Serialize}; +use rauth::auth::Session; +use snafu::Snafu; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Snafu)] #[serde(tag = "type")] -pub enum Notification { - MessageCreate { +pub enum WebSocketError { + #[snafu(display("This error has not been labelled."))] + LabelMe, + + #[snafu(display("Internal server error."))] + InternalError, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type")] +pub enum ServerboundNotification { + Authenticate(Session) +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] +pub enum ClientboundNotification { + Error(WebSocketError), + + /*MessageCreate { id: String, nonce: Option<String>, channel: String, @@ -57,7 +77,7 @@ pub enum Notification { GuildDelete { id: String, - }, + },*/ UserRelationship { id: String, diff --git a/src/notifications/hive.rs b/src/notifications/hive.rs index ff2f04a35f887f2314fe8b03229eb86416a34743..ec23101a88b6472b71291fcaab3d42935083cc4c 100644 --- a/src/notifications/hive.rs +++ b/src/notifications/hive.rs @@ -1,14 +1,14 @@ -use super::events::Notification; -// use super::websocket; +use super::events::ClientboundNotification; use crate::database::get_collection; use hive_pubsub::backend::mongo::MongodbPubSub; use hive_pubsub::PubSub; use once_cell::sync::OnceCell; use serde_json::to_string; +use futures::FutureExt; use log::{error, debug}; -static HIVE: OnceCell<MongodbPubSub<String, String, Notification>> = OnceCell::new(); +static HIVE: OnceCell<MongodbPubSub<String, String, ClientboundNotification>> = OnceCell::new(); pub async fn init_hive() { let hive = MongodbPubSub::new( @@ -20,7 +20,7 @@ pub async fn init_hive() { error!("Failed to serialise notification."); } }, - get_collection("hive"), + get_collection("pubsub"), ); if HIVE.set(hive).is_err() { @@ -28,7 +28,18 @@ pub async fn init_hive() { } } -pub fn publish(topic: &String, data: Notification) -> Result<(), String> { +pub async fn listen() { + HIVE.get() + .unwrap() + .listen() + .fuse() + .await + .expect("Hive hit an error"); + + dbg!("a"); +} + +pub fn publish(topic: &String, data: ClientboundNotification) -> Result<(), String> { let hive = HIVE.get().unwrap(); hive.publish(topic, data) } diff --git a/src/notifications/websocket.rs b/src/notifications/websocket.rs index dd98bd3834eb243a49a9f1a01a6ca404c23a5f53..1b1593af8b97184f78eda625c7f04c18fc1b1bd1 100644 --- a/src/notifications/websocket.rs +++ b/src/notifications/websocket.rs @@ -1,10 +1,25 @@ use crate::util::variables::WS_HOST; use log::info; +use ulid::Ulid; use async_std::task; use futures::prelude::*; +use std::str::from_utf8; +use std::sync::{Arc, RwLock}; +use many_to_many::ManyToMany; +use std::collections::HashMap; +use futures::stream::SplitSink; +use async_tungstenite::WebSocketStream; +use async_tungstenite::tungstenite::Message; use async_std::net::{TcpListener, TcpStream}; +lazy_static! { + static ref CONNECTIONS: Arc<RwLock<HashMap<String, SplitSink<WebSocketStream<TcpStream>, Message>>>> = + Arc::new(RwLock::new(HashMap::new())); + static ref USERS: Arc<RwLock<ManyToMany<String, String>>> = + Arc::new(RwLock::new(ManyToMany::new())); +} + pub async fn launch_server() { let try_socket = TcpListener::bind(WS_HOST.to_string()).await; let listener = try_socket.expect("Failed to bind"); @@ -23,6 +38,20 @@ async fn accept(stream: TcpStream) { info!("User established WebSocket connection from {}.", addr); + let id = Ulid::new().to_string(); let (write, read) = ws_stream.split(); - read.forward(write).await.expect("Failed to forward message") + + CONNECTIONS + .write() + .unwrap() + .insert(id, write); + + read + .for_each(|message| async { + let data = message.unwrap().into_data(); + // if you mess with the data, you get the bazooki + let string = from_utf8(&data).unwrap(); + println!("{}", string); + }) + .await; } diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 35f354fe2d5ec5258d91bba6f8b109c29d53b1e4..27894102b1d62f5d3adbc2ba6aecb4247ea12ef5 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -10,7 +10,7 @@ mod channels; pub fn mount(rocket: Rocket) -> Rocket { rocket - .mount("/", routes![root::root, root::teapot]) + .mount("/", routes![root::root]) .mount("/onboard", onboard::routes()) .mount("/users", users::routes()) .mount("/channels", channels::routes()) diff --git a/src/routes/root.rs b/src/routes/root.rs index bcede3cb346436ad0201fa1f8e567a886b27de09..e3f9fc77807e07633a0f40ae04ac1dccae7eda27 100644 --- a/src/routes/root.rs +++ b/src/routes/root.rs @@ -3,7 +3,6 @@ use crate::util::variables::{DISABLE_REGISTRATION, HCAPTCHA_SITEKEY, USE_EMAIL, use rocket_contrib::json::JsonValue; use mongodb::bson::doc; -/// root #[get("/")] pub async fn root() -> JsonValue { json!({ @@ -18,12 +17,3 @@ pub async fn root() -> JsonValue { } }) } - -/// I'm a teapot. -#[delete("/")] -pub async fn teapot() -> JsonValue { - json!({ - "teapot": true, - "can_delete": false - }) -}