Commit 5a583b03 authored by Stephen D

llama

parent 43c0d6c3
1 merge request: !16 (llama)
Pipeline #2782 passed
Cargo.lock
@@ -112,6 +112,12 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db"
[[package]]
name = "cargo-husky"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad"
[[package]]
name = "cat-disruptor-6500"
version = "0.2.0"
@@ -125,6 +131,7 @@ dependencies = [
"png",
"rand 0.8.5",
"reqwest",
"reqwest-streams",
"rusttype",
"serde",
"serde_json",
@@ -329,6 +336,7 @@ checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
@@ -351,12 +359,34 @@ version = "0.3.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf"
[[package]]
name = "futures-executor"
version = "0.3.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68"
[[package]]
name = "futures-macro"
version = "0.3.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17"
dependencies = [
"proc-macro2 1.0.46",
"quote 1.0.21",
"syn 1.0.102",
]
[[package]]
name = "futures-sink"
version = "0.3.24"
@@ -378,6 +408,7 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
@@ -1144,6 +1175,24 @@ dependencies = [
"winreg",
]
[[package]]
name = "reqwest-streams"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c79d1481ca294b5b0ba33f41df2ed704686531bc8892a5e30b96fd0a165c642f"
dependencies = [
"async-trait",
"bytes",
"cargo-husky",
"futures",
"futures-util",
"reqwest",
"serde",
"serde_json",
"tokio",
"tokio-util",
]
[[package]]
name = "ring"
version = "0.16.20"
...

Cargo.toml
@@ -23,3 +23,4 @@ rusttype = "0.4.3"
rand = "0.8"
itertools = "0.10"
anyhow = "1.0"
reqwest-streams = { version = "0.3.0", features = ["json"] }
src/config.rs

use serde::Deserialize;
-use std::fs;
+use std::{collections::HashMap, fs};
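
/// Connection settings for the Llama handler: `address` is the base URL of the
/// model server, `models` maps a command name (invoked as `!<name>`) to the
/// model identifier sent to the server, and `channel` optionally restricts the
/// handler to a single Discord channel.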
#[derive(Deserialize)]
pub struct LlamaConfig {
pub(crate) address: String,
pub(crate) models: HashMap<String, String>,
pub(crate) channel: Option<u64>,
}
#[derive(Deserialize)]
pub struct Config {
pub(crate) token: String,
pub(crate) llama: Option<LlamaConfig>,
}
pub fn get_conf() -> Config {
...
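
For reference, a config that enables the handler might look like the sketch below. This assumes the file is TOML (the format isn't visible in this diff, since the body of get_conf is elided) and all values are illustrative. Each key under [llama.models] becomes a !<name> command (and is what !people lists); its value is the model identifier forwarded to the server.

token = "<discord-bot-token>"

[llama]
# Base URL of the model server; call_llama appends /api/generate.
address = "http://localhost:11434"
# Optional: only respond in this Discord channel.
channel = 123456789012345678

[llama.models]
alice = "llama2:7b"
bob = "mistral:7b"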
src/handlers/llama.rs

use std::collections::HashMap;
use crate::config::LlamaConfig;
use super::LineHandler;
use reqwest_streams::*;
use serde::{Deserialize, Serialize};
use serenity::{
async_trait, futures::StreamExt, http::Typing, model::prelude::Message, prelude::*,
};
#[derive(Serialize)]
pub struct GenerateReq<'a, 'b> {
model: &'a str,
prompt: &'b str,
}
#[derive(Deserialize, Debug)]
pub struct GenerateResp {
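    // Some streamed objects may omit `response` (e.g. a final status chunk),
    // so default to an empty string instead of failing deserialization.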
#[serde(default = "String::new")]
response: String,
}
pub struct LlamaHandler {
address: String,
models: HashMap<String, String>,
channel_id: Option<u64>,
}
impl LlamaHandler {
pub fn new(lc: LlamaConfig) -> Self {
Self {
address: lc.address,
models: lc.models,
channel_id: lc.channel,
}
}
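
    /// POSTs the prompt to `<address>/api/generate` and concatenates the
    /// `response` fragments from the streamed JSON objects into one string.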
async fn call_llama(&self, model: &str, prompt: &str) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let mut stream = client
.post(format!("{}/api/generate", self.address))
.json(&GenerateReq { model, prompt })
.send()
.await?
.json_array_stream::<GenerateResp>(65536);
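        // Collect the text fragment carried by each streamed object; the
        // 65536 above caps the byte size of any single object in the stream.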
let mut resp = String::new();
while let Some(r) = stream.next().await {
resp.push_str(&r?.response);
}
Ok(resp)
}
}
#[async_trait]
impl LineHandler for LlamaHandler {
async fn message(&self, ctx: &Context, msg: &Message) {
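        // If a channel is configured, ignore messages from everywhere else.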
if let Some(cid) = self.channel_id {
if msg.channel_id.0 != cid {
return;
}
}
let txt = &msg.content;
if txt.starts_with("!people") {
let people = self
.models
.keys()
.map(|x| format!("- {x}"))
.collect::<Vec<_>>()
.join("\n");
if let Err(e) = msg
.reply(ctx, format!("Available models:\n{}", people))
.await
{
eprintln!("{:?}", e);
}
return;
}
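        // Dispatch `!<name> <prompt>` to the model registered under `name`.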
for (name, model) in &self.models {
if let Some(txt) = txt.strip_prefix(&format!("!{name} ")) {
if txt.is_empty() {
return;
}
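                // Start a typing indicator; serenity keeps it alive until
                // `_typing` is dropped at the end of this scope.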
let _typing = try_or_log(|| Typing::start(ctx.http.clone(), msg.channel_id.0));
let resp = self.call_llama(model, txt).await;
let resp = match resp.as_ref() {
Ok(x) => x,
Err(e) => {
eprintln!("{e:?}");
"Could not communicate with Llama. Check the server logs for more details."
}
};
let resp = if resp.is_empty() {
"[No response]"
} else {
resp
};
if let Err(e) = msg.reply(ctx, resp).await {
eprintln!("{e:?}");
}
return;
}
}
}
}
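
/// Runs `f`, logging any `Err` to stderr before returning the result unchanged.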
fn try_or_log<T, E: std::fmt::Debug, F: Fn() -> Result<T, E>>(f: F) -> Result<T, E> {
let res = f();
if let Err(e) = res.as_ref() {
eprintln!("{e:?}");
}
res
}
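
For context, call_llama above expects the server to stream a series of JSON objects whose `response` fields together form the completion. Assuming an Ollama-style backend (an assumption; only the /api/generate path is visible in this diff), the streamed objects look roughly like:

{"model": "llama2:7b", "response": "Why did", "done": false}
{"model": "llama2:7b", "response": " the chicken", "done": false}
{"model": "llama2:7b", "response": "", "done": true}

GenerateResp only deserializes the `response` field; everything else is ignored, and the accumulation loop simply ends when the stream does.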
src/handlers/mod.rs

mod horse;
mod joke;
mod llama;
mod react;
mod starboard;
mod sus;
mod xbasic;
use crate::config::LlamaConfig;
use crate::handlers::horse::HorseHandler;
use crate::handlers::joke::*;
use crate::handlers::react::*;
@@ -22,6 +24,8 @@ use std::env;
use std::sync::Arc;
use tokio::sync::Mutex;
use self::llama::LlamaHandler;
#[async_trait]
pub(crate) trait LineHandler: Send + Sync {
async fn message(&self, ctx: &Context, msg: &Message) {
@@ -69,18 +73,24 @@ impl EventHandler for Dispatcher {
}
}
-impl Default for Dispatcher {
-    fn default() -> Self {
+impl Dispatcher {
+    pub fn new(llama_config: Option<LlamaConfig>) -> Self {
let conn = Arc::new(Mutex::new(establish_connection()));
let mut handlers: Vec<Box<dyn LineHandler>> = vec![
Box::new(XbasicHandler::new(conn.clone())),
Box::<JokeHandler>::default(),
Box::<ReactHandler>::default(),
Box::<SusHandler>::default(),
Box::<HorseHandler>::default(),
];
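        // The Llama handler is optional; install it only when the config
        // provides llama settings.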
if let Some(lc) = llama_config {
handlers.push(Box::new(LlamaHandler::new(lc)));
}
Self {
-            handlers: vec![
-                Box::new(XbasicHandler::new(conn.clone())),
-                Box::<JokeHandler>::default(),
-                Box::<ReactHandler>::default(),
-                Box::<SusHandler>::default(),
-                Box::<HorseHandler>::default(),
-            ],
+            handlers,
reacts: vec![Box::new(StarboardHandler::new(conn))],
}
}
...

src/main.rs
@@ -25,7 +25,7 @@ async fn main() {
.union(GatewayIntents::GUILD_MESSAGE_REACTIONS)
.union(GatewayIntents::MESSAGE_CONTENT),
)
-        .event_handler(Dispatcher::default())
+        .event_handler(Dispatcher::new(config.llama))
.await
.expect("Error creating client");
if let Err(e) = client.start().await {
...