From 798047625ad4ab834c645b47911db04c0a3a8994 Mon Sep 17 00:00:00 2001
From: Paul Makles <paulmakles@gmail.com>
Date: Mon, 28 Dec 2020 21:47:32 +0000
Subject: [PATCH] Add onboarding and FromRequest for User.

---
 src/database/entities/mod.rs       | 13 ++++--
 src/database/entities/user.rs      | 64 +++++++++++++++++++++++++++-
 src/database/guards/mod.rs         |  5 ++-
 src/database/guards/reference.rs   |  0
 src/database/guards/user.rs        |  0
 src/database/migrations/init.rs    | 68 ++++++++++++++++++++----------
 src/database/migrations/scripts.rs | 30 ++++++++++++-
 src/database/mod.rs                |  1 +
 src/routes/mod.rs                  | 52 -----------------------
 src/routes/onboard/complete.rs     | 52 +++++++++++++++++++++++
 src/routes/onboard/hello.rs        |  9 ++--
 src/routes/onboard/mod.rs          |  4 +-
 src/util/result.rs                 | 26 +++++++++++-
 13 files changed, 236 insertions(+), 88 deletions(-)
 create mode 100644 src/database/guards/reference.rs
 create mode 100644 src/database/guards/user.rs
 create mode 100644 src/routes/onboard/complete.rs

diff --git a/src/database/entities/mod.rs b/src/database/entities/mod.rs
index 4ad9863..02c0f5e 100644
--- a/src/database/entities/mod.rs
+++ b/src/database/entities/mod.rs
@@ -1,4 +1,9 @@
-pub mod channel;
-pub mod message;
-pub mod guild;
-pub mod user;
\ No newline at end of file
+mod channel;
+mod message;
+mod guild;
+mod user;
+
+pub use channel::*;
+pub use message::*;
+pub use guild::*;
+pub use user::*;
diff --git a/src/database/entities/user.rs b/src/database/entities/user.rs
index 2e44ed6..d3cdf87 100644
--- a/src/database/entities/user.rs
+++ b/src/database/entities/user.rs
@@ -1,4 +1,9 @@
+use mongodb::bson::{doc, from_bson, Bson};
+use rauth::auth::Session;
+use rocket::http::Status;
 use serde::{Deserialize, Serialize};
