diff --git a/Cargo.lock b/Cargo.lock index a9fb7fd12f0fc1c77c355a9920a1a50fa60192c8..977d4055776a1becf27d6878d179bfdf2149de42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,7 +120,7 @@ checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db" [[package]] name = "cat-disruptor-6500" -version = "0.1.2" +version = "0.2.0" dependencies = [ "anyhow", "bigdecimal", @@ -202,6 +202,20 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if 1.0.0", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", + "serde", +] + [[package]] name = "deflate" version = "0.8.6" @@ -1511,10 +1525,12 @@ dependencies = [ "bitflags", "bytes 1.2.1", "cfg-if 1.0.0", + "dashmap", "flate2", "futures", "mime", "mime_guess", + "parking_lot", "percent-encoding", "reqwest 0.11.12", "serde", diff --git a/Cargo.toml b/Cargo.toml index 08e87216452f4c804a1889abe33bf82b4abc770d..56b54cc6b98629cada06293c5cea4782ef92ddb1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "cat-disruptor-6500" -version = "0.1.2" +version = "0.2.0" authors = ["Stephen <stephen@stephendownward.ca>"] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -serenity = {version = "0.11", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "static_assertions"] } +serenity = {version = "0.11", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "static_assertions", "cache"] } tokio = {version = "1.21", features = ["full", "time"] } phf = { version = "0.8", features = ["macros"] } toml = "0.5" diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 7df01df2c768f9e9c2de25e919eb4c4da861adbe..e39b04fac492e3805211c3e9ba374fb310ced7c5 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -35,7 +35,8 @@ pub(crate) trait LineHandler: Send + Sync { #[async_trait] pub(crate) trait ReactionHandler: Send + Sync { - async fn reaction(&self, _ctx: &Context, reaction: &Reaction); + async fn reaction_add(&self, ctx: &Context, reaction: &Reaction); + async fn reaction_del(&self, ctx: &Context, reaction: &Reaction); } pub(crate) struct Dispatcher { @@ -53,7 +54,13 @@ impl EventHandler for Dispatcher { async fn reaction_add(&self, ctx: Context, reaction: Reaction) { for r in &self.reacts { - r.reaction(&ctx, &reaction).await; + r.reaction_add(&ctx, &reaction).await; + } + } + + async fn reaction_remove(&self, ctx: Context, reaction: Reaction) { + for r in &self.reacts { + r.reaction_del(&ctx, &reaction).await; } } diff --git a/src/handlers/starboard.rs b/src/handlers/starboard.rs index b119f310a3bfd8dd046d476373cf26365f81d724..4fdc3b96d9978889baf3e75bc5aba2856d2c6fb2 100644 --- a/src/handlers/starboard.rs +++ b/src/handlers/starboard.rs @@ -7,35 +7,45 @@ use anyhow::Context as AnyhowContext; use bigdecimal::{BigDecimal, FromPrimitive, ToPrimitive}; use diesel::{ExpressionMethods, PgConnection, QueryDsl, RunQueryDsl}; use serenity::async_trait; -use serenity::model::prelude::{ChannelId, EmojiId, GuildId, Message, Reaction, ReactionType}; +use serenity::builder::CreateEmbed; +use serenity::model::prelude::{ + ChannelId, Emoji, EmojiId, Message, MessageReaction, Reaction, ReactionType, +}; use serenity::prelude::*; use std::sync::Arc; use tokio::sync::Mutex; -pub struct StarboardHandler { +#[derive(Clone)] +pub struct SingleMessageHandler { conn: Arc<Mutex<PgConnection>>, + server_settings: ServerSetting, + reaction_count: u64, + msg: Message, + emoji: Emoji, + name: String, + image: Option<String>, } -impl StarboardHandler { - pub fn new(conn: Arc<Mutex<PgConnection>>) -> Self { - Self { conn } - } - - async fn handle_reaction(&self, ctx: &Context, reaction: &Reaction) -> anyhow::Result<()> { +impl SingleMessageHandler { + pub async fn handle_reaction( + conn: Arc<Mutex<PgConnection>>, + ctx: &Context, + reaction: &Reaction, + ) -> anyhow::Result<()> { let guild = match reaction.channel_id.to_channel(ctx).await?.guild() { Some(x) => x, None => return Ok(()), }; // get corresponding guild settings - let gs: Vec<ServerSetting> = server_settings::dsl::server_settings + let mut gs: Vec<ServerSetting> = server_settings::dsl::server_settings .filter( schema::server_settings::columns::guild_id .eq(BigDecimal::from_u64(guild.guild_id.0).unwrap()), ) .limit(1) - .get_results(&*self.conn.lock().await)?; - let gs = match gs.first() { + .get_results(&*conn.lock().await)?; + let gs = match gs.pop() { Some(x) => x, None => return Ok(()), }; @@ -45,13 +55,92 @@ impl StarboardHandler { .context("Could not convert emoji id to u64")?, ); - if Self::emoji_match(&reaction.emoji, emoji) { - let msg = reaction.message(ctx).await?; - for mr in &msg.reactions { - if Self::emoji_match(&mr.reaction_type, emoji) { - self.handle_matching_reaction(ctx, gs, mr.count, &msg, &guild.guild_id) - .await?; - break; + let msg = reaction.message(ctx).await?; + + // reaction from event handler must match + // otherwise we'll update the repost (and add an "edited" + // to the message) + // whenever someone reacts with any reaction + if !Self::emoji_match(&reaction.emoji, emoji) { + return Ok(()); + } + + let reaction_count = Self::find_emoji_match(msg.reactions.iter(), emoji); + let guild_id = guild.guild_id; + + let emoji = guild_id + .emoji(ctx, emoji) + .await + .context("Could not get emoji from guild")?; + + let name = msg + .author + .nick_in(ctx, guild_id) + .await + .unwrap_or_else(|| msg.author.tag()); + + let image = msg + .attachments + .iter() + .filter(|a| a.width.is_some()) + .map(|a| &a.url) + .next() + .cloned(); + + let handler = Self { + conn, + server_settings: gs, + reaction_count, + msg, + emoji, + name, + image, + }; + + handler.process_match(ctx).await?; + + Ok(()) + } + + async fn process_match(&self, ctx: &Context) -> anyhow::Result<()> { + let original_id = BigDecimal::from(self.msg.id.0); + diesel::insert_into(starboard_mappings::dsl::starboard_mappings) + .values(starboard_mappings::columns::original_id.eq(&original_id)) + .returning(starboard_mappings::columns::repost_id) + .on_conflict_do_nothing() + .execute(&*self.conn.lock().await)?; + + let repost_id = starboard_mappings::dsl::starboard_mappings + .filter(starboard_mappings::columns::original_id.eq(&original_id)) + .select(starboard_mappings::columns::repost_id) + .limit(1) + .get_results(&*self.conn.lock().await)?; + + let repost_id: &Option<BigDecimal> = + repost_id.get(0).context("Insert of mapping failed")?; + + match repost_id { + Some(id) => { + self.edit_existing_starboard_message( + ctx, + id.to_u64() + .context("Could not convert repost message id to a u64")?, + ) + .await?; + } + None => { + if self.reaction_count >= self.server_settings.starboard_threshold as u64 { + // post to repost + let repost = self.post_new_starboard_message(ctx).await?; + + // update the DB + let repost_id = BigDecimal::from_u64(repost); + diesel::update( + starboard_mappings::dsl::starboard_mappings + .filter(starboard_mappings::columns::original_id.eq(original_id)), + ) + .set(starboard_mappings::columns::repost_id.eq(repost_id)) + .execute(&*self.conn.lock().await)?; } } } @@ -59,83 +148,64 @@ impl StarboardHandler { Ok(()) } - async fn handle_matching_reaction( + async fn post_new_starboard_message(&self, ctx: &Context) -> anyhow::Result<u64> { + let repost = ChannelId( + self.server_settings + .starboard_channel + .to_u64() + .context("Could not convert starboard channel to a u64")?, + ) + .send_message(ctx, |m| m.embed(|e| self.clone().starboard_message(e))) + .await?; + + Ok(repost.id.0) + } + + async fn edit_existing_starboard_message( &self, ctx: &Context, - gs: &ServerSetting, - count: u64, - msg: &Message, - guild: &GuildId, + message_id: u64, ) -> anyhow::Result<()> { - if count >= gs.starboard_threshold as u64 { - let original_id = BigDecimal::from(msg.id.0); - diesel::insert_into(starboard_mappings::dsl::starboard_mappings) - .values(starboard_mappings::columns::original_id.eq(&original_id)) - .returning(starboard_mappings::columns::repost_id) - .on_conflict_do_nothing() - .execute(&*self.conn.lock().await)?; - - let repost_id = starboard_mappings::dsl::starboard_mappings - .filter(starboard_mappings::columns::original_id.eq(&original_id)) - .select(starboard_mappings::columns::repost_id) - .limit(1) - .get_results(&*self.conn.lock().await)?; - - let repost_id: &Option<BigDecimal> = - repost_id.get(0).context("Insert of mapping failed")?; - - if repost_id.is_none() { - // post to repost channel - let name = msg - .author - .nick_in(&ctx, guild) - .await - .unwrap_or_else(|| msg.author.tag()); - - let image = msg - .attachments - .iter() - .filter(|a| a.width.is_some()) - .map(|a| &a.url) - .next(); - - let repost = ChannelId( - gs.starboard_channel - .to_u64() - .context("Could not convert starboard channel to a u64")?, - ) - .send_message(ctx, |m| { - m.embed(|e| { - let mut e = e - .description(format!( - "[Jump to source]({})\n{}", - msg.link(), - msg.content - )) - .author(|a| a.name(&name).icon_url(msg.author.face())) - .timestamp(&msg.timestamp); - - if let Some(url) = image { - e = e.image(url); - } - - e - }) - }) - .await?; + let channel_id = ChannelId( + self.server_settings + .starboard_channel + .to_u64() + .context("Could not convert starboard channel to a u64")?, + ); - // update the DB - let repost_id = BigDecimal::from_u64(repost.id.0); - diesel::update( - starboard_mappings::dsl::starboard_mappings - .filter(starboard_mappings::columns::original_id.eq(original_id)), - ) - .set(starboard_mappings::columns::repost_id.eq(repost_id)) - .execute(&*self.conn.lock().await)?; + let mut msg = channel_id.message(ctx, message_id).await?; + msg.edit(ctx, |m| m.embed(|e| self.clone().starboard_message(e))) + .await?; + + Ok(()) + } + + fn starboard_message(self, e: &mut CreateEmbed) -> &mut CreateEmbed { + let mut e = e + .description(format!( + "[Jump to source]({})\n{}", + self.msg.link(), + self.msg.content + )) + .title(format!("{} {}", self.reaction_count, self.emoji)) + .author(|a| a.name(&self.name).icon_url(self.msg.author.face())) + .timestamp(self.msg.timestamp); + + if let Some(url) = self.image { + e = e.image(url); + } + + e + } + + fn find_emoji_match<'a, I: Iterator<Item = &'a MessageReaction>>(iter: I, em: EmojiId) -> u64 { + for mr in iter { + if Self::emoji_match(&mr.reaction_type, em) { + return mr.count; } } - Ok(()) + 0 } fn emoji_match(rt: &ReactionType, em: EmojiId) -> bool { @@ -143,9 +213,29 @@ impl StarboardHandler { } } +pub struct StarboardHandler { + conn: Arc<Mutex<PgConnection>>, +} + +impl StarboardHandler { + pub fn new(conn: Arc<Mutex<PgConnection>>) -> Self { + Self { conn } + } + + async fn handle_reaction(&self, ctx: &Context, reaction: &Reaction) -> anyhow::Result<()> { + SingleMessageHandler::handle_reaction(self.conn.clone(), ctx, reaction).await + } +} + #[async_trait] impl ReactionHandler for StarboardHandler { - async fn reaction(&self, ctx: &Context, reaction: &Reaction) { + async fn reaction_add(&self, ctx: &Context, reaction: &Reaction) { + if let Err(e) = self.handle_reaction(ctx, reaction).await { + eprintln!("Error in starboard: {:?}", e); + } + } + + async fn reaction_del(&self, ctx: &Context, reaction: &Reaction) { if let Err(e) = self.handle_reaction(ctx, reaction).await { eprintln!("Error in starboard: {:?}", e); } diff --git a/src/models.rs b/src/models.rs index 10d32fea3696c52717ba4eb92d647342a6c39ac4..bae06f4d4f272ece41ec8dd572ff8443ca543da8 100644 --- a/src/models.rs +++ b/src/models.rs @@ -17,7 +17,7 @@ pub struct NewUserProgram<'a> { pub code: &'a str, } -#[derive(Queryable)] +#[derive(Clone, Queryable)] pub struct ServerSetting { pub id: i32, pub guild_id: BigDecimal, diff --git a/src/program.rs b/src/program.rs index 689e8e81499b51aa5cb5d44a7b7a2ee39878ba70..03c1547768a1453d70f181e9e19cfc217e998840 100644 --- a/src/program.rs +++ b/src/program.rs @@ -67,7 +67,7 @@ impl Program { .and(columns::name.eq(name)), ), ) - .set(columns::published.eq(if published { 1 } else { 0 })) + .set(columns::published.eq(i32::from(published))) .execute(conn) .ok()? == 1 {