From 79076b3bac8ff5b1e28b7f0e895ac0dd483c7423 Mon Sep 17 00:00:00 2001
From: Stephen D <webmaster@scd31.com>
Date: Thu, 13 Jun 2024 19:23:20 -0300
Subject: [PATCH] llama multi-channel support

---
 src/config.rs         |  8 ++++++--
 src/handlers/llama.rs | 15 ++++++++-------
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/config.rs b/src/config.rs
index eff6f53..e077fff 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,12 +1,16 @@
 use serde::Deserialize;
-use std::{collections::HashMap, fs};
+use std::{
+	collections::{HashMap, HashSet},
+	fs,
+};
 
 #[derive(Deserialize)]
 pub struct LlamaConfig {
 	pub(crate) address: String,
 	pub(crate) port: u16,
 	pub(crate) models: HashMap<String, String>,
-	pub(crate) channel: Option<u64>,
+	#[serde(default)]
+	pub(crate) channels: HashSet<u64>,
 }
 
 #[derive(Deserialize)]
diff --git a/src/handlers/llama.rs b/src/handlers/llama.rs
index 23c84ca..7bd0dd7 100644
--- a/src/handlers/llama.rs
+++ b/src/handlers/llama.rs
@@ -13,14 +13,17 @@ use serenity::{
 	model::prelude::{Message, MessageId},
 	prelude::*,
 };
-use std::{collections::HashMap, sync::Arc};
+use std::{
+	collections::{HashMap, HashSet},
+	sync::Arc,
+};
 
 pub struct LlamaHandler {
 	ollama: Ollama,
 	contexts: Arc<Mutex<HashMap<MessageId, (String, GenerationContext)>>>,
 
 	models: HashMap<String, String>,
-	channel_id: Option<u64>,
+	channel_ids: HashSet<u64>,
 }
 
 impl LlamaHandler {
@@ -30,7 +33,7 @@ impl LlamaHandler {
 			contexts: Arc::new(Mutex::new(HashMap::new())),
 
 			models: lc.models,
-			channel_id: lc.channel,
+			channel_ids: lc.channels,
 		}
 	}
 
@@ -146,10 +149,8 @@ impl LlamaHandler {
 #[async_trait]
 impl LineHandler for LlamaHandler {
 	async fn message(&self, ctx: &Context, msg: &Message) {
-		if let Some(cid) = self.channel_id {
-			if msg.channel_id.0 != cid {
-				return;
-			}
+		if !self.channel_ids.contains(&msg.channel_id.0) {
+			return;
 		}
 
 		let txt = &msg.content;
-- 
GitLab