+use crate::database::get_collection;
+use rocket::request::{self, FromRequest, Outcome, Request};
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct Relationship {
@@ -10,6 +15,63 @@ pub struct Relationship {
 pub struct User {
     #[serde(rename = "_id")]
     pub id: String,
-    pub username: Option<String>,
+    pub username: String,
     pub relations: Option<Vec<Relationship>>,
 }
+
+#[rocket::async_trait]
+impl<'a, 'r> FromRequest<'a, 'r> for User {
+    type Error = rauth::util::Error;
+
+    async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
+        let session: Session = try_outcome!(request.guard::<Session>().await);
+
+        if let Ok(result) = get_collection("users")
+            .find_one(
+                doc! {
+                    "_id": &session.user_id
+                }, None
+            )
+            .await {
+            if let Some(doc) = result {
+                Outcome::Success(
+                    from_bson(Bson::Document(doc)).unwrap()
+                )
+            } else {
+                Outcome::Failure((Status::Forbidden, rauth::util::Error::InvalidSession))
+            }
+        } else {
+            Outcome::Failure((Status::InternalServerError, rauth::util::Error::DatabaseError))
+        }
+
+        /*Outcome::Success(
+            User {
+                id: "gaming".to_string(),
+                username: None,
+                relations: None
+            }
+        )*/
+
+        /*match (
+            request.managed_state::<Auth>(),
+            header_user_id,
+            header_session_token,
+        ) {
+            (Some(auth), Some(user_id), Some(session_token)) => {
+                let session = Session {
+                    id: None,
+                    user_id,
+                    session_token,
+                };
+
+                if let Ok(session) = auth.verify_session(session).await {
+                    Outcome::Success(session)
+                } else {
+                    Outcome::Failure((Status::Forbidden, Error::InvalidSession))
+                }
+            }
+            (None, _, _) => Outcome::Failure((Status::InternalServerError, Error::InternalError)),
+            (_, _, _) => Outcome::Failure((Status::Forbidden, Error::MissingHeaders)),
+        }*/
+    }
+}
diff --git a/src/database/guards/mod.rs b/src/database/guards/mod.rs
index e308776..728ede5 100644
--- a/src/database/guards/mod.rs
+++ b/src/database/guards/mod.rs
@@ -1,4 +1,7 @@
-/**
+pub mod user;
+pub mod reference;
+
+/*
 // ! FIXME
 impl<'r> FromParam<'r> for User {
     type Error = &'r RawStr;
diff --git a/src/database/guards/reference.rs b/src/database/guards/reference.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/database/guards/user.rs b/src/database/guards/user.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/database/migrations/init.rs b/src/database/migrations/init.rs
index 8d561a6..4edb04f 100644
--- a/src/database/migrations/init.rs
+++ b/src/database/migrations/init.rs
@@ -10,28 +10,28 @@ pub async fn create_database() {
     let db = get_db();
 
     db.create_collection("users", None)
-        .await
-        .expect("Failed to create users collection.");
+    .await
+    .expect("Failed to create users collection.");
     
     db.create_collection("channels", None)
-        .await
-        .expect("Failed to create channels collection.");
+    .await
+    .expect("Failed to create channels collection.");
     
     db.create_collection("guilds", None)
-        .await
-        .expect("Failed to create guilds collection.");
+    .await
+    .expect("Failed to create guilds collection.");
     
     db.create_collection("members", None)
-        .await
-        .expect("Failed to create members collection.");
+    .await
+    .expect("Failed to create members collection.");
     
     db.create_collection("messages", None)
-        .await
-        .expect("Failed to create messages collection.");
+    .await
+    .expect("Failed to create messages collection.");
     
     db.create_collection("migrations", None)
-        .await
-        .expect("Failed to create migrations collection.");
+    .await
+    .expect("Failed to create migrations collection.");
 
     db.create_collection(
         "pubsub",
@@ -40,19 +40,41 @@ pub async fn create_database() {
             .size(1_000_000)
             .build(),
     )
-        .await
-        .expect("Failed to create pubsub collection.");
+    .await
+    .expect("Failed to create pubsub collection.");
+
+    db.run_command(
+        doc! {
+            "createIndexes": "users",
+            "indexes": [
+                {
+                    "key": {
+                        "username": 1
+                    },
+                    "name": "username",
+                    "unique": true,
+                    "collation": {
+                        "locale": "en",
+                        "strength": 2
+                    }
+                }
+            ]
+        },
+        None
+    )
+    .await
+    .expect("Failed to create username index.");
 
     db.collection("migrations")
-        .insert_one(
-            doc! {
-                "_id": 0,
-                "revision": LATEST_REVISION
-            },
-            None,
-        )
-        .await
-        .expect("Failed to save migration info.");
+    .insert_one(
+        doc! {
+            "_id": 0,
+            "revision": LATEST_REVISION
+        },
+        None,
+    )
+    .await
+    .expect("Failed to save migration info.");
 
     info!("Created database.");
 }
diff --git a/src/database/migrations/scripts.rs b/src/database/migrations/scripts.rs
index a90b83f..5d469ff 100644
--- a/src/database/migrations/scripts.rs
+++ b/src/database/migrations/scripts.rs
@@ -1,4 +1,4 @@
-use super::super::get_collection;
+use super::super::{get_db, get_collection};
 
 use log::info;
 use mongodb::options::FindOptions;
@@ -12,7 +12,7 @@ struct MigrationInfo {
     revision: i32,
 }
 
-pub const LATEST_REVISION: i32 = 2;
+pub const LATEST_REVISION: i32 = 3;
 
 pub async fn migrate_database() {
     let migrations = get_collection("migrations");
@@ -121,6 +121,32 @@ pub async fn run_migrations(revision: i32) -> i32 {
         }
     }
 
+    if revision <= 2 {
+        info!("Running migration [revision 2]: Add username index to users.");
+
+        get_db().run_command(
+            doc! {
+                "createIndexes": "users",
+                "indexes": [
+                    {
+                        "key": {
+                            "username": 1
+                        },
+                        "name": "username",
+                        "unique": true,
+                        "collation": {
+                            "locale": "en",
+                            "strength": 2
+                        }
+                    }
+                ]
+            },
+            None
+        )
+        .await
+        .expect("Failed to create username index.");
+    }
+
     // Reminder to update LATEST_REVISION when adding new migrations.
     LATEST_REVISION
 }
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 5f5810a..b287c2c 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -28,3 +28,4 @@ pub fn get_collection(collection: &str) -> Collection {
 
 pub mod migrations;
 pub mod entities;
+pub mod guards;
diff --git a/src/routes/mod.rs b/src/routes/mod.rs
index c5bc606..35f354f 100644
--- a/src/routes/mod.rs
+++ b/src/routes/mod.rs
@@ -15,56 +15,4 @@ pub fn mount(rocket: Rocket) -> Rocket {
         .mount("/users", users::routes())
         .mount("/channels", channels::routes())
         .mount("/guild", guild::routes())
-
-        /*.mount(
-            "/users",
-            routes![
-                user::me,
-                user::user,
-                user::query,
-                user::dms,
-                user::dm,
-                user::get_friends,
-                user::get_friend,
-                user::add_friend,
-                user::remove_friend,
-                user::block_user,
-                user::unblock_user,
-            ],
-        )
-        .mount(
-            "/channels",
-            routes![
-                channel::create_group,
-                channel::channel,
-                channel::add_member,
-                channel::remove_member,
-                channel::delete,
-                channel::messages,
-                channel::get_message,
-                channel::send_message,
-                channel::edit_message,
-                channel::delete_message,
-            ],
-        )
-        .mount(
-            "/guild",
-            routes![
-                guild::my_guilds,
-                guild::guild,
-                guild::remove_guild,
-                guild::create_channel,
-                guild::create_invite,
-                guild::remove_invite,
-                guild::fetch_invites,
-                guild::fetch_invite,
-                guild::use_invite,
-                guild::create_guild,
-                guild::fetch_members,
-                guild::fetch_member,
-                guild::kick_member,
-                guild::ban_member,
-                guild::unban_member,
-            ],
-        )*/
 }
