Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • stephen/cat-disruptor-6500
  • roygbyte/cat-disruptor-6500
  • tinyconan/cat-disruptor-6500
3 results
Show changes
Commits on Source (57)
Showing
with 2081 additions and 639 deletions
image: "rust:latest"
before_script:
- rustup component add rustfmt
- rustup component add clippy
- cargo install cargo-deb
test:
script:
- cargo fmt -- --check
- cargo clippy --all-targets --all-features -- -D warnings
- cargo test
build:
script:
- cargo deb
artifacts:
paths:
- target/debian/*.deb
This diff is collapsed.
[package]
name = "cat-disruptor-6500"
version = "0.1.0"
authors = ["Stephen <stephen@stephendownward.ca>"]
version = "0.3.0"
authors = ["Stephen <wemaster@scd31.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serenity = {version = "0.9", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "static_assertions"] }
tokio = {version = "0.2", features = ["full", "time"] }
serenity = {version = "0.11", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "static_assertions", "cache"] }
tokio = {version = "1.21", features = ["full", "time"] }
phf = { version = "0.8", features = ["macros"] }
toml = "0.5"
serde = { version = "1.0", features = ["derive"] }
xbasic = "0.3"
xbasic = "0.3.1"
png = "0.16"
diesel = { version = "1.4", features = ["postgres", "numeric"] }
diesel = { version = "2", features = ["postgres", "numeric"] }
dotenv = "0.15.0"
bigdecimal = "0.1.2"
\ No newline at end of file
bigdecimal = "0.1.2"
reqwest = { version = "0.11", features = ["json"] }
serde_json = "1.0"
rusttype = "0.4.3"
rand = "0.8"
itertools = "0.10"
anyhow = "1.0"
ollama-rs = "0.1.5"
......@@ -2,6 +2,21 @@
A Discord bot that reacts to certain messages with certain emojis. More importantly, it also lets you write code!
## Install
- Setup PostgreSQL
- Install via package manager `apt install postgresql`. Also get `libpq-dev` if you don't want a miserable life.
- `sudo su postgres` to login to default database user
- `createdb cat_disruptor_6500` to create the database
- `psql` and then `ALTER ROLE postgres WITH PASSWORD 'password';`
- Update the database environment variable (`DATABASE_URL=postgres://postgres:password@localhost/cat_disruptor_6500`)
- Create a new [Discord bot](https://discord.com/developers/applications)
- Enable all Privileged Gateway Intents
- Go to OAuth2 menu, URL generator, click "bot" and "administrator"
- Open link to add bot to server
- Create a `config.toml` file in project directory with `token=<your discord bot token>`
- Generate token on your bot page
## Commands:
!START - start the interpreter
......@@ -14,6 +29,14 @@ SAVE filename - save code into database
LOAD filename - load filename into memory
DIR - list programs you have saved
PUB filename - Publish a program
PUBDIR - List published programs
PUBLOAD id - Load public program by ID (IDs are listed in PUBDIR)
Code is specified by writing `line_num code`. For example, `10 print "hello world"`. This makes it easy to insert new code or overwrite lines.
### Example session
......@@ -62,4 +85,4 @@ The interpreter can execute arbitrary code safely - it includes CPU and memory l
It even does graphics!
![Example session](https://git.scd31.com/stephen/cat-disruptor-6500/raw/branch/master/example.png)
\ No newline at end of file
![Example session](example.png)
DROP TABLE server_settings;
CREATE TABLE server_settings (
id SERIAL PRIMARY KEY NOT NULL,
guild_id NUMERIC NOT NULL UNIQUE,
starboard_threshold INTEGER NOT NULL,
starboard_emoji_id NUMERIC NOT NULL,
starboard_channel NUMERIC NOT NULL
);
DROP TABLE starboard_mappings;
CREATE TABLE starboard_mappings (
original_id NUMERIC NOT NULL UNIQUE PRIMARY KEY,
repost_id NUMERIC UNIQUE
);
tab_spaces = 4
hard_tabs = true
\ No newline at end of file
hard_tabs = true
edition = "2018"
File added
use serde::Deserialize;
use std::fs;
use std::{
collections::{HashMap, HashSet},
fs,
};
#[derive(Deserialize)]
pub struct LlamaConfig {
pub(crate) address: String,
pub(crate) port: u16,
pub(crate) models: HashMap<String, String>,
#[serde(default)]
pub(crate) channels: HashSet<u64>,
}
#[derive(Deserialize)]
pub struct Config {
pub(crate) token: String,
pub(crate) llama: Option<LlamaConfig>,
}
pub fn get_conf() -> Config {
......
use crate::handlers::LineHandler;
use serenity::async_trait;
use serenity::model::channel::{Message, ReactionType};
use serenity::prelude::*;
use std::str::FromStr;
#[derive(Default)]
pub struct HorseHandler;
#[async_trait]
impl LineHandler for HorseHandler {
async fn message(&self, ctx: &Context, msg: &Message) {
let reaction = match ReactionType::from_str("🐴") {
Ok(x) => x,
Err(_) => return,
};
if rand::random::<f64>() <= 0.001 {
let _ = msg.react(&ctx, reaction).await;
}
}
}
use crate::handlers::LineHandler;
use crate::joker::tell_joke;
use serenity::async_trait;
use serenity::model::channel::Message;
use serenity::prelude::*;
pub struct JokeHandler;
#[async_trait]
impl LineHandler for JokeHandler {
async fn line(&self, ctx: &Context, msg: &Message, line: &str) {
if line == "!JOKE" {
match tell_joke().await {
Some(s) => msg.channel_id.say(&ctx, s).await.unwrap(),
None => msg
.channel_id
.say(&ctx, "There was an error while fetching a joke.")
.await
.unwrap(),
};
}
}
}
impl Default for JokeHandler {
fn default() -> Self {
Self
}
}
use super::LineHandler;
use crate::config::LlamaConfig;
use anyhow::{anyhow, Context as _};
use itertools::Itertools;
use ollama_rs::{
generation::completion::{request::GenerationRequest, GenerationContext},
Ollama,
};
use serenity::{
async_trait,
http::Typing,
model::prelude::{Message, MessageId},
prelude::*,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
pub struct LlamaHandler {
ollama: Ollama,
contexts: Arc<Mutex<HashMap<MessageId, (String, GenerationContext)>>>,
models: HashMap<String, String>,
channel_ids: HashSet<u64>,
}
impl LlamaHandler {
pub fn new(lc: LlamaConfig) -> Self {
Self {
ollama: Ollama::new(lc.address, lc.port),
contexts: Arc::new(Mutex::new(HashMap::new())),
models: lc.models,
channel_ids: lc.channels,
}
}
async fn call_llama(
&self,
model: &str,
prompt: &str,
context: Option<GenerationContext>,
) -> anyhow::Result<(String, GenerationContext)> {
let mut req = GenerationRequest::new(model.into(), prompt.into());
if let Some(c) = context {
req = req.context(c);
}
let resp = self
.ollama
.generate(req)
.await
.map_err(|x| anyhow!("{x}"))?;
let context = resp.final_data.context("Missing final data")?.context;
Ok((resp.response, context))
}
async fn list_models(&self, ctx: &Context, msg: &Message) {
let people = self
.models
.keys()
.map(|x| format!("- {x}"))
.collect::<Vec<_>>()
.join("\n");
if let Err(e) = msg
.reply(ctx, format!("Available models:\n{}", people))
.await
{
eprintln!("{:?}", e);
}
}
async fn reply(
&self,
ctx: &Context,
msg: &Message,
model: &str,
txt: &str,
context: Option<GenerationContext>,
) {
if txt.is_empty() {
return;
}
let _typing = try_or_log(|| Typing::start(ctx.http.clone(), msg.channel_id.0));
let resp = self.call_llama(model, txt, context).await;
let (resp, context) =
match resp {
Ok(x) => x,
Err(e) => {
eprintln!("{e:?}");
if let Err(e) = msg.reply(
ctx,
"Could not communicate with Llama. Check the server logs for more details.",
).await {
eprintln!("{e:?}");
};
return;
}
};
let resp = if resp.is_empty() {
"[No response]"
} else {
&resp
};
// discord messages are limited to 2000 codepoints
let chunks: Vec<String> = resp
.chars()
.chunks(2000)
.into_iter()
.map(|chunk| chunk.collect())
.collect();
let mut first = true;
for chunk in chunks {
let res = if first {
msg.reply(ctx, chunk).await
} else {
msg.channel_id.send_message(ctx, |m| m.content(chunk)).await
};
first = false;
match res {
Ok(x) => {
self.contexts
.lock()
.await
.insert(x.id, (model.to_string(), context.clone()));
}
Err(e) => {
eprintln!("{e:?}");
break;
}
}
}
}
}
#[async_trait]
impl LineHandler for LlamaHandler {
async fn message(&self, ctx: &Context, msg: &Message) {
if !self.channel_ids.contains(&msg.channel_id.0) {
return;
}
let txt = &msg.content;
if let Some(p) = &msg.referenced_message {
let x = {
let l = self.contexts.lock().await;
l.get(&p.id).cloned()
};
if let Some((model, context)) = x {
self.reply(ctx, msg, &model, txt, Some(context)).await;
return;
}
}
if txt.starts_with("!people") {
self.list_models(ctx, msg).await;
return;
}
for (name, model) in &self.models {
if let Some(txt) = txt.strip_prefix(&format!("!{name} ")) {
self.reply(ctx, msg, model, txt, None).await;
return;
}
}
}
}
fn try_or_log<T, E: std::fmt::Debug, F: Fn() -> Result<T, E>>(f: F) -> Result<T, E> {
let res = f();
if let Err(e) = res.as_ref() {
eprintln!("{e:?}");
}
res
}
mod horse;
mod joke;
mod llama;
mod react;
mod starboard;
mod sus;
mod xbasic;
use crate::config::LlamaConfig;
use crate::handlers::horse::HorseHandler;
use crate::handlers::joke::*;
use crate::handlers::react::*;
use crate::handlers::starboard::StarboardHandler;
use crate::handlers::sus::*;
use crate::handlers::xbasic::*;
use diesel::{Connection, PgConnection};
use dotenv::dotenv;
use serenity::async_trait;
use serenity::model::channel::Message;
use serenity::model::prelude::{Reaction, Ready};
use serenity::prelude::*;
use std::env;
use std::sync::Arc;
use tokio::sync::Mutex;
use self::llama::LlamaHandler;
#[async_trait]
pub(crate) trait LineHandler: Send + Sync {
async fn message(&self, ctx: &Context, msg: &Message) {
for line in msg.content.split('\n') {
self.line(ctx, msg, line).await
}
}
async fn line(&self, _ctx: &Context, _msg: &Message, _line: &str) {}
}
#[async_trait]
pub(crate) trait ReactionHandler: Send + Sync {
async fn reaction_add(&self, ctx: &Context, reaction: &Reaction);
async fn reaction_del(&self, ctx: &Context, reaction: &Reaction);
}
pub(crate) struct Dispatcher {
handlers: Vec<Box<dyn LineHandler>>,
reacts: Vec<Box<dyn ReactionHandler>>,
}
#[async_trait]
impl EventHandler for Dispatcher {
async fn message(&self, ctx: Context, msg: Message) {
for h in &self.handlers {
h.message(&ctx, &msg).await;
}
}
async fn reaction_add(&self, ctx: Context, reaction: Reaction) {
for r in &self.reacts {
r.reaction_add(&ctx, &reaction).await;
}
}
async fn reaction_remove(&self, ctx: Context, reaction: Reaction) {
for r in &self.reacts {
r.reaction_del(&ctx, &reaction).await;
}
}
async fn ready(&self, _: Context, ready: Ready) {
println!("{} is connected", ready.user.name);
}
}
impl Dispatcher {
pub fn new(llama_config: Option<LlamaConfig>) -> Self {
let conn = Arc::new(Mutex::new(establish_connection()));
let mut handlers: Vec<Box<dyn LineHandler>> = vec![
Box::new(XbasicHandler::new(conn.clone())),
Box::<JokeHandler>::default(),
Box::<ReactHandler>::default(),
Box::<SusHandler>::default(),
Box::<HorseHandler>::default(),
];
if let Some(lc) = llama_config {
handlers.push(Box::new(LlamaHandler::new(lc)));
}
Self {
handlers,
reacts: vec![Box::new(StarboardHandler::new(conn))],
}
}
}
fn establish_connection() -> PgConnection {
dotenv().ok();
let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
PgConnection::establish(&database_url).expect("Error connecting to database")
}
use crate::handlers::LineHandler;
use itertools::Itertools;
use phf::phf_map;
use serenity::async_trait;
use serenity::model::channel::{Message, ReactionType};
use serenity::prelude::*;
use std::str::FromStr;
static EMOJI_MAP: phf::Map<&'static str, &'static str> = phf_map! {
"cat" => "🐈",
"chicken" => "🐔",
"spaghetti" => "🍝",
"dog" => "🐕",
"bot" => "🤖",
"mango" => "🥭",
"banana" => "🍌",
"bee" => "🐝",
"horse" => "🐎",
"hat" => "🎩"
};
fn map_lookup(msg: &str) -> Option<&'static str> {
// We lose the O(1) benefits of the hashmap
// But whatever. It doesn't need to be fast
for (k, v) in EMOJI_MAP.entries() {
if &msg == k || msg == format!("{k}s") {
return Some(v);
}
}
None
}
#[derive(Default)]
pub struct ReactHandler;
#[async_trait]
impl LineHandler for ReactHandler {
async fn message(&self, ctx: &Context, msg: &Message) {
// Kind of convoluted, but allows us to react in the correct order.
let groups = msg
.content
.chars()
.map(|c| match c {
'!' | '?' | ',' | '.' | '(' | ')' | '[' | ']' | '\n' | '\r' => ' ',
_ => c,
})
.group_by(|c| *c == ' ');
let reacts: Vec<_> = groups
.into_iter()
.filter_map(|(_, c)| map_lookup(&c.collect::<String>().to_lowercase()))
.unique()
.collect();
for r in reacts {
let reaction_type = match ReactionType::from_str(r) {
Ok(x) => x,
Err(x) => {
eprintln!("Could not react: {x}");
return;
}
};
if let Err(e) = msg.react(&ctx, reaction_type).await {
eprintln!("Error reacting: {e}");
}
}
}
}
use crate::handlers::ReactionHandler;
use crate::models::ServerSetting;
use crate::schema;
use crate::schema::{server_settings, starboard_mappings};
use anyhow::Context as AnyhowContext;
use bigdecimal::{BigDecimal, FromPrimitive, ToPrimitive};
use diesel::{ExpressionMethods, PgConnection, QueryDsl, RunQueryDsl};
use serenity::async_trait;
use serenity::builder::CreateEmbed;
use serenity::model::prelude::{
ChannelId, Emoji, EmojiId, Message, MessageReaction, Reaction, ReactionType,
};
use serenity::prelude::*;
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Clone)]
pub struct SingleMessageHandler {
conn: Arc<Mutex<PgConnection>>,
server_settings: ServerSetting,
reaction_count: u64,
msg: Message,
emoji: Emoji,
name: String,
image: Option<String>,
}
impl SingleMessageHandler {
pub async fn handle_reaction(
conn: Arc<Mutex<PgConnection>>,
ctx: &Context,
reaction: &Reaction,
) -> anyhow::Result<()> {
let guild = match reaction.channel_id.to_channel(ctx).await?.guild() {
Some(x) => x,
None => return Ok(()),
};
// get corresponding guild settings
let mut gs: Vec<ServerSetting> = server_settings::dsl::server_settings
.filter(
schema::server_settings::columns::guild_id
.eq(BigDecimal::from_u64(guild.guild_id.0).unwrap()),
)
.limit(1)
.get_results(conn.lock().await.deref_mut())?;
let gs = match gs.pop() {
Some(x) => x,
None => return Ok(()),
};
let emoji = EmojiId(
gs.starboard_emoji_id
.to_u64()
.context("Could not convert emoji id to u64")?,
);
let msg = reaction.message(ctx).await?;
// reaction from event handler must match
// otherwise we'll update the repost (and add an "edited"
// to the message)
// whenever someone reacts with any reaction
if !Self::emoji_match(&reaction.emoji, emoji) {
return Ok(());
}
let reaction_count = Self::find_emoji_match(msg.reactions.iter(), emoji);
let guild_id = guild.guild_id;
let emoji = guild_id
.emoji(ctx, emoji)
.await
.context("Could not get emoji from guild")?;
let name = msg
.author
.nick_in(ctx, guild_id)
.await
.unwrap_or_else(|| msg.author.tag());
let image = msg
.attachments
.iter()
.filter(|a| a.width.is_some())
.map(|a| &a.url)
.next()
.cloned();
let handler = Self {
conn,
server_settings: gs,
reaction_count,
msg,
emoji,
name,
image,
};
handler.process_match(ctx).await?;
Ok(())
}
async fn process_match(&self, ctx: &Context) -> anyhow::Result<()> {
let original_id = BigDecimal::from(self.msg.id.0);
diesel::insert_into(starboard_mappings::dsl::starboard_mappings)
.values(starboard_mappings::columns::original_id.eq(&original_id))
.returning(starboard_mappings::columns::repost_id)
.on_conflict_do_nothing()
.execute(self.conn.lock().await.deref_mut())?;
let repost_id = starboard_mappings::dsl::starboard_mappings
.filter(starboard_mappings::columns::original_id.eq(&original_id))
.select(starboard_mappings::columns::repost_id)
.limit(1)
.get_results(self.conn.lock().await.deref_mut())?;
let repost_id: &Option<BigDecimal> =
repost_id.first().context("Insert of mapping failed")?;
match repost_id {
Some(id) => {
self.edit_existing_starboard_message(
ctx,
id.to_u64()
.context("Could not convert repost message id to a u64")?,
)
.await?;
}
None => {
if self.reaction_count >= self.server_settings.starboard_threshold as u64 {
// post to repost
let repost = self.post_new_starboard_message(ctx).await?;
// update the DB
let repost_id = BigDecimal::from_u64(repost);
diesel::update(
starboard_mappings::dsl::starboard_mappings
.filter(starboard_mappings::columns::original_id.eq(original_id)),
)
.set(starboard_mappings::columns::repost_id.eq(repost_id))
.execute(self.conn.lock().await.deref_mut())?;
}
}
}
Ok(())
}
async fn post_new_starboard_message(&self, ctx: &Context) -> anyhow::Result<u64> {
let repost = ChannelId(
self.server_settings
.starboard_channel
.to_u64()
.context("Could not convert starboard channel to a u64")?,
)
.send_message(ctx, |m| m.embed(|e| self.clone().starboard_message(e)))
.await?;
Ok(repost.id.0)
}
async fn edit_existing_starboard_message(
&self,
ctx: &Context,
message_id: u64,
) -> anyhow::Result<()> {
let channel_id = ChannelId(
self.server_settings
.starboard_channel
.to_u64()
.context("Could not convert starboard channel to a u64")?,
);
let mut msg = channel_id.message(ctx, message_id).await?;
msg.edit(ctx, |m| m.embed(|e| self.clone().starboard_message(e)))
.await?;
Ok(())
}
fn starboard_message(self, e: &mut CreateEmbed) -> &mut CreateEmbed {
let mut e = e
.description(format!(
"[Jump to source]({})\n{}",
self.msg.link(),
self.msg.content
))
.title(format!("{} {}", self.reaction_count, self.emoji))
.author(|a| a.name(&self.name).icon_url(self.msg.author.face()))
.timestamp(self.msg.timestamp);
if let Some(url) = self.image {
e = e.image(url);
}
e
}
fn find_emoji_match<'a, I: Iterator<Item = &'a MessageReaction>>(iter: I, em: EmojiId) -> u64 {
for mr in iter {
if Self::emoji_match(&mr.reaction_type, em) {
return mr.count;
}
}
0
}
fn emoji_match(rt: &ReactionType, em: EmojiId) -> bool {
matches!(rt, ReactionType::Custom { id, .. } if *id == em)
}
}
pub struct StarboardHandler {
conn: Arc<Mutex<PgConnection>>,
}
impl StarboardHandler {
pub fn new(conn: Arc<Mutex<PgConnection>>) -> Self {
Self { conn }
}
async fn handle_reaction(&self, ctx: &Context, reaction: &Reaction) -> anyhow::Result<()> {
SingleMessageHandler::handle_reaction(self.conn.clone(), ctx, reaction).await
}
}
#[async_trait]
impl ReactionHandler for StarboardHandler {
async fn reaction_add(&self, ctx: &Context, reaction: &Reaction) {
if let Err(e) = self.handle_reaction(ctx, reaction).await {
eprintln!("Error in starboard: {e:?}");
}
}
async fn reaction_del(&self, ctx: &Context, reaction: &Reaction) {
if let Err(e) = self.handle_reaction(ctx, reaction).await {
eprintln!("Error in starboard: {e:?}");
}
}
}
use crate::framebuffer::FrameBuffer;
use crate::handlers::LineHandler;
use rusttype::LayoutIter;
use rusttype::{point, FontCollection, Scale};
use serenity::async_trait;
use serenity::model::channel::{AttachmentType, Message};
use serenity::prelude::*;
use std::borrow::Cow;
const COLOUR: (u8, u8, u8) = (255, 0, 0);
const SCALE: Scale = Scale { x: 48.0, y: 48.0 };
const PADDING_X: u32 = 5;
const PADDING_Y: u32 = 5;
const MAX_WIDTH: i32 = 800;
const Y_GAP: i32 = 5;
// Wrap text when width exceeds MAX_WIDTH
struct WrappingLayoutIter<'a> {
iter: LayoutIter<'a, 'a>,
offset_x: i32,
offset_y: i32,
cur_x: i32,
cur_y: i32,
}
impl<'a> WrappingLayoutIter<'a> {
fn new(iter: LayoutIter<'a, 'a>, offset_x: i32, offset_y: i32) -> Self {
Self {
iter,
offset_x,
offset_y,
cur_x: 0,
cur_y: 0,
}
}
}
impl Iterator for WrappingLayoutIter<'_> {
type Item = Vec<(i32, i32, f32)>;
fn next(&mut self) -> Option<Self::Item> {
let (glyph, bounding_box) = loop {
let glyph = self.iter.next()?;
if let Some(bb) = glyph.pixel_bounding_box() {
break (glyph, bb);
}
};
if bounding_box.max.x + self.cur_x > MAX_WIDTH {
self.cur_x = -bounding_box.min.x;
self.cur_y += bounding_box.max.y - bounding_box.min.y + Y_GAP;
}
let mut buf = Vec::new();
glyph.draw(|x, y, a| {
let x = x as i32 + bounding_box.min.x + self.offset_x + self.cur_x;
let y = y as i32 + bounding_box.min.y + self.offset_y + self.cur_y;
buf.push((x, y, a));
});
Some(buf)
}
}
#[derive(Default)]
pub struct SusHandler;
impl SusHandler {
async fn render_font(&self, ctx: &Context, msg: &Message, s: &str) {
let font_data = include_bytes!("../amongus.ttf");
let collection = FontCollection::from_bytes(font_data as &[u8]);
let font = collection.into_font().unwrap();
let start = point(0.0, 0.0);
// Find image dimensions
let mut min_x = 0;
let mut max_x = 0;
let mut max_y = 0;
let mut min_y = 0;
for arr in WrappingLayoutIter::new(font.layout(s, SCALE, start), 0, 0) {
for (x, y, _) in arr {
if x < min_x {
min_x = x;
}
if x > max_x {
max_x = x;
}
if y < min_y {
min_y = y;
}
if y > max_y {
max_y = y;
}
}
}
let offset_x = -min_x;
let offset_y = -min_y;
max_x += offset_x;
max_y += offset_y;
let mut framebuffer =
FrameBuffer::new(max_x as u32 + PADDING_X * 2, max_y as u32 + PADDING_Y * 2);
for arr in WrappingLayoutIter::new(
font.layout(s, SCALE, start),
offset_x + PADDING_X as i32,
offset_y + PADDING_Y as i32,
) {
for (x, y, a) in arr {
framebuffer.set_pixel(
x as u32,
y as u32,
COLOUR.0,
COLOUR.1,
COLOUR.2,
(a * 255.0) as u8,
);
}
}
let buf = framebuffer.as_png_vec();
msg.channel_id
.send_message(&ctx, |e| {
e.add_file(AttachmentType::Bytes {
data: Cow::Borrowed(&buf),
filename: "output.png".to_string(),
});
e
})
.await
.unwrap();
}
}
#[async_trait]
impl LineHandler for SusHandler {
async fn line(&self, ctx: &Context, msg: &Message, line: &str) {
if line.starts_with("!SUS") {
let s = line.split(' ').collect::<Vec<_>>()[1..]
.join(" ")
.to_uppercase();
if s.is_empty() {
return;
}
self.render_font(ctx, msg, &s).await;
}
}
}
use crate::framebuffer::FrameBuffer;
use crate::handlers::LineHandler;
use crate::program::Program;
use diesel::PgConnection;
use phf::phf_map;
use serenity::async_trait;
use serenity::http::AttachmentType;
use serenity::model::channel::{Message, ReactionType};
use serenity::model::channel::{AttachmentType, Message, ReactionType};
use serenity::model::id::UserId;
use serenity::model::prelude::Ready;
use serenity::prelude::*;
use std::borrow::{Borrow, Cow};
use std::borrow::Cow;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task;
use xbasic::basic_io::BasicIO;
use xbasic::expr::ExprValue;
use xbasic::xbasic::XBasicBuilder;
static EMOJI_MAP: phf::Map<&'static str, &'static str> = phf_map! {
"cat" => "🐈",
"chicken" => "🐔",
"spaghetti" => "🍝",
"dog" => "🐕",
"bot" => "🤖",
"mango" => "🥭",
"banana" => "🍌",
"bee" => "🐝"
};
struct DiscordIO {
s: String,
frame: Option<FrameBuffer>,
......@@ -54,20 +44,20 @@ impl BasicIO for DiscordIO {
macro_rules! get_user_programs {
($self: expr, $user_id: expr) => {
$self.programs.lock().unwrap().get_mut($user_id).unwrap()
$self.programs.lock().await.get_mut($user_id).unwrap()
};
}
pub(crate) struct Handler {
pub struct XbasicHandler {
programs: Arc<Mutex<HashMap<UserId, Program>>>,
conn: Arc<Mutex<PgConnection>>,
}
impl Handler {
pub fn new(conn: PgConnection) -> Self {
impl XbasicHandler {
pub fn new(conn: Arc<Mutex<PgConnection>>) -> Self {
Self {
programs: Arc::new(Mutex::new(HashMap::new())),
conn: Arc::new(Mutex::new(conn)),
conn,
}
}
......@@ -75,10 +65,10 @@ impl Handler {
// TODO we lock the mutex to check, but unlock before locking again later
// allows another thread to screw it up
// we should lock once for this entire function
if self.programs.lock().unwrap().contains_key(&msg.author.id) {
if self.programs.lock().await.contains_key(&msg.author.id) {
match line {
"!STOP" => {
self.interpreter_stop(msg);
self.interpreter_stop(msg).await;
}
"RUN" => {
self.interpreter_run(msg, ctx).await;
......@@ -117,7 +107,7 @@ impl Handler {
if let Some(name) = line.strip_prefix("PUB ") {
if self.publish_program(name, msg, ctx).await.is_none() {
msg.channel_id
.say(&ctx, format!("Could not publish {}.", name))
.say(&ctx, format!("Could not publish {name}."))
.await
.unwrap();
}
......@@ -126,7 +116,7 @@ impl Handler {
if let Some(name) = line.strip_prefix("UNPUB ") {
if self.unpublish_program(name, msg, ctx).await.is_none() {
msg.channel_id
.say(&ctx, &format!("Could not unpublish {}.", name))
.say(&ctx, &format!("Could not unpublish {name}."))
.await
.unwrap();
}
......@@ -137,14 +127,14 @@ impl Handler {
Ok(id) => {
if self.load_published_program(msg, ctx, id).await.is_none() {
msg.channel_id
.say(&ctx, format!("Could not load {}.", id))
.say(&ctx, format!("Could not load {id}."))
.await
.unwrap();
}
}
Err(_) => {
msg.channel_id
.say(&ctx, "PUBLOAD requires a numerical ID.")
.say(&ctx, "Error: PUBLOAD requires a numerical ID.")
.await
.unwrap();
}
......@@ -158,16 +148,16 @@ impl Handler {
Some(x) => {
if x.is_empty() {
let _ =
get_user_programs!(&self, &msg.author.id).code.remove(&num);
get_user_programs!(self, &msg.author.id).code.remove(&num);
return;
}
get_user_programs!(&self, &msg.author.id)
get_user_programs!(self, &msg.author.id)
.code
.insert(num, x.to_owned());
}
None => {
let _ = get_user_programs!(&self, &msg.author.id).code.remove(&num);
let _ = get_user_programs!(self, &msg.author.id).code.remove(&num);
}
}
}
......@@ -176,28 +166,24 @@ impl Handler {
}
if line == "!START" {
self.interpreter_start(msg);
self.interpreter_start(msg).await;
}
}
fn interpreter_stop(&self, msg: &Message) {
self.programs
.lock()
.unwrap()
.remove(&msg.author.id)
.unwrap();
async fn interpreter_stop(&self, msg: &Message) {
self.programs.lock().await.remove(&msg.author.id).unwrap();
}
fn interpreter_start(&self, msg: &Message) {
async fn interpreter_start(&self, msg: &Message) {
self.programs
.lock()
.unwrap()
.await
.insert(msg.author.id, Program::new());
}
async fn list_saved_programs(&self, msg: &Message, ctx: &Context) -> Option<()> {
let program_names =
Program::list_programs_by_user(self.conn.lock().ok()?.borrow(), msg.author.id)?;
Program::list_programs_by_user(self.conn.lock().await.deref_mut(), msg.author.id)?;
msg.channel_id
.say(
&ctx,
......@@ -215,7 +201,7 @@ impl Handler {
async fn list_published_programs(&self, msg: &Message, ctx: &Context) -> Option<()> {
let program_names: Vec<String> =
Program::list_published_programs(self.conn.lock().ok()?.borrow())?
Program::list_published_programs(self.conn.lock().await.deref_mut())?
.iter()
.map(|row| format!("{}\t{}", row.0, row.1))
.collect();
......@@ -235,10 +221,15 @@ impl Handler {
}
async fn publish_program(&self, name: &str, msg: &Message, ctx: &Context) -> Option<()> {
Program::set_program_published(self.conn.lock().ok()?.borrow(), name, msg.author.id, true)?;
Program::set_program_published(
self.conn.lock().await.deref_mut(),
name,
msg.author.id,
true,
)?;
msg.channel_id
.say(&ctx, format!("Published {}.", name))
.say(&ctx, format!("Published {name}."))
.await
.unwrap();
......@@ -247,14 +238,14 @@ impl Handler {
async fn unpublish_program(&self, name: &str, msg: &Message, ctx: &Context) -> Option<()> {
Program::set_program_published(
self.conn.lock().ok()?.borrow(),
self.conn.lock().await.deref_mut(),
name,
msg.author.id,
false,
)?;
msg.channel_id
.say(&ctx, format!("Unpublished {}.", name))
.say(&ctx, format!("Unpublished {name}."))
.await
.unwrap();
......@@ -262,11 +253,11 @@ impl Handler {
}
async fn load_published_program(&self, msg: &Message, ctx: &Context, id: i32) -> Option<()> {
let name = get_user_programs!(&self, &msg.author.id)
.load_published_program(&self.conn.lock().ok()?.borrow(), id)?;
let name = get_user_programs!(self, &msg.author.id)
.load_published_program(self.conn.lock().await.deref_mut(), id)?;
msg.channel_id
.say(&ctx, format!("Loaded {} (\"{}\") into memory.", id, name))
.say(&ctx, format!("Loaded {id} (\"{name}\") into memory."))
.await
.unwrap();
......@@ -279,7 +270,7 @@ impl Handler {
&ctx,
format!(
"```\n{}\n```",
self.programs.lock().unwrap()[&msg.author.id].stringy_line_nums()
self.programs.lock().await[&msg.author.id].stringy_line_nums()
),
)
.await
......@@ -287,15 +278,15 @@ impl Handler {
}
async fn interpreter_load(&self, name: &str, msg: &Message, ctx: &Context) {
let result = get_user_programs!(&self, &msg.author.id).load_program(
&self.conn.lock().unwrap(),
let result = get_user_programs!(self, &msg.author.id).load_program(
self.conn.lock().await.deref_mut(),
msg.author.id,
name,
);
match result {
Some(_) => {
msg.channel_id
.say(&ctx, format!("Loaded {} into memory.", name))
.say(&ctx, format!("Loaded {name} into memory."))
.await
.unwrap();
}
......@@ -309,15 +300,15 @@ impl Handler {
}
async fn interpreter_save(&self, name: &str, msg: &Message, ctx: &Context) {
let result = get_user_programs!(&self, &msg.author.id).save_program(
&self.conn.lock().unwrap(),
let result = get_user_programs!(self, &msg.author.id).save_program(
self.conn.lock().await.deref_mut(),
msg.author.id,
name,
);
match result {
Some(_) => {
msg.channel_id
.say(&ctx, format!("Saved as {}", name))
.say(&ctx, format!("Saved as {name}"))
.await
.unwrap();
}
......@@ -336,7 +327,7 @@ impl Handler {
Err(_) => None,
};
let code = self.programs.lock().unwrap()[&msg.author.id].stringify();
let code = self.programs.lock().await[&msg.author.id].stringify();
let io = DiscordIO::new();
let (output, fb, errors) = task::spawn_blocking(move || {
......@@ -360,11 +351,8 @@ impl Handler {
let green = args[3].clone().into_decimal() as u8;
let blue = args[4].clone().into_decimal() as u8;
match &mut io.frame {
Some(fb) => {
fb.set_pixel(x, y, red, green, blue, 255);
}
None => {}
if let Some(fb) = &mut io.frame {
fb.set_pixel(x, y, red, green, blue, 255);
}
ExprValue::Decimal(0.0)
......@@ -373,7 +361,7 @@ impl Handler {
let mut xb = xbb.build();
let _ = xb.run(&format!("{}\n", code));
let _ = xb.run(&format!("{code}\n"));
let errors = if xb.error_handler.had_errors || xb.error_handler.had_runtime_error {
Some(xb.error_handler.errors.join("\n"))
......@@ -422,32 +410,8 @@ impl Handler {
}
#[async_trait]
impl EventHandler for Handler {
async fn message(&self, ctx: Context, msg: Message) {
for (key, value) in EMOJI_MAP.entries() {
let msg_lower = format!(" {} ", msg.content.to_lowercase());
if msg_lower.contains(&format!(" {} ", key))
|| msg_lower.contains(&format!(" {}s ", key))
{
let reaction_type = match ReactionType::from_str(value) {
Ok(x) => x,
Err(x) => {
println!("Could not react: {}", x);
return;
}
};
if let Err(e) = msg.react(&ctx, reaction_type).await {
println!("Error reacting: {}", e);
}
}
}
for line in msg.content.split('\n') {
self.interpret_line(&msg, &ctx, line).await;
}
}
async fn ready(&self, _: Context, ready: Ready) {
println!("{} is connected", ready.user.name);
impl LineHandler for XbasicHandler {
async fn line(&self, ctx: &Context, msg: &Message, line: &str) {
self.interpret_line(msg, ctx, line).await;
}
}
use serde_json::Value;
pub async fn tell_joke() -> Option<String> {
let joke: Value = reqwest::get("https://v2.jokeapi.dev/joke/Programming")
.await
.ok()?
.json()
.await
.ok()?;
match joke["type"].as_str()? {
"single" => Some(joke["joke"].as_str()?.to_owned()),
"twopart" => Some(format!(
"{}\r\n\r\n{}",
joke["setup"].as_str()?,
joke["delivery"].as_str()?
)),
_ => None,
}
}