From 5a583b039b465b44bcd58be84404033c1a21126e Mon Sep 17 00:00:00 2001
From: Stephen D <webmaster@scd31.com>
Date: Fri, 18 Aug 2023 12:58:02 -0300
Subject: [PATCH] llama

---
 Cargo.lock            |  49 ++++++++++++++++
 Cargo.toml            |   1 +
 src/config.rs         |  10 +++-
 src/handlers/llama.rs | 129 ++++++++++++++++++++++++++++++++++++++++++
 src/handlers/mod.rs   |  28 ++++++---
 src/main.rs           |   2 +-
 6 files changed, 208 insertions(+), 11 deletions(-)
 create mode 100644 src/handlers/llama.rs

diff --git a/Cargo.lock b/Cargo.lock
index b5fd1c4..183bcc1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -112,6 +112,12 @@ version = "1.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db"
 
+[[package]]
+name = "cargo-husky"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad"
+
 [[package]]
 name = "cat-disruptor-6500"
 version = "0.2.0"
@@ -125,6 +131,7 @@ dependencies = [
  "png",
  "rand 0.8.5",
  "reqwest",
+ "reqwest-streams",
  "rusttype",
  "serde",
  "serde_json",
@@ -329,6 +336,7 @@ checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c"
 dependencies = [
  "futures-channel",
  "futures-core",
+ "futures-executor",
  "futures-io",
  "futures-sink",
  "futures-task",
@@ -351,12 +359,34 @@ version = "0.3.24"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf"
 
+[[package]]
+name = "futures-executor"
+version = "0.3.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-io"
 version = "0.3.24"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68"
 
+[[package]]
+name = "futures-macro"
+version = "0.3.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17"
+dependencies = [
+ "proc-macro2 1.0.46",
+ "quote 1.0.21",
+ "syn 1.0.102",
+]
+
 [[package]]
 name = "futures-sink"
 version = "0.3.24"
@@ -378,6 +408,7 @@ dependencies = [
  "futures-channel",
  "futures-core",
  "futures-io",
+ "futures-macro",
  "futures-sink",
  "futures-task",
  "memchr",
@@ -1144,6 +1175,24 @@ dependencies = [
  "winreg",
 ]
 
+[[package]]
+name = "reqwest-streams"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c79d1481ca294b5b0ba33f41df2ed704686531bc8892a5e30b96fd0a165c642f"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "cargo-husky",
+ "futures",
+ "futures-util",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "tokio",
+ "tokio-util",
+]
+
 [[package]]
 name = "ring"
 version = "0.16.20"
diff --git a/Cargo.toml b/Cargo.toml
index 42af7a6..7bf8742 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,3 +23,4 @@ rusttype = "0.4.3"
 rand = "0.8"
 itertools = "0.10"
 anyhow = "1.0"
+reqwest-streams = { version = "0.3.0", features = ["json"] }
diff --git a/src/config.rs b/src/config.rs
index fb1c36c..30cdd5b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,9 +1,17 @@
 use serde::Deserialize;
-use std::fs;
+use std::{collections::HashMap, fs};
+
+#[derive(Deserialize)]
+pub struct LlamaConfig {
+	pub(crate) address: String,
+	pub(crate) models: HashMap<String, String>,
+	pub(crate) channel: Option<u64>,
+}
 
 #[derive(Deserialize)]
 pub struct Config {
 	pub(crate) token: String,
+	pub(crate) llama: Option<LlamaConfig>,
 }
 
 pub fn get_conf() -> Config {
diff --git a/src/handlers/llama.rs b/src/handlers/llama.rs
new file mode 100644
index 0000000..38be5d9
--- /dev/null
+++ b/src/handlers/llama.rs
@@ -0,0 +1,129 @@
+use std::collections::HashMap;
+
+use crate::config::LlamaConfig;
+
+use super::LineHandler;
+
+use reqwest_streams::*;
+use serde::{Deserialize, Serialize};
+use serenity::{
+	async_trait, futures::StreamExt, http::Typing, model::prelude::Message, prelude::*,
+};
+
+#[derive(Serialize)]
+pub struct GenerateReq<'a, 'b> {
+	model: &'a str,
+	prompt: &'b str,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct GenerateResp {
+	#[serde(default = "String::new")]
+	response: String,
+}
+
+pub struct LlamaHandler {
+	address: String,
+	models: HashMap<String, String>,
+	channel_id: Option<u64>,
+}
+
+impl LlamaHandler {
+	pub fn new(lc: LlamaConfig) -> Self {
+		Self {
+			address: lc.address,
+			models: lc.models,
+			channel_id: lc.channel,
+		}
+	}
+
+	async fn call_llama(&self, model: &str, prompt: &str) -> anyhow::Result<String> {
+		let client = reqwest::Client::new();
+		let mut stream = client
+			.post(format!("{}/api/generate", self.address))
+			.json(&GenerateReq { model, prompt })
+			.send()
+			.await?
+			.json_array_stream::<GenerateResp>(65536);
+
+		let mut resp = String::new();
+		while let Some(r) = stream.next().await {
+			resp.push_str(&r?.response);
+		}
+
+		Ok(resp)
+	}
+}
+
+#[async_trait]
+impl LineHandler for LlamaHandler {
+	async fn message(&self, ctx: &Context, msg: &Message) {
+		if let Some(cid) = self.channel_id {
+			if msg.channel_id.0 != cid {
+				return;
+			}
+		}
+
+		let txt = &msg.content;
+
+		if txt.starts_with("!people") {
+			let people = self
+				.models
+				.keys()
+				.map(|x| format!("- {x}"))
+				.collect::<Vec<_>>()
+				.join("\n");
+
+			if let Err(e) = msg
+				.reply(ctx, format!("Available models:\n{}", people))
+				.await
+			{
+				eprintln!("{:?}", e);
+			}
+
+			return;
+		}
+
+		for (name, model) in &self.models {
+			if let Some(txt) = txt.strip_prefix(&format!("!{name} ")) {
+				if txt.is_empty() {
+					return;
+				}
+
+				let _typing = try_or_log(|| Typing::start(ctx.http.clone(), msg.channel_id.0));
+
+				let resp = self.call_llama(model, txt).await;
+				let resp = match resp.as_ref() {
+					Ok(x) => x,
+					Err(e) => {
+						eprintln!("{e:?}");
+
+						"Could not communicate with Llama. Check the server logs for more details."
+					}
+				};
+
+				let resp = if resp.is_empty() {
+					"[No response]"
+				} else {
+					resp
+				};
+
+				if let Err(e) = msg.reply(ctx, resp).await {
+					eprintln!("{e:?}");
+				}
+
+				return;
+			}
+		}
+	}
+}
+
+fn try_or_log<T, E: std::fmt::Debug, F: Fn() -> Result<T, E>>(f: F) -> Result<T, E> {
+	let res = f();
+
+	if let Err(e) = res.as_ref() {
+		eprintln!("{e:?}");
+	}
+
+	res
+}
diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs
index ceca6d7..265483f 100644
--- a/src/handlers/mod.rs
+++ b/src/handlers/mod.rs
@@ -1,10 +1,12 @@
 mod horse;
 mod joke;
+mod llama;
 mod react;
 mod starboard;
 mod sus;
 mod xbasic;
 
+use crate::config::LlamaConfig;
 use crate::handlers::horse::HorseHandler;
 use crate::handlers::joke::*;
 use crate::handlers::react::*;
@@ -22,6 +24,8 @@ use std::env;
 use std::sync::Arc;
 use tokio::sync::Mutex;
 
+use self::llama::LlamaHandler;
+
 #[async_trait]
 pub(crate) trait LineHandler: Send + Sync {
 	async fn message(&self, ctx: &Context, msg: &Message) {
@@ -69,18 +73,24 @@ impl EventHandler for Dispatcher {
 	}
 }
 
-impl Default for Dispatcher {
-	fn default() -> Self {
+impl Dispatcher {
+	pub fn new(llama_config: Option<LlamaConfig>) -> Self {
 		let conn = Arc::new(Mutex::new(establish_connection()));
 
+		let mut handlers: Vec<Box<dyn LineHandler>> = vec![
+			Box::new(XbasicHandler::new(conn.clone())),
+			Box::<JokeHandler>::default(),
+			Box::<ReactHandler>::default(),
+			Box::<SusHandler>::default(),
+			Box::<HorseHandler>::default(),
+		];
+
+		if let Some(lc) = llama_config {
+			handlers.push(Box::new(LlamaHandler::new(lc)));
+		}
+
 		Self {
-			handlers: vec![
-				Box::new(XbasicHandler::new(conn.clone())),
-				Box::<JokeHandler>::default(),
-				Box::<ReactHandler>::default(),
-				Box::<SusHandler>::default(),
-				Box::<HorseHandler>::default(),
-			],
+			handlers,
 			reacts: vec![Box::new(StarboardHandler::new(conn))],
 		}
 	}
diff --git a/src/main.rs b/src/main.rs
index f0997ff..9b1f9aa 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,7 +25,7 @@ async fn main() {
 			.union(GatewayIntents::GUILD_MESSAGE_REACTIONS)
 			.union(GatewayIntents::MESSAGE_CONTENT),
 	)
-	.event_handler(Dispatcher::default())
+	.event_handler(Dispatcher::new(config.llama))
 	.await
 	.expect("Error creating client");
 	if let Err(e) = client.start().await {
-- 
GitLab