diff --git a/src/routes/onboard/complete.rs b/src/routes/onboard/complete.rs
new file mode 100644
index 0000000..93a2daf
--- /dev/null
+++ b/src/routes/onboard/complete.rs
@@ -0,0 +1,52 @@
+use mongodb::options::{Collation, FindOneOptions};
+use crate::util::result::{Error, Result};
+use serde::{Deserialize, Serialize};
+use crate::database::entities::User;
+use crate::database::get_collection;
+use rocket_contrib::json::Json;
+use rauth::auth::Session;
+use validator::Validate;
+use mongodb::bson::doc;
+
+#[derive(Validate, Serialize, Deserialize)]
+pub struct Data {
+    #[validate(length(min = 2, max = 32))]
+    username: String
+}
+
+#[post("/complete", data = "<data>")]
+pub async fn req(session: Session, user: Option<User>, data: Json<Data>) -> Result<()> {
+    if user.is_some() {
+        Err(Error::AlreadyOnboarded)?
+    }
+    
+    data.validate()
+        .map_err(|error| Error::FailedValidation { error })?;
+
+    let col = get_collection("users");
+    if col.find_one(
+        doc! {
+            "username": &data.username
+        },
+        FindOneOptions::builder()
+            .collation(Collation::builder().locale("en").strength(2).build())
+            .build()
+    )
+    .await
+    .map_err(|_| Error::DatabaseError { operation: "find_one", with: "user" })?
+    .is_some() {
+        Err(Error::UsernameTaken)?
+    }
+
+    col.insert_one(
+        doc! {
+            "_id": session.user_id,
+            "username": &data.username
+        },
+        None
+    )
+    .await
+    .map_err(|_| Error::DatabaseError { operation: "insert_one", with: "user" })?;
+
+    Ok(())
+}
diff --git a/src/routes/onboard/hello.rs b/src/routes/onboard/hello.rs
index ecf5847..8f8f77d 100644
--- a/src/routes/onboard/hello.rs
+++ b/src/routes/onboard/hello.rs
@@ -1,7 +1,10 @@
-use crate::util::result::Result;
+use rocket_contrib::json::JsonValue;
+use crate::database::entities::User;
 use rauth::auth::Session;
 
 #[get("/hello")]
