From 79076b3bac8ff5b1e28b7f0e895ac0dd483c7423 Mon Sep 17 00:00:00 2001 From: Stephen D <webmaster@scd31.com> Date: Thu, 13 Jun 2024 19:23:20 -0300 Subject: [PATCH] llama multi-channel support --- src/config.rs | 8 ++++++-- src/handlers/llama.rs | 15 ++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/config.rs b/src/config.rs index eff6f53..e077fff 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,12 +1,16 @@ use serde::Deserialize; -use std::{collections::HashMap, fs}; +use std::{ + collections::{HashMap, HashSet}, + fs, +}; #[derive(Deserialize)] pub struct LlamaConfig { pub(crate) address: String, pub(crate) port: u16, pub(crate) models: HashMap<String, String>, - pub(crate) channel: Option<u64>, + #[serde(default)] + pub(crate) channels: HashSet<u64>, } #[derive(Deserialize)] diff --git a/src/handlers/llama.rs b/src/handlers/llama.rs index 23c84ca..7bd0dd7 100644 --- a/src/handlers/llama.rs +++ b/src/handlers/llama.rs @@ -13,14 +13,17 @@ use serenity::{ model::prelude::{Message, MessageId}, prelude::*, }; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; pub struct LlamaHandler { ollama: Ollama, contexts: Arc<Mutex<HashMap<MessageId, (String, GenerationContext)>>>, models: HashMap<String, String>, - channel_id: Option<u64>, + channel_ids: HashSet<u64>, } impl LlamaHandler { @@ -30,7 +33,7 @@ impl LlamaHandler { contexts: Arc::new(Mutex::new(HashMap::new())), models: lc.models, - channel_id: lc.channel, + channel_ids: lc.channels, } } @@ -146,10 +149,8 @@ impl LlamaHandler { #[async_trait] impl LineHandler for LlamaHandler { async fn message(&self, ctx: &Context, msg: &Message) { - if let Some(cid) = self.channel_id { - if msg.channel_id.0 != cid { - return; - } + if !self.channel_ids.contains(&msg.channel_id.0) { + return; } let txt = &msg.content; -- GitLab