From 0167465e74f7f2e42101b3b9b4078a68c64e6ccb Mon Sep 17 00:00:00 2001 From: Stephen D <webmaster@scd31.com> Date: Wed, 5 Oct 2022 17:23:09 -0300 Subject: [PATCH] start using tokio mutex --- src/handlers/mod.rs | 3 ++- src/handlers/xbasic.rs | 41 +++++++++++++++++++---------------------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 0d86643..d1c71e3 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -17,7 +17,8 @@ use serenity::model::channel::Message; use serenity::model::prelude::Ready; use serenity::prelude::*; use std::env; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use tokio::sync::Mutex; #[async_trait] pub(crate) trait LineHandler: Send + Sync { diff --git a/src/handlers/xbasic.rs b/src/handlers/xbasic.rs index 0ae9238..fe3b27d 100644 --- a/src/handlers/xbasic.rs +++ b/src/handlers/xbasic.rs @@ -10,7 +10,8 @@ use serenity::prelude::*; use std::borrow::{Borrow, Cow}; use std::collections::HashMap; use std::str::FromStr; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use tokio::sync::Mutex; use tokio::task; use xbasic::basic_io::BasicIO; use xbasic::expr::ExprValue; @@ -42,7 +43,7 @@ impl BasicIO for DiscordIO { macro_rules! get_user_programs { ($self: expr, $user_id: expr) => { - $self.programs.lock().unwrap().get_mut($user_id).unwrap() + $self.programs.lock().await.get_mut($user_id).unwrap() }; } @@ -63,10 +64,10 @@ impl XbasicHandler { // TODO we lock the mutex to check, but unlock before locking again later // allows another thread to screw it up // we should lock once for this entire function - if self.programs.lock().unwrap().contains_key(&msg.author.id) { + if self.programs.lock().await.contains_key(&msg.author.id) { match line { "!STOP" => { - self.interpreter_stop(msg); + self.interpreter_stop(msg).await; } "RUN" => { self.interpreter_run(msg, ctx).await; @@ -164,28 +165,24 @@ impl XbasicHandler { } if line == "!START" { - self.interpreter_start(msg); + self.interpreter_start(msg).await; } } - fn interpreter_stop(&self, msg: &Message) { - self.programs - .lock() - .unwrap() - .remove(&msg.author.id) - .unwrap(); + async fn interpreter_stop(&self, msg: &Message) { + self.programs.lock().await.remove(&msg.author.id).unwrap(); } - fn interpreter_start(&self, msg: &Message) { + async fn interpreter_start(&self, msg: &Message) { self.programs .lock() - .unwrap() + .await .insert(msg.author.id, Program::new()); } async fn list_saved_programs(&self, msg: &Message, ctx: &Context) -> Option<()> { let program_names = - Program::list_programs_by_user(self.conn.lock().ok()?.borrow(), msg.author.id)?; + Program::list_programs_by_user(self.conn.lock().await.borrow(), msg.author.id)?; msg.channel_id .say( &ctx, @@ -203,7 +200,7 @@ impl XbasicHandler { async fn list_published_programs(&self, msg: &Message, ctx: &Context) -> Option<()> { let program_names: Vec<String> = - Program::list_published_programs(self.conn.lock().ok()?.borrow())? + Program::list_published_programs(self.conn.lock().await.borrow())? .iter() .map(|row| format!("{}\t{}", row.0, row.1)) .collect(); @@ -223,7 +220,7 @@ impl XbasicHandler { } async fn publish_program(&self, name: &str, msg: &Message, ctx: &Context) -> Option<()> { - Program::set_program_published(self.conn.lock().ok()?.borrow(), name, msg.author.id, true)?; + Program::set_program_published(self.conn.lock().await.borrow(), name, msg.author.id, true)?; msg.channel_id .say(&ctx, format!("Published {}.", name)) @@ -235,7 +232,7 @@ impl XbasicHandler { async fn unpublish_program(&self, name: &str, msg: &Message, ctx: &Context) -> Option<()> { Program::set_program_published( - self.conn.lock().ok()?.borrow(), + self.conn.lock().await.borrow(), name, msg.author.id, false, @@ -251,7 +248,7 @@ impl XbasicHandler { async fn load_published_program(&self, msg: &Message, ctx: &Context, id: i32) -> Option<()> { let name = get_user_programs!(self, &msg.author.id) - .load_published_program(self.conn.lock().ok()?.borrow(), id)?; + .load_published_program(self.conn.lock().await.borrow(), id)?; msg.channel_id .say(&ctx, format!("Loaded {} (\"{}\") into memory.", id, name)) @@ -267,7 +264,7 @@ impl XbasicHandler { &ctx, format!( "```\n{}\n```", - self.programs.lock().unwrap()[&msg.author.id].stringy_line_nums() + self.programs.lock().await[&msg.author.id].stringy_line_nums() ), ) .await @@ -276,7 +273,7 @@ impl XbasicHandler { async fn interpreter_load(&self, name: &str, msg: &Message, ctx: &Context) { let result = get_user_programs!(self, &msg.author.id).load_program( - &self.conn.lock().unwrap(), + &*self.conn.lock().await, msg.author.id, name, ); @@ -298,7 +295,7 @@ impl XbasicHandler { async fn interpreter_save(&self, name: &str, msg: &Message, ctx: &Context) { let result = get_user_programs!(self, &msg.author.id).save_program( - &self.conn.lock().unwrap(), + &*self.conn.lock().await, msg.author.id, name, ); @@ -324,7 +321,7 @@ impl XbasicHandler { Err(_) => None, }; - let code = self.programs.lock().unwrap()[&msg.author.id].stringify(); + let code = self.programs.lock().await[&msg.author.id].stringify(); let io = DiscordIO::new(); let (output, fb, errors) = task::spawn_blocking(move || { -- GitLab