-pub async fn req(session: Session) -> Result<String> {
-    Ok("try onboard user".to_string())
+pub async fn req(_session: Session, user: Option<User>) -> JsonValue {
+    json!({
+        "onboarding": user.is_none()
+    })
 }
diff --git a/src/routes/onboard/mod.rs b/src/routes/onboard/mod.rs
index 74216ab..c7e9971 100644
--- a/src/routes/onboard/mod.rs
+++ b/src/routes/onboard/mod.rs
@@ -1,9 +1,11 @@
 use rocket::Route;
 
 mod hello;
+mod complete;
 
 pub fn routes() -> Vec<Route> {
     routes! [
-        hello::req
+        hello::req,
+        complete::req
     ]
 }
diff --git a/src/util/result.rs b/src/util/result.rs
index e467f8a..4866756 100644
--- a/src/util/result.rs
+++ b/src/util/result.rs
@@ -1,5 +1,6 @@
 use rocket::response::{self, Responder, Response};
 use rocket::http::{ContentType, Status};
+use validator::ValidationErrors;
 use rocket::request::Request;
 use serde::Serialize;
 use std::io::Cursor;
@@ -12,6 +13,25 @@ pub enum Error {
     #[snafu(display("This error has not been labelled."))]
     #[serde(rename = "unlabelled_error")]
     LabelMe,
+
+    // ? Onboarding related errors.
+    #[snafu(display("Already finished onboarding."))]
+    #[serde(rename = "already_onboarded")]
+    AlreadyOnboarded,
+    
+    // ? User related errors.
+    #[snafu(display("Username has already been taken."))]
+    #[serde(rename = "username_taken")]
+    UsernameTaken,
+
+    // ? General errors.
+    #[snafu(display("Failed to validate fields."))]
+    #[serde(rename = "failed_validation")]
+    FailedValidation { error: ValidationErrors },
+    #[snafu(display("Encountered a database error."))]
+    #[serde(rename = "database_error")]
+    DatabaseError { operation: &'static str, with: &'static str },
+
     /* #[snafu(display("Failed to validate fields."))]
     #[serde(rename = "failed_validation")]
     FailedValidation { error: ValidationErrors },
@@ -50,7 +70,11 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
 impl<'r> Responder<'r, 'static> for Error {
     fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
         let status = match self {
-            Error::LabelMe => Status::InternalServerError
+            Error::AlreadyOnboarded => Status::Forbidden,
+            Error::DatabaseError { .. } => Status::InternalServerError,
+            Error::FailedValidation { .. } => Status::UnprocessableEntity,
+            Error::LabelMe => Status::InternalServerError,
+            Error::UsernameTaken => Status::Conflict,
         };
 
         // Serialize the error data structure into JSON.
-- 
GitLab