diff --git a/Cargo.lock b/Cargo.lock index b5fd1c4d7ceb365965bf00f2801d3f7a0489fa7f..8458ebeac6bd86576c37f2853beb28595f00b26c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,9 +112,15 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db" +[[package]] +name = "cargo-husky" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad" + [[package]] name = "cat-disruptor-6500" -version = "0.2.0" +version = "0.3.0" dependencies = [ "anyhow", "bigdecimal", @@ -125,6 +131,7 @@ dependencies = [ "png", "rand 0.8.5", "reqwest", + "reqwest-streams", "rusttype", "serde", "serde_json", @@ -329,6 +336,7 @@ checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -351,12 +359,34 @@ version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" +[[package]] +name = "futures-executor" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68" +[[package]] +name = "futures-macro" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17" +dependencies = [ + "proc-macro2 1.0.46", + "quote 1.0.21", + "syn 1.0.102", +] + [[package]] name = "futures-sink" version = "0.3.24" @@ -378,6 +408,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1144,6 +1175,24 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-streams" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c79d1481ca294b5b0ba33f41df2ed704686531bc8892a5e30b96fd0a165c642f" +dependencies = [ + "async-trait", + "bytes", + "cargo-husky", + "futures", + "futures-util", + "reqwest", + "serde", + "serde_json", + "tokio", + "tokio-util", +] + [[package]] name = "ring" version = "0.16.20" diff --git a/Cargo.toml b/Cargo.toml index 42af7a659bea8f15454e6bf68da5597a851e245c..2c186bf8ba350a82b1d8686cb6657b747a74902f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cat-disruptor-6500" -version = "0.2.0" +version = "0.3.0" authors = ["Stephen <wemaster@scd31.com>"] edition = "2018" @@ -23,3 +23,4 @@ rusttype = "0.4.3" rand = "0.8" itertools = "0.10" anyhow = "1.0" +reqwest-streams = { version = "0.3.0", features = ["json"] } diff --git a/src/config.rs b/src/config.rs index fb1c36c0226ce8577a75f9cc10fc250985356c02..30cdd5bc933ed5774477c78a3bb9f6ffeed3f599 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,9 +1,17 @@ use serde::Deserialize; -use std::fs; +use std::{collections::HashMap, fs}; + +#[derive(Deserialize)] +pub struct LlamaConfig { + pub(crate) address: String, + pub(crate) models: HashMap<String, String>, + pub(crate) channel: Option<u64>, +} #[derive(Deserialize)] pub struct Config { pub(crate) token: String, + pub(crate) llama: Option<LlamaConfig>, } pub fn get_conf() -> Config { diff --git a/src/handlers/llama.rs b/src/handlers/llama.rs new file mode 100644 index 0000000000000000000000000000000000000000..8a83b816871a89db89aa6a54d58cfb8063cce575 --- /dev/null +++ b/src/handlers/llama.rs @@ -0,0 +1,148 @@ +use std::collections::HashMap; + +use crate::config::LlamaConfig; + +use super::LineHandler; + +use itertools::Itertools; +use reqwest_streams::*; +use serde::{Deserialize, Serialize}; +use serenity::{ + async_trait, futures::StreamExt, http::Typing, model::prelude::Message, prelude::*, +}; + +#[derive(Serialize)] +pub struct GenerateReq<'a, 'b> { + model: &'a str, + prompt: &'b str, +} + +#[derive(Deserialize, Debug)] +pub struct GenerateResp { + #[serde(default = "String::new")] + response: String, +} + +pub struct LlamaHandler { + address: String, + models: HashMap<String, String>, + channel_id: Option<u64>, +} + +impl LlamaHandler { + pub fn new(lc: LlamaConfig) -> Self { + Self { + address: lc.address, + models: lc.models, + channel_id: lc.channel, + } + } + + async fn call_llama(&self, model: &str, prompt: &str) -> anyhow::Result<String> { + let client = reqwest::Client::new(); + let mut stream = client + .post(format!("{}/api/generate", self.address)) + .json(&GenerateReq { model, prompt }) + .send() + .await? + .json_array_stream::<GenerateResp>(65536); + + let mut resp = String::new(); + while let Some(r) = stream.next().await { + resp.push_str(&r?.response); + } + + Ok(resp) + } + + async fn list_models(&self, ctx: &Context, msg: &Message) { + let people = self + .models + .keys() + .map(|x| format!("- {x}")) + .collect::<Vec<_>>() + .join("\n"); + + if let Err(e) = msg + .reply(ctx, format!("Available models:\n{}", people)) + .await + { + eprintln!("{:?}", e); + } + } + + async fn reply(&self, ctx: &Context, msg: &Message, model: &str, txt: &str) { + if txt.is_empty() { + return; + } + + let _typing = try_or_log(|| Typing::start(ctx.http.clone(), msg.channel_id.0)); + + let resp = self.call_llama(model, txt).await; + let resp = match resp.as_ref() { + Ok(x) => x, + Err(e) => { + eprintln!("{e:?}"); + + "Could not communicate with Llama. Check the server logs for more details." + } + }; + + let resp = if resp.is_empty() { + "[No response]" + } else { + resp + }; + + // discord messages are limited to 2000 codepoints + let chunks: Vec<String> = resp + .chars() + .chunks(2000) + .into_iter() + .map(|chunk| chunk.collect()) + .collect(); + + for chunk in chunks { + if let Err(e) = msg.reply(ctx, chunk).await { + eprintln!("{e:?}"); + } + } + } +} + +#[async_trait] +impl LineHandler for LlamaHandler { + async fn message(&self, ctx: &Context, msg: &Message) { + if let Some(cid) = self.channel_id { + if msg.channel_id.0 != cid { + return; + } + } + + let txt = &msg.content; + + if txt.starts_with("!people") { + self.list_models(ctx, msg).await; + + return; + } + + for (name, model) in &self.models { + if let Some(txt) = txt.strip_prefix(&format!("!{name} ")) { + self.reply(ctx, msg, model, txt).await; + + return; + } + } + } +} + +fn try_or_log<T, E: std::fmt::Debug, F: Fn() -> Result<T, E>>(f: F) -> Result<T, E> { + let res = f(); + + if let Err(e) = res.as_ref() { + eprintln!("{e:?}"); + } + + res +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index ceca6d7bb94c650a3e53a4d315109a1d90d86501..265483f5ee58fb1b1e5261bc23723e2d613ba70a 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,10 +1,12 @@ mod horse; mod joke; +mod llama; mod react; mod starboard; mod sus; mod xbasic; +use crate::config::LlamaConfig; use crate::handlers::horse::HorseHandler; use crate::handlers::joke::*; use crate::handlers::react::*; @@ -22,6 +24,8 @@ use std::env; use std::sync::Arc; use tokio::sync::Mutex; +use self::llama::LlamaHandler; + #[async_trait] pub(crate) trait LineHandler: Send + Sync { async fn message(&self, ctx: &Context, msg: &Message) { @@ -69,18 +73,24 @@ impl EventHandler for Dispatcher { } } -impl Default for Dispatcher { - fn default() -> Self { +impl Dispatcher { + pub fn new(llama_config: Option<LlamaConfig>) -> Self { let conn = Arc::new(Mutex::new(establish_connection())); + let mut handlers: Vec<Box<dyn LineHandler>> = vec![ + Box::new(XbasicHandler::new(conn.clone())), + Box::<JokeHandler>::default(), + Box::<ReactHandler>::default(), + Box::<SusHandler>::default(), + Box::<HorseHandler>::default(), + ]; + + if let Some(lc) = llama_config { + handlers.push(Box::new(LlamaHandler::new(lc))); + } + Self { - handlers: vec![ - Box::new(XbasicHandler::new(conn.clone())), - Box::<JokeHandler>::default(), - Box::<ReactHandler>::default(), - Box::<SusHandler>::default(), - Box::<HorseHandler>::default(), - ], + handlers, reacts: vec![Box::new(StarboardHandler::new(conn))], } } diff --git a/src/main.rs b/src/main.rs index f0997ffc7eaf5316e8781f28e9330b270fc3640c..9b1f9aa4e8b6a46e9382b50a784b5248527eaed7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,7 @@ async fn main() { .union(GatewayIntents::GUILD_MESSAGE_REACTIONS) .union(GatewayIntents::MESSAGE_CONTENT), ) - .event_handler(Dispatcher::default()) + .event_handler(Dispatcher::new(config.llama)) .await .expect("Error creating client"); if let Err(e) = client.start().await {