diff --git a/src/config.rs b/src/config.rs index eff6f5312299d1268268cff010f679e19a831469..e077fffb71cc4238bfebf99448e903313c62b69a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,12 +1,16 @@ use serde::Deserialize; -use std::{collections::HashMap, fs}; +use std::{ + collections::{HashMap, HashSet}, + fs, +}; #[derive(Deserialize)] pub struct LlamaConfig { pub(crate) address: String, pub(crate) port: u16, pub(crate) models: HashMap<String, String>, - pub(crate) channel: Option<u64>, + #[serde(default)] + pub(crate) channels: HashSet<u64>, } #[derive(Deserialize)] diff --git a/src/handlers/llama.rs b/src/handlers/llama.rs index 23c84ca5b7f2d715fd607f105e063eb307b542de..7bd0dd7794b15af92bfa7ecab1613b90590e7c0e 100644 --- a/src/handlers/llama.rs +++ b/src/handlers/llama.rs @@ -13,14 +13,17 @@ use serenity::{ model::prelude::{Message, MessageId}, prelude::*, }; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; pub struct LlamaHandler { ollama: Ollama, contexts: Arc<Mutex<HashMap<MessageId, (String, GenerationContext)>>>, models: HashMap<String, String>, - channel_id: Option<u64>, + channel_ids: HashSet<u64>, } impl LlamaHandler { @@ -30,7 +33,7 @@ impl LlamaHandler { contexts: Arc::new(Mutex::new(HashMap::new())), models: lc.models, - channel_id: lc.channel, + channel_ids: lc.channels, } } @@ -146,10 +149,8 @@ impl LlamaHandler { #[async_trait] impl LineHandler for LlamaHandler { async fn message(&self, ctx: &Context, msg: &Message) { - if let Some(cid) = self.channel_id { - if msg.channel_id.0 != cid { - return; - } + if !self.channel_ids.contains(&msg.channel_id.0) { + return; } let txt = &msg.content;