From 5a583b039b465b44bcd58be84404033c1a21126e Mon Sep 17 00:00:00 2001 From: Stephen D <webmaster@scd31.com> Date: Fri, 18 Aug 2023 12:58:02 -0300 Subject: [PATCH] llama --- Cargo.lock | 49 ++++++++++++++++ Cargo.toml | 1 + src/config.rs | 10 +++- src/handlers/llama.rs | 129 ++++++++++++++++++++++++++++++++++++++++++ src/handlers/mod.rs | 28 ++++++--- src/main.rs | 2 +- 6 files changed, 208 insertions(+), 11 deletions(-) create mode 100644 src/handlers/llama.rs diff --git a/Cargo.lock b/Cargo.lock index b5fd1c4..183bcc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,6 +112,12 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db" +[[package]] +name = "cargo-husky" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad" + [[package]] name = "cat-disruptor-6500" version = "0.2.0" @@ -125,6 +131,7 @@ dependencies = [ "png", "rand 0.8.5", "reqwest", + "reqwest-streams", "rusttype", "serde", "serde_json", @@ -329,6 +336,7 @@ checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -351,12 +359,34 @@ version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" +[[package]] +name = "futures-executor" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68" +[[package]] +name = "futures-macro" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17" +dependencies = [ + "proc-macro2 1.0.46", + "quote 1.0.21", + "syn 1.0.102", +] + [[package]] name = "futures-sink" version = "0.3.24" @@ -378,6 +408,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1144,6 +1175,24 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-streams" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c79d1481ca294b5b0ba33f41df2ed704686531bc8892a5e30b96fd0a165c642f" +dependencies = [ + "async-trait", + "bytes", + "cargo-husky", + "futures", + "futures-util", + "reqwest", + "serde", + "serde_json", + "tokio", + "tokio-util", +] + [[package]] name = "ring" version = "0.16.20" diff --git a/Cargo.toml b/Cargo.toml index 42af7a6..7bf8742 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,4 @@ rusttype = "0.4.3" rand = "0.8" itertools = "0.10" anyhow = "1.0" +reqwest-streams = { version = "0.3.0", features = ["json"] } diff --git a/src/config.rs b/src/config.rs index fb1c36c..30cdd5b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,9 +1,17 @@ use serde::Deserialize; -use std::fs; +use std::{collections::HashMap, fs}; + +#[derive(Deserialize)] +pub struct LlamaConfig { + pub(crate) address: String, + pub(crate) models: HashMap<String, String>, + pub(crate) channel: Option<u64>, +} #[derive(Deserialize)] pub struct Config { pub(crate) token: String, + pub(crate) llama: Option<LlamaConfig>, } pub fn get_conf() -> Config { diff --git a/src/handlers/llama.rs b/src/handlers/llama.rs new file mode 100644 index 0000000..38be5d9 --- /dev/null +++ b/src/handlers/llama.rs @@ -0,0 +1,129 @@ +use std::collections::HashMap; + +use crate::config::LlamaConfig; + +use super::LineHandler; + +use reqwest_streams::*; +use serde::{Deserialize, Serialize}; +use serenity::{ + async_trait, futures::StreamExt, http::Typing, model::prelude::Message, prelude::*, +}; + +#[derive(Serialize)] +pub struct GenerateReq<'a, 'b> { + model: &'a str, + prompt: &'b str, +} + +#[derive(Deserialize, Debug)] +pub struct GenerateResp { + #[serde(default = "String::new")] + response: String, +} + +pub struct LlamaHandler { + address: String, + models: HashMap<String, String>, + channel_id: Option<u64>, +} + +impl LlamaHandler { + pub fn new(lc: LlamaConfig) -> Self { + Self { + address: lc.address, + models: lc.models, + channel_id: lc.channel, + } + } + + async fn call_llama(&self, model: &str, prompt: &str) -> anyhow::Result<String> { + let client = reqwest::Client::new(); + let mut stream = client + .post(format!("{}/api/generate", self.address)) + .json(&GenerateReq { model, prompt }) + .send() + .await? + .json_array_stream::<GenerateResp>(65536); + + let mut resp = String::new(); + while let Some(r) = stream.next().await { + resp.push_str(&r?.response); + } + + Ok(resp) + } +} + +#[async_trait] +impl LineHandler for LlamaHandler { + async fn message(&self, ctx: &Context, msg: &Message) { + if let Some(cid) = self.channel_id { + if msg.channel_id.0 != cid { + return; + } + } + + let txt = &msg.content; + + if txt.starts_with("!people") { + let people = self + .models + .keys() + .map(|x| format!("- {x}")) + .collect::<Vec<_>>() + .join("\n"); + + if let Err(e) = msg + .reply(ctx, format!("Available models:\n{}", people)) + .await + { + eprintln!("{:?}", e); + } + + return; + } + + for (name, model) in &self.models { + if let Some(txt) = txt.strip_prefix(&format!("!{name} ")) { + if txt.is_empty() { + return; + } + + let _typing = try_or_log(|| Typing::start(ctx.http.clone(), msg.channel_id.0)); + + let resp = self.call_llama(model, txt).await; + let resp = match resp.as_ref() { + Ok(x) => x, + Err(e) => { + eprintln!("{e:?}"); + + "Could not communicate with Llama. Check the server logs for more details." + } + }; + + let resp = if resp.is_empty() { + "[No response]" + } else { + resp + }; + + if let Err(e) = msg.reply(ctx, resp).await { + eprintln!("{e:?}"); + } + + return; + } + } + } +} + +fn try_or_log<T, E: std::fmt::Debug, F: Fn() -> Result<T, E>>(f: F) -> Result<T, E> { + let res = f(); + + if let Err(e) = res.as_ref() { + eprintln!("{e:?}"); + } + + res +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index ceca6d7..265483f 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,10 +1,12 @@ mod horse; mod joke; +mod llama; mod react; mod starboard; mod sus; mod xbasic; +use crate::config::LlamaConfig; use crate::handlers::horse::HorseHandler; use crate::handlers::joke::*; use crate::handlers::react::*; @@ -22,6 +24,8 @@ use std::env; use std::sync::Arc; use tokio::sync::Mutex; +use self::llama::LlamaHandler; + #[async_trait] pub(crate) trait LineHandler: Send + Sync { async fn message(&self, ctx: &Context, msg: &Message) { @@ -69,18 +73,24 @@ impl EventHandler for Dispatcher { } } -impl Default for Dispatcher { - fn default() -> Self { +impl Dispatcher { + pub fn new(llama_config: Option<LlamaConfig>) -> Self { let conn = Arc::new(Mutex::new(establish_connection())); + let mut handlers: Vec<Box<dyn LineHandler>> = vec![ + Box::new(XbasicHandler::new(conn.clone())), + Box::<JokeHandler>::default(), + Box::<ReactHandler>::default(), + Box::<SusHandler>::default(), + Box::<HorseHandler>::default(), + ]; + + if let Some(lc) = llama_config { + handlers.push(Box::new(LlamaHandler::new(lc))); + } + Self { - handlers: vec![ - Box::new(XbasicHandler::new(conn.clone())), - Box::<JokeHandler>::default(), - Box::<ReactHandler>::default(), - Box::<SusHandler>::default(), - Box::<HorseHandler>::default(), - ], + handlers, reacts: vec![Box::new(StarboardHandler::new(conn))], } } diff --git a/src/main.rs b/src/main.rs index f0997ff..9b1f9aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,7 @@ async fn main() { .union(GatewayIntents::GUILD_MESSAGE_REACTIONS) .union(GatewayIntents::MESSAGE_CONTENT), ) - .event_handler(Dispatcher::default()) + .event_handler(Dispatcher::new(config.llama)) .await .expect("Error creating client"); if let Err(e) = client.start().await { -- GitLab