Skip to content
Snippets Groups Projects
Commit 1aaa6f3c authored by insert's avatar insert
Browse files

Get notifications working properly.

parent f39bc07b
Branches
Tags
No related merge requests found
......@@ -983,9 +983,8 @@ checksum = "644f9158b2f133fd50f5fb3242878846d9eb792e445c893805ff0e3824006e35"
[[package]]
name = "hive_pubsub"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10a9804dd748a82752283c672f906611ae273b59386d67dd349decf1b382fdb4"
version = "0.4.2"
source = "git+https://gitlab.insrt.uk/insert/hive#7ab66da23bc86b6fa6497aac928d29d4b885a878"
dependencies = [
"futures",
"many-to-many",
......@@ -1502,9 +1501,9 @@ dependencies = [
[[package]]
name = "native-tls"
version = "0.2.6"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fcc7939b5edc4e4f86b1b4a04bb1498afaaf871b1a6691838ed06fcb48d3a3f"
checksum = "b8d96b2e1c8da3957d58100b09f102c6d9cfdfced01b7ec5a8974044bb09dbd4"
dependencies = [
"lazy_static",
"libc",
......@@ -2131,7 +2130,7 @@ dependencies = [
[[package]]
name = "rauth"
version = "0.1.0"
source = "git+https://gitlab.insrt.uk/insert/rauth#7e6366cc0c49445355e1f176d23ad244ade6a34b"
source = "git+https://gitlab.insrt.uk/insert/rauth#e11fe0fe429f7df68194891f3f2d5fc3df8e370c"
dependencies = [
"json",
"lazy_static",
......@@ -2165,18 +2164,18 @@ checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce"
[[package]]
name = "ref-cast"
version = "1.0.5"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e84b8a3c77dd38893c11b59284a40f304a1346d4da020e603fab3671727df95d"
checksum = "300f2a835d808734ee295d45007adacb9ebb29dd3ae2424acfa17930cae541da"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.5"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d5173fc07aa6595363a38ca7d69d438cc32cca4216ccd1a3a8f2d4b10bbcd0"
checksum = "4c38e3aecd2b21cb3959637b883bb3714bc7e43f0268b9a29d3743ee3e55cdd2"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.8",
......@@ -2573,9 +2572,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.60"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1500e84d27fe482ed1dc791a56eddc2f230046a040fa908c08bda1d9fb615779"
checksum = "4fceb2595057b6891a4ee808f70054bd2d12f0e97f1cbb78689b59f676df325a"
dependencies = [
"indexmap",
"itoa",
......@@ -2687,7 +2686,6 @@ version = "0.6.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eab12d3c261b2308b0d80c26fffb58d17eba81a4be97890101f416b478c79ca7"
dependencies = [
"backtrace",
"doc-comment",
"snafu-derive",
]
......@@ -2722,9 +2720,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "standback"
version = "0.2.13"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf906c8b8fc3f6ecd1046e01da1d8ddec83e48c8b08b84dcc02b585a6bedf5a8"
checksum = "c66a8cff4fa24853fdf6b51f75c6d7f8206d7c75cab4e467bcd7f25c2b1febe0"
dependencies = [
"version_check",
]
......
......@@ -15,7 +15,7 @@ async-tungstenite = { version = "0.10.0", features = ["async-std-runtime"] }
rauth = { git = "https://gitlab.insrt.uk/insert/rauth" }
async-std = { version = "1.8.0", features = ["tokio02"] }
hive_pubsub = { version = "0.4.1", features = ["mongo"] }
hive_pubsub = { git = "https://gitlab.insrt.uk/insert/hive", features = ["mongo"] }
rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", branch = "master" }
rocket_contrib = { git = "https://github.com/SergioBenitez/Rocket", branch = "master" }
rocket = { git = "https://github.com/SergioBenitez/Rocket", branch = "master", default-features = false }
......
use rauth::auth::Session;
use serde::{Deserialize, Serialize};
use snafu::Snafu;
use hive_pubsub::PubSub;
use crate::database::entities::RelationshipStatus;
use super::hive::get_hive;
#[derive(Serialize, Deserialize, Debug, Snafu)]
#[serde(tag = "type")]
pub enum WebSocketError {
#[snafu(display("This error has not been labelled."))]
LabelMe,
#[snafu(display("Internal server error."))]
InternalError,
#[snafu(display("Invalid session."))]
InvalidSession,
#[snafu(display("User hasn't completed onboarding."))]
OnboardingNotFinished,
#[snafu(display("Already authenticated with server."))]
AlreadyAuthenticated,
}
#[derive(Deserialize, Debug)]
......@@ -22,6 +32,7 @@ pub enum ServerboundNotification {
#[serde(tag = "type")]
pub enum ClientboundNotification {
Error(WebSocketError),
Authenticated,
/*MessageCreate {
id: String,
......@@ -78,9 +89,16 @@ pub enum ClientboundNotification {
GuildDelete {
id: String,
},*/
UserRelationship {
id: String,
user: String,
status: i32,
status: RelationshipStatus,
},
}
impl ClientboundNotification {
pub async fn publish(self, topic: String) -> Result<(), String> {
hive_pubsub::backend::mongo::publish(get_hive(), &topic, self).await
}
}
use super::events::ClientboundNotification;
use super::{events::ClientboundNotification, websocket};
use crate::database::get_collection;
use futures::FutureExt;
......@@ -8,14 +8,15 @@ use log::{debug, error};
use once_cell::sync::OnceCell;
use serde_json::to_string;
static HIVE: OnceCell<MongodbPubSub<String, String, ClientboundNotification>> = OnceCell::new();
type Hive = MongodbPubSub<String, String, ClientboundNotification>;
static HIVE: OnceCell<Hive> = OnceCell::new();
pub async fn init_hive() {
let hive = MongodbPubSub::new(
|_ids, notification| {
|ids, notification| {
if let Ok(data) = to_string(&notification) {
debug!("Pushing out notification. {}", data);
// ! FIXME: push to websocket
websocket::publish(ids, notification);
} else {
error!("Failed to serialise notification.");
}
......@@ -39,12 +40,7 @@ pub async fn listen() {
dbg!("a");
}
pub fn publish(topic: &String, data: ClientboundNotification) -> Result<(), String> {
let hive = HIVE.get().unwrap();
hive.publish(topic, data)
}
pub fn subscribe(user: String, topics: Vec<String>) -> Result<(), String> {
pub fn subscribe_multiple(user: String, topics: Vec<String>) -> Result<(), String> {
let hive = HIVE.get().unwrap();
for topic in topics {
hive.subscribe(user.clone(), topic)?;
......@@ -53,16 +49,15 @@ pub fn subscribe(user: String, topics: Vec<String>) -> Result<(), String> {
Ok(())
}
pub fn drop_user(user: &String) -> Result<(), String> {
pub fn subscribe_if_exists(user: String, topic: String) -> Result<(), String> {
let hive = HIVE.get().unwrap();
hive.drop_client(user)?;
if hive.hive.map.lock().unwrap().get_left(&user).is_some() {
hive.subscribe(user, topic)?;
}
Ok(())
}
pub fn drop_topic(topic: &String) -> Result<(), String> {
let hive = HIVE.get().unwrap();
hive.drop_topic(topic)?;
Ok(())
pub fn get_hive() -> &'static Hive {
HIVE.get().unwrap()
}
pub mod events;
pub mod hive;
pub mod websocket;
pub mod subscriptions;
use crate::database::entities::User;
use super::hive::get_hive;
use hive_pubsub::PubSub;
pub async fn generate_subscriptions(user: &User) -> Result<(), String> {
let hive = get_hive();
hive.subscribe(user.id.clone(), user.id.clone())?;
if let Some(relations) = &user.relations {
for relation in relations {
hive.subscribe(user.id.clone(), relation.id.clone())?;
}
}
Ok(())
}
use crate::database::get_collection;
use crate::{database::entities::User, util::variables::WS_HOST};
use crate::database::guards::reference::Ref;
use super::subscriptions;
use async_std::net::{TcpListener, TcpStream};
use async_std::task;
use async_tungstenite::tungstenite::Message;
use futures::channel::mpsc::{unbounded, UnboundedSender};
use futures::stream::TryStreamExt;
use futures::{pin_mut, prelude::*};
use log::info;
use log::{debug, info};
use many_to_many::ManyToMany;
use rauth::auth::Session;
use rauth::auth::{Auth, Session};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::str::from_utf8;
use std::sync::{Arc, Mutex, RwLock};
use ulid::Ulid;
use hive_pubsub::PubSub;
use super::events::ServerboundNotification;
use super::{events::{ClientboundNotification, ServerboundNotification, WebSocketError}, hive::get_hive};
type Tx = UnboundedSender<Message>;
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
lazy_static! {
static ref CONNECTIONS: PeerMap = Arc::new(Mutex::new(HashMap::new()));
static ref USERS: Arc<RwLock<ManyToMany<String, String>>> =
static ref USERS: Arc<RwLock<ManyToMany<String, SocketAddr>>> =
Arc::new(RwLock::new(ManyToMany::new()));
}
......@@ -45,32 +49,111 @@ async fn accept(stream: TcpStream) {
info!("User established WebSocket connection from {}.", &addr);
let id = Ulid::new().to_string();
let (write, read) = ws_stream.split();
let (tx, rx) = unbounded();
CONNECTIONS.lock().unwrap().insert(addr, tx);
CONNECTIONS.lock().unwrap().insert(addr, tx.clone());
let session: Option<Session> = None;
let user: Option<User> = None;
let send = |notification: ClientboundNotification| {
if let Ok(response) = serde_json::to_string(
&notification,
) {
if let Err(_) = tx.unbounded_send(Message::Text(response)) {
debug!("Failed unbounded_send to websocket stream.");
}
}
};
let mut session: Option<Session> = None;
let fwd = rx.map(Ok).forward(write);
let reading = read.for_each(|message| async {
let data = message.unwrap().into_data();
// if you mess with the data, you get the bazooki
let string = from_utf8(&data).unwrap();
if let Ok(notification) = serde_json::from_str::<ServerboundNotification>(string) {
match notification {
ServerboundNotification::Authenticate(a) => {
dbg!(a);
let incoming = read.try_for_each(|msg| {
if let Message::Text(text) = msg {
if let Ok(notification) = serde_json::from_str::<ServerboundNotification>(&text) {
match notification {
ServerboundNotification::Authenticate(new_session) => {
if session.is_some() {
send(ClientboundNotification::Error(WebSocketError::AlreadyAuthenticated));
return future::ok(())
}
match task::block_on(
Auth::new(get_collection("accounts")).verify_session(new_session),
) {
Ok(validated_session) => {
match task::block_on(
Ref { id: validated_session.user_id.clone() }
.fetch_user()
) {
Ok(user) => {
if let Ok(mut map) = USERS.write() {
map.insert(validated_session.user_id.clone(), addr);
session = Some(validated_session);
if let Ok(_) = task::block_on(subscriptions::generate_subscriptions(&user)) {
send(ClientboundNotification::Authenticated);
} else {
send(ClientboundNotification::Error(WebSocketError::InternalError));
}
} else {
send(ClientboundNotification::Error(WebSocketError::InternalError));
}
},
Err(_) => {
send(ClientboundNotification::Error(WebSocketError::OnboardingNotFinished));
}
}
}
Err(_) => {
send(ClientboundNotification::Error(WebSocketError::InvalidSession));
}
}
}
}
}
}
future::ok(())
});
pin_mut!(fwd, reading);
future::select(fwd, reading).await;
pin_mut!(fwd, incoming);
future::select(fwd, incoming).await;
info!("User {} disconnected.", &addr);
CONNECTIONS.lock().unwrap().remove(&addr);
if let Some(session) = session {
let mut users = USERS.write().unwrap();
users.remove(&session.user_id, &addr);
if users.get_left(&session.user_id).is_none() {
get_hive().drop_client(&session.user_id).unwrap();
}
}
}
pub fn publish(ids: Vec<String>, notification: ClientboundNotification) {
let mut targets = vec![];
{
let users = USERS.read().unwrap();
for id in ids {
// Block certain notifications from reaching users that aren't meant to see them.
if let ClientboundNotification::UserRelationship { id: user_id, .. } = &notification {
if &id != user_id {
continue;
}
}
println!("User {} disconnected.", &addr);
if let Some(mut arr) = users.get_left(&id) {
targets.append(&mut arr);
}
}
}
let msg = Message::Text(serde_json::to_string(&notification).unwrap());
let connections = CONNECTIONS.lock().unwrap();
for target in targets {
if let Some(conn) = connections.get(&target) {
if let Err(_) = conn.unbounded_send(msg.clone()) {
debug!("Failed unbounded_send.");
}
}
}
}
use crate::util::result::Result;
use crate::{notifications::{events::ClientboundNotification, hive}, util::result::Result};
use crate::{
database::{
entities::{RelationshipStatus, User},
......@@ -49,7 +49,22 @@ pub async fn req(user: User, target: Ref) -> Result<JsonValue> {
None
)
) {
Ok(_) => Ok(json!({ "status": "Friend" })),
Ok(_) => {
try_join!(
ClientboundNotification::UserRelationship {
id: user.id.clone(),
user: target.id.clone(),
status: RelationshipStatus::Friend
}.publish(user.id.clone()),
ClientboundNotification::UserRelationship {
id: target.id.clone(),
user: user.id.clone(),
status: RelationshipStatus::Friend
}.publish(target.id.clone())
).ok();
Ok(json!({ "status": "Friend" }))
},
Err(_) => Err(Error::DatabaseError {
operation: "update_one",
with: "user",
......@@ -87,7 +102,25 @@ pub async fn req(user: User, target: Ref) -> Result<JsonValue> {
None
)
) {
Ok(_) => Ok(json!({ "status": "Outgoing" })),
Ok(_) => {
try_join!(
ClientboundNotification::UserRelationship {
id: user.id.clone(),
user: target.id.clone(),
status: RelationshipStatus::Outgoing
}.publish(user.id.clone()),
ClientboundNotification::UserRelationship {
id: target.id.clone(),
user: user.id.clone(),
status: RelationshipStatus::Incoming
}.publish(target.id.clone())
).ok();
hive::subscribe_if_exists(user.id.clone(), target.id.clone()).ok();
hive::subscribe_if_exists(target.id.clone(), user.id.clone()).ok();
Ok(json!({ "status": "Outgoing" }))
},
Err(_) => Err(Error::DatabaseError {
operation: "update_one",
with: "user",
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment