use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::{fmt, thread};

use async_trait::async_trait;
use crossbeam::queue::SegQueue;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot, Semaphore};
use uuid::Uuid;

use crate::config::ConcurrencyConfig;
use crate::error::{ConcurrencyOperation, Error, ErrorKind, Result, ResultExt};
use crate::manager::{ManagedState, Manager, ManagerStatus};

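/// The kinds of worker pools managed by this module. The default sizing below
/// follows common heuristics: CPU-bound pools match the core count, I/O pools
/// oversubscribe, and background pools stay small.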
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ThreadPoolType {
    Compute,
    Io,
    Blocking,
    Background,
    Custom(u8),
}

impl ThreadPoolType {
    pub fn default_thread_count(self) -> usize {
        match self {
            Self::Compute => num_cpus::get(),
            Self::Io => num_cpus::get() * 2,
            Self::Blocking => num_cpus::get().max(4),
            Self::Background => 2,
            Self::Custom(_) => 4,
        }
    }

    pub fn default_queue_capacity(self) -> usize {
        match self {
            Self::Compute => 1000,
            Self::Io => 5000,
            Self::Blocking => 2000,
            Self::Background => 500,
            Self::Custom(_) => 1000,
        }
    }
}

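/// Tuning knobs for a single [`ThreadPool`]. Note that `work_stealing` is
/// accepted but not yet honored by the workers (see `worker_thread` below),
/// and `priority`, `daemon`, and `keep_alive` are recorded but not applied
/// by the pool as implemented here.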
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadPoolConfig {
    pub thread_count: usize,
    pub queue_capacity: usize,
    pub stack_size: Option<usize>,
    pub priority: Option<i32>,
    pub name_prefix: String,
    pub daemon: bool,
    pub keep_alive: Duration,
    pub work_stealing: bool,
}

impl Default for ThreadPoolConfig {
    fn default() -> Self {
        Self {
            thread_count: num_cpus::get(),
            queue_capacity: 1000,
            stack_size: None,
            priority: None,
            name_prefix: "worker".to_string(),
            daemon: false,
            keep_alive: Duration::from_secs(60),
            work_stealing: true,
        }
    }
}

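/// A unit of work: any `FnOnce` closure that can be shipped to a worker thread.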
type WorkItem = Box<dyn FnOnce() + Send + 'static>;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadPoolStats {
    pub pool_type: ThreadPoolType,
    pub active_threads: usize,
    pub idle_threads: usize,
    pub queue_size: usize,
    pub total_executed: u64,
    pub total_rejected: u64,
    pub avg_execution_time_ms: f64,
    pub peak_queue_size: usize,
    pub utilization_percent: f64,
}

#[derive(Debug)]
struct ThreadWorker {
    id: usize,
    thread_handle: Option<thread::JoinHandle<()>>,
    #[allow(dead_code)]
    work_queue: Arc<SegQueue<WorkItem>>,
    stats: Arc<ThreadWorkerStats>,
    #[allow(dead_code)]
    shutdown_signal: Arc<parking_lot::Mutex<bool>>,
}

#[derive(Debug)]
struct ThreadWorkerStats {
    tasks_executed: AtomicU64,
    total_execution_time_ms: AtomicU64,
    last_activity: parking_lot::Mutex<Instant>,
}

impl ThreadWorkerStats {
    fn new() -> Self {
        Self {
            tasks_executed: AtomicU64::new(0),
            total_execution_time_ms: AtomicU64::new(0),
            last_activity: parking_lot::Mutex::new(Instant::now()),
        }
    }

    fn record_task_execution(&self, duration: Duration) {
        self.tasks_executed.fetch_add(1, Ordering::Relaxed);
        self.total_execution_time_ms
            .fetch_add(duration.as_millis() as u64, Ordering::Relaxed);
        *self.last_activity.lock() = Instant::now();
    }

    #[allow(dead_code)]
    fn get_average_execution_time(&self) -> f64 {
        let total_tasks = self.tasks_executed.load(Ordering::Relaxed);
        if total_tasks == 0 {
            0.0
        } else {
            let total_time = self.total_execution_time_ms.load(Ordering::Relaxed);
            total_time as f64 / total_tasks as f64
        }
    }

    fn is_idle(&self, threshold: Duration) -> bool {
        self.last_activity.lock().elapsed() > threshold
    }
}

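/// A fixed-size pool of OS threads draining a shared lock-free queue.
///
/// A minimal usage sketch (assumes a Tokio runtime for `submit_async`;
/// marked `ignore` so rustdoc does not compile it):
///
/// ```ignore
/// let pool = ThreadPool::new(ThreadPoolType::Compute, ThreadPoolConfig::default())?;
/// pool.submit(|| println!("ran on a worker thread"))?;
/// let answer = pool.submit_async(|| 6 * 7).await?;
/// assert_eq!(answer, 42);
/// pool.shutdown(Duration::from_secs(5))?;
/// ```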
#[derive(Debug)]
pub struct ThreadPool {
    #[allow(dead_code)]
    pool_type: ThreadPoolType,
    config: ThreadPoolConfig,
    workers: Vec<ThreadWorker>,
    global_queue: Arc<SegQueue<WorkItem>>,
    stats: Arc<RwLock<ThreadPoolStats>>,
    task_counter: Arc<AtomicU64>,
    rejection_counter: Arc<AtomicU64>,
    shutdown_signal: Arc<parking_lot::Mutex<bool>>,
}

impl ThreadPool {
    pub fn new(pool_type: ThreadPoolType, config: ThreadPoolConfig) -> Result<Self> {
        let global_queue = Arc::new(SegQueue::new());
        let shutdown_signal = Arc::new(parking_lot::Mutex::new(false));
        let task_counter = Arc::new(AtomicU64::new(0));
        let rejection_counter = Arc::new(AtomicU64::new(0));

        let stats = Arc::new(RwLock::new(ThreadPoolStats {
            pool_type,
            active_threads: 0,
            idle_threads: 0,
            queue_size: 0,
            total_executed: 0,
            total_rejected: 0,
            avg_execution_time_ms: 0.0,
            peak_queue_size: 0,
            utilization_percent: 0.0,
        }));

        let mut workers = Vec::with_capacity(config.thread_count);
        // Copy the flag out of `config` so the spawned closures capture a
        // plain bool instead of (on pre-2021 editions) the whole config.
        let work_stealing = config.work_stealing;

        for worker_id in 0..config.thread_count {
            let worker_queue = Arc::new(SegQueue::new());
            let worker_stats = Arc::new(ThreadWorkerStats::new());
            let worker_shutdown = Arc::clone(&shutdown_signal);
            let worker_global_queue = Arc::clone(&global_queue);
            let worker_task_counter = Arc::clone(&task_counter);
            let worker_stats_clone = Arc::clone(&worker_stats);
            let worker_queue_clone = Arc::clone(&worker_queue);
            let thread_name = format!("{}-{}", config.name_prefix, worker_id);

            let mut thread_builder = thread::Builder::new().name(thread_name);

            if let Some(stack_size) = config.stack_size {
                thread_builder = thread_builder.stack_size(stack_size);
            }

            let thread_handle = thread_builder
                .spawn(move || {
                    Self::worker_thread(
                        worker_id,
                        worker_queue_clone,
                        worker_global_queue,
                        worker_stats_clone,
                        worker_shutdown,
                        worker_task_counter,
                        work_stealing,
                    );
                })
                .with_context(|| format!("Failed to spawn worker thread {}", worker_id))?;

            let worker = ThreadWorker {
                id: worker_id,
                thread_handle: Some(thread_handle),
                work_queue: worker_queue,
                stats: worker_stats,
                shutdown_signal: Arc::clone(&shutdown_signal),
            };

            workers.push(worker);
        }

        {
            let mut stats_guard = stats.write();
            stats_guard.active_threads = config.thread_count;
        }

        Ok(Self {
            pool_type,
            config,
            workers,
            global_queue,
            stats,
            task_counter,
            rejection_counter,
            shutdown_signal,
        })
    }

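    /// Enqueues `task` on the shared queue, rejecting it when the pool is
    /// shutting down or the queue has reached `queue_capacity`.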
    pub fn submit<F>(&self, task: F) -> Result<()>
    where
        F: FnOnce() + Send + 'static,
    {
        if *self.shutdown_signal.lock() {
            return Err(Error::new(
                ErrorKind::Concurrency {
                    thread_id: None,
                    operation: ConcurrencyOperation::ThreadPool,
                },
                "Thread pool is shutting down",
            ));
        }

        let current_queue_size = self.global_queue.len();
        if current_queue_size >= self.config.queue_capacity {
            self.rejection_counter.fetch_add(1, Ordering::Relaxed);
            return Err(Error::new(
                ErrorKind::Concurrency {
                    thread_id: None,
                    operation: ConcurrencyOperation::ThreadPool,
                },
                "Thread pool queue is full",
            ));
        }

        let work_item: WorkItem = Box::new(task);
        self.global_queue.push(work_item);

        self.update_stats();

        Ok(())
    }

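    /// Submits `task` and awaits its result over a oneshot channel; a dropped
    /// sender (for example, after a worker panic) surfaces as a cancellation
    /// error.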
    pub async fn submit_async<F, R>(&self, task: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (tx, rx) = oneshot::channel();

        let work_item = move || {
            let result = task();
            let _ = tx.send(result);
        };

        self.submit(work_item)?;

        rx.await.map_err(|_| {
            Error::new(
                ErrorKind::Concurrency {
                    thread_id: None,
                    operation: ConcurrencyOperation::ThreadPool,
                },
                "Task execution was cancelled",
            )
        })
    }

    pub fn stats(&self) -> ThreadPoolStats {
        // Refresh the snapshot first so callers see up-to-date counters even
        // when nothing has been submitted since the last update.
        self.update_stats();
        self.stats.read().clone()
    }

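    /// Signals all workers to stop and joins them, spending at most `timeout`
    /// across the whole pool.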
    pub fn shutdown(mut self, timeout: Duration) -> Result<()> {
        *self.shutdown_signal.lock() = true;

        let start_time = Instant::now();
        for mut worker in self.workers.drain(..) {
            let remaining_time = timeout.saturating_sub(start_time.elapsed());

            if let Some(handle) = worker.thread_handle.take() {
                // Once the budget is spent, the handle is dropped without
                // joining, detaching any worker that has not yet observed the
                // shutdown flag.
                let join_result = if remaining_time.is_zero() {
                    Err("Thread join timeout")
                } else {
                    match handle.join() {
                        Ok(()) => Ok(()),
                        Err(_) => Err("Thread join failed"),
                    }
                };

                if join_result.is_err() {
                    eprintln!("Worker thread {} did not shut down gracefully", worker.id);
                }
            }
        }

        Ok(())
    }

    fn worker_thread(
        worker_id: usize,
        local_queue: Arc<SegQueue<WorkItem>>,
        global_queue: Arc<SegQueue<WorkItem>>,
        stats: Arc<ThreadWorkerStats>,
        shutdown_signal: Arc<parking_lot::Mutex<bool>>,
        task_counter: Arc<AtomicU64>,
        work_stealing: bool,
    ) {
        eprintln!("Worker thread {} started", worker_id);

        while !*shutdown_signal.lock() {
            // Prefer local work, then fall back to the shared global queue.
            // Stealing from sibling workers is not implemented yet, so the
            // `work_stealing` flag currently has no effect here.
            let _ = work_stealing;
            let work_item = local_queue.pop().or_else(|| global_queue.pop());

            if let Some(task) = work_item {
                let start_time = Instant::now();

                let execution_result =
                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                        task();
                    }));

                let execution_time = start_time.elapsed();
                stats.record_task_execution(execution_time);
                task_counter.fetch_add(1, Ordering::Relaxed);

                if execution_result.is_err() {
                    eprintln!("Task panicked in worker thread {}", worker_id);
                }
            } else {
                // Nothing queued: back off briefly instead of spinning.
                thread::sleep(Duration::from_millis(1));
            }
        }

        eprintln!("Worker thread {} shutting down", worker_id);
    }

    fn update_stats(&self) {
        let mut stats = self.stats.write();

        stats.queue_size = self.global_queue.len();
        stats.total_executed = self.task_counter.load(Ordering::Relaxed);
        stats.total_rejected = self.rejection_counter.load(Ordering::Relaxed);

        if stats.queue_size > stats.peak_queue_size {
            stats.peak_queue_size = stats.queue_size;
        }

        let mut total_execution_time = 0u64;
        let mut total_tasks = 0u64;
        let mut active_threads = 0;
        let mut idle_threads = 0;

        for worker in &self.workers {
            let worker_tasks = worker.stats.tasks_executed.load(Ordering::Relaxed);
            let worker_time = worker.stats.total_execution_time_ms.load(Ordering::Relaxed);

            total_tasks += worker_tasks;
            total_execution_time += worker_time;

            if worker.stats.is_idle(Duration::from_secs(5)) {
                idle_threads += 1;
            } else {
                active_threads += 1;
            }
        }

        stats.active_threads = active_threads;
        stats.idle_threads = idle_threads;

        if total_tasks > 0 {
            stats.avg_execution_time_ms = total_execution_time as f64 / total_tasks as f64;
        }

        let total_threads = active_threads + idle_threads;
        if total_threads > 0 {
            stats.utilization_percent = (active_threads as f64 / total_threads as f64) * 100.0;
        }
    }
}

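/// Bridges async callers onto a dedicated executor task, bounding in-flight
/// work with a semaphore. Results cross the boundary as `serde_json::Value`,
/// which keeps the work-item type uniform at the cost of a serialize and
/// deserialize round trip per task.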
#[derive(Debug)]
pub struct AsyncWorkCoordinator {
    semaphore: Arc<Semaphore>,
    work_sender: mpsc::UnboundedSender<AsyncWorkItem>,
    stats: Arc<RwLock<AsyncCoordinatorStats>>,
}

struct AsyncWorkItem {
    task: Box<dyn FnOnce() -> Result<serde_json::Value> + Send>,
    result_sender: oneshot::Sender<Result<serde_json::Value>>,
}

impl fmt::Debug for AsyncWorkItem {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AsyncWorkItem")
            .field("task", &"FnOnce(..)")
            .field("result_sender", &"oneshot::Sender")
            .finish()
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AsyncCoordinatorStats {
    pub total_coordinated: u64,
    pub active_permits: usize,
    pub max_concurrent: usize,
    pub avg_coordination_time_ms: f64,
}

impl AsyncWorkCoordinator {
    pub fn new(max_concurrent: usize) -> Self {
        let (work_sender, mut work_receiver): (
            mpsc::UnboundedSender<AsyncWorkItem>,
            mpsc::UnboundedReceiver<AsyncWorkItem>,
        ) = mpsc::unbounded_channel();
        let semaphore = Arc::new(Semaphore::new(max_concurrent));

        // Single receiver task: the semaphore bounds how many callers are in
        // flight, but the closures themselves run serially on this task.
        tokio::spawn(async move {
            while let Some(work_item) = work_receiver.recv().await {
                let result = (work_item.task)();
                let _ = work_item.result_sender.send(result);
            }
        });

        Self {
            semaphore,
            work_sender,
            stats: Arc::new(RwLock::new(AsyncCoordinatorStats {
                total_coordinated: 0,
                active_permits: 0,
                max_concurrent,
                avg_coordination_time_ms: 0.0,
            })),
        }
    }

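    /// Runs `task` on the coordinator task under a concurrency permit. `R`
    /// must round-trip through `serde_json::Value`; non-serializable results
    /// surface as serialization errors.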
    pub async fn coordinate<F, R>(&self, task: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
    {
        let _permit = self.semaphore.acquire().await.map_err(|_| {
            Error::new(
                ErrorKind::Concurrency {
                    thread_id: None,
                    operation: ConcurrencyOperation::Sync,
                },
                "Failed to acquire coordination permit",
            )
        })?;

        let start_time = Instant::now();

        let (result_sender, result_receiver) = oneshot::channel();

        let work_item = AsyncWorkItem {
            task: Box::new(move || {
                let val = task();
                serde_json::to_value(val).map_err(|e| {
                    Error::new(
                        ErrorKind::Serialization,
                        format!("Failed to serialize result: {}", e),
                    )
                })
            }),
            result_sender,
        };

        self.work_sender.send(work_item).map_err(|_| {
            Error::new(
                ErrorKind::Concurrency {
                    thread_id: None,
                    operation: ConcurrencyOperation::Channel,
                },
                "Failed to send work item",
            )
        })?;

        let result = result_receiver.await.map_err(|_| {
            Error::new(
                ErrorKind::Concurrency {
                    thread_id: None,
                    operation: ConcurrencyOperation::Sync,
                },
                "Task execution was cancelled",
            )
        })??;

        let coordination_time = start_time.elapsed();
        self.update_stats(coordination_time).await;

        serde_json::from_value(result).map_err(|e| {
            Error::new(
                ErrorKind::Serialization,
                format!("Failed to deserialize result: {}", e),
            )
        })
    }

    async fn update_stats(&self, coordination_time: Duration) {
        let mut stats = self.stats.write();
        stats.total_coordinated += 1;
        stats.active_permits = self.semaphore.available_permits();

        // Running average: fold the new sample into the previous mean.
        let total_time = stats.avg_coordination_time_ms * (stats.total_coordinated - 1) as f64;
        stats.avg_coordination_time_ms =
            (total_time + coordination_time.as_millis() as f64) / stats.total_coordinated as f64;
    }

    pub async fn stats(&self) -> AsyncCoordinatorStats {
        self.stats.read().clone()
    }
}

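/// Owns the typed thread pools plus the async coordinator and exposes them
/// through the [`Manager`] lifecycle.
///
/// A minimal usage sketch (assumes a Tokio runtime and a default
/// `ConcurrencyConfig`; marked `ignore` so rustdoc does not compile it):
///
/// ```ignore
/// let mut manager = ConcurrencyManager::new(ConcurrencyConfig::default())?;
/// manager.initialize().await?;
/// let sum = manager.execute_compute(|| (0..1000i32).sum::<i32>()).await?;
/// assert_eq!(sum, 499500);
/// manager.shutdown().await?;
/// ```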
#[derive(Debug)]
pub struct ConcurrencyManager {
    state: ManagedState,
    #[allow(dead_code)]
    config: ConcurrencyConfig,
    thread_pools: HashMap<ThreadPoolType, ThreadPool>,
    async_coordinator: AsyncWorkCoordinator,
}

impl ConcurrencyManager {
    pub fn new(config: ConcurrencyConfig) -> Result<Self> {
        let async_coordinator = AsyncWorkCoordinator::new(config.thread_pool_size * 2);

        let mut thread_pools = HashMap::new();

        let compute_config = ThreadPoolConfig {
            thread_count: config.thread_pool_size,
            name_prefix: "compute".to_string(),
            ..Default::default()
        };
        let compute_pool = ThreadPool::new(ThreadPoolType::Compute, compute_config)?;
        thread_pools.insert(ThreadPoolType::Compute, compute_pool);

        let io_config = ThreadPoolConfig {
            thread_count: config.io_thread_pool_size,
            name_prefix: "io".to_string(),
            ..Default::default()
        };
        let io_pool = ThreadPool::new(ThreadPoolType::Io, io_config)?;
        thread_pools.insert(ThreadPoolType::Io, io_pool);

        let blocking_config = ThreadPoolConfig {
            thread_count: config.blocking_thread_pool_size,
            name_prefix: "blocking".to_string(),
            ..Default::default()
        };
        let blocking_pool = ThreadPool::new(ThreadPoolType::Blocking, blocking_config)?;
        thread_pools.insert(ThreadPoolType::Blocking, blocking_pool);

        Ok(Self {
            state: ManagedState::new(Uuid::new_v4(), "concurrency_manager"),
            config,
            thread_pools,
            async_coordinator,
        })
    }

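    /// Runs `task` on the compute pool and awaits its result. `execute_io`
    /// and `execute_blocking` below follow the same shape against their
    /// respective pools.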
    pub async fn execute_compute<F, R>(&self, task: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let compute_pool = self
            .thread_pools
            .get(&ThreadPoolType::Compute)
            .ok_or_else(|| {
                Error::new(
                    ErrorKind::Concurrency {
                        thread_id: None,
                        operation: ConcurrencyOperation::ThreadPool,
                    },
                    "Compute thread pool not available",
                )
            })?;

        compute_pool.submit_async(task).await
    }

    pub async fn execute_io<F, R>(&self, task: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let io_pool = self.thread_pools.get(&ThreadPoolType::Io).ok_or_else(|| {
            Error::new(
                ErrorKind::Concurrency {
                    thread_id: None,
                    operation: ConcurrencyOperation::ThreadPool,
                },
                "I/O thread pool not available",
            )
        })?;

        io_pool.submit_async(task).await
    }

    pub async fn execute_blocking<F, R>(&self, task: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let blocking_pool = self
            .thread_pools
            .get(&ThreadPoolType::Blocking)
            .ok_or_else(|| {
                Error::new(
                    ErrorKind::Concurrency {
                        thread_id: None,
                        operation: ConcurrencyOperation::ThreadPool,
                    },
                    "Blocking thread pool not available",
                )
            })?;

        blocking_pool.submit_async(task).await
    }

    pub fn get_thread_pool_stats(&self, pool_type: ThreadPoolType) -> Option<ThreadPoolStats> {
        self.thread_pools.get(&pool_type).map(|pool| pool.stats())
    }

    pub fn get_all_thread_pool_stats(&self) -> HashMap<ThreadPoolType, ThreadPoolStats> {
        self.thread_pools
            .iter()
            .map(|(pool_type, pool)| (*pool_type, pool.stats()))
            .collect()
    }

    pub async fn get_async_coordinator_stats(&self) -> AsyncCoordinatorStats {
        self.async_coordinator.stats().await
    }

    pub fn create_custom_pool(&mut self, pool_id: u8, config: ThreadPoolConfig) -> Result<()> {
        let pool_type = ThreadPoolType::Custom(pool_id);
        let thread_pool = ThreadPool::new(pool_type, config)?;
        self.thread_pools.insert(pool_type, thread_pool);
        Ok(())
    }

    pub async fn execute_custom<F, R>(&self, pool_id: u8, task: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let pool_type = ThreadPoolType::Custom(pool_id);
        let custom_pool = self.thread_pools.get(&pool_type).ok_or_else(|| {
            Error::new(
                ErrorKind::Concurrency {
                    thread_id: None,
                    operation: ConcurrencyOperation::ThreadPool,
                },
                format!("Custom thread pool {} not available", pool_id),
            )
        })?;

        custom_pool.submit_async(task).await
    }
}

#[async_trait]
impl Manager for ConcurrencyManager {
    fn name(&self) -> &str {
        "concurrency_manager"
    }

    fn id(&self) -> Uuid {
        // NOTE: this mints a fresh UUID on every call; returning the stable
        // id held by `ManagedState` would be more consistent.
        Uuid::new_v4()
    }

    async fn initialize(&mut self) -> Result<()> {
        self.state
            .set_state(crate::manager::ManagerState::Initializing)
            .await;

        self.state
            .set_state(crate::manager::ManagerState::Running)
            .await;
        Ok(())
    }

    async fn shutdown(&mut self) -> Result<()> {
        self.state
            .set_state(crate::manager::ManagerState::ShuttingDown)
            .await;

        // Drain the pools out of the map and shut each one down in turn.
        let shutdown_timeout = Duration::from_secs(30);
        for (pool_type, pool) in self.thread_pools.drain() {
            if let Err(e) = pool.shutdown(shutdown_timeout) {
                eprintln!("Failed to shutdown {:?} thread pool: {}", pool_type, e);
            }
        }

        self.state
            .set_state(crate::manager::ManagerState::Shutdown)
            .await;
        Ok(())
    }

    async fn status(&self) -> ManagerStatus {
        let mut status = self.state.status().await;

        let mut total_active_threads = 0;
        let mut total_queued_tasks = 0;
        let mut total_executed_tasks = 0u64;

        for (pool_type, stats) in self.get_all_thread_pool_stats() {
            total_active_threads += stats.active_threads;
            total_queued_tasks += stats.queue_size;
            total_executed_tasks += stats.total_executed;

            status.add_metadata(
                format!("{:?}_threads", pool_type).to_lowercase(),
                serde_json::Value::from(stats.active_threads),
            );
            status.add_metadata(
                format!("{:?}_queue_size", pool_type).to_lowercase(),
                serde_json::Value::from(stats.queue_size),
            );
        }

        status.add_metadata(
            "total_active_threads",
            serde_json::Value::from(total_active_threads),
        );
        status.add_metadata(
            "total_queued_tasks",
            serde_json::Value::from(total_queued_tasks),
        );
        status.add_metadata(
            "total_executed_tasks",
            serde_json::Value::from(total_executed_tasks),
        );

        let coordinator_stats = self.get_async_coordinator_stats().await;
        status.add_metadata(
            "async_coordinated_tasks",
            serde_json::Value::from(coordinator_stats.total_coordinated),
        );
        status.add_metadata(
            "async_active_permits",
            serde_json::Value::from(coordinator_stats.active_permits),
        );

        status
    }
}

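/// Free-standing helpers for fan-out/fan-in patterns on top of Tokio.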
pub mod utils {
    use super::*;
    use std::pin::Pin;
    use std::sync::Arc;
    use tokio::sync::Barrier;

    pub async fn join_all<F, R>(tasks: Vec<F>) -> Vec<Result<R>>
    where
        F: Future<Output = Result<R>> + Send + 'static,
        R: Send + 'static,
    {
        let handles: Vec<_> = tasks.into_iter().map(tokio::spawn).collect();

        let mut results = Vec::new();
        for handle in handles {
            match handle.await {
                Ok(result) => results.push(result),
                Err(e) => results.push(Err(Error::new(
                    ErrorKind::Concurrency {
                        thread_id: None,
                        operation: ConcurrencyOperation::Spawn,
                    },
                    format!("Task join error: {}", e),
                ))),
            }
        }

        results
    }

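    /// Like `join_all`, but at most `limit` tasks make progress at once,
    /// gated by a shared semaphore.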
    pub async fn execute_with_limit<F, R>(tasks: Vec<F>, limit: usize) -> Vec<Result<R>>
    where
        F: Future<Output = Result<R>> + Send + 'static,
        R: Send + 'static,
    {
        let semaphore = Arc::new(Semaphore::new(limit));
        let handles: Vec<_> = tasks
            .into_iter()
            .map(|task| {
                let sem = Arc::clone(&semaphore);
                tokio::spawn(async move {
                    let _permit = sem.acquire().await.map_err(|_| {
                        Error::new(
                            ErrorKind::Concurrency {
                                thread_id: None,
                                operation: ConcurrencyOperation::Sync,
                            },
                            "Failed to acquire semaphore permit",
                        )
                    })?;
                    task.await
                })
            })
            .collect();

        let mut results = Vec::new();
        for handle in handles {
            match handle.await {
                Ok(result) => results.push(result),
                Err(e) => results.push(Err(Error::new(
                    ErrorKind::Concurrency {
                        thread_id: None,
                        operation: ConcurrencyOperation::Spawn,
                    },
                    format!("Task join error: {}", e),
                ))),
            }
        }

        results
    }

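    /// Spawns each task with a clone of one barrier sized to the task count,
    /// so every task can rendezvous at the same point before proceeding.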
    pub async fn synchronize_at_barrier(
        tasks: Vec<
            Box<
                dyn FnOnce(Arc<Barrier>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send,
            >,
        >,
    ) -> Result<()> {
        let barrier = Arc::new(Barrier::new(tasks.len()));
        let handles: Vec<_> = tasks
            .into_iter()
            .map(|task| {
                let barrier_clone = Arc::clone(&barrier);
                tokio::spawn(task(barrier_clone))
            })
            .collect();

        for handle in handles {
            handle.await.map_err(|e| {
                Error::new(
                    ErrorKind::Concurrency {
                        thread_id: None,
                        operation: ConcurrencyOperation::Spawn,
                    },
                    format!("Barrier synchronization error: {}", e),
                )
            })??;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    #[tokio::test]
    async fn test_thread_pool_creation() {
        let config = ThreadPoolConfig::default();
        let pool = ThreadPool::new(ThreadPoolType::Compute, config).unwrap();

        let stats = pool.stats();
        assert_eq!(stats.pool_type, ThreadPoolType::Compute);
        assert!(stats.active_threads > 0);
    }

    #[tokio::test]
    async fn test_thread_pool_task_execution() {
        let config = ThreadPoolConfig {
            thread_count: 2,
            ..Default::default()
        };
        let pool = ThreadPool::new(ThreadPoolType::Compute, config).unwrap();

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = pool
            .submit_async(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                42i32
            })
            .await
            .unwrap();

        assert_eq!(result, 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_concurrency_manager_initialization() {
        let config = ConcurrencyConfig::default();
        let mut manager = ConcurrencyManager::new(config).unwrap();

        manager.initialize().await.unwrap();

        let status = manager.status().await;
        assert_eq!(status.state, crate::manager::ManagerState::Running);

        manager.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn test_compute_task_execution() {
        let config = ConcurrencyConfig::default();
        let manager = ConcurrencyManager::new(config).unwrap();

        let result = manager
            .execute_compute(|| {
                let mut sum = 0i32;
                for i in 0..1000i32 {
                    sum += i;
                }
                sum
            })
            .await
            .unwrap();

        assert_eq!(result, 499500);
    }

    #[tokio::test]
    async fn test_thread_pool_stats() {
        let config = ThreadPoolConfig {
            thread_count: 2,
            ..Default::default()
        };
        let pool = ThreadPool::new(ThreadPoolType::Io, config).unwrap();

        for i in 0..5i32 {
            let _ = pool
                .submit_async(move || {
                    thread::sleep(Duration::from_millis(10));
                    i * 2
                })
                .await;
        }

        let stats = pool.stats();
        assert_eq!(stats.pool_type, ThreadPoolType::Io);
        assert!(stats.total_executed >= 5);
    }

    #[tokio::test]
    async fn test_utils_join_all() {
        let tasks = vec![
            Box::pin(async { Ok(1i32) }) as Pin<Box<dyn Future<Output = Result<i32>> + Send>>,
            Box::pin(async { Ok(2i32) }) as Pin<Box<dyn Future<Output = Result<i32>> + Send>>,
            Box::pin(async { Ok(3i32) }) as Pin<Box<dyn Future<Output = Result<i32>> + Send>>,
        ];

        let results = utils::join_all(tasks).await;
        assert_eq!(results.len(), 3);

        for (i, result) in results.into_iter().enumerate() {
            assert!(result.is_ok());
            assert_eq!(result.unwrap(), (i + 1) as i32);
        }
    }

    #[tokio::test]
    async fn test_utils_execute_with_limit() {
        let counter = Arc::new(AtomicU32::new(0));
        let tasks: Vec<_> = (0..10i32)
            .map(|i| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    Ok(i)
                }
            })
            .collect();

        let results = utils::execute_with_limit(tasks, 3).await;
        assert_eq!(results.len(), 10);
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }
}