Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 1645 additions and 714 deletions
use crate::database::{permissions, get_collection, get_db, PermissionTuple};
use futures::StreamExt;
use log::info;
use mongodb::{bson::{doc, from_document, to_document}, options::FindOptions};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
struct MigrationInfo {
_id: i32,
revision: i32,
}
pub const LATEST_REVISION: i32 = 7;
pub async fn migrate_database() {
let migrations = get_collection("migrations");
let data = migrations
.find_one(None, None)
.await
.expect("Failed to fetch migration data.");
if let Some(doc) = data {
let info: MigrationInfo =
from_document(doc).expect("Failed to read migration information.");
let revision = run_migrations(info.revision).await;
migrations
.update_one(
doc! {
"_id": info._id
},
doc! {
"$set": {
"revision": revision
}
},
None,
)
.await
.expect("Failed to commit migration information.");
info!("Migration complete. Currently at revision {}.", revision);
} else {
panic!("Database was configured incorrectly, possibly because initalization failed.")
}
}
pub async fn run_migrations(revision: i32) -> i32 {
info!("Starting database migration.");
if revision <= 0 {
info!("Running migration [revision 0]: Test migration system.");
}
if revision <= 1 {
info!("Running migration [revision 1 / 2021-04-24]: Migrate to Autumn v1.0.0.");
let messages = get_collection("messages");
let attachments = get_collection("attachments");
messages
.update_many(
doc! { "attachment": { "$exists": 1 } },
doc! { "$set": { "attachment.tag": "attachments", "attachment.size": 0 } },
None,
)
.await
.expect("Failed to update messages.");
attachments
.update_many(
doc! {},
doc! { "$set": { "tag": "attachments", "size": 0 } },
None,
)
.await
.expect("Failed to update attachments.");
}
if revision <= 2 {
info!("Running migration [revision 2 / 2021-05-08]: Add servers collection.");
get_db()
.create_collection("servers", None)
.await
.expect("Failed to create servers collection.");
}
if revision <= 3 {
info!("Running migration [revision 3 / 2021-05-25]: Support multiple file uploads, add channel_unreads and user_settings.");
let messages = get_collection("messages");
let mut cursor = messages
.find(
doc! {
"attachment": {
"$exists": 1
}
},
FindOptions::builder()
.projection(doc! {
"_id": 1,
"attachments": [ "$attachment" ]
})
.build(),
)
.await
.expect("Failed to fetch messages.");
while let Some(result) = cursor.next().await {
let doc = result.unwrap();
let id = doc.get_str("_id").unwrap();
let attachments = doc.get_array("attachments").unwrap();
messages
.update_one(
doc! { "_id": id },
doc! { "$unset": { "attachment": 1 }, "$set": { "attachments": attachments } },
None,
)
.await
.unwrap();
}
get_db()
.create_collection("channel_unreads", None)
.await
.expect("Failed to create channel_unreads collection.");
get_db()
.create_collection("user_settings", None)
.await
.expect("Failed to create user_settings collection.");
}
if revision <= 4 {
info!("Running migration [revision 4 / 2021-06-01]: Add more server collections.");
get_db()
.create_collection("server_members", None)
.await
.expect("Failed to create server_members collection.");
get_db()
.create_collection("server_bans", None)
.await
.expect("Failed to create server_bans collection.");
get_db()
.create_collection("channel_invites", None)
.await
.expect("Failed to create channel_invites collection.");
}
if revision <= 5 {
info!("Running migration [revision 5 / 2021-06-26]: Add permissions.");
#[derive(Serialize)]
struct Server {
pub default_permissions: PermissionTuple,
}
let server = Server {
default_permissions: (
*permissions::server::DEFAULT_PERMISSION as i32,
*permissions::channel::DEFAULT_PERMISSION_SERVER as i32
)
};
get_collection("servers")
.update_many(
doc! { },
doc! {
"$set": to_document(&server).unwrap()
},
None
)
.await
.expect("Failed to migrate servers.");
}
if revision <= 6 {
info!("Running migration [revision 6 / 2021-07-09]: Add message text index.");
get_db()
.run_command(
doc! {
"createIndexes": "messages",
"indexes": [
{
"key": {
"content": "text"
},
"name": "content"
}
]
},
None,
)
.await
.expect("Failed to create message index.");
}
// Reminder to update LATEST_REVISION when adding new migrations.
LATEST_REVISION
}
use mongodb::{ Client, Collection, Database };
use std::env;
use crate::util::variables::MONGO_URI;
use mongodb::{Client, Collection, Database};
use once_cell::sync::OnceCell;
static DBCONN: OnceCell<Client> = OnceCell::new();
pub fn connect() {
let client = Client::with_uri_str(
&env::var("DB_URI").expect("DB_URI not in environment variables!"))
.expect("Failed to init db connection.");
pub async fn connect() {
let client = Client::with_uri_str(&MONGO_URI)
.await
.expect("Failed to init db connection.");
DBCONN.set(client).unwrap();
DBCONN.set(client).unwrap();
migrations::run_migrations().await;
}
pub fn get_connection() -> &'static Client {
DBCONN.get().unwrap()
DBCONN.get().unwrap()
}
pub fn get_db() -> Database {
get_connection().database("revolt")
get_connection().database("revolt")
}
pub fn get_collection(collection: &str) -> Collection {
get_db().collection(collection)
get_db().collection(collection)
}
pub mod user;
pub mod channel;
pub mod message;
pub mod entities;
pub mod guards;
pub mod migrations;
pub mod permissions;
pub use entities::*;
pub use guards::*;
pub use permissions::*;
use crate::database::*;
use crate::util::result::{Error, Result};
use super::PermissionCalculator;
use num_enum::TryFromPrimitive;
use std::ops;
#[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)]
#[repr(u32)]
pub enum ChannelPermission {
View = 0b00000000000000000000000000000001, // 1
SendMessage = 0b00000000000000000000000000000010, // 2
ManageMessages = 0b00000000000000000000000000000100, // 4
ManageChannel = 0b00000000000000000000000000001000, // 8
VoiceCall = 0b00000000000000000000000000010000, // 16
InviteOthers = 0b00000000000000000000000000100000, // 32
EmbedLinks = 0b00000000000000000000000001000000, // 64
UploadFiles = 0b00000000000000000000000010000000, // 128
}
lazy_static! {
pub static ref DEFAULT_PERMISSION_DM: u32 =
ChannelPermission::View
+ ChannelPermission::SendMessage
+ ChannelPermission::ManageChannel
+ ChannelPermission::VoiceCall
+ ChannelPermission::InviteOthers
+ ChannelPermission::EmbedLinks
+ ChannelPermission::UploadFiles;
pub static ref DEFAULT_PERMISSION_SERVER: u32 =
ChannelPermission::View
+ ChannelPermission::SendMessage
+ ChannelPermission::VoiceCall
+ ChannelPermission::InviteOthers
+ ChannelPermission::EmbedLinks
+ ChannelPermission::UploadFiles;
}
impl_op_ex!(+ |a: &ChannelPermission, b: &ChannelPermission| -> u32 { *a as u32 | *b as u32 });
impl_op_ex_commutative!(+ |a: &u32, b: &ChannelPermission| -> u32 { *a | *b as u32 });
bitfield! {
pub struct ChannelPermissions(MSB0 [u32]);
u32;
pub get_view, _: 31;
pub get_send_message, _: 30;
pub get_manage_messages, _: 29;
pub get_manage_channel, _: 28;
pub get_voice_call, _: 27;
pub get_invite_others, _: 26;
pub get_embed_links, _: 25;
pub get_upload_files, _: 24;
}
impl<'a> PermissionCalculator<'a> {
pub async fn calculate_channel(self) -> Result<u32> {
let channel = if let Some(channel) = self.channel {
channel
} else {
unreachable!()
};
match channel {
Channel::SavedMessages { user: owner, .. } => {
if &self.perspective.id == owner {
Ok(u32::MAX)
} else {
Ok(0)
}
}
Channel::DirectMessage { recipients, .. } => {
if recipients
.iter()
.find(|x| *x == &self.perspective.id)
.is_some()
{
if let Some(recipient) = recipients.iter().find(|x| *x != &self.perspective.id)
{
let perms = self.for_user(recipient).await?;
if perms.get_send_message() {
return Ok(*DEFAULT_PERMISSION_DM);
}
return Ok(ChannelPermission::View as u32);
}
}
Ok(0)
}
Channel::Group { recipients, permissions, owner, .. } => {
if &self.perspective.id == owner {
return Ok(*DEFAULT_PERMISSION_DM)
}
if recipients
.iter()
.find(|x| *x == &self.perspective.id)
.is_some()
{
if let Some(permissions) = permissions {
Ok(permissions.clone() as u32)
} else {
Ok(*DEFAULT_PERMISSION_DM)
}
} else {
Ok(0)
}
}
Channel::TextChannel { server, default_permissions, role_permissions, .. }
| Channel::VoiceChannel { server, default_permissions, role_permissions, .. } => {
let server = Ref::from_unchecked(server.clone()).fetch_server().await?;
if self.perspective.id == server.owner {
Ok(u32::MAX)
} else {
match Ref::from_unchecked(self.perspective.id.clone()).fetch_member(&server.id).await {
Ok(member) => {
let mut perm = if let Some(permission) = default_permissions {
*permission as u32
} else {
server.default_permissions.1 as u32
};
if let Some(roles) = member.roles {
for role in roles {
if let Some(permission) = role_permissions.get(&role) {
perm |= *permission as u32;
}
if let Some(server_role) = server.roles.get(&role) {
perm |= server_role.permissions.1 as u32;
}
}
}
Ok(perm)
}
Err(error) => {
match &error {
Error::NotFound => Ok(0),
_ => Err(error)
}
}
}
}
}
}
}
pub async fn for_channel(self) -> Result<ChannelPermissions<[u32; 1]>> {
Ok(ChannelPermissions([self.calculate_channel().await?]))
}
}
pub use crate::database::*;
pub mod channel;
pub mod server;
pub mod user;
pub use user::get_relationship;
pub struct PermissionCalculator<'a> {
perspective: &'a User,
user: Option<&'a User>,
relationship: Option<&'a RelationshipStatus>,
channel: Option<&'a Channel>,
server: Option<&'a Server>,
// member: Option<&'a Member>,
has_mutual_connection: bool,
}
impl<'a> PermissionCalculator<'a> {
pub fn new(perspective: &'a User) -> PermissionCalculator {
PermissionCalculator {
perspective,
user: None,
relationship: None,
channel: None,
server: None,
// member: None,
has_mutual_connection: false,
}
}
pub fn with_user(self, user: &'a User) -> PermissionCalculator {
PermissionCalculator {
user: Some(&user),
..self
}
}
pub fn with_relationship(self, relationship: &'a RelationshipStatus) -> PermissionCalculator {
PermissionCalculator {
relationship: Some(&relationship),
..self
}
}
pub fn with_channel(self, channel: &'a Channel) -> PermissionCalculator {
PermissionCalculator {
channel: Some(&channel),
..self
}
}
pub fn with_server(self, server: &'a Server) -> PermissionCalculator {
PermissionCalculator {
server: Some(&server),
..self
}
}
/* pub fn with_member(self, member: &'a Member) -> PermissionCalculator {
PermissionCalculator {
member: Some(&member),
..self
}
} */
pub fn with_mutual_connection(self) -> PermissionCalculator<'a> {
PermissionCalculator {
has_mutual_connection: true,
..self
}
}
}
use crate::util::result::{Error, Result};
use super::PermissionCalculator;
use super::Ref;
use num_enum::TryFromPrimitive;
use std::ops;
#[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)]
#[repr(u32)]
pub enum ServerPermission {
View = 0b00000000000000000000000000000001, // 1
ManageRoles = 0b00000000000000000000000000000010, // 2
ManageChannels = 0b00000000000000000000000000000100, // 4
ManageServer = 0b00000000000000000000000000001000, // 8
KickMembers = 0b00000000000000000000000000010000, // 16
BanMembers = 0b00000000000000000000000000100000, // 32
// 6 bits of space
ChangeNickname = 0b00000000000000000001000000000000, // 4096
ManageNicknames = 0b00000000000000000010000000000000, // 8192
ChangeAvatar = 0b00000000000000000100000000000000, // 16382
RemoveAvatars = 0b00000000000000001000000000000000, // 32768
// 16 bits of space
}
lazy_static! {
pub static ref DEFAULT_PERMISSION: u32 =
ServerPermission::View
+ ServerPermission::ChangeNickname
+ ServerPermission::ChangeAvatar;
}
impl_op_ex!(+ |a: &ServerPermission, b: &ServerPermission| -> u32 { *a as u32 | *b as u32 });
impl_op_ex_commutative!(+ |a: &u32, b: &ServerPermission| -> u32 { *a | *b as u32 });
bitfield! {
pub struct ServerPermissions(MSB0 [u32]);
u32;
pub get_view, _: 31;
pub get_manage_roles, _: 30;
pub get_manage_channels, _: 29;
pub get_manage_server, _: 28;
pub get_kick_members, _: 27;
pub get_ban_members, _: 26;
pub get_change_nickname, _: 19;
pub get_manage_nicknames, _: 18;
pub get_change_avatar, _: 17;
pub get_remove_avatars, _: 16;
}
impl<'a> PermissionCalculator<'a> {
pub async fn calculate_server(self) -> Result<u32> {
let server = if let Some(server) = self.server {
server
} else {
unreachable!()
};
if self.perspective.id == server.owner {
Ok(u32::MAX)
} else {
match Ref::from_unchecked(self.perspective.id.clone()).fetch_member(&server.id).await {
Ok(member) => {
let mut perm = server.default_permissions.0 as u32;
if let Some(roles) = member.roles {
for role in roles {
if let Some(server_role) = server.roles.get(&role) {
perm |= server_role.permissions.0 as u32;
}
}
}
Ok(perm)
}
Err(error) => {
match &error {
Error::NotFound => Ok(0),
_ => Err(error)
}
}
}
}
}
pub async fn for_server(self) -> Result<ServerPermissions<[u32; 1]>> {
Ok(ServerPermissions([self.calculate_server().await?]))
}
}
use crate::database::*;
use crate::util::result::{Error, Result};
use super::PermissionCalculator;
use mongodb::bson::doc;
use num_enum::TryFromPrimitive;
use std::ops;
#[derive(Debug, PartialEq, Eq, TryFromPrimitive, Copy, Clone)]
#[repr(u32)]
pub enum UserPermission {
Access = 0b00000000000000000000000000000001, // 1
ViewProfile = 0b00000000000000000000000000000010, // 2
SendMessage = 0b00000000000000000000000000000100, // 4
Invite = 0b00000000000000000000000000001000, // 8
}
bitfield! {
pub struct UserPermissions(MSB0 [u32]);
u32;
pub get_access, _: 31;
pub get_view_profile, _: 30;
pub get_send_message, _: 29;
pub get_invite, _: 28;
}
impl_op_ex!(+ |a: &UserPermission, b: &UserPermission| -> u32 { *a as u32 | *b as u32 });
impl_op_ex_commutative!(+ |a: &u32, b: &UserPermission| -> u32 { *a | *b as u32 });
pub fn get_relationship(a: &User, b: &str) -> RelationshipStatus {
if a.id == b {
return RelationshipStatus::User;
}
if let Some(relations) = &a.relations {
if let Some(relationship) = relations.iter().find(|x| x.id == b) {
return relationship.status.clone();
}
}
RelationshipStatus::None
}
impl<'a> PermissionCalculator<'a> {
pub async fn calculate_user(self, target: &str) -> Result<u32> {
if &self.perspective.id == target {
return Ok(u32::MAX);
}
let mut permissions: u32 = 0;
match self
.relationship
.clone()
.map(|v| v.to_owned())
.unwrap_or_else(|| get_relationship(&self.perspective, &target))
{
RelationshipStatus::Friend | RelationshipStatus::User => return Ok(u32::MAX),
RelationshipStatus::Blocked | RelationshipStatus::BlockedOther => {
return Ok(UserPermission::Access as u32)
}
RelationshipStatus::Incoming | RelationshipStatus::Outgoing => {
permissions = UserPermission::Access as u32;
// ! INFO: if we add boolean switch for permission to
// ! message people who have mutual, we need to get
// ! rid of this return statement.
// return Ok(permissions);
}
_ => {}
}
let check_server_overlap = async || {
let server_ids = User::fetch_server_ids(&self.perspective.id).await?;
Ok(
get_collection("server_members")
.find_one(
doc! {
"_id.user": &target,
"_id.server": {
"$in": server_ids
}
},
None
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "server_members",
})?
.is_some()
)
};
if self.has_mutual_connection
|| check_server_overlap().await?
|| get_collection("channels")
.find_one(
doc! {
"channel_type": {
"$in": ["Group", "DirectMessage"]
},
"recipients": {
"$all": [ &self.perspective.id, target ]
}
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find_one",
with: "channels",
})?
.is_some()
{
// ! FIXME: add privacy settings
return Ok(UserPermission::Access + UserPermission::ViewProfile);
}
Ok(permissions)
}
pub async fn for_user(self, target: &str) -> Result<UserPermissions<[u32; 1]>> {
Ok(UserPermissions([self.calculate_user(&target).await?]))
}
pub async fn for_user_given(self) -> Result<UserPermissions<[u32; 1]>> {
let id = &self.user.unwrap().id;
Ok(UserPermissions([self.calculate_user(&id).await?]))
}
}
use serde::{ Deserialize, Serialize };
use bson::UtcDateTime;
#[derive(Serialize, Deserialize, Debug)]
pub struct UserEmailVerification {
pub verified: bool,
pub target: Option<String>,
pub expiry: Option<UtcDateTime>,
pub rate_limit: Option<UtcDateTime>,
pub code: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct UserRelationship {
pub id: String,
pub status: u8,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct User {
#[serde(rename = "_id")]
pub id: String,
pub email: String,
pub username: String,
pub password: String,
pub access_token: Option<String>,
pub email_verification: UserEmailVerification,
pub relations: Option<Vec<UserRelationship>>,
}
use reqwest::blocking::Client;
use std::collections::HashMap;
use std::env;
pub fn send_email(target: String, subject: String, body: String, html: String) -> Result<(), ()> {
let mut map = HashMap::new();
map.insert("target", target.clone());
map.insert("subject", subject);
map.insert("body", body);
map.insert("html", html);
let client = Client::new();
match client.post("http://192.168.0.26:3838/send")
.json(&map)
.send() {
Ok(_) => Ok(()),
Err(_) => Err(())
}
}
fn public_uri() -> String {
env::var("PUBLIC_URI").expect("PUBLIC_URI not in environment variables!")
}
pub fn send_verification_email(email: String, code: String) -> bool {
let url = format!("{}/api/account/verify/{}", public_uri(), code);
match send_email(
email,
"Verify your email!".to_string(),
format!("Verify your email here: {}", url).to_string(),
format!("<a href=\"{}\">Click to verify your email!</a>", url).to_string()
) {
Ok(_) => true,
Err(_) => false,
}
}
pub fn send_welcome_email(email: String, username: String) -> bool {
match send_email(
email,
"Welcome to REVOLT!".to_string(),
format!("Welcome, {}! You can now use REVOLT.", username.clone()).to_string(),
format!("<b>Welcome, {}!</b><br/>You can now use REVOLT.<br/><a href=\"{}\">Go to REVOLT</a>", username.clone(), public_uri()).to_string()
) {
Ok(_) => true,
Err(_) => false,
}
}
use rocket::Outcome;
use rocket::http::{ Status, RawStr };
use rocket::request::{ self, Request, FromRequest, FromParam };
use bson::{ bson, doc, from_bson };
use crate::database;
use database::user::User;
#[derive(Debug)]
pub enum AuthError {
BadCount,
Missing,
Invalid,
}
impl<'a, 'r> FromRequest<'a, 'r> for User {
type Error = AuthError;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
let keys: Vec<_> = request.headers().get("x-auth-token").collect();
match keys.len() {
0 => Outcome::Failure((Status::Forbidden, AuthError::Missing)),
1 => {
let key = keys[0];
let col = database::get_db().collection("users");
let result = col.find_one(doc! { "access_token": key }, None).unwrap();
if let Some(user) = result {
Outcome::Success(from_bson(bson::Bson::Document(user)).expect("Failed to unwrap user."))
} else {
Outcome::Failure((Status::Forbidden, AuthError::Invalid))
}
},
_ => Outcome::Failure((Status::BadRequest, AuthError::BadCount)),
}
}
}
impl<'r> FromParam<'r> for User {
type Error = &'r RawStr;
fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
let col = database::get_db().collection("users");
let result = col.find_one(doc! { "_id": param.to_string() }, None).unwrap();
if let Some(user) = result {
Ok(from_bson(bson::Bson::Document(user)).expect("Failed to unwrap user."))
} else {
Err(param)
}
}
}
use rocket::http::{ RawStr };
use rocket::request::{ FromParam };
use bson::{ bson, doc, from_bson };
use crate::database::{ self, user::User };
use database::channel::Channel;
use database::message::Message;
impl<'r> FromParam<'r> for Channel {
type Error = &'r RawStr;
fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
let col = database::get_db().collection("channels");
let result = col.find_one(doc! { "_id": param.to_string() }, None).unwrap();
if let Some(channel) = result {
Ok(from_bson(bson::Bson::Document(channel)).expect("Failed to unwrap channel."))
} else {
Err(param)
}
}
}
impl<'r> FromParam<'r> for Message {
type Error = &'r RawStr;
fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
let col = database::get_db().collection("messages");
let result = col.find_one(doc! { "_id": param.to_string() }, None).unwrap();
if let Some(message) = result {
Ok(from_bson(bson::Bson::Document(message)).expect("Failed to unwrap message."))
} else {
Err(param)
}
}
}
pub mod auth;
pub mod channel;
#![feature(proc_macro_hygiene, decl_macro)]
#[macro_use] extern crate rocket;
#[macro_use] extern crate rocket_contrib;
#![feature(async_closure)]
#[macro_use]
extern crate rocket;
#[macro_use]
extern crate rocket_contrib;
#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate impl_ops;
#[macro_use]
extern crate bitfield;
extern crate ctrlc;
pub mod database;
pub mod guards;
pub mod notifications;
pub mod routes;
pub mod email;
pub mod util;
pub mod version;
use async_std::task;
use chrono::Duration;
use futures::join;
use log::info;
use rauth::options::{EmailVerification, Options, SMTP};
use rauth::{
auth::Auth,
options::{Template, Templates},
};
use rocket_cors::AllowedOrigins;
use rocket_prometheus::PrometheusMetrics;
use util::variables::{
APP_URL, HCAPTCHA_KEY, INVITE_ONLY, PUBLIC_URL, SMTP_FROM, SMTP_HOST, SMTP_PASSWORD,
SMTP_USERNAME, USE_EMAIL, USE_HCAPTCHA, USE_PROMETHEUS,
};
#[async_std::main]
async fn main() {
dotenv::dotenv().ok();
env_logger::init_from_env(env_logger::Env::default().filter_or("RUST_LOG", "info"));
info!(
"Starting REVOLT server [version {}].",
crate::version::VERSION
);
util::variables::preflight_checks();
database::connect().await;
notifications::hive::init_hive().await;
ctrlc::set_handler(move || {
// Force ungraceful exit to avoid hang.
std::process::exit(0);
})
.expect("Error setting Ctrl-C handler");
let web_task = task::spawn(launch_web());
let hive_task = task::spawn(notifications::hive::listen());
join!(
web_task,
hive_task,
notifications::websocket::launch_server()
);
}
async fn launch_web() {
let cors = rocket_cors::CorsOptions {
allowed_origins: AllowedOrigins::All,
..Default::default()
}
.to_cors()
.expect("Failed to create CORS.");
let mut options = Options::new()
.base_url(format!("{}/auth", *PUBLIC_URL))
.email_verification(if *USE_EMAIL {
EmailVerification::Enabled {
success_redirect_uri: format!("{}/login", *APP_URL),
welcome_redirect_uri: format!("{}/welcome", *APP_URL),
password_reset_url: Some(format!("{}/login/reset", *APP_URL)),
verification_expiry: Duration::days(1),
password_reset_expiry: Duration::hours(1),
templates: Templates {
verify_email: Template {
title: "Verify your Revolt account.",
text: "You're almost there!
If you did not perform this action you can safely ignore this email.
Please verify your account here: {{url}}",
html: None,
},
reset_password: Template {
title: "Reset your Revolt password.",
text: "You requested a password reset, if you did not perform this action you can safely ignore this email.
Reset your password here: {{url}}",
html: None,
},
welcome: None,
},
smtp: SMTP {
from: (*SMTP_FROM).to_string(),
host: (*SMTP_HOST).to_string(),
username: (*SMTP_USERNAME).to_string(),
password: (*SMTP_PASSWORD).to_string(),
},
}
} else {
EmailVerification::Disabled
});
if *INVITE_ONLY {
options = options.invite_only_collection(database::get_collection("invites"))
}
if *USE_HCAPTCHA {
options = options.hcaptcha_secret(HCAPTCHA_KEY.clone());
}
let auth = Auth::new(database::get_collection("accounts"), options);
use dotenv;
let mut rocket = rocket::ignite();
fn main() {
dotenv::dotenv().ok();
database::connect();
if *USE_PROMETHEUS {
info!("Enabled Prometheus metrics!");
let prometheus = PrometheusMetrics::new();
rocket = rocket
.attach(prometheus.clone())
.mount("/metrics", prometheus);
}
routes::mount(rocket::ignite()).launch();
routes::mount(rocket)
.mount("/", rocket_cors::catch_all_options_routes())
.mount("/auth", rauth::routes::routes())
.manage(auth)
.manage(cors.clone())
.attach(cors)
.launch()
.await
.unwrap();
}
use hive_pubsub::PubSub;
use mongodb::bson::doc;
use rauth::auth::Session;
use rocket_contrib::json::JsonValue;
use serde::{Deserialize, Serialize};
use super::hive::{get_hive, subscribe_if_exists};
use crate::{database::*, util::result::Result};
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "error")]
pub enum WebSocketError {
LabelMe,
InternalError { at: String },
InvalidSession,
OnboardingNotFinished,
AlreadyAuthenticated,
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
pub enum ServerboundNotification {
Authenticate(Session),
BeginTyping { channel: String },
EndTyping { channel: String },
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveUserField {
ProfileContent,
ProfileBackground,
StatusText,
Avatar,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveChannelField {
Icon,
Description
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveServerField {
Icon,
Banner,
Description,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveRoleField {
Colour,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum RemoveMemberField {
Nickname,
Avatar,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum ClientboundNotification {
Error(WebSocketError),
Authenticated,
Ready {
users: Vec<User>,
servers: Vec<Server>,
channels: Vec<Channel>,
members: Vec<Member>
},
Message(Message),
MessageUpdate {
id: String,
channel: String,
data: JsonValue,
},
MessageDelete {
id: String,
channel: String,
},
ChannelCreate(Channel),
ChannelUpdate {
id: String,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveChannelField>,
},
ChannelDelete {
id: String,
},
ChannelGroupJoin {
id: String,
user: String,
},
ChannelGroupLeave {
id: String,
user: String,
},
ChannelStartTyping {
id: String,
user: String,
},
ChannelStopTyping {
id: String,
user: String,
},
ChannelAck {
id: String,
user: String,
message_id: String,
},
ServerUpdate {
id: String,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveServerField>,
},
ServerDelete {
id: String,
},
ServerMemberUpdate {
id: MemberCompositeKey,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveMemberField>,
},
ServerMemberJoin {
id: String,
user: String,
},
ServerMemberLeave {
id: String,
user: String,
},
ServerRoleUpdate {
id: String,
role_id: String,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveRoleField>
},
ServerRoleDelete {
id: String,
role_id: String
},
UserUpdate {
id: String,
data: JsonValue,
#[serde(skip_serializing_if = "Option::is_none")]
clear: Option<RemoveUserField>,
},
UserRelationship {
id: String,
user: User,
status: RelationshipStatus,
},
UserSettingsUpdate {
id: String,
update: JsonValue,
},
}
impl ClientboundNotification {
pub fn publish(self, topic: String) {
async_std::task::spawn(async move {
prehandle_hook(&self).await.ok(); // ! FIXME: this should be moved to pubsub
hive_pubsub::backend::mongo::publish(get_hive(), &topic, self)
.await
.ok();
});
}
pub fn publish_as_user(self, user: String) {
self.clone().publish(user.clone());
async_std::task::spawn(async move {
if let Ok(server_ids) = User::fetch_server_ids(&user).await {
for server in server_ids {
self.clone().publish(server.clone());
}
}
});
}
}
pub async fn prehandle_hook(notification: &ClientboundNotification) -> Result<()> {
match &notification {
ClientboundNotification::ChannelGroupJoin { id, user } => {
subscribe_if_exists(user.clone(), id.clone()).ok();
}
ClientboundNotification::ChannelCreate(channel) => {
let channel_id = channel.id();
match &channel {
Channel::SavedMessages { user, .. } => {
subscribe_if_exists(user.clone(), channel_id.to_string()).ok();
}
Channel::DirectMessage { recipients, .. } | Channel::Group { recipients, .. } => {
for recipient in recipients {
subscribe_if_exists(recipient.clone(), channel_id.to_string()).ok();
}
}
Channel::TextChannel { server, .. }
| Channel::VoiceChannel { server, .. } => {
// ! FIXME: write a better algorithm?
let members = Server::fetch_member_ids(server).await?;
for member in members {
subscribe_if_exists(member.clone(), channel_id.to_string()).ok();
}
}
}
}
ClientboundNotification::ServerMemberJoin { id, user } => {
let server = Ref::from_unchecked(id.clone()).fetch_server().await?;
subscribe_if_exists(user.clone(), id.clone()).ok();
for channel in server.channels {
subscribe_if_exists(user.clone(), channel).ok();
}
}
ClientboundNotification::UserRelationship { id, user, status } => {
if status != &RelationshipStatus::None {
subscribe_if_exists(id.clone(), user.id.clone()).ok();
}
}
_ => {}
}
Ok(())
}
pub async fn posthandle_hook(notification: &ClientboundNotification) {
match &notification {
ClientboundNotification::ChannelDelete { id } => {
get_hive().hive.drop_topic(&id).ok();
}
ClientboundNotification::ChannelGroupLeave { id, user } => {
get_hive().hive.unsubscribe(user, id).ok();
}
ClientboundNotification::ServerDelete { id } => {
get_hive().hive.drop_topic(&id).ok();
}
ClientboundNotification::ServerMemberLeave { id, user } => {
get_hive().hive.unsubscribe(user, id).ok();
if let Ok(server) = Ref::from_unchecked(id.clone()).fetch_server().await {
for channel in server.channels {
get_hive().hive.unsubscribe(user, &channel).ok();
}
}
}
ClientboundNotification::UserRelationship { id, user, status } => {
if status == &RelationshipStatus::None {
get_hive().hive.unsubscribe(id, &user.id).ok();
}
}
_ => {}
}
}
use super::{events::ClientboundNotification, websocket};
use crate::database::*;
use futures::FutureExt;
use hive_pubsub::backend::mongo::MongodbPubSub;
use hive_pubsub::PubSub;
use log::{debug, error};
use once_cell::sync::OnceCell;
use serde_json::to_string;
type Hive = MongodbPubSub<String, String, ClientboundNotification>;
static HIVE: OnceCell<Hive> = OnceCell::new();
pub async fn init_hive() {
let hive = MongodbPubSub::new(
|ids, notification: ClientboundNotification| {
let notif = notification.clone();
async_std::task::spawn(async move {
super::events::posthandle_hook(&notif).await;
});
if let Ok(data) = to_string(&notification) {
debug!("Pushing out notification. {}", data);
websocket::publish(ids, notification);
} else {
error!("Failed to serialise notification.");
}
},
get_collection("pubsub"),
);
if HIVE.set(hive).is_err() {
panic!("Failed to set global pubsub instance.");
}
}
pub async fn listen() {
HIVE.get()
.unwrap()
.listen()
.fuse()
.await
.expect("Hive hit an error");
}
pub fn subscribe_multiple(user: String, topics: Vec<String>) -> Result<(), String> {
let hive = HIVE.get().unwrap();
for topic in topics {
hive.subscribe(user.clone(), topic)?;
}
Ok(())
}
pub fn subscribe_if_exists(user: String, topic: String) -> Result<(), String> {
let hive = HIVE.get().unwrap();
if hive.hive.map.lock().unwrap().get_left(&user).is_some() {
hive.subscribe(user, topic)?;
}
Ok(())
}
pub fn get_hive() -> &'static Hive {
HIVE.get().unwrap()
}
pub mod events;
pub mod hive;
pub mod payload;
pub mod subscriptions;
pub mod websocket;
use std::collections::HashSet;
use crate::{database::*, notifications::events::ClientboundNotification};
use crate::{
database::{entities::User, get_collection},
util::result::{Error, Result},
};
use futures::StreamExt;
use mongodb::bson::{doc, from_document};
pub async fn generate_ready(mut user: User) -> Result<ClientboundNotification> {
let mut user_ids: HashSet<String> = HashSet::new();
if let Some(relationships) = &user.relations {
user_ids.extend(
relationships
.iter()
.map(|relationship| relationship.id.clone()),
);
}
let members = User::fetch_memberships(&user.id).await?;
let server_ids: Vec<String> = members.iter()
.map(|x| x.id.server.clone())
.collect();
let mut cursor = get_collection("servers")
.find(
doc! {
"_id": {
"$in": server_ids
}
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "servers",
})?;
let mut servers = vec![];
let mut channel_ids = vec![];
while let Some(result) = cursor.next().await {
if let Ok(doc) = result {
let server: Server = from_document(doc).map_err(|_| Error::DatabaseError {
operation: "from_document",
with: "server",
})?;
channel_ids.extend(server.channels.iter().cloned());
servers.push(server);
}
}
let mut cursor = get_collection("channels")
.find(
doc! {
"$or": [
{
"_id": {
"$in": channel_ids
}
},
{
"channel_type": "SavedMessages",
"user": &user.id
},
{
"channel_type": "DirectMessage",
"recipients": &user.id
},
{
"channel_type": "Group",
"recipients": &user.id
}
]
},
None,
)
.await
.map_err(|_| Error::DatabaseError {
operation: "find",
with: "channels",
})?;
let mut channels = vec![];
while let Some(result) = cursor.next().await {
if let Ok(doc) = result {
let channel = from_document(doc).map_err(|_| Error::DatabaseError {
operation: "from_document",
with: "channel",
})?;
if let Channel::Group { recipients, .. } = &channel {
user_ids.extend(recipients.iter().cloned());
} else if let Channel::DirectMessage { recipients, .. } = &channel {
user_ids.extend(recipients.iter().cloned());
}
channels.push(channel);
}
}
user_ids.remove(&user.id);
let mut users = if user_ids.len() > 0 {
user.fetch_multiple_users(user_ids.into_iter().collect::<Vec<String>>())
.await?
} else {
vec![]
};
user.relationship = Some(RelationshipStatus::User);
user.online = Some(true);
users.push(user);
Ok(ClientboundNotification::Ready {
users,
servers,
channels,
members
})
}
use crate::database::*;
use super::hive::get_hive;
use futures::StreamExt;
use hive_pubsub::PubSub;
use mongodb::bson::doc;
use mongodb::bson::Document;
use mongodb::options::FindOptions;
pub async fn generate_subscriptions(user: &User) -> Result<(), String> {
let hive = get_hive();
hive.subscribe(user.id.clone(), user.id.clone())?;
if let Some(relations) = &user.relations {
for relation in relations {
hive.subscribe(user.id.clone(), relation.id.clone())?;
}
}
let server_ids = User::fetch_server_ids(&user.id)
.await
.map_err(|_| "Failed to fetch memberships.".to_string())?;
let channel_ids = get_collection("servers")
.find(
doc! {
"_id": {
"$in": &server_ids
}
},
None,
)
.await
.map_err(|_| "Failed to fetch servers.".to_string())?
.filter_map(async move |s| s.ok())
.collect::<Vec<Document>>()
.await
.into_iter()
.filter_map(|x| {
x.get_array("channels").ok().map(|v| {
v.into_iter()
.filter_map(|x| x.as_str().map(|x| x.to_string()))
.collect::<Vec<String>>()
})
})
.flatten()
.collect::<Vec<String>>();
for id in server_ids {
hive.subscribe(user.id.clone(), id)?;
}
for id in channel_ids {
hive.subscribe(user.id.clone(), id)?;
}
let mut cursor = get_collection("channels")
.find(
doc! {
"$or": [
{
"channel_type": "SavedMessages",
"user": &user.id
},
{
"channel_type": "DirectMessage",
"recipients": &user.id
},
{
"channel_type": "Group",
"recipients": &user.id
}
]
},
FindOptions::builder().projection(doc! { "_id": 1 }).build(),
)
.await
.map_err(|_| "Failed to fetch channels.".to_string())?;
while let Some(result) = cursor.next().await {
if let Ok(doc) = result {
hive.subscribe(user.id.clone(), doc.get_str("_id").unwrap().to_string())?;
}
}
Ok(())
}
use crate::database::*;
use crate::util::variables::WS_HOST;
use super::subscriptions;
use async_std::net::{TcpListener, TcpStream};
use async_std::task;
use async_tungstenite::tungstenite::Message;
use futures::channel::mpsc::{unbounded, UnboundedSender};
use futures::stream::TryStreamExt;
use futures::{pin_mut, prelude::*};
use hive_pubsub::PubSub;
use log::{debug, info};
use many_to_many::ManyToMany;
use rauth::{
auth::{Auth, Session},
options::Options,
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, RwLock};
use super::{
events::{ClientboundNotification, ServerboundNotification, WebSocketError},
hive::get_hive,
};
type Tx = UnboundedSender<Message>;
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
lazy_static! {
static ref CONNECTIONS: PeerMap = Arc::new(Mutex::new(HashMap::new()));
static ref USERS: Arc<RwLock<ManyToMany<String, SocketAddr>>> =
Arc::new(RwLock::new(ManyToMany::new()));
}
pub async fn launch_server() {
let try_socket = TcpListener::bind(WS_HOST.to_string()).await;
let listener = try_socket.expect("Failed to bind");
info!("Listening on: {}", *WS_HOST);
while let Ok((stream, _)) = listener.accept().await {
task::spawn(accept(stream));
}
}
async fn accept(stream: TcpStream) {
let addr = stream
.peer_addr()
.expect("Connected streams should have a peer address.");
let ws_stream = async_tungstenite::accept_async(stream)
.await
.expect("Error during websocket handshake.");
info!("User established WebSocket connection from {}.", &addr);
let (write, read) = ws_stream.split();
let (tx, rx) = unbounded();
CONNECTIONS.lock().unwrap().insert(addr, tx.clone());
let send = |notification: ClientboundNotification| {
if let Ok(response) = serde_json::to_string(&notification) {
if let Err(_) = tx.unbounded_send(Message::Text(response)) {
debug!("Failed unbounded_send to websocket stream.");
}
}
};
let session: Arc<Mutex<Option<Session>>> = Arc::new(Mutex::new(None));
let mutex_generator = || session.clone();
let fwd = rx.map(Ok).forward(write);
let incoming = read.try_for_each(async move |msg| {
let mutex = mutex_generator();
if let Message::Text(text) = msg {
if let Ok(notification) = serde_json::from_str::<ServerboundNotification>(&text) {
match notification {
ServerboundNotification::Authenticate(new_session) => {
{
if mutex.lock().unwrap().is_some() {
send(ClientboundNotification::Error(
WebSocketError::AlreadyAuthenticated,
));
return Ok(());
}
}
if let Ok(validated_session) =
Auth::new(get_collection("accounts"), Options::new())
.verify_session(new_session)
.await
{
let id = validated_session.user_id.clone();
if let Ok(user) = (Ref { id: id.clone() }).fetch_user().await {
let was_online = is_online(&id);
{
match USERS.write() {
Ok(mut map) => {
map.insert(id.clone(), addr);
}
Err(_) => {
send(ClientboundNotification::Error(
WebSocketError::InternalError {
at: "Writing users map.".to_string(),
},
));
return Ok(());
}
}
}
*mutex.lock().unwrap() = Some(validated_session);
if let Err(_) = subscriptions::generate_subscriptions(&user).await {
send(ClientboundNotification::Error(
WebSocketError::InternalError {
at: "Generating subscriptions.".to_string(),
},
));
return Ok(());
}
send(ClientboundNotification::Authenticated);
match super::payload::generate_ready(user).await {
Ok(payload) => {
send(payload);
if !was_online {
ClientboundNotification::UserUpdate {
id: id.clone(),
data: json!({
"online": true
}),
clear: None
}
.publish_as_user(id);
}
}
Err(_) => {
send(ClientboundNotification::Error(
WebSocketError::InternalError {
at: "Generating payload.".to_string(),
},
));
return Ok(());
}
}
} else {
send(ClientboundNotification::Error(
WebSocketError::OnboardingNotFinished,
));
}
} else {
send(ClientboundNotification::Error(
WebSocketError::InvalidSession,
));
}
}
// ! TEMP: verify user part of channel
// ! Could just run permission check here.
ServerboundNotification::BeginTyping { channel } => {
if mutex.lock().unwrap().is_some() {
let user = {
let mutex = mutex.lock().unwrap();
let session = mutex.as_ref().unwrap();
session.user_id.clone()
};
ClientboundNotification::ChannelStartTyping {
id: channel.clone(),
user,
}
.publish(channel);
} else {
send(ClientboundNotification::Error(
WebSocketError::AlreadyAuthenticated,
));
return Ok(());
}
}
ServerboundNotification::EndTyping { channel } => {
if mutex.lock().unwrap().is_some() {
let user = {
let mutex = mutex.lock().unwrap();
let session = mutex.as_ref().unwrap();
session.user_id.clone()
};
ClientboundNotification::ChannelStopTyping {
id: channel.clone(),
user,
}
.publish(channel);
} else {
send(ClientboundNotification::Error(
WebSocketError::AlreadyAuthenticated,
));
return Ok(());
}
}
}
}
}
Ok(())
});
pin_mut!(fwd, incoming);
future::select(fwd, incoming).await;
info!("User {} disconnected.", &addr);
CONNECTIONS.lock().unwrap().remove(&addr);
let mut offline = None;
{
let session = session.lock().unwrap();
if let Some(session) = session.as_ref() {
let mut users = USERS.write().unwrap();
users.remove(&session.user_id, &addr);
if users.get_left(&session.user_id).is_none() {
get_hive().drop_client(&session.user_id).unwrap();
offline = Some(session.user_id.clone());
}
}
}
if let Some(id) = offline {
ClientboundNotification::UserUpdate {
id: id.clone(),
data: json!({
"online": false
}),
clear: None
}
.publish_as_user(id);
}
}
pub fn publish(ids: Vec<String>, notification: ClientboundNotification) {
let mut targets = vec![];
{
let users = USERS.read().unwrap();
for id in ids {
// Block certain notifications from reaching users that aren't meant to see them.
match &notification {
ClientboundNotification::UserRelationship { id: user_id, .. }
| ClientboundNotification::UserSettingsUpdate { id: user_id, .. }
| ClientboundNotification::ChannelAck { user: user_id, .. } => {
if &id != user_id {
continue;
}
}
_ => {}
}
if let Some(mut arr) = users.get_left(&id) {
targets.append(&mut arr);
}
}
}
let msg = Message::Text(serde_json::to_string(&notification).unwrap());
let connections = CONNECTIONS.lock().unwrap();
for target in targets {
if let Some(conn) = connections.get(&target) {
if let Err(_) = conn.unbounded_send(msg.clone()) {
debug!("Failed unbounded_send.");
}
}
}
}
pub fn is_online(user: &String) -> bool {
USERS.read().unwrap().get_left(&user).is_some()
}
use crate::database;
use crate::email;
use bson::{ bson, doc, Bson::UtcDatetime, from_bson};
use rand::{ Rng, distributions::Alphanumeric };
use rocket_contrib::json::{ Json, JsonValue };
use serde::{ Serialize, Deserialize };
use validator::validate_email;
use bcrypt::{ hash, verify };
use database::user::User;
use chrono::prelude::*;
use ulid::Ulid;
fn gen_token(l: usize) -> String {
rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(l)
.collect::<String>()
}
#[derive(Serialize, Deserialize)]
pub struct Create {
username: String,
password: String,
email: String,
}
/// create a new Revolt account
/// (1) validate input
/// [username] 2 to 32 characters
/// [password] 8 to 72 characters
/// [email] validate against RFC
/// (2) check email existence
/// (3) add user and send email verification
#[post("/create", data = "<info>")]
pub fn create(info: Json<Create>) -> JsonValue {
let col = database::get_collection("users");
if info.username.len() < 2 || info.username.len() > 32 {
return json!({
"success": false,
"error": "Username requirements not met! Must be between 2 and 32 characters.",
})
}
if info.password.len() < 8 || info.password.len() > 72 {
return json!({
"success": false,
"error": "Password requirements not met! Must be between 8 and 72 characters.",
})
}
if !validate_email(info.email.clone()) {
return json!({
"success": false,
"error": "Invalid email provided!",
})
}
if let Some(_) = col.find_one(doc! { "email": info.email.clone() }, None).expect("Failed user lookup") {
return json!({
"success": false,
"error": "Email already in use!",
})
}
if let Ok(hashed) = hash(info.password.clone(), 10) {
let access_token = gen_token(92);
let code = gen_token(48);
match col.insert_one(doc! {
"_id": Ulid::new().to_string(),
"email": info.email.clone(),
"username": info.username.clone(),
"password": hashed,
"access_token": access_token,
"email_verification": {
"verified": false,
"target": info.email.clone(),
"expiry": UtcDatetime(Utc::now() + chrono::Duration::days(1)),
"rate_limit": UtcDatetime(Utc::now() + chrono::Duration::minutes(1)),
"code": code.clone(),
}
}, None) {
Ok(_) => {
let sent = email::send_verification_email(info.email.clone(), code);
json!({
"success": true,
"email_sent": sent,
})
},
Err(_) => json!({
"success": false,
"error": "Failed to create account!",
})
}
} else {
json!({
"success": false,
"error": "Failed to hash password!",
})
}
}
/// verify an email for a Revolt account
/// (1) check if code is valid
/// (2) check if it expired yet
/// (3) set account as verified
#[get("/verify/<code>")]
pub fn verify_email(code: String) -> JsonValue {
let col = database::get_collection("users");
if let Some(u) =
col.find_one(doc! { "email_verification.code": code.clone() }, None).expect("Failed user lookup") {
let user: User = from_bson(bson::Bson::Document(u)).expect("Failed to unwrap user.");
let ev = user.email_verification;
if Utc::now() > *ev.expiry.unwrap() {
json!({
"success": false,
"error": "Token has expired!",
})
} else {
let target = ev.target.unwrap();
col.update_one(
doc! { "_id": user.id },
doc! {
"$unset": {
"email_verification.code": "",
"email_verification.expiry": "",
"email_verification.target": "",
"email_verification.rate_limit": "",
},
"$set": {
"email_verification.verified": true,
"email": target.clone(),
},
},
None,
).expect("Failed to update user!");
email::send_welcome_email(
target.to_string(),
user.username
);
json!({
"success": true
})
}
} else {
json!({
"success": false,
"error": "Invalid code!",
})
}
}
#[derive(Serialize, Deserialize)]
pub struct Resend {
email: String,
}
/// resend a verification email
/// (1) check if verification is pending for x email
/// (2) check for rate limit
/// (3) resend the email
#[post("/resend", data = "<info>")]
pub fn resend_email(info: Json<Resend>) -> JsonValue {
let col = database::get_collection("users");
if let Some(u) =
col.find_one(doc! { "email_verification.target": info.email.clone() }, None).expect("Failed user lookup") {
let user: User = from_bson(bson::Bson::Document(u)).expect("Failed to unwrap user.");
let ev = user.email_verification;
let expiry = ev.expiry.unwrap();
let rate_limit = ev.rate_limit.unwrap();
if Utc::now() < *rate_limit {
json!({
"success": false,
"error": "Hit rate limit! Please try again in a minute or so.",
})
} else {
let mut new_expiry = UtcDatetime(Utc::now() + chrono::Duration::days(1));
if info.email.clone() != user.email {
if Utc::now() > *expiry {
return json!({
"success": "false",
"error": "For security reasons, please login and change your email again.",
})
}
new_expiry = UtcDatetime(*expiry);
}
let code = gen_token(48);
col.update_one(
doc! { "_id": user.id },
doc! {
"$set": {
"email_verification.code": code.clone(),
"email_verification.expiry": new_expiry,
"email_verification.rate_limit": UtcDatetime(Utc::now() + chrono::Duration::minutes(1)),
},
},
None,
).expect("Failed to update user!");
match email::send_verification_email(
info.email.to_string(),
code,
) {
true => json!({
"success": true,
}),
false => json!({
"success": false,
"error": "Failed to send email! Likely an issue with the backend API.",
})
}
}
} else {
json!({
"success": false,
"error": "Email not pending verification!",
})
}
}
#[derive(Serialize, Deserialize)]
pub struct Login {
email: String,
password: String,
}
/// login to a Revolt account
/// (1) find user by email
/// (2) verify password
/// (3) return access token
#[post("/login", data = "<info>")]
pub fn login(info: Json<Login>) -> JsonValue {
let col = database::get_collection("users");
if let Some(u) =
col.find_one(doc! { "email": info.email.clone() }, None).expect("Failed user lookup") {
let user: User = from_bson(bson::Bson::Document(u)).expect("Failed to unwrap user.");
match verify(info.password.clone(), &user.password)
.expect("Failed to check hash of password.") {
true => {
let token =
match user.access_token {
Some(t) => t.to_string(),
None => {
let token = gen_token(92);
col.update_one(
doc! { "_id": &user.id },
doc! { "$set": { "access_token": token.clone() } },
None
).expect("Failed to update user object");
token
}
};
json!({
"success": true,
"access_token": token
})
},
false => json!({
"success": false,
"error": "Invalid password."
})
}
} else {
json!({
"success": false,
"error": "Email is not registered.",
})
}
}
use crate::database::{ self, user::User, channel::Channel, message::Message };
use bson::{ bson, doc, from_bson, Bson::UtcDatetime };
use rocket_contrib::json::{ JsonValue, Json };
use serde::{ Serialize, Deserialize };
use num_enum::TryFromPrimitive;
use chrono::prelude::*;
use ulid::Ulid;
#[derive(Debug, TryFromPrimitive)]
#[repr(usize)]
pub enum ChannelType {
DM = 0,
GROUP_DM = 1,
GUILD_CHANNEL = 2,
}
fn has_permission(user: &User, target: &Channel) -> bool {
match target.channel_type {
0..=1 => {
if let Some(arr) = &target.recipients {
for item in arr {
if item == &user.id {
return true;
}
}
}
false
},
2 =>
false,
_ =>
false
}
}
/// fetch channel information
#[get("/<target>")]
pub fn channel(user: User, target: Channel) -> Option<JsonValue> {
if !has_permission(&user, &target) {
return None
}
Some(
json!({
"id": target.id,
"type": target.channel_type
}
))
}
/// delete channel
/// or leave group DM
/// or close DM conversation
#[delete("/<target>")]
pub fn delete(user: User, target: Channel) -> Option<JsonValue> {
if !has_permission(&user, &target) {
return None
}
let col = database::get_collection("channels");
Some(match target.channel_type {
0 => {
col.update_one(
doc! { "_id": target.id },
doc! { "$set": { "active": false } },
None
).expect("Failed to update channel.");
json!({
"success": true
})
},
1 => {
// ? TODO: group dm
json!({
"success": true
})
},
2 => {
// ? TODO: guild
json!({
"success": true
})
},
_ =>
json!({
"success": false
})
})
}
/// fetch channel messages
#[get("/<target>/messages")]
pub fn messages(user: User, target: Channel) -> Option<JsonValue> {
if !has_permission(&user, &target) {
return None
}
let col = database::get_collection("messages");
let result = col.find(
doc! { "channel": target.id },
None
).unwrap();
let mut messages = Vec::new();
for item in result {
let message: Message = from_bson(bson::Bson::Document(item.unwrap())).expect("Failed to unwrap message.");
messages.push(
json!({
"id": message.id,
"author": message.author,
"content": message.content,
"edited": if let Some(t) = message.edited { Some(t.timestamp()) } else { None }
})
);
}
Some(json!(messages))
}
#[derive(Serialize, Deserialize)]
pub struct SendMessage {
content: String,
}
/// send a message to a channel
#[post("/<target>/messages", data = "<message>")]
pub fn send_message(user: User, target: Channel, message: Json<SendMessage>) -> Option<JsonValue> {
if !has_permission(&user, &target) {
return None
}
let col = database::get_collection("messages");
let id = Ulid::new().to_string();
Some(match col.insert_one(
doc! {
"_id": id.clone(),
"channel": target.id,
"author": user.id,
"content": message.content.clone(),
},
None
) {
Ok(_) =>
json!({
"success": true,
"id": id
}),
Err(_) =>
json!({
"success": false,
"error": "Failed database query."
})
})
}
#[derive(Serialize, Deserialize)]
pub struct EditMessage {
content: String,
}
/// edit a message
#[patch("/<target>/messages/<message>", data = "<edit>")]
pub fn edit_message(user: User, target: Channel, message: Message, edit: Json<SendMessage>) -> Option<JsonValue> {
if !has_permission(&user, &target) {
return None
}
Some(
if message.author != user.id {
json!({
"success": false,
"error": "You did not send this message."
})
} else {
let col = database::get_collection("messages");
match col.update_one(
doc! { "_id": message.id },
doc! {
"$set": {
"content": edit.content.clone(),
"edited": UtcDatetime(Utc::now())
}
},
None
) {
Ok(_) =>
json!({
"success": true
}),
Err(_) =>
json!({
"success": false,
"error": "Failed to update message."
})
}
}
)
}
/// delete a message
#[delete("/<target>/messages/<message>")]
pub fn delete_message(user: User, target: Channel, message: Message) -> Option<JsonValue> {
if !has_permission(&user, &target) {
return None
}
Some(
if message.author != user.id {
json!({
"success": false,
"error": "You did not send this message."
})
} else {
let col = database::get_collection("messages");
match col.delete_one(
doc! { "_id": message.id },
None
) {
Ok(_) =>
json!({
"success": true
}),
Err(_) =>
json!({
"success": false,
"error": "Failed to delete message."
})
}
}
)
}