diff --git a/Cargo.lock b/Cargo.lock index e10a154cd0fdff4194e09e64a348ccb96877e938..7ea9f63dcc1f7d08b85c945e90a163e3d3bf6acd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -983,9 +983,8 @@ checksum = "644f9158b2f133fd50f5fb3242878846d9eb792e445c893805ff0e3824006e35" [[package]] name = "hive_pubsub" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a9804dd748a82752283c672f906611ae273b59386d67dd349decf1b382fdb4" +version = "0.4.2" +source = "git+https://gitlab.insrt.uk/insert/hive#7ab66da23bc86b6fa6497aac928d29d4b885a878" dependencies = [ "futures", "many-to-many", @@ -1502,9 +1501,9 @@ dependencies = [ [[package]] name = "native-tls" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fcc7939b5edc4e4f86b1b4a04bb1498afaaf871b1a6691838ed06fcb48d3a3f" +checksum = "b8d96b2e1c8da3957d58100b09f102c6d9cfdfced01b7ec5a8974044bb09dbd4" dependencies = [ "lazy_static", "libc", @@ -2131,7 +2130,7 @@ dependencies = [ [[package]] name = "rauth" version = "0.1.0" -source = "git+https://gitlab.insrt.uk/insert/rauth#7e6366cc0c49445355e1f176d23ad244ade6a34b" +source = "git+https://gitlab.insrt.uk/insert/rauth#e11fe0fe429f7df68194891f3f2d5fc3df8e370c" dependencies = [ "json", "lazy_static", @@ -2165,18 +2164,18 @@ checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" [[package]] name = "ref-cast" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e84b8a3c77dd38893c11b59284a40f304a1346d4da020e603fab3671727df95d" +checksum = "300f2a835d808734ee295d45007adacb9ebb29dd3ae2424acfa17930cae541da" dependencies = [ "ref-cast-impl", ] [[package]] name = "ref-cast-impl" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d5173fc07aa6595363a38ca7d69d438cc32cca4216ccd1a3a8f2d4b10bbcd0" +checksum = "4c38e3aecd2b21cb3959637b883bb3714bc7e43f0268b9a29d3743ee3e55cdd2" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.8", @@ -2573,9 +2572,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1500e84d27fe482ed1dc791a56eddc2f230046a040fa908c08bda1d9fb615779" +checksum = "4fceb2595057b6891a4ee808f70054bd2d12f0e97f1cbb78689b59f676df325a" dependencies = [ "indexmap", "itoa", @@ -2687,7 +2686,6 @@ version = "0.6.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eab12d3c261b2308b0d80c26fffb58d17eba81a4be97890101f416b478c79ca7" dependencies = [ - "backtrace", "doc-comment", "snafu-derive", ] @@ -2722,9 +2720,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "standback" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf906c8b8fc3f6ecd1046e01da1d8ddec83e48c8b08b84dcc02b585a6bedf5a8" +checksum = "c66a8cff4fa24853fdf6b51f75c6d7f8206d7c75cab4e467bcd7f25c2b1febe0" dependencies = [ "version_check", ] diff --git a/Cargo.toml b/Cargo.toml index 102ae23c0893baf883273848fda6a103c6609611..06e7c9373aa169ecc4d6abf7a8308b7fcf98002f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ async-tungstenite = { version = "0.10.0", features = ["async-std-runtime"] } rauth = { git = "https://gitlab.insrt.uk/insert/rauth" } async-std = { version = "1.8.0", features = ["tokio02"] } -hive_pubsub = { version = "0.4.1", features = ["mongo"] } +hive_pubsub = { git = "https://gitlab.insrt.uk/insert/hive", features = ["mongo"] } rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", branch = "master" } rocket_contrib = { git = "https://github.com/SergioBenitez/Rocket", branch = "master" } rocket = { git = "https://github.com/SergioBenitez/Rocket", branch = "master", default-features = false } diff --git a/src/notifications/events.rs b/src/notifications/events.rs index b04673745b1b51f2ecd4714b984658f508bd3089..3011da7592e7e8c139fc615baa40f1d5c77df9e2 100644 --- a/src/notifications/events.rs +++ b/src/notifications/events.rs @@ -1,15 +1,25 @@ use rauth::auth::Session; use serde::{Deserialize, Serialize}; use snafu::Snafu; +use hive_pubsub::PubSub; + +use crate::database::entities::RelationshipStatus; + +use super::hive::get_hive; #[derive(Serialize, Deserialize, Debug, Snafu)] #[serde(tag = "type")] pub enum WebSocketError { #[snafu(display("This error has not been labelled."))] LabelMe, - #[snafu(display("Internal server error."))] InternalError, + #[snafu(display("Invalid session."))] + InvalidSession, + #[snafu(display("User hasn't completed onboarding."))] + OnboardingNotFinished, + #[snafu(display("Already authenticated with server."))] + AlreadyAuthenticated, } #[derive(Deserialize, Debug)] @@ -22,6 +32,7 @@ pub enum ServerboundNotification { #[serde(tag = "type")] pub enum ClientboundNotification { Error(WebSocketError), + Authenticated, /*MessageCreate { id: String, @@ -78,9 +89,16 @@ pub enum ClientboundNotification { GuildDelete { id: String, },*/ + UserRelationship { id: String, user: String, - status: i32, + status: RelationshipStatus, }, } + +impl ClientboundNotification { + pub async fn publish(self, topic: String) -> Result<(), String> { + hive_pubsub::backend::mongo::publish(get_hive(), &topic, self).await + } +} diff --git a/src/notifications/hive.rs b/src/notifications/hive.rs index 98c8e4b74b211619f521ef2fa3de5652560689b9..96d56ba2f0f3cd7c99bd14562d36be2e81c6f62b 100644 --- a/src/notifications/hive.rs +++ b/src/notifications/hive.rs @@ -1,4 +1,4 @@ -use super::events::ClientboundNotification; +use super::{events::ClientboundNotification, websocket}; use crate::database::get_collection; use futures::FutureExt; @@ -8,14 +8,15 @@ use log::{debug, error}; use once_cell::sync::OnceCell; use serde_json::to_string; -static HIVE: OnceCell<MongodbPubSub<String, String, ClientboundNotification>> = OnceCell::new(); +type Hive = MongodbPubSub<String, String, ClientboundNotification>; +static HIVE: OnceCell<Hive> = OnceCell::new(); pub async fn init_hive() { let hive = MongodbPubSub::new( - |_ids, notification| { + |ids, notification| { if let Ok(data) = to_string(¬ification) { debug!("Pushing out notification. {}", data); - // ! FIXME: push to websocket + websocket::publish(ids, notification); } else { error!("Failed to serialise notification."); } @@ -39,12 +40,7 @@ pub async fn listen() { dbg!("a"); } -pub fn publish(topic: &String, data: ClientboundNotification) -> Result<(), String> { - let hive = HIVE.get().unwrap(); - hive.publish(topic, data) -} - -pub fn subscribe(user: String, topics: Vec<String>) -> Result<(), String> { +pub fn subscribe_multiple(user: String, topics: Vec<String>) -> Result<(), String> { let hive = HIVE.get().unwrap(); for topic in topics { hive.subscribe(user.clone(), topic)?; @@ -53,16 +49,15 @@ pub fn subscribe(user: String, topics: Vec<String>) -> Result<(), String> { Ok(()) } -pub fn drop_user(user: &String) -> Result<(), String> { +pub fn subscribe_if_exists(user: String, topic: String) -> Result<(), String> { let hive = HIVE.get().unwrap(); - hive.drop_client(user)?; + if hive.hive.map.lock().unwrap().get_left(&user).is_some() { + hive.subscribe(user, topic)?; + } Ok(()) } -pub fn drop_topic(topic: &String) -> Result<(), String> { - let hive = HIVE.get().unwrap(); - hive.drop_topic(topic)?; - - Ok(()) +pub fn get_hive() -> &'static Hive { + HIVE.get().unwrap() } diff --git a/src/notifications/mod.rs b/src/notifications/mod.rs index 23abbfd52dd5c810e66c3b4c87e2c6ec3b29d821..c1dd5c6718b79984bdf2792816483a74330dca32 100644 --- a/src/notifications/mod.rs +++ b/src/notifications/mod.rs @@ -1,3 +1,4 @@ pub mod events; pub mod hive; pub mod websocket; +pub mod subscriptions; diff --git a/src/notifications/subscriptions.rs b/src/notifications/subscriptions.rs new file mode 100644 index 0000000000000000000000000000000000000000..2148208cf41af12e872666ed1c1e0e12fa31312f --- /dev/null +++ b/src/notifications/subscriptions.rs @@ -0,0 +1,17 @@ +use crate::database::entities::User; + +use super::hive::get_hive; +use hive_pubsub::PubSub; + +pub async fn generate_subscriptions(user: &User) -> Result<(), String> { + let hive = get_hive(); + hive.subscribe(user.id.clone(), user.id.clone())?; + + if let Some(relations) = &user.relations { + for relation in relations { + hive.subscribe(user.id.clone(), relation.id.clone())?; + } + } + + Ok(()) +} diff --git a/src/notifications/websocket.rs b/src/notifications/websocket.rs index 9d884fbab395a9004a84119be86cdcfc03780226..3fa3fa5f5e1e172358b3dc1537f87aba861f97f4 100644 --- a/src/notifications/websocket.rs +++ b/src/notifications/websocket.rs @@ -1,27 +1,31 @@ +use crate::database::get_collection; use crate::{database::entities::User, util::variables::WS_HOST}; +use crate::database::guards::reference::Ref; + +use super::subscriptions; use async_std::net::{TcpListener, TcpStream}; use async_std::task; use async_tungstenite::tungstenite::Message; use futures::channel::mpsc::{unbounded, UnboundedSender}; +use futures::stream::TryStreamExt; use futures::{pin_mut, prelude::*}; -use log::info; +use log::{debug, info}; use many_to_many::ManyToMany; -use rauth::auth::Session; +use rauth::auth::{Auth, Session}; use std::collections::HashMap; use std::net::SocketAddr; -use std::str::from_utf8; use std::sync::{Arc, Mutex, RwLock}; -use ulid::Ulid; +use hive_pubsub::PubSub; -use super::events::ServerboundNotification; +use super::{events::{ClientboundNotification, ServerboundNotification, WebSocketError}, hive::get_hive}; type Tx = UnboundedSender<Message>; type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>; lazy_static! { static ref CONNECTIONS: PeerMap = Arc::new(Mutex::new(HashMap::new())); - static ref USERS: Arc<RwLock<ManyToMany<String, String>>> = + static ref USERS: Arc<RwLock<ManyToMany<String, SocketAddr>>> = Arc::new(RwLock::new(ManyToMany::new())); } @@ -45,32 +49,111 @@ async fn accept(stream: TcpStream) { info!("User established WebSocket connection from {}.", &addr); - let id = Ulid::new().to_string(); let (write, read) = ws_stream.split(); - let (tx, rx) = unbounded(); - CONNECTIONS.lock().unwrap().insert(addr, tx); + CONNECTIONS.lock().unwrap().insert(addr, tx.clone()); - let session: Option<Session> = None; - let user: Option<User> = None; + let send = |notification: ClientboundNotification| { + if let Ok(response) = serde_json::to_string( + ¬ification, + ) { + if let Err(_) = tx.unbounded_send(Message::Text(response)) { + debug!("Failed unbounded_send to websocket stream."); + } + } + }; + let mut session: Option<Session> = None; let fwd = rx.map(Ok).forward(write); - let reading = read.for_each(|message| async { - let data = message.unwrap().into_data(); - // if you mess with the data, you get the bazooki - let string = from_utf8(&data).unwrap(); - - if let Ok(notification) = serde_json::from_str::<ServerboundNotification>(string) { - match notification { - ServerboundNotification::Authenticate(a) => { - dbg!(a); + let incoming = read.try_for_each(|msg| { + if let Message::Text(text) = msg { + if let Ok(notification) = serde_json::from_str::<ServerboundNotification>(&text) { + match notification { + ServerboundNotification::Authenticate(new_session) => { + if session.is_some() { + send(ClientboundNotification::Error(WebSocketError::AlreadyAuthenticated)); + return future::ok(()) + } + + match task::block_on( + Auth::new(get_collection("accounts")).verify_session(new_session), + ) { + Ok(validated_session) => { + match task::block_on( + Ref { id: validated_session.user_id.clone() } + .fetch_user() + ) { + Ok(user) => { + if let Ok(mut map) = USERS.write() { + map.insert(validated_session.user_id.clone(), addr); + session = Some(validated_session); + if let Ok(_) = task::block_on(subscriptions::generate_subscriptions(&user)) { + send(ClientboundNotification::Authenticated); + } else { + send(ClientboundNotification::Error(WebSocketError::InternalError)); + } + } else { + send(ClientboundNotification::Error(WebSocketError::InternalError)); + } + }, + Err(_) => { + send(ClientboundNotification::Error(WebSocketError::OnboardingNotFinished)); + } + } + } + Err(_) => { + send(ClientboundNotification::Error(WebSocketError::InvalidSession)); + } + } + } } } } + + future::ok(()) }); - pin_mut!(fwd, reading); - future::select(fwd, reading).await; + pin_mut!(fwd, incoming); + future::select(fwd, incoming).await; + + info!("User {} disconnected.", &addr); + CONNECTIONS.lock().unwrap().remove(&addr); + + if let Some(session) = session { + let mut users = USERS.write().unwrap(); + users.remove(&session.user_id, &addr); + if users.get_left(&session.user_id).is_none() { + get_hive().drop_client(&session.user_id).unwrap(); + } + } +} + +pub fn publish(ids: Vec<String>, notification: ClientboundNotification) { + let mut targets = vec![]; + { + let users = USERS.read().unwrap(); + for id in ids { + // Block certain notifications from reaching users that aren't meant to see them. + if let ClientboundNotification::UserRelationship { id: user_id, .. } = ¬ification { + if &id != user_id { + continue; + } + } - println!("User {} disconnected.", &addr); + if let Some(mut arr) = users.get_left(&id) { + targets.append(&mut arr); + } + } + } + + let msg = Message::Text(serde_json::to_string(¬ification).unwrap()); + + let connections = CONNECTIONS.lock().unwrap(); + for target in targets { + if let Some(conn) = connections.get(&target) { + if let Err(_) = conn.unbounded_send(msg.clone()) { + debug!("Failed unbounded_send."); + } + } + } } diff --git a/src/routes/users/add_friend.rs b/src/routes/users/add_friend.rs index a0762f23f4d7f9c520e78721b217feb6ef8da251..514c3ba004b9a4a81bf6c5a0692ed6c8125ff975 100644 --- a/src/routes/users/add_friend.rs +++ b/src/routes/users/add_friend.rs @@ -1,4 +1,4 @@ -use crate::util::result::Result; +use crate::{notifications::{events::ClientboundNotification, hive}, util::result::Result}; use crate::{ database::{ entities::{RelationshipStatus, User}, @@ -49,7 +49,22 @@ pub async fn req(user: User, target: Ref) -> Result<JsonValue> { None ) ) { - Ok(_) => Ok(json!({ "status": "Friend" })), + Ok(_) => { + try_join!( + ClientboundNotification::UserRelationship { + id: user.id.clone(), + user: target.id.clone(), + status: RelationshipStatus::Friend + }.publish(user.id.clone()), + ClientboundNotification::UserRelationship { + id: target.id.clone(), + user: user.id.clone(), + status: RelationshipStatus::Friend + }.publish(target.id.clone()) + ).ok(); + + Ok(json!({ "status": "Friend" })) + }, Err(_) => Err(Error::DatabaseError { operation: "update_one", with: "user", @@ -87,7 +102,25 @@ pub async fn req(user: User, target: Ref) -> Result<JsonValue> { None ) ) { - Ok(_) => Ok(json!({ "status": "Outgoing" })), + Ok(_) => { + try_join!( + ClientboundNotification::UserRelationship { + id: user.id.clone(), + user: target.id.clone(), + status: RelationshipStatus::Outgoing + }.publish(user.id.clone()), + ClientboundNotification::UserRelationship { + id: target.id.clone(), + user: user.id.clone(), + status: RelationshipStatus::Incoming + }.publish(target.id.clone()) + ).ok(); + + hive::subscribe_if_exists(user.id.clone(), target.id.clone()).ok(); + hive::subscribe_if_exists(target.id.clone(), user.id.clone()).ok(); + + Ok(json!({ "status": "Outgoing" })) + }, Err(_) => Err(Error::DatabaseError { operation: "update_one", with: "user",