Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 1963 additions and 269 deletions
use crate::util::{
result::{Error, Result},
variables::JANUARY_URL,
};
use linkify::{LinkFinder, LinkKind};
use regex::Regex;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ImageSize {
Large,
Preview,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Image {
pub url: String,
pub width: isize,
pub height: isize,
pub size: ImageSize,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Video {
pub url: String,
pub width: isize,
pub height: isize,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum TwitchType {
Channel,
Video,
Clip,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum BandcampType {
Album,
Track,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum Special {
None,
YouTube {
id: String,
},
Twitch {
content_type: TwitchType,
id: String,
},
Spotify {
content_type: String,
id: String,
},
Soundcloud,
Bandcamp {
content_type: BandcampType,
id: String,
},
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Metadata {
#[serde(skip_serializing_if = "Option::is_none")]
url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
special: Option<Special>,
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
image: Option<Image>,
#[serde(skip_serializing_if = "Option::is_none")]
video: Option<Video>,
// #[serde(skip_serializing_if = "Option::is_none")]
// opengraph_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
site_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
icon_url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
color: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum Embed {
Website(Metadata),
Image(Image),
None,
}
impl Embed {
pub async fn generate(content: String) -> Result<Vec<Embed>> {
lazy_static! {
static ref RE_CODE: Regex = Regex::new("```(?:.|\n)+?```|`(?:.|\n)+?`").unwrap();
}
// Ignore code blocks.
let content = RE_CODE.replace_all(&content, "");
let content = content
// Ignore quoted lines.
.split("\n")
.map(|v| {
if let Some(c) = v.chars().next() {
if c == '>' {
return "";
}
}
v
})
.collect::<Vec<&str>>()
.join("\n");
// ! FIXME: allow multiple links
// ! FIXME: prevent generation if link is surrounded with < >
let mut finder = LinkFinder::new();
finder.kinds(&[LinkKind::Url]);
let links: Vec<_> = finder.links(&content).collect();
if links.len() == 0 {
return Err(Error::LabelMe);
}
let link = &links[0];
let client = reqwest::Client::new();
let result = client
.get(&format!("{}/embed", *JANUARY_URL))
.query(&[("url", link.as_str())])
.send()
.await;
match result {
Err(_) => return Err(Error::LabelMe),
Ok(result) => match result.status() {
reqwest::StatusCode::OK => {
let res: Embed = result.json().await.map_err(|_| Error::InvalidOperation)?;
Ok(vec![res])
}
_ => return Err(Error::LabelMe),
},
}
}
}
pub mod autumn;
pub mod january;
mod channel; mod channel;
mod guild; mod invites;
mod message; mod message;
mod microservice;
mod server;
mod sync;
mod user; mod user;
use microservice::*;
pub use autumn::*;
pub use channel::*; pub use channel::*;
pub use guild::*; pub use invites::*;
pub use january::*;
pub use message::*; pub use message::*;
pub use server::*;
pub use sync::*;
pub use user::*; pub use user::*;
use std::collections::HashMap;
use crate::database::*;
use crate::notifications::events::ClientboundNotification;
use crate::util::result::{Error, Result};
use futures::StreamExt;
use mongodb::bson::{Bson, doc};
use mongodb::bson::from_document;
use mongodb::bson::to_document;
use mongodb::bson::Document;
use rocket_contrib::json::JsonValue;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MemberCompositeKey {
pub server: String,
pub user: String,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Member {
#[serde(rename = "_id")]
pub id: MemberCompositeKey,
#[serde(skip_serializing_if = "Option::is_none")]
pub nickname: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub avatar: Option<File>,
#[serde(skip_serializing_if = "Option::is_none")]
pub roles: Option<Vec<String>>
}
pub type PermissionTuple = (
i32, // server permission
i32 // channel permission
);
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Role {
pub name: String,
pub permissions: PermissionTuple,
#[serde(skip_serializing_if = "Option::is_none")]
pub colour: Option<String>
// Bri'ish API conventions
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Category {
pub id: String,
pub title: String,
pub channels: Vec<String>
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Ban {
#[serde(rename = "_id")]
pub id: MemberCompositeKey,
pub reason: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SystemMessageChannels {
pub user_joined: Option<String>,
pub user_left: Option<String>,
pub user_kicked: Option<String>,
pub user_banned: Option<String>,
}
pub enum RemoveMember {
Leave,
Kick,
Ban,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Server {
#[serde(rename = "_id")]
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub nonce: Option<String>,
pub owner: String,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub channels: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub categories: Option<Vec<Category>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_messages: Option<SystemMessageChannels>,
#[serde(default = "HashMap::new", skip_serializing_if = "HashMap::is_empty")]
pub roles: HashMap<String, Role>,
pub default_permissions: PermissionTuple,
#[serde(skip_serializing_if = "Option::is_none")]
pub icon: Option<File>,
#[serde(skip_serializing_if = "Option::is_none")]
pub banner: Option<File>,
}
impl Server {
pub async fn create(self) -> Result<()> {
get_collection("servers")
.insert_one(
to_document(&self).map_err(|_| Error::DatabaseError {
operation: "to_bson",
with: "channel",
})?,
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "insert_one",
with: "server",
})?;
Ok(())
}
pub async fn publish_update(&self, data: JsonValue) -> Result<()> {
ClientboundNotification::ServerUpdate {
id: self.id.clone(),
data,
clear: None,
}
.publish(self.id.clone());
Ok(())
}
pub async fn delete(&self) -> Result<()> {
// Check if there are any attachments we need to delete.
Channel::delete_messages(Bson::Document(doc! { "$in": &self.channels })).await?;
// Delete all channels.
get_collection("channels")
.delete_many(
doc! {
"server": &self.id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "delete_many",
with: "channels",
})?;
// Delete any associated objects, e.g. unreads and invites.
Channel::delete_associated_objects(Bson::Document(doc! { "$in": &self.channels })).await?;
// Delete members and bans.
for with in &["server_members", "server_bans"] {
get_collection(with)
.delete_many(
doc! {
"_id.server": &self.id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "delete_many",
with,
})?;
}
// Delete server icon / banner.
if let Some(attachment) = &self.icon {
attachment.delete().await?;
}
if let Some(attachment) = &self.banner {
attachment.delete().await?;
}
// Delete the server
get_collection("servers")
.delete_one(
doc! {
"_id": &self.id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "delete_one",
with: "server",
})?;
ClientboundNotification::ServerDelete {
id: self.id.clone(),
}
.publish(self.id.clone());
Ok(())
}
pub async fn fetch_members(id: &str) -> Result<Vec<Member>> {
Ok(get_collection("server_members")
.find(
doc! {
"_id.server": id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "server_members",
})?
.filter_map(async move |s| s.ok())
.collect::<Vec<Document>>()
.await
.into_iter()
.filter_map(|x| from_document(x).ok())
.collect::<Vec<Member>>())
}
pub async fn fetch_member_ids(id: &str) -> Result<Vec<String>> {
Ok(get_collection("server_members")
.find(
doc! {
"_id.server": id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "server_members",
})?
.filter_map(async move |s| s.ok())
.collect::<Vec<Document>>()
.await
.into_iter()
.filter_map(|x| {
x.get_document("_id")
.ok()
.map(|i| i.get_str("user").ok().map(|x| x.to_string()))
})
.flatten()
.collect::<Vec<String>>())
}
pub async fn join_member(&self, id: &str) -> Result<()> {
if get_collection("server_bans")
.find_one(
doc! {
"_id.server": &self.id,
"_id.user": &id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "server_bans",
})?
.is_some()
{
return Err(Error::Banned);
}
get_collection("server_members")
.insert_one(
doc! {
"_id": {
"server": &self.id,
"user": &id
}
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "insert_one",
with: "server_members",
})?;
ClientboundNotification::ServerMemberJoin {
id: self.id.clone(),
user: id.to_string(),
}
.publish(self.id.clone());
if let Some(channels) = &self.system_messages {
if let Some(cid) = &channels.user_joined {
let channel = Ref::from_unchecked(cid.clone()).fetch_channel().await?;
Content::SystemMessage(SystemMessage::UserJoined { id: id.to_string() })
.send_as_system(&channel)
.await?;
}
}
Ok(())
}
pub async fn remove_member(&self, id: &str, removal: RemoveMember) -> Result<()> {
let result = get_collection("server_members")
.delete_one(
doc! {
"_id": {
"server": &self.id,
"user": &id
}
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "delete_one",
with: "server_members",
})?;
if result.deleted_count > 0 {
ClientboundNotification::ServerMemberLeave {
id: self.id.clone(),
user: id.to_string(),
}
.publish(self.id.clone());
if let Some(channels) = &self.system_messages {
let message = match removal {
RemoveMember::Leave => {
if let Some(cid) = &channels.user_left {
Some((cid.clone(), SystemMessage::UserLeft { id: id.to_string() }))
} else {
None
}
}
RemoveMember::Kick => {
if let Some(cid) = &channels.user_kicked {
Some((
cid.clone(),
SystemMessage::UserKicked { id: id.to_string() },
))
} else {
None
}
}
RemoveMember::Ban => {
if let Some(cid) = &channels.user_banned {
Some((
cid.clone(),
SystemMessage::UserBanned { id: id.to_string() },
))
} else {
None
}
}
};
if let Some((cid, message)) = message {
let channel = Ref::from_unchecked(cid).fetch_channel().await?;
Content::SystemMessage(message)
.send_as_system(&channel)
.await?;
}
}
}
Ok(())
}
pub async fn get_member_count(id: &str) -> Result<i64> {
Ok(get_collection("server_members")
.count_documents(
doc! {
"_id.server": id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "count_documents",
with: "server_members",
})?)
}
}
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub type UserSettings = HashMap<String, (i64, String)>;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChannelCompositeKey {
pub channel: String,
pub user: String,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChannelUnread {
#[serde(rename = "_id")]
pub id: ChannelCompositeKey,
pub last_id: Option<String>,
pub mentions: Option<Vec<String>>,
}
use futures::StreamExt;
use mongodb::bson::Document;
use mongodb::options::{Collation, FindOneOptions};
use mongodb::{
bson::{doc, from_document},
options::FindOptions,
};
use num_enum::TryFromPrimitive;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::ops;
use ulid::Ulid;
use validator::Validate;
use crate::database::permissions::user::UserPermissions;
use crate::database::*;
use crate::notifications::websocket::is_online;
use crate::util::result::{Error, Result};
use crate::util::variables::EARLY_ADOPTER_BADGE;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum RelationshipStatus { pub enum RelationshipStatus {
...@@ -11,24 +28,273 @@ pub enum RelationshipStatus { ...@@ -11,24 +28,273 @@ pub enum RelationshipStatus {
BlockedOther, BlockedOther,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Relationship { pub struct Relationship {
#[serde(rename = "_id")] #[serde(rename = "_id")]
pub id: String, pub id: String,
pub status: RelationshipStatus, pub status: RelationshipStatus,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum Presence {
Online,
Idle,
Busy,
Invisible,
}
#[derive(Validate, Serialize, Deserialize, Debug, Clone)]
pub struct UserStatus {
#[validate(length(min = 1, max = 128))]
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence: Option<Presence>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct UserProfile {
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub background: Option<File>,
}
#[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)]
#[repr(i32)]
pub enum Badges {
Developer = 1,
Translator = 2,
Supporter = 4,
ResponsibleDisclosure = 8,
RevoltTeam = 16,
EarlyAdopter = 256,
}
impl_op_ex_commutative!(+ |a: &i32, b: &Badges| -> i32 { *a | *b as i32 });
// When changing this struct, update notifications/payload.rs#80
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct User { pub struct User {
#[serde(rename = "_id")] #[serde(rename = "_id")]
pub id: String, pub id: String,
pub username: String, pub username: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub avatar: Option<File>,
#[serde(skip_serializing_if = "Option::is_none")]
pub relations: Option<Vec<Relationship>>, pub relations: Option<Vec<Relationship>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub badges: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub status: Option<UserStatus>,
#[serde(skip_serializing_if = "Option::is_none")]
pub profile: Option<UserProfile>,
// ? This should never be pushed to the collection. // ? This should never be pushed to the collection.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub relationship: Option<RelationshipStatus>, pub relationship: Option<RelationshipStatus>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub online: Option<bool>, pub online: Option<bool>,
} }
impl User {
/// Mutate the user object to include relationship as seen by user.
pub fn from(mut self, user: &User) -> User {
self.relationship = Some(RelationshipStatus::None);
if self.id == user.id {
self.relationship = Some(RelationshipStatus::User);
return self;
}
self.relations = None;
if let Some(relations) = &user.relations {
if let Some(relationship) = relations.iter().find(|x| self.id == x.id) {
self.relationship = Some(relationship.status.clone());
return self;
}
}
self
}
/// Mutate the user object to appear as seen by user.
pub fn with(mut self, permissions: UserPermissions<[u32; 1]>) -> User {
let mut badges = self.badges.unwrap_or_else(|| 0);
if let Ok(id) = Ulid::from_string(&self.id) {
if id.datetime().timestamp_millis() < *EARLY_ADOPTER_BADGE {
badges = badges + Badges::EarlyAdopter;
}
}
self.badges = Some(badges);
if permissions.get_view_profile() {
self.online = Some(is_online(&self.id));
} else {
self.status = None;
}
self.profile = None;
self
}
/// Mutate the user object to appear as seen by user.
/// Also overrides the relationship status.
pub async fn from_override(
mut self,
user: &User,
relationship: RelationshipStatus,
) -> Result<User> {
let permissions = PermissionCalculator::new(&user)
.with_relationship(&relationship)
.for_user(&self.id)
.await?;
self.relations = None;
self.relationship = Some(relationship);
Ok(self.with(permissions))
}
/// Utility function for checking claimed usernames.
pub async fn is_username_taken(username: &str) -> Result<bool> {
if username.to_lowercase() == "revolt" || username.to_lowercase() == "admin" || username.to_lowercase() == "system" {
return Ok(true);
}
if get_collection("users")
.find_one(
doc! {
"username": username
},
FindOneOptions::builder()
.collation(Collation::builder().locale("en").strength(2).build())
.build(),
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "user",
})?
.is_some()
{
Ok(true)
} else {
Ok(false)
}
}
/// Utility function for fetching multiple users from the perspective of one.
/// Assumes user has a mutual connection with others.
pub async fn fetch_multiple_users(&self, user_ids: Vec<String>) -> Result<Vec<User>> {
let mut users = vec![];
let mut cursor = get_collection("users")
.find(
doc! {
"_id": {
"$in": user_ids
}
},
FindOptions::builder()
.projection(
doc! { "_id": 1, "username": 1, "avatar": 1, "badges": 1, "status": 1 },
)
.build(),
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "users",
})?;
while let Some(result) = cursor.next().await {
if let Ok(doc) = result {
let other: User = from_document(doc).map_err(|_| Error::DatabaseError {
operation: "from_document",
with: "user",
})?;
let permissions = PermissionCalculator::new(&self)
.with_mutual_connection()
.with_user(&other)
.for_user_given()
.await?;
users.push(other.from(&self).with(permissions));
}
}
Ok(users)
}
/// Utility function to get all of a user's memberships.
pub async fn fetch_memberships(id: &str) -> Result<Vec<Member>> {
Ok(get_collection("server_members")
.find(
doc! {
"_id.user": id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "server_members",
})?
.filter_map(async move |s| s.ok())
.collect::<Vec<Document>>()
.await
.into_iter()
.filter_map(|x| {
from_document(x).ok()
})
.collect::<Vec<Member>>())
}
/// Utility function to get all the server IDs the user is in.
pub async fn fetch_server_ids(id: &str) -> Result<Vec<String>> {
Ok(get_collection("server_members")
.find(
doc! {
"_id.user": id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "server_members",
})?
.filter_map(async move |s| s.ok())
.collect::<Vec<Document>>()
.await
.into_iter()
.filter_map(|x| {
x.get_document("_id")
.ok()
.map(|i| i.get_str("server").ok().map(|x| x.to_string()))
})
.flatten()
.collect::<Vec<String>>())
}
/// Utility function to fetch unread objects for user.
pub async fn fetch_unreads(id: &str) -> Result<Vec<Document>> {
Ok(get_collection("channel_unreads")
.find(
doc! {
"_id.user": id
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "user_settings",
})?
.filter_map(async move |s| s.ok())
.collect::<Vec<Document>>()
.await)
}
}
...@@ -9,16 +9,23 @@ use validator::Validate; ...@@ -9,16 +9,23 @@ use validator::Validate;
#[derive(Validate, Serialize, Deserialize)] #[derive(Validate, Serialize, Deserialize)]
pub struct Ref { pub struct Ref {
#[validate(length(min = 26, max = 26))] #[validate(length(min = 1, max = 26))]
pub id: String, pub id: String,
} }
impl Ref { impl Ref {
pub fn from_unchecked(id: String) -> Ref {
Ref { id }
}
pub fn from(id: String) -> Result<Ref> { pub fn from(id: String) -> Result<Ref> {
Ok(Ref { id }) let r = Ref { id };
r.validate()
.map_err(|error| Error::FailedValidation { error })?;
Ok(r)
} }
pub async fn fetch<T: DeserializeOwned>(&self, collection: &'static str) -> Result<T> { async fn fetch<T: DeserializeOwned>(&self, collection: &'static str) -> Result<T> {
let doc = get_collection(&collection) let doc = get_collection(&collection)
.find_one( .find_one(
doc! { doc! {
...@@ -31,7 +38,7 @@ impl Ref { ...@@ -31,7 +38,7 @@ impl Ref {
operation: "find_one", operation: "find_one",
with: &collection, with: &collection,
})? })?
.ok_or_else(|| Error::UnknownUser)?; .ok_or_else(|| Error::NotFound)?;
Ok(from_document::<T>(doc).map_err(|_| Error::DatabaseError { Ok(from_document::<T>(doc).map_err(|_| Error::DatabaseError {
operation: "from_document", operation: "from_document",
...@@ -47,6 +54,60 @@ impl Ref { ...@@ -47,6 +54,60 @@ impl Ref {
self.fetch("channels").await self.fetch("channels").await
} }
pub async fn fetch_server(&self) -> Result<Server> {
self.fetch("servers").await
}
pub async fn fetch_invite(&self) -> Result<Invite> {
self.fetch("channel_invites").await
}
pub async fn fetch_member(&self, server: &str) -> Result<Member> {
let doc = get_collection("server_members")
.find_one(
doc! {
"_id.user": &self.id,
"_id.server": server
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "server_member",
})?
.ok_or_else(|| Error::NotFound)?;
Ok(
from_document::<Member>(doc).map_err(|_| Error::DatabaseError {
operation: "from_document",
with: "server_member",
})?,
)
}
pub async fn fetch_ban(&self, server: &str) -> Result<Ban> {
let doc = get_collection("server_bans")
.find_one(
doc! {
"_id.user": &self.id,
"_id.server": server
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "server_ban",
})?
.ok_or_else(|| Error::NotFound)?;
Ok(from_document::<Ban>(doc).map_err(|_| Error::DatabaseError {
operation: "from_document",
with: "server_ban",
})?)
}
pub async fn fetch_message(&self, channel: &Channel) -> Result<Message> { pub async fn fetch_message(&self, channel: &Channel) -> Result<Message> {
let message: Message = self.fetch("messages").await?; let message: Message = self.fetch("messages").await?;
if &message.channel != channel.id() { if &message.channel != channel.id() {
......
...@@ -29,7 +29,10 @@ impl<'a, 'r> FromRequest<'a, 'r> for User { ...@@ -29,7 +29,10 @@ impl<'a, 'r> FromRequest<'a, 'r> for User {
} else { } else {
Outcome::Failure(( Outcome::Failure((
Status::InternalServerError, Status::InternalServerError,
rauth::util::Error::DatabaseError, rauth::util::Error::DatabaseError {
operation: "find_one",
with: "user",
},
)) ))
} }
} }
......
...@@ -9,6 +9,10 @@ pub async fn create_database() { ...@@ -9,6 +9,10 @@ pub async fn create_database() {
info!("Creating database."); info!("Creating database.");
let db = get_db(); let db = get_db();
db.create_collection("accounts", None)
.await
.expect("Failed to create accounts collection.");
db.create_collection("users", None) db.create_collection("users", None)
.await .await
.expect("Failed to create users collection."); .expect("Failed to create users collection.");
...@@ -17,22 +21,42 @@ pub async fn create_database() { ...@@ -17,22 +21,42 @@ pub async fn create_database() {
.await .await
.expect("Failed to create channels collection."); .expect("Failed to create channels collection.");
db.create_collection("guilds", None) db.create_collection("messages", None)
.await .await
.expect("Failed to create guilds collection."); .expect("Failed to create messages collection.");
db.create_collection("members", None) db.create_collection("servers", None)
.await .await
.expect("Failed to create members collection."); .expect("Failed to create servers collection.");
db.create_collection("messages", None) db.create_collection("server_members", None)
.await .await
.expect("Failed to create messages collection."); .expect("Failed to create server_members collection.");
db.create_collection("server_bans", None)
.await
.expect("Failed to create server_bans collection.");
db.create_collection("channel_invites", None)
.await
.expect("Failed to create channel_invites collection.");
db.create_collection("channel_unreads", None)
.await
.expect("Failed to create channel_unreads collection.");
db.create_collection("migrations", None) db.create_collection("migrations", None)
.await .await
.expect("Failed to create migrations collection."); .expect("Failed to create migrations collection.");
db.create_collection("attachments", None)
.await
.expect("Failed to create attachments collection.");
db.create_collection("user_settings", None)
.await
.expect("Failed to create user_settings collection.");
db.create_collection( db.create_collection(
"pubsub", "pubsub",
CreateCollectionOptions::builder() CreateCollectionOptions::builder()
...@@ -43,6 +67,39 @@ pub async fn create_database() { ...@@ -43,6 +67,39 @@ pub async fn create_database() {
.await .await
.expect("Failed to create pubsub collection."); .expect("Failed to create pubsub collection.");
db.run_command(
doc! {
"createIndexes": "accounts",
"indexes": [
{
"key": {
"email": 1
},
"name": "email",
"unique": true,
"collation": {
"locale": "en",
"strength": 2
}
},
{
"key": {
"email_normalised": 1
},
"name": "email_normalised",
"unique": true,
"collation": {
"locale": "en",
"strength": 2
}
}
]
},
None,
)
.await
.expect("Failed to create account index.");
db.run_command( db.run_command(
doc! { doc! {
"createIndexes": "users", "createIndexes": "users",
...@@ -65,6 +122,23 @@ pub async fn create_database() { ...@@ -65,6 +122,23 @@ pub async fn create_database() {
.await .await
.expect("Failed to create username index."); .expect("Failed to create username index.");
db.run_command(
doc! {
"createIndexes": "messages",
"indexes": [
{
"key": {
"content": "text"
},
"name": "content"
}
]
},
None,
)
.await
.expect("Failed to create message index.");
db.collection("migrations") db.collection("migrations")
.insert_one( .insert_one(
doc! { doc! {
......
use super::super::{get_collection, get_db}; use crate::database::{permissions, get_collection, get_db, PermissionTuple};
use crate::rocket::futures::StreamExt; use futures::StreamExt;
use log::info; use log::info;
use mongodb::bson::{doc, from_document}; use mongodb::{bson::{doc, from_document, to_document}, options::FindOptions};
use mongodb::options::FindOptions;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
...@@ -12,7 +11,7 @@ struct MigrationInfo { ...@@ -12,7 +11,7 @@ struct MigrationInfo {
revision: i32, revision: i32,
} }
pub const LATEST_REVISION: i32 = 3; pub const LATEST_REVISION: i32 = 7;
pub async fn migrate_database() { pub async fn migrate_database() {
let migrations = get_collection("migrations"); let migrations = get_collection("migrations");
...@@ -56,96 +55,152 @@ pub async fn run_migrations(revision: i32) -> i32 { ...@@ -56,96 +55,152 @@ pub async fn run_migrations(revision: i32) -> i32 {
} }
if revision <= 1 { if revision <= 1 {
info!("Running migration [revision 1]: Add channels to guild object."); info!("Running migration [revision 1 / 2021-04-24]: Migrate to Autumn v1.0.0.");
let col = get_collection("guilds"); let messages = get_collection("messages");
let mut guilds = col let attachments = get_collection("attachments");
.find(
messages
.update_many(
doc! { "attachment": { "$exists": 1 } },
doc! { "$set": { "attachment.tag": "attachments", "attachment.size": 0 } },
None, None,
FindOptions::builder().projection(doc! { "_id": 1 }).build(),
) )
.await .await
.expect("Failed to fetch guilds."); .expect("Failed to update messages.");
let mut result = get_collection("channels") attachments
.find( .update_many(
doc! { doc! {},
"type": 2 doc! { "$set": { "tag": "attachments", "size": 0 } },
}, None,
FindOptions::builder()
.projection(doc! { "_id": 1, "guild": 1 })
.build(),
) )
.await .await
.expect("Failed to fetch channels."); .expect("Failed to update attachments.");
}
let mut channels = vec![];
while let Some(doc) = result.next().await {
let channel = doc.expect("Failed to fetch channel.");
let id = channel
.get_str("_id")
.expect("Failed to get channel id.")
.to_string();
let gid = channel
.get_str("guild")
.expect("Failed to get guild id.")
.to_string();
channels.push((id, gid));
}
while let Some(doc) = guilds.next().await { if revision <= 2 {
let guild = doc.expect("Failed to fetch guild."); info!("Running migration [revision 2 / 2021-05-08]: Add servers collection.");
let id = guild.get_str("_id").expect("Failed to get guild id.");
let list: Vec<String> = channels get_db()
.iter() .create_collection("servers", None)
.filter(|x| x.1 == id) .await
.map(|x| x.0.clone()) .expect("Failed to create servers collection.");
.collect(); }
col.update_one( if revision <= 3 {
doc! { info!("Running migration [revision 3 / 2021-05-25]: Support multiple file uploads, add channel_unreads and user_settings.");
"_id": id
}, let messages = get_collection("messages");
let mut cursor = messages
.find(
doc! { doc! {
"$set": { "attachment": {
"channels": list "$exists": 1
} }
}, },
None, FindOptions::builder()
.projection(doc! {
"_id": 1,
"attachments": [ "$attachment" ]
})
.build(),
) )
.await .await
.expect("Failed to update guild."); .expect("Failed to fetch messages.");
while let Some(result) = cursor.next().await {
let doc = result.unwrap();
let id = doc.get_str("_id").unwrap();
let attachments = doc.get_array("attachments").unwrap();
messages
.update_one(
doc! { "_id": id },
doc! { "$unset": { "attachment": 1 }, "$set": { "attachments": attachments } },
None,
)
.await
.unwrap();
} }
get_db()
.create_collection("channel_unreads", None)
.await
.expect("Failed to create channel_unreads collection.");
get_db()
.create_collection("user_settings", None)
.await
.expect("Failed to create user_settings collection.");
} }
if revision <= 2 { if revision <= 4 {
info!("Running migration [revision 2]: Add username index to users."); info!("Running migration [revision 4 / 2021-06-01]: Add more server collections.");
get_db()
.create_collection("server_members", None)
.await
.expect("Failed to create server_members collection.");
get_db() get_db()
.run_command( .create_collection("server_bans", None)
.await
.expect("Failed to create server_bans collection.");
get_db()
.create_collection("channel_invites", None)
.await
.expect("Failed to create channel_invites collection.");
}
if revision <= 5 {
info!("Running migration [revision 5 / 2021-06-26]: Add permissions.");
#[derive(Serialize)]
struct Server {
pub default_permissions: PermissionTuple,
}
let server = Server {
default_permissions: (
*permissions::server::DEFAULT_PERMISSION as i32,
*permissions::channel::DEFAULT_PERMISSION_SERVER as i32
)
};
get_collection("servers")
.update_many(
doc! { },
doc! { doc! {
"createIndexes": "users", "$set": to_document(&server).unwrap()
"indexes": [
{
"key": {
"username": 1
},
"name": "username",
"unique": true,
"collation": {
"locale": "en",
"strength": 2
}
}
]
}, },
None, None
) )
.await .await
.expect("Failed to create username index."); .expect("Failed to migrate servers.");
}
if revision <= 6 {
info!("Running migration [revision 6 / 2021-07-09]: Add message text index.");
get_db()
.run_command(
doc! {
"createIndexes": "messages",
"indexes": [
{
"key": {
"content": "text"
},
"name": "content"
}
]
},
None,
)
.await
.expect("Failed to create message index.");
} }
// Reminder to update LATEST_REVISION when adding new migrations. // Reminder to update LATEST_REVISION when adding new migrations.
......
use crate::database::*; use crate::database::*;
use crate::util::result::{Error, Result};
use super::PermissionCalculator;
use num_enum::TryFromPrimitive; use num_enum::TryFromPrimitive;
use std::ops; use std::ops;
#[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)] #[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)]
#[repr(u32)] #[repr(u32)]
pub enum ChannelPermission { pub enum ChannelPermission {
View = 1, View = 0b00000000000000000000000000000001, // 1
SendMessage = 2, SendMessage = 0b00000000000000000000000000000010, // 2
ManageMessages = 4, ManageMessages = 0b00000000000000000000000000000100, // 4
ManageChannel = 0b00000000000000000000000000001000, // 8
VoiceCall = 0b00000000000000000000000000010000, // 16
InviteOthers = 0b00000000000000000000000000100000, // 32
EmbedLinks = 0b00000000000000000000000001000000, // 64
UploadFiles = 0b00000000000000000000000010000000, // 128
} }
lazy_static! {
pub static ref DEFAULT_PERMISSION_DM: u32 =
ChannelPermission::View
+ ChannelPermission::SendMessage
+ ChannelPermission::ManageChannel
+ ChannelPermission::VoiceCall
+ ChannelPermission::InviteOthers
+ ChannelPermission::EmbedLinks
+ ChannelPermission::UploadFiles;
pub static ref DEFAULT_PERMISSION_SERVER: u32 =
ChannelPermission::View
+ ChannelPermission::SendMessage
+ ChannelPermission::VoiceCall
+ ChannelPermission::InviteOthers
+ ChannelPermission::EmbedLinks
+ ChannelPermission::UploadFiles;
}
impl_op_ex!(+ |a: &ChannelPermission, b: &ChannelPermission| -> u32 { *a as u32 | *b as u32 });
impl_op_ex_commutative!(+ |a: &u32, b: &ChannelPermission| -> u32 { *a | *b as u32 });
bitfield! { bitfield! {
pub struct ChannelPermissions(MSB0 [u32]); pub struct ChannelPermissions(MSB0 [u32]);
u32; u32;
pub get_view, _: 31; pub get_view, _: 31;
pub get_send_message, _: 30; pub get_send_message, _: 30;
pub get_manage_messages, _: 29; pub get_manage_messages, _: 29;
pub get_manage_channel, _: 28;
pub get_voice_call, _: 27;
pub get_invite_others, _: 26;
pub get_embed_links, _: 25;
pub get_upload_files, _: 24;
} }
impl_op_ex!(+ |a: &ChannelPermission, b: &ChannelPermission| -> u32 { *a as u32 | *b as u32 }); impl<'a> PermissionCalculator<'a> {
impl_op_ex_commutative!(+ |a: &u32, b: &ChannelPermission| -> u32 { *a | *b as u32 }); pub async fn calculate_channel(self) -> Result<u32> {
let channel = if let Some(channel) = self.channel {
channel
} else {
unreachable!()
};
pub async fn calculate(user: &User, target: &Channel) -> ChannelPermissions<[u32; 1]> { match channel {
match target { Channel::SavedMessages { user: owner, .. } => {
Channel::SavedMessages { user: owner, .. } => { if &self.perspective.id == owner {
if &user.id == owner { Ok(u32::MAX)
ChannelPermissions([ChannelPermission::View } else {
+ ChannelPermission::SendMessage Ok(0)
+ ChannelPermission::ManageMessages]) }
} else {
ChannelPermissions([0])
} }
} Channel::DirectMessage { recipients, .. } => {
Channel::DirectMessage { recipients, .. } => { if recipients
if recipients.iter().find(|x| *x == &user.id).is_some() { .iter()
if let Some(recipient) = recipients.iter().find(|x| *x != &user.id) { .find(|x| *x == &self.perspective.id)
let perms = super::user::calculate(&user, recipient).await; .is_some()
{
if perms.get_send_message() { if let Some(recipient) = recipients.iter().find(|x| *x != &self.perspective.id)
return ChannelPermissions([ {
ChannelPermission::View + ChannelPermission::SendMessage let perms = self.for_user(recipient).await?;
]);
if perms.get_send_message() {
return Ok(*DEFAULT_PERMISSION_DM);
}
return Ok(ChannelPermission::View as u32);
} }
}
return ChannelPermissions([ChannelPermission::View as u32]); Ok(0)
}
Channel::Group { recipients, permissions, owner, .. } => {
if &self.perspective.id == owner {
return Ok(*DEFAULT_PERMISSION_DM)
}
if recipients
.iter()
.find(|x| *x == &self.perspective.id)
.is_some()
{
if let Some(permissions) = permissions {
Ok(permissions.clone() as u32)
} else {
Ok(*DEFAULT_PERMISSION_DM)
}
} else {
Ok(0)
} }
} }
Channel::TextChannel { server, default_permissions, role_permissions, .. }
| Channel::VoiceChannel { server, default_permissions, role_permissions, .. } => {
let server = Ref::from_unchecked(server.clone()).fetch_server().await?;
ChannelPermissions([0]) if self.perspective.id == server.owner {
} Ok(u32::MAX)
Channel::Group { recipients, .. } => { } else {
if recipients.iter().find(|x| *x == &user.id).is_some() { match Ref::from_unchecked(self.perspective.id.clone()).fetch_member(&server.id).await {
ChannelPermissions([ChannelPermission::View + ChannelPermission::SendMessage]) Ok(member) => {
} else { let mut perm = if let Some(permission) = default_permissions {
ChannelPermissions([0]) *permission as u32
} else {
server.default_permissions.1 as u32
};
if let Some(roles) = member.roles {
for role in roles {
if let Some(permission) = role_permissions.get(&role) {
perm |= *permission as u32;
}
if let Some(server_role) = server.roles.get(&role) {
perm |= server_role.permissions.1 as u32;
}
}
}
Ok(perm)
}
Err(error) => {
match &error {
Error::NotFound => Ok(0),
_ => Err(error)
}
}
}
}
} }
} }
} }
pub async fn for_channel(self) -> Result<ChannelPermissions<[u32; 1]>> {
Ok(ChannelPermissions([self.calculate_channel().await?]))
}
} }
pub use crate::database::*;
pub mod channel; pub mod channel;
pub mod server;
pub mod user; pub mod user;
pub use user::get_relationship; pub use user::get_relationship;
pub struct PermissionCalculator<'a> {
perspective: &'a User,
user: Option<&'a User>,
relationship: Option<&'a RelationshipStatus>,
channel: Option<&'a Channel>,
server: Option<&'a Server>,
// member: Option<&'a Member>,
has_mutual_connection: bool,
}
impl<'a> PermissionCalculator<'a> {
pub fn new(perspective: &'a User) -> PermissionCalculator {
PermissionCalculator {
perspective,
user: None,
relationship: None,
channel: None,
server: None,
// member: None,
has_mutual_connection: false,
}
}
pub fn with_user(self, user: &'a User) -> PermissionCalculator {
PermissionCalculator {
user: Some(&user),
..self
}
}
pub fn with_relationship(self, relationship: &'a RelationshipStatus) -> PermissionCalculator {
PermissionCalculator {
relationship: Some(&relationship),
..self
}
}
pub fn with_channel(self, channel: &'a Channel) -> PermissionCalculator {
PermissionCalculator {
channel: Some(&channel),
..self
}
}
pub fn with_server(self, server: &'a Server) -> PermissionCalculator {
PermissionCalculator {
server: Some(&server),
..self
}
}
/* pub fn with_member(self, member: &'a Member) -> PermissionCalculator {
PermissionCalculator {
member: Some(&member),
..self
}
} */
pub fn with_mutual_connection(self) -> PermissionCalculator<'a> {
PermissionCalculator {
has_mutual_connection: true,
..self
}
}
}
use crate::util::result::{Error, Result};
use super::PermissionCalculator;
use super::Ref;
use num_enum::TryFromPrimitive;
use std::ops;
#[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)]
#[repr(u32)]
pub enum ServerPermission {
View = 0b00000000000000000000000000000001, // 1
ManageRoles = 0b00000000000000000000000000000010, // 2
ManageChannels = 0b00000000000000000000000000000100, // 4
ManageServer = 0b00000000000000000000000000001000, // 8
KickMembers = 0b00000000000000000000000000010000, // 16
BanMembers = 0b00000000000000000000000000100000, // 32
// 6 bits of space
ChangeNickname = 0b00000000000000000001000000000000, // 4096
ManageNicknames = 0b00000000000000000010000000000000, // 8192
ChangeAvatar = 0b00000000000000000100000000000000, // 16382
RemoveAvatars = 0b00000000000000001000000000000000, // 32768
// 16 bits of space
}
lazy_static! {
pub static ref DEFAULT_PERMISSION: u32 =
ServerPermission::View
+ ServerPermission::ChangeNickname
+ ServerPermission::ChangeAvatar;
}
impl_op_ex!(+ |a: &ServerPermission, b: &ServerPermission| -> u32 { *a as u32 | *b as u32 });
impl_op_ex_commutative!(+ |a: &u32, b: &ServerPermission| -> u32 { *a | *b as u32 });
bitfield! {
pub struct ServerPermissions(MSB0 [u32]);
u32;
pub get_view, _: 31;
pub get_manage_roles, _: 30;
pub get_manage_channels, _: 29;
pub get_manage_server, _: 28;
pub get_kick_members, _: 27;
pub get_ban_members, _: 26;
pub get_change_nickname, _: 19;
pub get_manage_nicknames, _: 18;
pub get_change_avatar, _: 17;
pub get_remove_avatars, _: 16;
}
impl<'a> PermissionCalculator<'a> {
pub async fn calculate_server(self) -> Result<u32> {
let server = if let Some(server) = self.server {
server
} else {
unreachable!()
};
if self.perspective.id == server.owner {
Ok(u32::MAX)
} else {
match Ref::from_unchecked(self.perspective.id.clone()).fetch_member(&server.id).await {
Ok(member) => {
let mut perm = server.default_permissions.0 as u32;
if let Some(roles) = member.roles {
for role in roles {
if let Some(server_role) = server.roles.get(&role) {
perm |= server_role.permissions.0 as u32;
}
}
}
Ok(perm)
}
Err(error) => {
match &error {
Error::NotFound => Ok(0),
_ => Err(error)
}
}
}
}
}
pub async fn for_server(self) -> Result<ServerPermissions<[u32; 1]>> {
Ok(ServerPermissions([self.calculate_server().await?]))
}
}
use crate::database::*; use crate::database::*;
use crate::util::result::{Error, Result};
use super::PermissionCalculator;
use mongodb::bson::doc;
use num_enum::TryFromPrimitive; use num_enum::TryFromPrimitive;
use std::ops; use std::ops;
#[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)] #[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)]
#[repr(u32)] #[repr(u32)]
pub enum UserPermission { pub enum UserPermission {
Access = 1, Access = 0b00000000000000000000000000000001, // 1
SendMessage = 2, ViewProfile = 0b00000000000000000000000000000010, // 2
Invite = 4, SendMessage = 0b00000000000000000000000000000100, // 4
Invite = 0b00000000000000000000000000001000, // 8
} }
bitfield! { bitfield! {
pub struct UserPermissions(MSB0 [u32]); pub struct UserPermissions(MSB0 [u32]);
u32; u32;
pub get_access, _: 31; pub get_access, _: 31;
pub get_send_message, _: 30; pub get_view_profile, _: 30;
pub get_invite, _: 29; pub get_send_message, _: 29;
pub get_invite, _: 28;
} }
impl_op_ex!(+ |a: &UserPermission, b: &UserPermission| -> u32 { *a as u32 | *b as u32 }); impl_op_ex!(+ |a: &UserPermission, b: &UserPermission| -> u32 { *a as u32 | *b as u32 });
impl_op_ex_commutative!(+ |a: &u32, b: &UserPermission| -> u32 { *a | *b as u32 }); impl_op_ex_commutative!(+ |a: &u32, b: &UserPermission| -> u32 { *a | *b as u32 });
pub async fn calculate(user: &User, target: &str) -> UserPermissions<[u32; 1]> {
// if friends; Access + Message + Invite
// if mutually know each other:
// and has DMs from users enabled -> Access + Message
// otherwise -> Access
// otherwise; None
let mut permissions: u32 = 0;
match get_relationship(&user, &target) {
RelationshipStatus::Friend => {
return UserPermissions([UserPermission::Access
+ UserPermission::SendMessage
+ UserPermission::Invite])
}
RelationshipStatus::Blocked | RelationshipStatus::BlockedOther => {
return UserPermissions([UserPermission::Access as u32])
}
RelationshipStatus::Incoming | RelationshipStatus::Outgoing => {
permissions = UserPermission::Access as u32;
}
_ => {}
}
UserPermissions([permissions])
}
pub fn get_relationship(a: &User, b: &str) -> RelationshipStatus { pub fn get_relationship(a: &User, b: &str) -> RelationshipStatus {
if a.id == b { if a.id == b {
return RelationshipStatus::Friend; return RelationshipStatus::User;
} }
if let Some(relations) = &a.relations { if let Some(relations) = &a.relations {
...@@ -60,3 +41,91 @@ pub fn get_relationship(a: &User, b: &str) -> RelationshipStatus { ...@@ -60,3 +41,91 @@ pub fn get_relationship(a: &User, b: &str) -> RelationshipStatus {
RelationshipStatus::None RelationshipStatus::None
} }
impl<'a> PermissionCalculator<'a> {
pub async fn calculate_user(self, target: &str) -> Result<u32> {
if &self.perspective.id == target {
return Ok(u32::MAX);
}
let mut permissions: u32 = 0;
match self
.relationship
.clone()
.map(|v| v.to_owned())
.unwrap_or_else(|| get_relationship(&self.perspective, &target))
{
RelationshipStatus::Friend | RelationshipStatus::User => return Ok(u32::MAX),
RelationshipStatus::Blocked | RelationshipStatus::BlockedOther => {
return Ok(UserPermission::Access as u32)
}
RelationshipStatus::Incoming | RelationshipStatus::Outgoing => {
permissions = UserPermission::Access as u32;
// ! INFO: if we add boolean switch for permission to
// ! message people who have mutual, we need to get
// ! rid of this return statement.
// return Ok(permissions);
}
_ => {}
}
let check_server_overlap = async || {
let server_ids = User::fetch_server_ids(&self.perspective.id).await?;
Ok(
get_collection("server_members")
.find_one(
doc! {
"_id.user": &target,
"_id.server": {
"$in": server_ids
}
},
None
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "server_members",
})?
.is_some()
)
};
if self.has_mutual_connection
|| check_server_overlap().await?
|| get_collection("channels")
.find_one(
doc! {
"channel_type": {
"$in": ["Group", "DirectMessage"]
},
"recipients": {
"$all": [ &self.perspective.id, target ]
}
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "channels",
})?
.is_some()
{
// ! FIXME: add privacy settings
return Ok(UserPermission::Access + UserPermission::ViewProfile);
}
Ok(permissions)
}
pub async fn for_user(self, target: &str) -> Result<UserPermissions<[u32; 1]>> {
Ok(UserPermissions([self.calculate_user(&target).await?]))
}
pub async fn for_user_given(self) -> Result<UserPermissions<[u32; 1]>> {
let id = &self.user.unwrap().id;
Ok(UserPermissions([self.calculate_user(&id).await?]))
}
}
...@@ -17,18 +17,33 @@ pub mod database; ...@@ -17,18 +17,33 @@ pub mod database;
pub mod notifications; pub mod notifications;
pub mod routes; pub mod routes;
pub mod util; pub mod util;
pub mod version;
use async_std::task;
use chrono::Duration;
use futures::join; use futures::join;
use log::info; use log::info;
use rauth; use rauth::options::{EmailVerification, Options, SMTP};
use rauth::{
auth::Auth,
options::{Template, Templates},
};
use rocket_cors::AllowedOrigins; use rocket_cors::AllowedOrigins;
use rocket_prometheus::PrometheusMetrics;
use util::variables::{
APP_URL, HCAPTCHA_KEY, INVITE_ONLY, PUBLIC_URL, SMTP_FROM, SMTP_HOST, SMTP_PASSWORD,
SMTP_USERNAME, USE_EMAIL, USE_HCAPTCHA, USE_PROMETHEUS,
};
#[async_std::main] #[async_std::main]
async fn main() { async fn main() {
dotenv::dotenv().ok(); dotenv::dotenv().ok();
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")); env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "info"));
info!("Starting REVOLT server."); info!(
"Starting REVOLT server [version {}].",
crate::version::VERSION
);
util::variables::preflight_checks(); util::variables::preflight_checks();
database::connect().await; database::connect().await;
...@@ -40,10 +55,13 @@ async fn main() { ...@@ -40,10 +55,13 @@ async fn main() {
}) })
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
let web_task = task::spawn(launch_web());
let hive_task = task::spawn(notifications::hive::listen());
join!( join!(
launch_web(), web_task,
notifications::websocket::launch_server(), hive_task,
notifications::hive::listen(), notifications::websocket::launch_server()
); );
} }
...@@ -55,10 +73,71 @@ async fn launch_web() { ...@@ -55,10 +73,71 @@ async fn launch_web() {
.to_cors() .to_cors()
.expect("Failed to create CORS."); .expect("Failed to create CORS.");
let auth = rauth::auth::Auth::new(database::get_collection("accounts")); let mut options = Options::new()
.base_url(format!("{}/auth", *PUBLIC_URL))
.email_verification(if *USE_EMAIL {
EmailVerification::Enabled {
success_redirect_uri: format!("{}/login", *APP_URL),
welcome_redirect_uri: format!("{}/welcome", *APP_URL),
password_reset_url: Some(format!("{}/login/reset", *APP_URL)),
verification_expiry: Duration::days(1),
password_reset_expiry: Duration::hours(1),
templates: Templates {
verify_email: Template {
title: "Verify your Revolt account.",
text: "You're almost there!
If you did not perform this action you can safely ignore this email.
Please verify your account here: {{url}}",
html: None,
},
reset_password: Template {
title: "Reset your Revolt password.",
text: "You requested a password reset, if you did not perform this action you can safely ignore this email.
Reset your password here: {{url}}",
html: None,
},
welcome: None,
},
smtp: SMTP {
from: (*SMTP_FROM).to_string(),
host: (*SMTP_HOST).to_string(),
username: (*SMTP_USERNAME).to_string(),
password: (*SMTP_PASSWORD).to_string(),
},
}
} else {
EmailVerification::Disabled
});
if *INVITE_ONLY {
options = options.invite_only_collection(database::get_collection("invites"))
}
if *USE_HCAPTCHA {
options = options.hcaptcha_secret(HCAPTCHA_KEY.clone());
}
let auth = Auth::new(database::get_collection("accounts"), options);
let mut rocket = rocket::ignite();
if *USE_PROMETHEUS {
info!("Enabled Prometheus metrics!");
let prometheus = PrometheusMetrics::new();
rocket = rocket
.attach(prometheus.clone())
.mount("/metrics", prometheus);
}
routes::mount(rauth::routes::mount(rocket::ignite(), "/auth", auth)) routes::mount(rocket)
.mount("/", rocket_cors::catch_all_options_routes()) .mount("/", rocket_cors::catch_all_options_routes())
.mount("/auth", rauth::routes::routes())
.manage(auth)
.manage(cors.clone()) .manage(cors.clone())
.attach(cors) .attach(cors)
.launch() .launch()
......
use hive_pubsub::PubSub; use hive_pubsub::PubSub;
use mongodb::bson::doc;
use rauth::auth::Session; use rauth::auth::Session;
use rocket_contrib::json::JsonValue; use rocket_contrib::json::JsonValue;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use snafu::Snafu;
use super::hive::{get_hive, subscribe_if_exists}; use super::hive::{get_hive, subscribe_if_exists};
use crate::database::*; use crate::{database::*, util::result::Result};
#[derive(Serialize, Deserialize, Debug, Snafu)] #[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "error")] #[serde(tag = "error")]
pub enum WebSocketError { pub enum WebSocketError {
#[snafu(display("This error has not been labelled."))]
LabelMe, LabelMe,
#[snafu(display("Internal server error."))]
InternalError { at: String }, InternalError { at: String },
#[snafu(display("Invalid session."))]
InvalidSession, InvalidSession,
#[snafu(display("User hasn't completed onboarding."))]
OnboardingNotFinished, OnboardingNotFinished,
#[snafu(display("Already authenticated with server."))]
AlreadyAuthenticated, AlreadyAuthenticated,
} }
...@@ -26,26 +21,75 @@ pub enum WebSocketError { ...@@ -26,26 +21,75 @@ pub enum WebSocketError {
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum ServerboundNotification { pub enum ServerboundNotification {
Authenticate(Session), Authenticate(Session),
BeginTyping { channel: String },
EndTyping { channel: String },
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveUserField {
ProfileContent,
ProfileBackground,
StatusText,
Avatar,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveChannelField {
Icon,
Description
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveServerField {
Icon,
Banner,
Description,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveRoleField {
Colour,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveMemberField {
Nickname,
Avatar,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum ClientboundNotification { pub enum ClientboundNotification {
Error(WebSocketError), Error(WebSocketError),
Authenticated, Authenticated,
Ready { Ready {
users: Vec<User>, users: Vec<User>,
servers: Vec<Server>,
channels: Vec<Channel>, channels: Vec<Channel>,
members: Vec<Member>
}, },
Message(Message), Message(Message),
MessageUpdate(JsonValue), MessageUpdate {
id: String,
channel: String,
data: JsonValue,
},
MessageDelete { MessageDelete {
id: String, id: String,
channel: String,
}, },
ChannelCreate(Channel), ChannelCreate(Channel),
ChannelUpdate(JsonValue), ChannelUpdate {
id: String,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveChannelField>,
},
ChannelDelete {
id: String,
},
ChannelGroupJoin { ChannelGroupJoin {
id: String, id: String,
user: String, user: String,
...@@ -54,29 +98,96 @@ pub enum ClientboundNotification { ...@@ -54,29 +98,96 @@ pub enum ClientboundNotification {
id: String, id: String,
user: String, user: String,
}, },
ChannelDelete { ChannelStartTyping {
id: String, id: String,
user: String,
},
ChannelStopTyping {
id: String,
user: String,
},
ChannelAck {
id: String,
user: String,
message_id: String,
}, },
UserRelationship { ServerUpdate {
id: String,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveServerField>,
},
ServerDelete {
id: String,
},
ServerMemberUpdate {
id: MemberCompositeKey,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveMemberField>,
},
ServerMemberJoin {
id: String,
user: String,
},
ServerMemberLeave {
id: String, id: String,
user: String, user: String,
},
ServerRoleUpdate {
id: String,
role_id: String,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveRoleField>
},
ServerRoleDelete {
id: String,
role_id: String
},
UserUpdate {
id: String,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveUserField>,
},
UserRelationship {
id: String,
user: User,
status: RelationshipStatus, status: RelationshipStatus,
}, },
UserPresence { UserSettingsUpdate {
id: String, id: String,
online: bool, update: JsonValue,
}, },
} }
impl ClientboundNotification { impl ClientboundNotification {
pub async fn publish(self, topic: String) -> Result<(), String> { pub fn publish(self, topic: String) {
prehandle_hook(&self); // ! TODO: this should be moved to pubsub async_std::task::spawn(async move {
hive_pubsub::backend::mongo::publish(get_hive(), &topic, self).await prehandle_hook(&self).await.ok(); // ! FIXME: this should be moved to pubsub
hive_pubsub::backend::mongo::publish(get_hive(), &topic, self)
.await
.ok();
});
}
pub fn publish_as_user(self, user: String) {
self.clone().publish(user.clone());
async_std::task::spawn(async move {
if let Ok(server_ids) = User::fetch_server_ids(&user).await {
for server in server_ids {
self.clone().publish(server.clone());
}
}
});
} }
} }
pub fn prehandle_hook(notification: &ClientboundNotification) { pub async fn prehandle_hook(notification: &ClientboundNotification) -> Result<()> {
match &notification { match &notification {
ClientboundNotification::ChannelGroupJoin { id, user } => { ClientboundNotification::ChannelGroupJoin { id, user } => {
subscribe_if_exists(user.clone(), id.clone()).ok(); subscribe_if_exists(user.clone(), id.clone()).ok();
...@@ -92,34 +203,59 @@ pub fn prehandle_hook(notification: &ClientboundNotification) { ...@@ -92,34 +203,59 @@ pub fn prehandle_hook(notification: &ClientboundNotification) {
subscribe_if_exists(recipient.clone(), channel_id.to_string()).ok(); subscribe_if_exists(recipient.clone(), channel_id.to_string()).ok();
} }
} }
Channel::TextChannel { server, .. }
| Channel::VoiceChannel { server, .. } => {
// ! FIXME: write a better algorithm?
let members = Server::fetch_member_ids(server).await?;
for member in members {
subscribe_if_exists(member.clone(), channel_id.to_string()).ok();
}
}
} }
} }
ClientboundNotification::ChannelGroupLeave { id, user } => { ClientboundNotification::ServerMemberJoin { id, user } => {
get_hive() let server = Ref::from_unchecked(id.clone()).fetch_server().await?;
.hive
.unsubscribe(&user.to_string(), &id.to_string()) subscribe_if_exists(user.clone(), id.clone()).ok();
.ok();
for channel in server.channels {
subscribe_if_exists(user.clone(), channel).ok();
}
} }
ClientboundNotification::UserRelationship { id, user, status } => { ClientboundNotification::UserRelationship { id, user, status } => {
if status != &RelationshipStatus::None { if status != &RelationshipStatus::None {
subscribe_if_exists(id.clone(), user.clone()).ok(); subscribe_if_exists(id.clone(), user.id.clone()).ok();
} }
} }
_ => {} _ => {}
} }
Ok(())
} }
pub fn posthandle_hook(notification: &ClientboundNotification) { pub async fn posthandle_hook(notification: &ClientboundNotification) {
match &notification { match &notification {
ClientboundNotification::ChannelDelete { id } => { ClientboundNotification::ChannelDelete { id } => {
get_hive().hive.drop_topic(&id).ok(); get_hive().hive.drop_topic(&id).ok();
} }
ClientboundNotification::ChannelGroupLeave { id, user } => {
get_hive().hive.unsubscribe(user, id).ok();
}
ClientboundNotification::ServerDelete { id } => {
get_hive().hive.drop_topic(&id).ok();
}
ClientboundNotification::ServerMemberLeave { id, user } => {
get_hive().hive.unsubscribe(user, id).ok();
if let Ok(server) = Ref::from_unchecked(id.clone()).fetch_server().await {
for channel in server.channels {
get_hive().hive.unsubscribe(user, &channel).ok();
}
}
}
ClientboundNotification::UserRelationship { id, user, status } => { ClientboundNotification::UserRelationship { id, user, status } => {
if status == &RelationshipStatus::None { if status == &RelationshipStatus::None {
get_hive() get_hive().hive.unsubscribe(id, &user.id).ok();
.hive
.unsubscribe(&id.to_string(), &user.to_string())
.ok();
} }
} }
_ => {} _ => {}
......
...@@ -13,8 +13,11 @@ static HIVE: OnceCell<Hive> = OnceCell::new(); ...@@ -13,8 +13,11 @@ static HIVE: OnceCell<Hive> = OnceCell::new();
pub async fn init_hive() { pub async fn init_hive() {
let hive = MongodbPubSub::new( let hive = MongodbPubSub::new(
|ids, notification| { |ids, notification: ClientboundNotification| {
super::events::posthandle_hook(&notification); let notif = notification.clone();
async_std::task::spawn(async move {
super::events::posthandle_hook(&notif).await;
});
if let Ok(data) = to_string(&notification) { if let Ok(data) = to_string(&notification) {
debug!("Pushing out notification. {}", data); debug!("Pushing out notification. {}", data);
......
use crate::notifications::events::ClientboundNotification; use std::collections::HashSet;
use crate::{database::*, notifications::events::ClientboundNotification};
use crate::{ use crate::{
database::{entities::User, get_collection}, database::{entities::User, get_collection},
util::result::{Error, Result}, util::result::{Error, Result},
}; };
use futures::StreamExt; use futures::StreamExt;
use mongodb::{ use mongodb::bson::{doc, from_document};
bson::{doc, from_document},
options::FindOptions,
};
use super::websocket::is_online;
pub async fn generate_ready(mut user: User) -> Result<ClientboundNotification> { pub async fn generate_ready(mut user: User) -> Result<ClientboundNotification> {
let mut users = vec![]; let mut user_ids: HashSet<String> = HashSet::new();
if let Some(relationships) = &user.relations { if let Some(relationships) = &user.relations {
let user_ids: Vec<String> = relationships user_ids.extend(
.iter() relationships
.map(|relationship| relationship.id.clone()) .iter()
.collect(); .map(|relationship| relationship.id.clone()),
);
let mut cursor = get_collection("users") }
.find(
doc! {
"_id": {
"$in": user_ids
}
},
FindOptions::builder()
.projection(doc! { "_id": 1, "username": 1 })
.build(),
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "users",
})?;
while let Some(result) = cursor.next().await {
if let Ok(doc) = result {
let mut user: User = from_document(doc).map_err(|_| Error::DatabaseError {
operation: "from_document",
with: "user",
})?;
user.relationship = Some( let members = User::fetch_memberships(&user.id).await?;
relationships let server_ids: Vec<String> = members.iter()
.iter() .map(|x| x.id.server.clone())
.find(|x| user.id == x.id) .collect();
.ok_or_else(|| Error::InternalError)?
.status let mut cursor = get_collection("servers")
.clone(), .find(
); doc! {
"_id": {
"$in": server_ids
}
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "servers",
})?;
user.online = Some(is_online(&user.id)); let mut servers = vec![];
let mut channel_ids = vec![];
while let Some(result) = cursor.next().await {
if let Ok(doc) = result {
let server: Server = from_document(doc).map_err(|_| Error::DatabaseError {
operation: "from_document",
with: "server",
})?;
users.push(user); channel_ids.extend(server.channels.iter().cloned());
} servers.push(server);
} }
} }
...@@ -65,16 +58,20 @@ pub async fn generate_ready(mut user: User) -> Result<ClientboundNotification> { ...@@ -65,16 +58,20 @@ pub async fn generate_ready(mut user: User) -> Result<ClientboundNotification> {
doc! { doc! {
"$or": [ "$or": [
{ {
"type": "SavedMessages", "_id": {
"$in": channel_ids
}
},
{
"channel_type": "SavedMessages",
"user": &user.id "user": &user.id
}, },
{ {
"type": "DirectMessage", "channel_type": "DirectMessage",
"recipients": &user.id, "recipients": &user.id
"active": true
}, },
{ {
"type": "Group", "channel_type": "Group",
"recipients": &user.id "recipients": &user.id
} }
] ]
...@@ -90,15 +87,37 @@ pub async fn generate_ready(mut user: User) -> Result<ClientboundNotification> { ...@@ -90,15 +87,37 @@ pub async fn generate_ready(mut user: User) -> Result<ClientboundNotification> {
let mut channels = vec![]; let mut channels = vec![];
while let Some(result) = cursor.next().await { while let Some(result) = cursor.next().await {
if let Ok(doc) = result { if let Ok(doc) = result {
channels.push(from_document(doc).map_err(|_| Error::DatabaseError { let channel = from_document(doc).map_err(|_| Error::DatabaseError {
operation: "from_document", operation: "from_document",
with: "channel", with: "channel",
})?); })?;
if let Channel::Group { recipients, .. } = &channel {
user_ids.extend(recipients.iter().cloned());
} else if let Channel::DirectMessage { recipients, .. } = &channel {
user_ids.extend(recipients.iter().cloned());
}
channels.push(channel);
} }
} }
user_ids.remove(&user.id);
let mut users = if user_ids.len() > 0 {
user.fetch_multiple_users(user_ids.into_iter().collect::<Vec<String>>())
.await?
} else {
vec![]
};
user.relationship = Some(RelationshipStatus::User);
user.online = Some(true); user.online = Some(true);
users.push(user); users.push(user);
Ok(ClientboundNotification::Ready { users, channels }) Ok(ClientboundNotification::Ready {
users,
servers,
channels,
members
})
} }
...@@ -4,6 +4,7 @@ use super::hive::get_hive; ...@@ -4,6 +4,7 @@ use super::hive::get_hive;
use futures::StreamExt; use futures::StreamExt;
use hive_pubsub::PubSub; use hive_pubsub::PubSub;
use mongodb::bson::doc; use mongodb::bson::doc;
use mongodb::bson::Document;
use mongodb::options::FindOptions; use mongodb::options::FindOptions;
pub async fn generate_subscriptions(user: &User) -> Result<(), String> { pub async fn generate_subscriptions(user: &User) -> Result<(), String> {
...@@ -16,21 +17,57 @@ pub async fn generate_subscriptions(user: &User) -> Result<(), String> { ...@@ -16,21 +17,57 @@ pub async fn generate_subscriptions(user: &User) -> Result<(), String> {
} }
} }
let server_ids = User::fetch_server_ids(&user.id)
.await
.map_err(|_| "Failed to fetch memberships.".to_string())?;
let channel_ids = get_collection("servers")
.find(
doc! {
"_id": {
"$in": &server_ids
}
},
None,
)
.await
.map_err(|_| "Failed to fetch servers.".to_string())?
.filter_map(async move |s| s.ok())
.collect::<Vec<Document>>()
.await
.into_iter()
.filter_map(|x| {
x.get_array("channels").ok().map(|v| {
v.into_iter()
.filter_map(|x| x.as_str().map(|x| x.to_string()))
.collect::<Vec<String>>()
})
})
.flatten()
.collect::<Vec<String>>();
for id in server_ids {
hive.subscribe(user.id.clone(), id)?;
}
for id in channel_ids {
hive.subscribe(user.id.clone(), id)?;
}
let mut cursor = get_collection("channels") let mut cursor = get_collection("channels")
.find( .find(
doc! { doc! {
"$or": [ "$or": [
{ {
"type": "SavedMessages", "channel_type": "SavedMessages",
"user": &user.id "user": &user.id
}, },
{ {
"type": "DirectMessage", "channel_type": "DirectMessage",
"recipients": &user.id, "recipients": &user.id
"active": true
}, },
{ {
"type": "Group", "channel_type": "Group",
"recipients": &user.id "recipients": &user.id
} }
] ]
......
...@@ -12,7 +12,10 @@ use futures::{pin_mut, prelude::*}; ...@@ -12,7 +12,10 @@ use futures::{pin_mut, prelude::*};
use hive_pubsub::PubSub; use hive_pubsub::PubSub;
use log::{debug, info}; use log::{debug, info};
use many_to_many::ManyToMany; use many_to_many::ManyToMany;
use rauth::auth::{Auth, Session}; use rauth::{
auth::{Auth, Session},
options::Options,
};
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
...@@ -68,8 +71,6 @@ async fn accept(stream: TcpStream) { ...@@ -68,8 +71,6 @@ async fn accept(stream: TcpStream) {
let fwd = rx.map(Ok).forward(write); let fwd = rx.map(Ok).forward(write);
let incoming = read.try_for_each(async move |msg| { let incoming = read.try_for_each(async move |msg| {
let mutex = mutex_generator(); let mutex = mutex_generator();
//dbg!(&mutex.lock().unwrap());
if let Message::Text(text) = msg { if let Message::Text(text) = msg {
if let Ok(notification) = serde_json::from_str::<ServerboundNotification>(&text) { if let Ok(notification) = serde_json::from_str::<ServerboundNotification>(&text) {
match notification { match notification {
...@@ -84,9 +85,10 @@ async fn accept(stream: TcpStream) { ...@@ -84,9 +85,10 @@ async fn accept(stream: TcpStream) {
} }
} }
if let Ok(validated_session) = Auth::new(get_collection("accounts")) if let Ok(validated_session) =
.verify_session(new_session) Auth::new(get_collection("accounts"), Options::new())
.await .verify_session(new_session)
.await
{ {
let id = validated_session.user_id.clone(); let id = validated_session.user_id.clone();
if let Ok(user) = (Ref { id: id.clone() }).fetch_user().await { if let Ok(user) = (Ref { id: id.clone() }).fetch_user().await {
...@@ -127,13 +129,14 @@ async fn accept(stream: TcpStream) { ...@@ -127,13 +129,14 @@ async fn accept(stream: TcpStream) {
send(payload); send(payload);
if !was_online { if !was_online {
ClientboundNotification::UserPresence { ClientboundNotification::UserUpdate {
id: id.clone(), id: id.clone(),
online: true, data: json!({
"online": true
}),
clear: None
} }
.publish(id) .publish_as_user(id);
.await
.ok();
} }
} }
Err(_) => { Err(_) => {
...@@ -157,6 +160,50 @@ async fn accept(stream: TcpStream) { ...@@ -157,6 +160,50 @@ async fn accept(stream: TcpStream) {
)); ));
} }
} }
// ! TEMP: verify user part of channel
// ! Could just run permission check here.
ServerboundNotification::BeginTyping { channel } => {
if mutex.lock().unwrap().is_some() {
let user = {
let mutex = mutex.lock().unwrap();
let session = mutex.as_ref().unwrap();
session.user_id.clone()
};
ClientboundNotification::ChannelStartTyping {
id: channel.clone(),
user,
}
.publish(channel);
} else {
send(ClientboundNotification::Error(
WebSocketError::AlreadyAuthenticated,
));
return Ok(());
}
}
ServerboundNotification::EndTyping { channel } => {
if mutex.lock().unwrap().is_some() {
let user = {
let mutex = mutex.lock().unwrap();
let session = mutex.as_ref().unwrap();
session.user_id.clone()
};
ClientboundNotification::ChannelStopTyping {
id: channel.clone(),
user,
}
.publish(channel);
} else {
send(ClientboundNotification::Error(
WebSocketError::AlreadyAuthenticated,
));
return Ok(());
}
}
} }
} }
} }
...@@ -170,14 +217,29 @@ async fn accept(stream: TcpStream) { ...@@ -170,14 +217,29 @@ async fn accept(stream: TcpStream) {
info!("User {} disconnected.", &addr); info!("User {} disconnected.", &addr);
CONNECTIONS.lock().unwrap().remove(&addr); CONNECTIONS.lock().unwrap().remove(&addr);
let session = session.lock().unwrap(); let mut offline = None;
if let Some(session) = session.as_ref() { {
let mut users = USERS.write().unwrap(); let session = session.lock().unwrap();
users.remove(&session.user_id, &addr); if let Some(session) = session.as_ref() {
if users.get_left(&session.user_id).is_none() { let mut users = USERS.write().unwrap();
get_hive().drop_client(&session.user_id).unwrap(); users.remove(&session.user_id, &addr);
if users.get_left(&session.user_id).is_none() {
get_hive().drop_client(&session.user_id).unwrap();
offline = Some(session.user_id.clone());
}
} }
} }
if let Some(id) = offline {
ClientboundNotification::UserUpdate {
id: id.clone(),
data: json!({
"online": false
}),
clear: None
}
.publish_as_user(id);
}
} }
pub fn publish(ids: Vec<String>, notification: ClientboundNotification) { pub fn publish(ids: Vec<String>, notification: ClientboundNotification) {
...@@ -186,10 +248,15 @@ pub fn publish(ids: Vec<String>, notification: ClientboundNotification) { ...@@ -186,10 +248,15 @@ pub fn publish(ids: Vec<String>, notification: ClientboundNotification) {
let users = USERS.read().unwrap(); let users = USERS.read().unwrap();
for id in ids { for id in ids {
// Block certain notifications from reaching users that aren't meant to see them. // Block certain notifications from reaching users that aren't meant to see them.
if let ClientboundNotification::UserRelationship { id: user_id, .. } = &notification { match &notification {
if &id != user_id { ClientboundNotification::UserRelationship { id: user_id, .. }
continue; | ClientboundNotification::UserSettingsUpdate { id: user_id, .. }
| ClientboundNotification::ChannelAck { user: user_id, .. } => {
if &id != user_id {
continue;
}
} }
_ => {}
} }
if let Some(mut arr) = users.get_left(&id) { if let Some(mut arr) = users.get_left(&id) {
......