diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 7df01df2c768f9e9c2de25e919eb4c4da861adbe..e39b04fac492e3805211c3e9ba374fb310ced7c5 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -35,7 +35,8 @@ pub(crate) trait LineHandler: Send + Sync { #[async_trait] pub(crate) trait ReactionHandler: Send + Sync { - async fn reaction(&self, _ctx: &Context, reaction: &Reaction); + async fn reaction_add(&self, ctx: &Context, reaction: &Reaction); + async fn reaction_del(&self, ctx: &Context, reaction: &Reaction); } pub(crate) struct Dispatcher { @@ -53,7 +54,13 @@ impl EventHandler for Dispatcher { async fn reaction_add(&self, ctx: Context, reaction: Reaction) { for r in &self.reacts { - r.reaction(&ctx, &reaction).await; + r.reaction_add(&ctx, &reaction).await; + } + } + + async fn reaction_remove(&self, ctx: Context, reaction: Reaction) { + for r in &self.reacts { + r.reaction_del(&ctx, &reaction).await; } } diff --git a/src/handlers/starboard.rs b/src/handlers/starboard.rs index 25bb92b4bc35ac237002181aa7593472cc241b52..82393ebdd99912f50aae6a617921186b534d9731 100644 --- a/src/handlers/starboard.rs +++ b/src/handlers/starboard.rs @@ -8,7 +8,9 @@ use bigdecimal::{BigDecimal, FromPrimitive, ToPrimitive}; use diesel::{ExpressionMethods, PgConnection, QueryDsl, RunQueryDsl}; use serenity::async_trait; use serenity::builder::CreateEmbed; -use serenity::model::prelude::{ChannelId, Emoji, EmojiId, Message, Reaction, ReactionType}; +use serenity::model::prelude::{ + ChannelId, Emoji, EmojiId, Message, MessageReaction, Reaction, ReactionType, +}; use serenity::prelude::*; use std::sync::Arc; use tokio::sync::Mutex; @@ -53,80 +55,93 @@ impl SingleMessageHandler { .context("Could not convert emoji id to u64")?, ); - if Self::emoji_match(&reaction.emoji, emoji) { - let msg = reaction.message(ctx).await?; - for mr in &msg.reactions { - if Self::emoji_match(&mr.reaction_type, emoji) { - let guild_id = guild.guild_id; - - let emoji = guild_id - .emoji(ctx, emoji) - .await - .context("Could not get emoji from guild")?; - - let name = msg - .author - .nick_in(ctx, guild_id) - .await - .unwrap_or_else(|| msg.author.tag()); - - let image = msg - .attachments - .iter() - .filter(|a| a.width.is_some()) - .map(|a| &a.url) - .next() - .cloned(); - - let handler = Self { - conn, - server_settings: gs, - reaction_count: mr.count, - msg, - emoji, - name, - image, - }; - - handler.process_match(ctx).await?; - break; - } - } + let msg = reaction.message(ctx).await?; + + // reaction from event handler must match + // otherwise we'll update the repost (and add an "edited" + // to the message) + // whenever someone reacts with any reaction + if !Self::emoji_match(&reaction.emoji, emoji) { + return Ok(()); } + let reaction_count = Self::find_emoji_match(msg.reactions.iter(), emoji); + let guild_id = guild.guild_id; + + let emoji = guild_id + .emoji(ctx, emoji) + .await + .context("Could not get emoji from guild")?; + + let name = msg + .author + .nick_in(ctx, guild_id) + .await + .unwrap_or_else(|| msg.author.tag()); + + let image = msg + .attachments + .iter() + .filter(|a| a.width.is_some()) + .map(|a| &a.url) + .next() + .cloned(); + + let handler = Self { + conn, + server_settings: gs, + reaction_count, + msg, + emoji, + name, + image, + }; + + handler.process_match(ctx).await?; + Ok(()) } async fn process_match(&self, ctx: &Context) -> anyhow::Result<()> { - if self.reaction_count >= self.server_settings.starboard_threshold as u64 { - let original_id = BigDecimal::from(self.msg.id.0); - diesel::insert_into(starboard_mappings::dsl::starboard_mappings) - .values(starboard_mappings::columns::original_id.eq(&original_id)) - .returning(starboard_mappings::columns::repost_id) - .on_conflict_do_nothing() - .execute(&*self.conn.lock().await)?; - - let repost_id = starboard_mappings::dsl::starboard_mappings - .filter(starboard_mappings::columns::original_id.eq(&original_id)) - .select(starboard_mappings::columns::repost_id) - .limit(1) - .get_results(&*self.conn.lock().await)?; - - let repost_id: &Option<BigDecimal> = - repost_id.get(0).context("Insert of mapping failed")?; - - if repost_id.is_none() { - // post to repost - let repost = self.post_new_starboard_message(ctx).await?; - - // update the DB - let repost_id = BigDecimal::from_u64(repost); - diesel::update( - starboard_mappings::dsl::starboard_mappings - .filter(starboard_mappings::columns::original_id.eq(original_id)), + let original_id = BigDecimal::from(self.msg.id.0); + diesel::insert_into(starboard_mappings::dsl::starboard_mappings) + .values(starboard_mappings::columns::original_id.eq(&original_id)) + .returning(starboard_mappings::columns::repost_id) + .on_conflict_do_nothing() + .execute(&*self.conn.lock().await)?; + + let repost_id = starboard_mappings::dsl::starboard_mappings + .filter(starboard_mappings::columns::original_id.eq(&original_id)) + .select(starboard_mappings::columns::repost_id) + .limit(1) + .get_results(&*self.conn.lock().await)?; + + let repost_id: &Option<BigDecimal> = + repost_id.get(0).context("Insert of mapping failed")?; + + match repost_id { + Some(id) => { + self.edit_existing_starboard_message( + ctx, + id.to_u64() + .context("Could not convert repost message id to a u64")?, ) - .set(starboard_mappings::columns::repost_id.eq(repost_id)) - .execute(&*self.conn.lock().await)?; + .await?; + } + None => { + if self.reaction_count >= self.server_settings.starboard_threshold as u64 { + // post to repost + let repost = self.post_new_starboard_message(ctx).await?; + + // update the DB + let repost_id = BigDecimal::from_u64(repost); + diesel::update( + starboard_mappings::dsl::starboard_mappings + .filter(starboard_mappings::columns::original_id.eq(original_id)), + ) + .set(starboard_mappings::columns::repost_id.eq(repost_id)) + .execute(&*self.conn.lock().await)?; + } } } @@ -146,6 +161,25 @@ impl SingleMessageHandler { Ok(repost.id.0) } + async fn edit_existing_starboard_message( + &self, + ctx: &Context, + message_id: u64, + ) -> anyhow::Result<()> { + let channel_id = ChannelId( + self.server_settings + .starboard_channel + .to_u64() + .context("Could not convert starboard channel to a u64")?, + ); + + let mut msg = channel_id.message(ctx, message_id).await?; + msg.edit(ctx, |m| m.embed(|e| self.clone().starboard_message(e))) + .await?; + + Ok(()) + } + fn starboard_message(self, e: &mut CreateEmbed) -> &mut CreateEmbed { let mut e = e .description(format!( @@ -164,6 +198,16 @@ impl SingleMessageHandler { e } + fn find_emoji_match<'a, I: Iterator<Item = &'a MessageReaction>>(iter: I, em: EmojiId) -> u64 { + for mr in iter { + if Self::emoji_match(&mr.reaction_type, em) { + return mr.count; + } + } + + 0 + } + fn emoji_match(rt: &ReactionType, em: EmojiId) -> bool { matches!(rt, ReactionType::Custom { id, .. } if *id == em) } @@ -185,7 +229,13 @@ impl StarboardHandler { #[async_trait] impl ReactionHandler for StarboardHandler { - async fn reaction(&self, ctx: &Context, reaction: &Reaction) { + async fn reaction_add(&self, ctx: &Context, reaction: &Reaction) { + if let Err(e) = self.handle_reaction(ctx, reaction).await { + eprintln!("Error in starboard: {:?}", e); + } + } + + async fn reaction_del(&self, ctx: &Context, reaction: &Reaction) { if let Err(e) = self.handle_reaction(ctx, reaction).await { eprintln!("Error in starboard: {:?}", e); }