From 798047625ad4ab834c645b47911db04c0a3a8994 Mon Sep 17 00:00:00 2001 From: Paul Makles <paulmakles@gmail.com> Date: Mon, 28 Dec 2020 21:47:32 +0000 Subject: [PATCH] Add onboarding and FromRequest for User. --- src/database/entities/mod.rs | 13 ++++-- src/database/entities/user.rs | 64 +++++++++++++++++++++++++++- src/database/guards/mod.rs | 5 ++- src/database/guards/reference.rs | 0 src/database/guards/user.rs | 0 src/database/migrations/init.rs | 68 ++++++++++++++++++++---------- src/database/migrations/scripts.rs | 30 ++++++++++++- src/database/mod.rs | 1 + src/routes/mod.rs | 52 ----------------------- src/routes/onboard/complete.rs | 52 +++++++++++++++++++++++ src/routes/onboard/hello.rs | 9 ++-- src/routes/onboard/mod.rs | 4 +- src/util/result.rs | 26 +++++++++++- 13 files changed, 236 insertions(+), 88 deletions(-) create mode 100644 src/database/guards/reference.rs create mode 100644 src/database/guards/user.rs create mode 100644 src/routes/onboard/complete.rs diff --git a/src/database/entities/mod.rs b/src/database/entities/mod.rs index 4ad9863..02c0f5e 100644 --- a/src/database/entities/mod.rs +++ b/src/database/entities/mod.rs @@ -1,4 +1,9 @@ -pub mod channel; -pub mod message; -pub mod guild; -pub mod user; \ No newline at end of file +mod channel; +mod message; +mod guild; +mod user; + +pub use channel::*; +pub use message::*; +pub use guild::*; +pub use user::*; diff --git a/src/database/entities/user.rs b/src/database/entities/user.rs index 2e44ed6..d3cdf87 100644 --- a/src/database/entities/user.rs +++ b/src/database/entities/user.rs @@ -1,4 +1,9 @@ +use mongodb::bson::{doc, from_bson, Bson}; +use rauth::auth::Session; +use rocket::http::Status; use serde::{Deserialize, Serialize}; +use crate::database::get_collection; +use rocket::request::{self, FromRequest, Outcome, Request}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Relationship { @@ -10,6 +15,63 @@ pub struct Relationship { pub struct User { #[serde(rename = "_id")] pub id: String, - pub username: Option<String>, + pub username: String, pub relations: Option<Vec<Relationship>>, } + +#[rocket::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for User { + type Error = rauth::util::Error; + + async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> { + let session: Session = try_outcome!(request.guard::<Session>().await); + + if let Ok(result) = get_collection("users") + .find_one( + doc! { + "_id": &session.user_id + }, None + ) + .await { + if let Some(doc) = result { + Outcome::Success( + from_bson(Bson::Document(doc)).unwrap() + ) + } else { + Outcome::Failure((Status::Forbidden, rauth::util::Error::InvalidSession)) + } + } else { + Outcome::Failure((Status::InternalServerError, rauth::util::Error::DatabaseError)) + } + + /*Outcome::Success( + User { + id: "gaming".to_string(), + username: None, + relations: None + } + )*/ + + /*match ( + request.managed_state::<Auth>(), + header_user_id, + header_session_token, + ) { + (Some(auth), Some(user_id), Some(session_token)) => { + let session = Session { + id: None, + user_id, + session_token, + }; + + if let Ok(session) = auth.verify_session(session).await { + Outcome::Success(session) + } else { + Outcome::Failure((Status::Forbidden, Error::InvalidSession)) + } + } + (None, _, _) => Outcome::Failure((Status::InternalServerError, Error::InternalError)), + (_, _, _) => Outcome::Failure((Status::Forbidden, Error::MissingHeaders)), + }*/ + } +} diff --git a/src/database/guards/mod.rs b/src/database/guards/mod.rs index e308776..728ede5 100644 --- a/src/database/guards/mod.rs +++ b/src/database/guards/mod.rs @@ -1,4 +1,7 @@ -/** +pub mod user; +pub mod reference; + +/* // ! FIXME impl<'r> FromParam<'r> for User { type Error = &'r RawStr; diff --git a/src/database/guards/reference.rs b/src/database/guards/reference.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/database/guards/user.rs b/src/database/guards/user.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/database/migrations/init.rs b/src/database/migrations/init.rs index 8d561a6..4edb04f 100644 --- a/src/database/migrations/init.rs +++ b/src/database/migrations/init.rs @@ -10,28 +10,28 @@ pub async fn create_database() { let db = get_db(); db.create_collection("users", None) - .await - .expect("Failed to create users collection."); + .await + .expect("Failed to create users collection."); db.create_collection("channels", None) - .await - .expect("Failed to create channels collection."); + .await + .expect("Failed to create channels collection."); db.create_collection("guilds", None) - .await - .expect("Failed to create guilds collection."); + .await + .expect("Failed to create guilds collection."); db.create_collection("members", None) - .await - .expect("Failed to create members collection."); + .await + .expect("Failed to create members collection."); db.create_collection("messages", None) - .await - .expect("Failed to create messages collection."); + .await + .expect("Failed to create messages collection."); db.create_collection("migrations", None) - .await - .expect("Failed to create migrations collection."); + .await + .expect("Failed to create migrations collection."); db.create_collection( "pubsub", @@ -40,19 +40,41 @@ pub async fn create_database() { .size(1_000_000) .build(), ) - .await - .expect("Failed to create pubsub collection."); + .await + .expect("Failed to create pubsub collection."); + + db.run_command( + doc! { + "createIndexes": "users", + "indexes": [ + { + "key": { + "username": 1 + }, + "name": "username", + "unique": true, + "collation": { + "locale": "en", + "strength": 2 + } + } + ] + }, + None + ) + .await + .expect("Failed to create username index."); db.collection("migrations") - .insert_one( - doc! { - "_id": 0, - "revision": LATEST_REVISION - }, - None, - ) - .await - .expect("Failed to save migration info."); + .insert_one( + doc! { + "_id": 0, + "revision": LATEST_REVISION + }, + None, + ) + .await + .expect("Failed to save migration info."); info!("Created database."); } diff --git a/src/database/migrations/scripts.rs b/src/database/migrations/scripts.rs index a90b83f..5d469ff 100644 --- a/src/database/migrations/scripts.rs +++ b/src/database/migrations/scripts.rs @@ -1,4 +1,4 @@ -use super::super::get_collection; +use super::super::{get_db, get_collection}; use log::info; use mongodb::options::FindOptions; @@ -12,7 +12,7 @@ struct MigrationInfo { revision: i32, } -pub const LATEST_REVISION: i32 = 2; +pub const LATEST_REVISION: i32 = 3; pub async fn migrate_database() { let migrations = get_collection("migrations"); @@ -121,6 +121,32 @@ pub async fn run_migrations(revision: i32) -> i32 { } } + if revision <= 2 { + info!("Running migration [revision 2]: Add username index to users."); + + get_db().run_command( + doc! { + "createIndexes": "users", + "indexes": [ + { + "key": { + "username": 1 + }, + "name": "username", + "unique": true, + "collation": { + "locale": "en", + "strength": 2 + } + } + ] + }, + None + ) + .await + .expect("Failed to create username index."); + } + // Reminder to update LATEST_REVISION when adding new migrations. LATEST_REVISION } diff --git a/src/database/mod.rs b/src/database/mod.rs index 5f5810a..b287c2c 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -28,3 +28,4 @@ pub fn get_collection(collection: &str) -> Collection { pub mod migrations; pub mod entities; +pub mod guards; diff --git a/src/routes/mod.rs b/src/routes/mod.rs index c5bc606..35f354f 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -15,56 +15,4 @@ pub fn mount(rocket: Rocket) -> Rocket { .mount("/users", users::routes()) .mount("/channels", channels::routes()) .mount("/guild", guild::routes()) - - /*.mount( - "/users", - routes![ - user::me, - user::user, - user::query, - user::dms, - user::dm, - user::get_friends, - user::get_friend, - user::add_friend, - user::remove_friend, - user::block_user, - user::unblock_user, - ], - ) - .mount( - "/channels", - routes![ - channel::create_group, - channel::channel, - channel::add_member, - channel::remove_member, - channel::delete, - channel::messages, - channel::get_message, - channel::send_message, - channel::edit_message, - channel::delete_message, - ], - ) - .mount( - "/guild", - routes![ - guild::my_guilds, - guild::guild, - guild::remove_guild, - guild::create_channel, - guild::create_invite, - guild::remove_invite, - guild::fetch_invites, - guild::fetch_invite, - guild::use_invite, - guild::create_guild, - guild::fetch_members, - guild::fetch_member, - guild::kick_member, - guild::ban_member, - guild::unban_member, - ], - )*/ } diff --git a/src/routes/onboard/complete.rs b/src/routes/onboard/complete.rs new file mode 100644 index 0000000..93a2daf --- /dev/null +++ b/src/routes/onboard/complete.rs @@ -0,0 +1,52 @@ +use mongodb::options::{Collation, FindOneOptions}; +use crate::util::result::{Error, Result}; +use serde::{Deserialize, Serialize}; +use crate::database::entities::User; +use crate::database::get_collection; +use rocket_contrib::json::Json; +use rauth::auth::Session; +use validator::Validate; +use mongodb::bson::doc; + +#[derive(Validate, Serialize, Deserialize)] +pub struct Data { + #[validate(length(min = 2, max = 32))] + username: String +} + +#[post("/complete", data = "<data>")] +pub async fn req(session: Session, user: Option<User>, data: Json<Data>) -> Result<()> { + if user.is_some() { + Err(Error::AlreadyOnboarded)? + } + + data.validate() + .map_err(|error| Error::FailedValidation { error })?; + + let col = get_collection("users"); + if col.find_one( + doc! { + "username": &data.username + }, + FindOneOptions::builder() + .collation(Collation::builder().locale("en").strength(2).build()) + .build() + ) + .await + .map_err(|_| Error::DatabaseError { operation: "find_one", with: "user" })? + .is_some() { + Err(Error::UsernameTaken)? + } + + col.insert_one( + doc! { + "_id": session.user_id, + "username": &data.username + }, + None + ) + .await + .map_err(|_| Error::DatabaseError { operation: "insert_one", with: "user" })?; + + Ok(()) +} diff --git a/src/routes/onboard/hello.rs b/src/routes/onboard/hello.rs index ecf5847..8f8f77d 100644 --- a/src/routes/onboard/hello.rs +++ b/src/routes/onboard/hello.rs @@ -1,7 +1,10 @@ -use crate::util::result::Result; +use rocket_contrib::json::JsonValue; +use crate::database::entities::User; use rauth::auth::Session; #[get("/hello")] -pub async fn req(session: Session) -> Result<String> { - Ok("try onboard user".to_string()) +pub async fn req(_session: Session, user: Option<User>) -> JsonValue { + json!({ + "onboarding": user.is_none() + }) } diff --git a/src/routes/onboard/mod.rs b/src/routes/onboard/mod.rs index 74216ab..c7e9971 100644 --- a/src/routes/onboard/mod.rs +++ b/src/routes/onboard/mod.rs @@ -1,9 +1,11 @@ use rocket::Route; mod hello; +mod complete; pub fn routes() -> Vec<Route> { routes! [ - hello::req + hello::req, + complete::req ] } diff --git a/src/util/result.rs b/src/util/result.rs index e467f8a..4866756 100644 --- a/src/util/result.rs +++ b/src/util/result.rs @@ -1,5 +1,6 @@ use rocket::response::{self, Responder, Response}; use rocket::http::{ContentType, Status}; +use validator::ValidationErrors; use rocket::request::Request; use serde::Serialize; use std::io::Cursor; @@ -12,6 +13,25 @@ pub enum Error { #[snafu(display("This error has not been labelled."))] #[serde(rename = "unlabelled_error")] LabelMe, + + // ? Onboarding related errors. + #[snafu(display("Already finished onboarding."))] + #[serde(rename = "already_onboarded")] + AlreadyOnboarded, + + // ? User related errors. + #[snafu(display("Username has already been taken."))] + #[serde(rename = "username_taken")] + UsernameTaken, + + // ? General errors. + #[snafu(display("Failed to validate fields."))] + #[serde(rename = "failed_validation")] + FailedValidation { error: ValidationErrors }, + #[snafu(display("Encountered a database error."))] + #[serde(rename = "database_error")] + DatabaseError { operation: &'static str, with: &'static str }, + /* #[snafu(display("Failed to validate fields."))] #[serde(rename = "failed_validation")] FailedValidation { error: ValidationErrors }, @@ -50,7 +70,11 @@ pub type Result<T, E = Error> = std::result::Result<T, E>; impl<'r> Responder<'r, 'static> for Error { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { let status = match self { - Error::LabelMe => Status::InternalServerError + Error::AlreadyOnboarded => Status::Forbidden, + Error::DatabaseError { .. } => Status::InternalServerError, + Error::FailedValidation { .. } => Status::UnprocessableEntity, + Error::LabelMe => Status::InternalServerError, + Error::UsernameTaken => Status::Conflict, }; // Serialize the error data structure into JSON. -- GitLab