diff --git a/Cargo.lock b/Cargo.lock index 4324f79279935eaf114b585c86cce7b2416b0662..d0b7d221abf39fd5215ca79862a5807d74b99d8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -243,6 +243,16 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4521f3e3d031370679b3b140beb36dfe4801b09ac77e30c61941f97df3ef28b" +[[package]] +name = "base64" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5032d51da2741729bfdaeb2664d9b8c6d9fd1e2b90715c660b6def36628499c2" +dependencies = [ + "byteorder", + "safemem", +] + [[package]] name = "base64" version = "0.11.0" @@ -686,6 +696,15 @@ dependencies = [ "termcolor", ] +[[package]] +name = "erased-serde" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0465971a8cc1fa2455c8465aaa377131e1f1cf4983280f474a13e68793aa770c" +dependencies = [ + "serde", +] + [[package]] name = "err-derive" version = "0.2.4" @@ -2454,10 +2473,11 @@ dependencies = [ [[package]] name = "revolt" -version = "0.3.3-alpha.5" +version = "0.3.3-alpha.6" dependencies = [ "async-std", "async-tungstenite", + "base64 0.13.0", "bitfield", "chrono", "ctrlc", @@ -2489,6 +2509,7 @@ dependencies = [ "ulid", "urlencoding", "validator", + "web-push", ] [[package]] @@ -2501,7 +2522,7 @@ dependencies = [ "libc", "once_cell", "spin", - "untrusted", + "untrusted 0.7.1", "web-sys", "winapi 0.3.9", ] @@ -2673,6 +2694,12 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" +[[package]] +name = "safemem" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e27a8b19b835f7aea908818e871f5cc3a5a186550c30773be987e155e8163d8f" + [[package]] name = "schannel" version = "0.1.19" @@ -2705,7 +2732,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3042af939fca8c3453b7af0f1c66e533a15a86169e39de2657310ade8f98d3c" dependencies = [ "ring", - "untrusted", + "untrusted 0.7.1", ] [[package]] @@ -3477,6 +3504,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564" +[[package]] +name = "untrusted" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cd1f4b4e96b46aeb8d4855db4a7a9bd96eeeb5c6a1ab54593328761642ce2f" + [[package]] name = "untrusted" version = "0.7.1" @@ -3678,6 +3711,31 @@ version = "0.2.70" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd4945e4943ae02d15c13962b38a5b1e81eadd4b71214eee75af64a4d6a4fd64" +[[package]] +name = "web-push" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8353dd6c7cfb9a02737fd6dc9a66a80dca2a93fb690f6ad264d2a7672e6f1c0" +dependencies = [ + "base64 0.7.0", + "chrono", + "erased-serde", + "futures", + "http", + "hyper", + "hyper-tls", + "lazy_static", + "log", + "native-tls", + "openssl", + "ring", + "serde", + "serde_derive", + "serde_json", + "time 0.1.44", + "untrusted 0.6.2", +] + [[package]] name = "web-sys" version = "0.3.47" @@ -3695,7 +3753,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea" dependencies = [ "ring", - "untrusted", + "untrusted 0.7.1", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a3511ebf45e8f5c549511be9c179918bfe85014a..b02ce5c0cf0cc3ea8a39137809604ad9f9753ec5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,14 +1,16 @@ [package] name = "revolt" -version = "0.3.3-alpha.5" +version = "0.3.3-alpha.6" authors = ["Paul Makles <paulmakles@gmail.com>"] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +base64 = "0.13.0" futures = "0.3.8" impl_ops = "0.1.1" +web-push = "0.7.2" many-to-many = "0.1.2" ctrlc = { version = "3.0", features = ["termination"] } async-std = { version = "1.8.0", features = ["tokio02", "attributes"] } diff --git a/src/database/entities/autumn.rs b/src/database/entities/autumn.rs index 7ae1102cdab2aa31f20ae0e5ea1d5feb59ca7123..8b42c4397ba2f81a4b634d3b9d4774c938436da3 100644 --- a/src/database/entities/autumn.rs +++ b/src/database/entities/autumn.rs @@ -1,4 +1,4 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "type")] @@ -17,5 +17,5 @@ pub struct File { metadata: Metadata, content_type: String, - message_id: Option<String> + message_id: Option<String>, } diff --git a/src/database/entities/channel.rs b/src/database/entities/channel.rs index da8a61ebf16c64720db11e3190bb471c96bde790..13346164c2e152618c0b831906e607276a977bfc 100644 --- a/src/database/entities/channel.rs +++ b/src/database/entities/channel.rs @@ -46,9 +46,9 @@ pub enum Channel { impl Channel { pub fn id(&self) -> &str { match self { - Channel::SavedMessages { id, .. } => id, - Channel::DirectMessage { id, .. } => id, - Channel::Group { id, .. } => id, + Channel::SavedMessages { id, .. } + | Channel::DirectMessage { id, .. } + | Channel::Group { id, .. } => id, } } diff --git a/src/database/entities/message.rs b/src/database/entities/message.rs index ee26c40d00c84d0ac79fca5fea9a26f84035bc3b..2fdc70c156df710381bfd51dd88724c724ba4715 100644 --- a/src/database/entities/message.rs +++ b/src/database/entities/message.rs @@ -1,12 +1,19 @@ +use crate::util::variables::VAPID_PRIVATE_KEY; use crate::{ database::*, - notifications::events::ClientboundNotification, + notifications::{events::ClientboundNotification, websocket::is_online}, util::result::{Error, Result}, }; -use mongodb::bson::{doc, to_bson, DateTime}; + +use futures::StreamExt; +use mongodb::{ + bson::{doc, to_bson, DateTime}, + options::FindOptions, +}; use rocket_contrib::json::JsonValue; use serde::{Deserialize, Serialize}; use ulid::Ulid; +use web_push::{ContentEncoding, SubscriptionInfo, VapidSignatureBuilder, WebPushClient, WebPushMessageBuilder}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Message { @@ -49,7 +56,7 @@ impl Message { // ! FIXME: temp code let channels = get_collection("channels"); - match channel { + match &channel { Channel::DirectMessage { id, .. } => { channels .update_one( @@ -96,11 +103,84 @@ impl Message { _ => {} } + let enc = serde_json::to_string(&self).unwrap(); ClientboundNotification::Message(self) .publish(channel.id().to_string()) .await .ok(); + /* + Web Push Test Code + ! FIXME: temp code + */ + + // Find all offline users. + let mut target_ids = vec![]; + match &channel { + Channel::DirectMessage { recipients, .. } | Channel::Group { recipients, .. } => { + for recipient in recipients { + // if !is_online(recipient) { + target_ids.push(recipient.clone()); + // } + } + } + _ => {} + } + + // Fetch their corresponding sessions. + let mut cursor = get_collection("accounts") + .find( + doc! { + "_id": { + "$in": target_ids + }, + "sessions.subscription": { + "$exists": true + } + }, + FindOptions::builder() + .projection(doc! { "sessions": 1 }) + .build(), + ) + .await + .unwrap(); // !FIXME + + let mut subscriptions = vec![]; + while let Some(result) = cursor.next().await { + if let Ok(doc) = result { + if let Ok(sessions) = doc.get_array("sessions") { + for session in sessions { + if let Some(doc) = session.as_document() { + if let Ok(sub) = doc.get_document("subscription") { + let endpoint = sub.get_str("endpoint").unwrap().to_string(); + let p256dh = sub.get_str("p256dh").unwrap().to_string(); + let auth = sub.get_str("auth").unwrap().to_string(); + + subscriptions.push(SubscriptionInfo::new(endpoint, p256dh, auth)); + } + } + } + } + } + } + + if subscriptions.len() > 0 { + let client = WebPushClient::new(); + let key = base64::decode_config(VAPID_PRIVATE_KEY.clone(), base64::URL_SAFE).unwrap(); + + for subscription in subscriptions { + let mut builder = WebPushMessageBuilder::new(&subscription).unwrap(); + let sig_builder = + VapidSignatureBuilder::from_pem(std::io::Cursor::new(&key), &subscription).unwrap(); + let signature = sig_builder.build().unwrap(); + builder.set_vapid_signature(signature); + builder.set_payload(ContentEncoding::AesGcm, enc.as_bytes()); + let m = builder.build().unwrap(); + let response = client.send(m).await.unwrap(); + dbg!(response); + } + } + Ok(()) } diff --git a/src/database/entities/mod.rs b/src/database/entities/mod.rs index 864ee9ca7106616d663a78aa7f4b980316aa7ff2..cdc9a435f953c8348d5a526a50677b7ed2e20493 100644 --- a/src/database/entities/mod.rs +++ b/src/database/entities/mod.rs @@ -1,11 +1,11 @@ +mod autumn; mod channel; mod guild; mod message; mod user; -mod autumn; +pub use autumn::*; pub use channel::*; pub use guild::*; pub use message::*; pub use user::*; -pub use autumn::*; diff --git a/src/routes/channels/message_query_stale.rs b/src/routes/channels/message_query_stale.rs index d2ee19db597ff88333d37a1e31c74b4a7ae5d7b0..300ae76524de9634d62f05dcf77c1028000eece8 100644 --- a/src/routes/channels/message_query_stale.rs +++ b/src/routes/channels/message_query_stale.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize)] pub struct Options { - ids: Vec<String> + ids: Vec<String>, } #[post("/<target>/messages/stale", data = "<data>")] @@ -35,7 +35,7 @@ pub async fn req(user: User, target: Ref, data: Json<Options>) -> Result<JsonVal }, "channel": target.id() }, - None + None, ) .await .map_err(|_| Error::DatabaseError { @@ -47,12 +47,11 @@ pub async fn req(user: User, target: Ref, data: Json<Options>) -> Result<JsonVal let mut found_ids = vec![]; while let Some(result) = cursor.next().await { if let Ok(doc) = result { - let msg = from_document::<Message>(doc) - .map_err(|_| Error::DatabaseError { - operation: "from_document", - with: "message", - })?; - + let msg = from_document::<Message>(doc).map_err(|_| Error::DatabaseError { + operation: "from_document", + with: "message", + })?; + found_ids.push(msg.id.clone()); if msg.edited.is_some() { updated.push(msg); diff --git a/src/routes/channels/message_send.rs b/src/routes/channels/message_send.rs index 66c9e94cafd335b50b5ba5aef800d1ad76803072..aaf275ed3c50ba60dda4758a5bb961cfb962f992 100644 --- a/src/routes/channels/message_send.rs +++ b/src/routes/channels/message_send.rs @@ -1,7 +1,10 @@ use crate::database::*; use crate::util::result::{Error, Result}; -use mongodb::{bson::{doc, from_document}, options::FindOneOptions}; +use mongodb::{ + bson::{doc, from_document}, + options::FindOneOptions, +}; use rocket_contrib::json::{Json, JsonValue}; use serde::{Deserialize, Serialize}; use ulid::Ulid; @@ -41,7 +44,7 @@ pub async fn req(user: User, target: Ref, message: Json<Data>) -> Result<JsonVal }, FindOneOptions::builder() .projection(doc! { "_id": 1 }) - .build() + .build(), ) .await .map_err(|_| Error::DatabaseError { @@ -64,33 +67,40 @@ pub async fn req(user: User, target: Ref, message: Json<Data>) -> Result<JsonVal "$exists": false } }, - None + None, ) .await .map_err(|_| Error::DatabaseError { operation: "find_one", with: "attachment", - })? { - let attachment = from_document::<File>(doc) - .map_err(|_| Error::DatabaseError { operation: "from_document", with: "attachment" })?; + })? + { + let attachment = from_document::<File>(doc).map_err(|_| Error::DatabaseError { + operation: "from_document", + with: "attachment", + })?; - attachments.update_one( - doc! { - "_id": &attachment.id - }, - doc! { - "$set": { - "message_id": &id - } - }, - None - ) - .await - .map_err(|_| Error::DatabaseError { operation: "update_one", with: "attachment" })?; + attachments + .update_one( + doc! { + "_id": &attachment.id + }, + doc! { + "$set": { + "message_id": &id + } + }, + None, + ) + .await + .map_err(|_| Error::DatabaseError { + operation: "update_one", + with: "attachment", + })?; Some(attachment) } else { - return Err(Error::UnknownAttachment) + return Err(Error::UnknownAttachment); } } else { None diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 77ea709d784afdcb901a2fabf2c9021d439d7d19..6d8229b8dee8d0da0455dab590a951f8782dcdc4 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -5,6 +5,7 @@ use rocket::Rocket; mod channels; mod guild; mod onboard; +mod push; mod root; mod users; @@ -15,4 +16,5 @@ pub fn mount(rocket: Rocket) -> Rocket { .mount("/users", users::routes()) .mount("/channels", channels::routes()) .mount("/guild", guild::routes()) + .mount("/push", push::routes()) } diff --git a/src/routes/push/mod.rs b/src/routes/push/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..ff0a14cd53bf5b4b13769d1489313be40c787a96 --- /dev/null +++ b/src/routes/push/mod.rs @@ -0,0 +1,8 @@ +use rocket::Route; + +mod subscribe; +mod unsubscribe; + +pub fn routes() -> Vec<Route> { + routes![subscribe::req, unsubscribe::req] +} diff --git a/src/routes/push/subscribe.rs b/src/routes/push/subscribe.rs new file mode 100644 index 0000000000000000000000000000000000000000..57d30a3541ef55300b703c8b78e0ec3959773cc3 --- /dev/null +++ b/src/routes/push/subscribe.rs @@ -0,0 +1,36 @@ +use crate::database::*; +use crate::util::result::{Error, Result}; + +use mongodb::bson::{doc, to_document}; +use rauth::auth::Session; +use rocket_contrib::json::Json; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct Subscription { + endpoint: String, + p256dh: String, + auth: String, +} + +#[post("/subscribe", data = "<data>")] +pub async fn req(session: Session, data: Json<Subscription>) -> Result<()> { + let data = data.into_inner(); + let col = get_collection("accounts") + .update_one( + doc! { + "_id": session.user_id, + "sessions.id": session.id.unwrap() + }, + doc! { + "$set": { + "sessions.$.subscription": to_document(&data).unwrap() + } + }, + None, + ) + .await + .unwrap(); + + Ok(()) +} diff --git a/src/routes/push/unsubscribe.rs b/src/routes/push/unsubscribe.rs new file mode 100644 index 0000000000000000000000000000000000000000..deb1026d4b3ab60f04816e13888b54b04c208981 --- /dev/null +++ b/src/routes/push/unsubscribe.rs @@ -0,0 +1,26 @@ +use crate::database::*; +use crate::util::result::{Error, Result}; + +use mongodb::bson::doc; +use rauth::auth::Session; + +#[post("/unsubscribe")] +pub async fn req(session: Session) -> Result<()> { + let col = get_collection("accounts") + .update_one( + doc! { + "_id": session.user_id, + "sessions.id": session.id.unwrap() + }, + doc! { + "$unset": { + "sessions.$.subscription": 1 + } + }, + None, + ) + .await + .unwrap(); + + Ok(()) +} diff --git a/src/routes/root.rs b/src/routes/root.rs index b1cc49ebcd414dcae187230b12e6505188884cb0..23f276901ff9c38de5546c103b76bbfbef2ca01b 100644 --- a/src/routes/root.rs +++ b/src/routes/root.rs @@ -1,5 +1,6 @@ use crate::util::variables::{ - DISABLE_REGISTRATION, EXTERNAL_WS_URL, HCAPTCHA_SITEKEY, INVITE_ONLY, USE_EMAIL, USE_HCAPTCHA, USE_AUTUMN, AUTUMN_URL + AUTUMN_URL, DISABLE_REGISTRATION, EXTERNAL_WS_URL, HCAPTCHA_SITEKEY, INVITE_ONLY, USE_AUTUMN, + USE_EMAIL, USE_HCAPTCHA, VAPID_PUBLIC_KEY, }; use mongodb::bson::doc; @@ -8,7 +9,7 @@ use rocket_contrib::json::JsonValue; #[get("/")] pub async fn root() -> JsonValue { json!({ - "revolt": "0.3.3-alpha.5", + "revolt": "0.3.3-alpha.6", "features": { "registration": !*DISABLE_REGISTRATION, "captcha": { @@ -23,5 +24,6 @@ pub async fn root() -> JsonValue { } }, "ws": *EXTERNAL_WS_URL, + "vapid": *VAPID_PUBLIC_KEY }) } diff --git a/src/routes/users/find_mutual.rs b/src/routes/users/find_mutual.rs index e75c07e917d7061ca166b7ece4664062ff69249a..1d52e59353cbf1e74c4b534e3300d8db3e05b8d1 100644 --- a/src/routes/users/find_mutual.rs +++ b/src/routes/users/find_mutual.rs @@ -2,8 +2,8 @@ use crate::database::*; use crate::util::result::{Error, Result}; use futures::StreamExt; +use mongodb::bson::{doc, Document}; use mongodb::options::FindOptions; -use mongodb::bson::{Document, doc}; use rocket_contrib::json::JsonValue; #[get("/<target>/mutual")] @@ -19,9 +19,7 @@ pub async fn req(user: User, target: Ref) -> Result<JsonValue> { { "recipients": &target.id } ] }, - FindOptions::builder() - .projection(doc! { "_id": 1 }) - .build() + FindOptions::builder().projection(doc! { "_id": 1 }).build(), ) .await .map_err(|_| Error::DatabaseError { @@ -35,7 +33,5 @@ pub async fn req(user: User, target: Ref) -> Result<JsonValue> { .filter_map(|x| x.get_str("_id").ok().map(|x| x.to_string())) .collect::<Vec<String>>(); - Ok(json!({ - "channels": channels - })) + Ok(json!({ "channels": channels })) } diff --git a/src/util/variables.rs b/src/util/variables.rs index 7c07adef092612f49e465a5408739d4850d6ca5f..6431673acd6901c2f6aff381c08c10328af7c3d2 100644 --- a/src/util/variables.rs +++ b/src/util/variables.rs @@ -21,6 +21,10 @@ lazy_static! { env::var("REVOLT_HCAPTCHA_KEY").unwrap_or_else(|_| "0x0000000000000000000000000000000000000000".to_string()); pub static ref HCAPTCHA_SITEKEY: String = env::var("REVOLT_HCAPTCHA_SITEKEY").unwrap_or_else(|_| "10000000-ffff-ffff-ffff-000000000001".to_string()); + pub static ref VAPID_PRIVATE_KEY: String = + env::var("REVOLT_VAPID_PRIVATE_KEY").expect("Missing REVOLT_VAPID_PRIVATE_KEY environment variable."); + pub static ref VAPID_PUBLIC_KEY: String = + env::var("REVOLT_VAPID_PUBLIC_KEY").expect("Missing REVOLT_VAPID_PUBLIC_KEY environment variable."); // Application Flags pub static ref DISABLE_REGISTRATION: bool = env::var("REVOLT_DISABLE_REGISTRATION").map_or(false, |v| v == "1"); @@ -48,6 +52,8 @@ lazy_static! { // Application Logic Settings pub static ref MAX_GROUP_SIZE: usize = env::var("REVOLT_MAX_GROUP_SIZE").unwrap_or_else(|_| "50".to_string()).parse().unwrap(); + pub static ref PUSH_LIMIT: usize = + env::var("REVOLT_PUSH_LIMIT").unwrap_or_else(|_| "50".to_string()).parse().unwrap(); } pub fn preflight_checks() {