1use std::collections::HashMap;
6use std::fmt;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::AtomicU64;
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::utils::Time;
14use async_trait::async_trait;
15use chrono::{DateTime, Utc};
16use dashmap::DashMap;
17use serde::{Deserialize, Serialize};
18use tokio::sync::{broadcast, RwLock, Semaphore};
19use tokio::time::{timeout, Instant};
20use uuid::Uuid;
21
22use crate::config::TaskConfig;
23use crate::error::{Error, Result};
24use crate::event::{Event, EventBusManager};
25use crate::manager::{ManagedState, Manager, ManagerStatus};
26use crate::types::{CorrelationId, Metadata};
27
/// Lifecycle state of a task tracked by `TaskManager`.
///
/// `Completed`, `Failed`, `Cancelled` and `TimedOut` are the terminal states
/// (see `TaskInfo::is_terminal`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// Submitted and waiting to be claimed by a worker.
    Pending,
    /// Claimed by a worker; its function is executing or about to execute.
    Running,
    /// The task function returned `Ok`.
    Completed,
    /// The task function returned `Err`.
    Failed,
    /// Cancelled through `TaskManager::cancel_task`.
    Cancelled,
    /// The task function did not finish within the task's timeout.
    TimedOut,
    /// NOTE(review): not assigned anywhere in the code shown in this module.
    Paused,
}
38
39impl fmt::Display for TaskStatus {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 Self::Pending => write!(f, "PENDING"),
43 Self::Running => write!(f, "RUNNING"),
44 Self::Completed => write!(f, "COMPLETED"),
45 Self::Failed => write!(f, "FAILED"),
46 Self::Cancelled => write!(f, "CANCELLED"),
47 Self::TimedOut => write!(f, "TIMED_OUT"),
48 Self::Paused => write!(f, "PAUSED"),
49 }
50 }
51}
52
/// Relative importance of a task.
///
/// The derived `Ord` follows declaration order (`Low < Normal < High < Critical`),
/// which agrees with the explicit discriminants. The `Default` is `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum TaskPriority {
    Low = 0,
    Normal = 50,
    High = 100,
    Critical = 200,
}
60
61impl Default for TaskPriority {
62 fn default() -> Self {
63 Self::Normal
64 }
65}
66
/// Broad classification of a task.
///
/// Its `Display` form (a lower-case name, or the custom name verbatim) is the
/// key under which `TaskManager::submit_task` counts tasks in
/// `TaskManagerStats::by_category`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TaskCategory {
    Core,
    Plugin,
    Ui,
    Io,
    Background,
    User,
    Maintenance,
    /// Caller-defined category; displayed as the contained name.
    Custom(String),
}
78
79impl fmt::Display for TaskCategory {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 match self {
82 Self::Core => write!(f, "core"),
83 Self::Plugin => write!(f, "plugin"),
84 Self::Ui => write!(f, "ui"),
85 Self::Io => write!(f, "io"),
86 Self::Background => write!(f, "background"),
87 Self::User => write!(f, "user"),
88 Self::Maintenance => write!(f, "maintenance"),
89 Self::Custom(name) => write!(f, "{}", name),
90 }
91 }
92}
93
/// Snapshot of a task's progress as reported by the task function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskProgress {
    /// Completion percentage; `new` and `set_percent` clamp it to at most 100.
    pub percent: u8,
    /// Human-readable description of the current activity.
    pub message: String,
    /// Current step number, when progress is expressed in steps.
    pub current_step: Option<u32>,
    /// Total number of steps, when progress is expressed in steps.
    pub total_steps: Option<u32>,
    /// When this snapshot was created or last modified via its setters.
    pub updated_at: DateTime<Utc>,
    /// Free-form extra data attached to the snapshot.
    pub metadata: Metadata,
}
103
104impl Default for TaskProgress {
105 fn default() -> Self {
106 Self {
107 percent: 0,
108 message: String::new(),
109 current_step: None,
110 total_steps: None,
111 updated_at: Time::now(),
112 metadata: HashMap::new(),
113 }
114 }
115}
116
117impl TaskProgress {
118 pub fn new(percent: u8, message: impl Into<String>) -> Self {
119 Self {
120 percent: percent.min(100),
121 message: message.into(),
122 updated_at: Time::now(),
123 ..Default::default()
124 }
125 }
126
127 pub fn with_steps(current: u32, total: u32, message: impl Into<String>) -> Self {
128 let percent = if total > 0 {
129 ((current as f64 / total as f64) * 100.0) as u8
130 } else {
131 0
132 };
133
134 Self {
135 percent,
136 message: message.into(),
137 current_step: Some(current),
138 total_steps: Some(total),
139 updated_at: Time::now(),
140 ..Default::default()
141 }
142 }
143
144 pub fn set_percent(&mut self, percent: u8) {
145 self.percent = percent.min(100);
146 self.updated_at = Time::now();
147 }
148
149 pub fn set_message(&mut self, message: impl Into<String>) {
150 self.message = message.into();
151 self.updated_at = Time::now();
152 }
153
154 pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
155 self.metadata.insert(key.into(), value);
156 self.updated_at = Time::now();
157 }
158}
159
/// Outcome of one task run, as recorded by the manager's worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    /// `true` only when the task function returned `Ok`.
    pub success: bool,
    /// The value returned by the task function on success.
    pub data: Option<serde_json::Value>,
    /// Error text when the run failed or timed out.
    pub error: Option<String>,
    /// Wall-clock time the worker measured for the run.
    pub duration: Duration,
    /// Resource counters; the worker in this module leaves them at their defaults.
    pub resource_usage: ResourceUsage,
    /// Free-form metadata attached to the result.
    pub metadata: Metadata,
}
169
/// Resource counters for a task run.
///
/// NOTE(review): nothing in the code shown here populates these beyond their
/// zero defaults.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResourceUsage {
    pub peak_memory_bytes: u64,
    pub cpu_time_ms: u64,
    pub file_operations: u32,
    pub network_bytes: u64,
}
177
178pub trait ProgressReporter: Send + Sync + fmt::Debug {
179 fn report(&self, progress: TaskProgress);
180 fn report_percent(&self, percent: u8, message: String) {
181 self.report(TaskProgress::new(percent, message));
182 }
183 fn report_step(&self, current: u32, total: u32, message: String) {
184 self.report(TaskProgress::with_steps(current, total, message));
185 }
186}
187
/// Cooperative cancellation flag; all clones observe the same state.
///
/// On non-wasm targets this wraps `tokio_util::sync::CancellationToken`; on
/// `wasm32` it is a shared atomic boolean.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    #[cfg(not(target_arch = "wasm32"))]
    inner: tokio_util::sync::CancellationToken,
    #[cfg(target_arch = "wasm32")]
    cancelled: Arc<std::sync::atomic::AtomicBool>,
}
196
197impl Default for CancellationToken {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl CancellationToken {
204 pub fn new() -> Self {
205 Self {
206 #[cfg(not(target_arch = "wasm32"))]
207 inner: tokio_util::sync::CancellationToken::new(),
208 #[cfg(target_arch = "wasm32")]
209 cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
210 }
211 }
212
213 #[allow(dead_code)]
214 fn default() -> Self {
215 Self::new()
216 }
217
218 pub fn cancel(&self) {
219 #[cfg(not(target_arch = "wasm32"))]
220 self.inner.cancel();
221 #[cfg(target_arch = "wasm32")]
222 self.cancelled.store(true, Ordering::SeqCst);
223 }
224
225 pub fn is_cancelled(&self) -> bool {
226 #[cfg(not(target_arch = "wasm32"))]
227 return self.inner.is_cancelled();
228 #[cfg(target_arch = "wasm32")]
229 return self.cancelled.load(Ordering::SeqCst);
230 }
231
232 pub async fn cancelled(&self) {
233 #[cfg(not(target_arch = "wasm32"))]
234 self.inner.cancelled().await;
235 #[cfg(target_arch = "wasm32")]
236 {
237 while !self.is_cancelled() {
239 tokio::time::sleep(Duration::from_millis(10)).await;
240 }
241 }
242 }
243}
244
/// Per-run handle passed to a task's `TaskFunction`.
///
/// Carries the task's identity and metadata, plus the means to report progress
/// and observe cancellation.
#[derive(Debug)]
pub struct TaskContext {
    pub task_id: Uuid,
    pub name: String,
    pub category: TaskCategory,
    pub plugin_id: Option<String>,
    pub correlation_id: Option<CorrelationId>,
    /// Where progress updates from the task function are sent.
    pub progress: Arc<dyn ProgressReporter>,
    /// Shared with the manager's record so `cancel_task` can signal this run.
    pub cancellation_token: CancellationToken,
    pub metadata: Metadata,
}
256
257impl TaskContext {
258 pub fn is_cancelled(&self) -> bool {
259 self.cancellation_token.is_cancelled()
260 }
261
262 pub async fn cancelled(&self) {
263 self.cancellation_token.cancelled().await;
264 }
265
266 pub fn report_progress(&self, progress: TaskProgress) {
267 self.progress.report(progress);
268 }
269
270 pub fn report_percent(&self, percent: u8, message: impl Into<String>) {
271 self.progress.report_percent(percent, message.into());
272 }
273
274 pub fn report_step(&self, current: u32, total: u32, message: impl Into<String>) {
275 self.progress.report_step(current, total, message.into());
276 }
277}
278
/// Type-erased, shareable task body stored in a `TaskDefinition`.
///
/// A worker invokes it once per run with that run's `TaskContext`. It returns
/// a boxed `Send` future that resolves to a JSON value on success or an error.
pub type TaskFunction = Arc<
    dyn Fn(TaskContext) -> Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send>>
        + Send
        + Sync,
>;
284
/// Everything needed to submit a task: its static attributes plus its body.
pub struct TaskDefinition {
    /// Unique id; `TaskManager::submit_task` uses it as the task's key.
    pub id: Uuid,
    pub name: String,
    pub category: TaskCategory,
    pub priority: TaskPriority,
    pub plugin_id: Option<String>,
    /// Tasks that must already be `Completed` when this one is submitted.
    pub dependencies: Vec<Uuid>,
    /// Upper bound on a single run; the worker wraps the function in
    /// `tokio::time::timeout` with this value.
    pub timeout: Duration,
    /// Copied into `TaskInfo`.
    /// NOTE(review): no retry logic is visible in this module.
    pub max_retries: u32,
    /// Whether `TaskManager::cancel_task` may cancel this task.
    pub cancellable: bool,
    pub metadata: Metadata,
    pub correlation_id: Option<CorrelationId>,
    /// The task body.
    pub function: TaskFunction,
}
299
300impl fmt::Debug for TaskDefinition {
301 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302 f.debug_struct("TaskDefinition")
303 .field("id", &self.id)
304 .field("name", &self.name)
305 .field("category", &self.category)
306 .field("priority", &self.priority)
307 .field("plugin_id", &self.plugin_id)
308 .field("dependencies", &self.dependencies)
309 .field("timeout", &self.timeout)
310 .field("max_retries", &self.max_retries)
311 .field("cancellable", &self.cancellable)
312 .field("metadata", &self.metadata)
313 .field("correlation_id", &self.correlation_id)
314 .field("function", &"<function>")
315 .finish()
316 }
317}
318
/// Cloneable, serializable view of a task's state as tracked by the manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskInfo {
    pub id: Uuid,
    pub name: String,
    pub category: TaskCategory,
    pub priority: TaskPriority,
    /// Current lifecycle state.
    pub status: TaskStatus,
    pub plugin_id: Option<String>,
    pub dependencies: Vec<Uuid>,
    /// Set when the info is built from its definition.
    pub created_at: DateTime<Utc>,
    /// Set when a worker claims the task.
    pub started_at: Option<DateTime<Utc>>,
    /// Set when the task reaches a terminal state.
    pub completed_at: Option<DateTime<Utc>>,
    pub progress: TaskProgress,
    /// Outcome of the run, once the worker has recorded one.
    pub result: Option<TaskResult>,
    pub retry_count: u32,
    pub max_retries: u32,
    pub timeout: Duration,
    pub cancellable: bool,
    pub correlation_id: Option<CorrelationId>,
    pub metadata: Metadata,
}
340
341impl TaskInfo {
342 pub fn from_definition(definition: &TaskDefinition) -> Self {
343 Self {
344 id: definition.id,
345 name: definition.name.clone(),
346 category: definition.category.clone(),
347 priority: definition.priority,
348 status: TaskStatus::Pending,
349 plugin_id: definition.plugin_id.clone(),
350 dependencies: definition.dependencies.clone(),
351 created_at: Time::now(),
352 started_at: None,
353 completed_at: None,
354 progress: TaskProgress::default(),
355 result: None,
356 retry_count: 0,
357 max_retries: definition.max_retries,
358 timeout: definition.timeout,
359 cancellable: definition.cancellable,
360 correlation_id: definition.correlation_id,
361 metadata: definition.metadata.clone(),
362 }
363 }
364
365 pub fn duration(&self) -> Option<Duration> {
366 if let (Some(started), Some(completed)) = (self.started_at, self.completed_at) {
367 Some((completed - started).to_std().ok()?)
368 } else {
369 None
370 }
371 }
372
373 pub fn is_terminal(&self) -> bool {
374 matches!(
375 self.status,
376 TaskStatus::Completed
377 | TaskStatus::Failed
378 | TaskStatus::Cancelled
379 | TaskStatus::TimedOut
380 )
381 }
382
383 pub fn can_retry(&self) -> bool {
384 matches!(self.status, TaskStatus::Failed | TaskStatus::TimedOut)
385 && self.retry_count < self.max_retries
386 }
387}
388
/// Internal per-task record stored in `TaskManager::tasks`.
#[derive(Debug)]
struct TaskExecution {
    /// Mutable status snapshot returned to callers as clones.
    info: TaskInfo,
    /// The submitted definition, including the task body.
    definition: TaskDefinition,
    /// Shared with the run's `TaskContext` so cancellation reaches the task.
    cancellation_token: CancellationToken,
    /// Broadcasts progress; `wait_for_task` subscribes to it for wake-ups.
    progress_sender: broadcast::Sender<TaskProgress>,
}
396
/// Aggregate counters maintained by `TaskManager` and its workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskManagerStats {
    /// Tasks accepted by `submit_task`.
    pub total_created: u64,
    /// Runs that ended `Completed`.
    pub total_completed: u64,
    /// Runs that ended `Failed` or `TimedOut`.
    pub total_failed: u64,
    /// Successful `cancel_task` calls.
    pub total_cancelled: u64,
    /// Gauge of tasks currently executing.
    pub currently_running: u32,
    /// Gauge of tasks waiting for a worker.
    pub currently_pending: u32,
    /// Submitted tasks per category, keyed by the category's `Display` string.
    pub by_category: HashMap<String, u64>,
    /// Submitted tasks per priority.
    pub by_priority: HashMap<TaskPriority, u64>,
    /// Running mean duration, in ms, over completed and failed runs.
    pub avg_execution_time_ms: f64,
    /// NOTE(review): not updated anywhere in the code shown here.
    pub total_resource_usage: ResourceUsage,
}
410
/// Event (`"task.created"`) published by `TaskManager::submit_task` when an
/// event bus is configured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskCreatedEvent {
    pub task_id: Uuid,
    pub name: String,
    pub category: TaskCategory,
    pub priority: TaskPriority,
    pub timestamp: DateTime<Utc>,
    /// Emitting component; the manager uses `"task_manager"`.
    pub source: String,
    pub metadata: Metadata,
}
421
422impl Event for TaskCreatedEvent {
423 fn event_type(&self) -> &'static str {
424 "task.created"
425 }
426
427 fn source(&self) -> &str {
428 &self.source
429 }
430
431 fn metadata(&self) -> &Metadata {
432 &self.metadata
433 }
434
435 fn as_any(&self) -> &dyn std::any::Any {
436 self
437 }
438}
439
/// Event (`"task.status_changed"`) published by the manager and its workers on
/// status transitions, when an event bus is configured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskStatusChangedEvent {
    pub task_id: Uuid,
    pub name: String,
    pub old_status: TaskStatus,
    pub new_status: TaskStatus,
    pub timestamp: DateTime<Utc>,
    /// Emitting component; the manager uses `"task_manager"`.
    pub source: String,
    pub metadata: Metadata,
}
450
451impl Event for TaskStatusChangedEvent {
452 fn event_type(&self) -> &'static str {
453 "task.status_changed"
454 }
455
456 fn source(&self) -> &str {
457 &self.source
458 }
459
460 fn metadata(&self) -> &Metadata {
461 &self.metadata
462 }
463
464 fn as_any(&self) -> &dyn std::any::Any {
465 self
466 }
467}
468
/// Event (`"task.progress"`) carrying a progress snapshot.
///
/// NOTE(review): not published by any code shown in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskProgressEvent {
    pub task_id: Uuid,
    pub name: String,
    pub progress: TaskProgress,
    pub timestamp: DateTime<Utc>,
    pub source: String,
    pub metadata: Metadata,
}
478
479impl Event for TaskProgressEvent {
480 fn event_type(&self) -> &'static str {
481 "task.progress"
482 }
483
484 fn source(&self) -> &str {
485 &self.source
486 }
487
488 fn metadata(&self) -> &Metadata {
489 &self.metadata
490 }
491
492 fn as_any(&self) -> &dyn std::any::Any {
493 self
494 }
495}
496
/// `ProgressReporter` handed to task functions by the manager's workers.
///
/// Forwards updates to the task's broadcast channel.
#[derive(Debug)]
struct TaskProgressReporter {
    /// Task the updates belong to; used for logging.
    task_id: Uuid,
    /// Clone of the task's progress broadcast sender.
    progress_sender: broadcast::Sender<TaskProgress>,
}
502
503impl ProgressReporter for TaskProgressReporter {
504 fn report(&self, progress: TaskProgress) {
505 tracing::debug!(
506 "Task {} progress: {}% - {}",
507 self.task_id,
508 progress.percent,
509 progress.message
510 );
511 let _ = self.progress_sender.send(progress);
512 }
513}
514
/// Schedules and executes asynchronous tasks on a pool of worker loops.
#[derive(Debug)]
pub struct TaskManager {
    /// Lifecycle bookkeeping driven by the `Manager` impl.
    state: ManagedState,
    /// Stored configuration; only `max_concurrent` is read, in `new`.
    #[allow(dead_code)]
    config: TaskConfig,
    /// All known tasks keyed by id; shared with the workers.
    tasks: Arc<DashMap<Uuid, TaskExecution>>,
    /// Aggregate counters; shared with the workers.
    stats: Arc<RwLock<TaskManagerStats>>,
    /// NOTE(review): not used by any code shown in this module.
    #[allow(dead_code)]
    task_counter: Arc<AtomicU64>,
    /// `config.max_concurrent` permits; a worker holds one while executing a task.
    concurrency_semaphore: Arc<Semaphore>,
    /// Optional destination for task lifecycle events.
    event_bus: Option<Arc<EventBusManager>>,
    /// Join handles of the workers spawned by `start_workers`.
    worker_handles: Vec<tokio::task::JoinHandle<()>>,
    /// Set to `true` by `stop_workers`; each worker checks it at the top of its loop.
    shutdown_flag: Arc<tokio::sync::RwLock<bool>>,
}
529
530impl TaskManager {
531 pub fn new(config: TaskConfig) -> Self {
532 let max_concurrent = config.max_concurrent;
533 Self {
534 state: ManagedState::new(Uuid::new_v4(), "task_manager"),
535 config,
536 tasks: Arc::new(DashMap::new()),
537 stats: Arc::new(RwLock::new(TaskManagerStats {
538 total_created: 0,
539 total_completed: 0,
540 total_failed: 0,
541 total_cancelled: 0,
542 currently_running: 0,
543 currently_pending: 0,
544 by_category: HashMap::new(),
545 by_priority: HashMap::new(),
546 avg_execution_time_ms: 0.0,
547 total_resource_usage: ResourceUsage::default(),
548 })),
549 task_counter: Arc::new(AtomicU64::new(0)),
550 concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
551 event_bus: None,
552 worker_handles: Vec::new(),
553 shutdown_flag: Arc::new(tokio::sync::RwLock::new(false)),
554 }
555 }
556
557 pub fn set_event_bus(&mut self, event_bus: Arc<EventBusManager>) {
558 self.event_bus = Some(event_bus);
559 }
560
561 pub async fn submit_task(&self, definition: TaskDefinition) -> Result<Uuid> {
562 let task_id = definition.id;
563 let task_info = TaskInfo::from_definition(&definition);
564
565 tracing::info!("Submitting task {} ({})", task_info.name, task_id);
566
567 for dep_id in &task_info.dependencies {
569 if let Some(dep_task) = self.tasks.get(dep_id) {
570 if !dep_task.info.is_terminal() {
571 return Err(Error::task(
572 Some(task_id),
573 Some(format!("dep_{}", dep_id)),
574 format!("Dependency task {} is not completed", dep_id),
575 ));
576 }
577 if dep_task.info.status != TaskStatus::Completed {
578 return Err(Error::task(
579 Some(task_id),
580 Some(format!("dep_{}", dep_id)),
581 format!("Dependency task {} failed", dep_id),
582 ));
583 }
584 } else {
585 return Err(Error::task(
586 Some(task_id),
587 Some(format!("dep_{}", dep_id)),
588 format!("Dependency task {} not found", dep_id),
589 ));
590 }
591 }
592
593 let (progress_sender, _) = broadcast::channel(100);
594 let cancellation_token = CancellationToken::new();
595
596 let execution = TaskExecution {
597 info: task_info.clone(),
598 definition,
599 cancellation_token,
600 progress_sender,
601 };
602
603 self.tasks.insert(task_id, execution);
605 tracing::debug!(
606 "Task {} added to collection, total tasks: {}",
607 task_id,
608 self.tasks.len()
609 );
610
611 {
613 let mut stats = self.stats.write().await;
614 stats.total_created += 1;
615 stats.currently_pending += 1;
616 *stats
617 .by_category
618 .entry(task_info.category.to_string())
619 .or_insert(0) += 1;
620 *stats.by_priority.entry(task_info.priority).or_insert(0) += 1;
621 tracing::debug!("Stats updated: {} pending tasks", stats.currently_pending);
622 }
623
624 if let Some(event_bus) = &self.event_bus {
626 let event = TaskCreatedEvent {
627 task_id,
628 name: task_info.name.clone(),
629 category: task_info.category,
630 priority: task_info.priority,
631 timestamp: Time::now(),
632 source: "task_manager".to_string(),
633 metadata: task_info.metadata.clone(),
634 };
635 let _ = event_bus.publish(event).await;
636 }
637
638 tracing::info!("Task {} submitted successfully", task_id);
639 Ok(task_id)
640 }
641
642 pub async fn cancel_task(&self, task_id: Uuid) -> Result<bool> {
643 if let Some(mut task) = self.tasks.get_mut(&task_id) {
644 if !task.info.cancellable {
645 return Err(Error::task(Some(task_id), None, "Task is not cancellable"));
646 }
647
648 if task.info.is_terminal() {
649 return Ok(false);
650 }
651
652 task.cancellation_token.cancel();
654 task.info.status = TaskStatus::Cancelled;
655 task.info.completed_at = Some(Time::now());
656
657 {
659 let mut stats = self.stats.write().await;
660 stats.total_cancelled += 1;
661 if task.info.status == TaskStatus::Running {
662 stats.currently_running = stats.currently_running.saturating_sub(1);
663 } else {
664 stats.currently_pending = stats.currently_pending.saturating_sub(1);
665 }
666 }
667
668 self.publish_status_change_event(
670 &task.info,
671 TaskStatus::Running,
672 TaskStatus::Cancelled,
673 )
674 .await;
675
676 Ok(true)
677 } else {
678 Err(Error::task(Some(task_id), None, "Task not found"))
679 }
680 }
681
682 pub async fn get_task_info(&self, task_id: Uuid) -> Option<TaskInfo> {
683 self.tasks.get(&task_id).map(|task| task.info.clone())
684 }
685
686 pub async fn list_tasks(
687 &self,
688 status_filter: Option<TaskStatus>,
689 category_filter: Option<TaskCategory>,
690 limit: Option<usize>,
691 ) -> Vec<TaskInfo> {
692 let tasks: Vec<TaskInfo> = self
693 .tasks
694 .iter()
695 .filter_map(|entry| {
696 let task_info = &entry.value().info;
697
698 if let Some(status) = status_filter {
699 if task_info.status != status {
700 return None;
701 }
702 }
703
704 if let Some(category) = &category_filter {
705 if task_info.category != *category {
706 return None;
707 }
708 }
709
710 Some(task_info.clone())
711 })
712 .collect();
713
714 if let Some(limit) = limit {
715 tasks.into_iter().take(limit).collect()
716 } else {
717 tasks
718 }
719 }
720
721 pub async fn wait_for_task(
722 &self,
723 task_id: Uuid,
724 timeout_duration: Option<Duration>,
725 ) -> Result<TaskInfo> {
726 tracing::info!(
727 "Waiting for task {} with timeout {:?}",
728 task_id,
729 timeout_duration
730 );
731
732 if let Some(task) = self.tasks.get(&task_id) {
733 if task.info.is_terminal() {
734 tracing::info!(
735 "Task {} already completed with status: {:?}",
736 task_id,
737 task.info.status
738 );
739 return Ok(task.info.clone());
740 }
741
742 let progress_receiver = task.progress_sender.subscribe();
743 drop(task); let wait_future = self.wait_for_completion(task_id, progress_receiver);
746
747 if let Some(timeout_duration) = timeout_duration {
748 match timeout(timeout_duration, wait_future).await {
749 Ok(result) => result,
750 Err(_) => {
751 tracing::error!(
752 "Task {} wait timed out after {:?}",
753 task_id,
754 timeout_duration
755 );
756 Err(Error::timeout("Task wait timeout"))
757 }
758 }
759 } else {
760 wait_future.await
761 }
762 } else {
763 Err(Error::task(Some(task_id), None, "Task not found"))
764 }
765 }
766
767 async fn wait_for_completion(
768 &self,
769 task_id: Uuid,
770 mut progress_receiver: broadcast::Receiver<TaskProgress>,
771 ) -> Result<TaskInfo> {
772 loop {
773 if let Some(updated_task) = self.tasks.get(&task_id) {
775 if updated_task.info.is_terminal() {
776 tracing::info!(
777 "Task {} completed with status: {:?}",
778 task_id,
779 updated_task.info.status
780 );
781 return Ok(updated_task.info.clone());
782 }
783 }
784
785 match tokio::time::timeout(Duration::from_millis(500), progress_receiver.recv()).await {
787 Ok(Ok(progress)) => {
788 tracing::debug!(
789 "Task {} progress: {}% - {}",
790 task_id,
791 progress.percent,
792 progress.message
793 );
794 continue;
795 }
796 Ok(Err(_)) => {
797 if let Some(updated_task) = self.tasks.get(&task_id) {
799 if updated_task.info.is_terminal() {
800 return Ok(updated_task.info.clone());
801 }
802 }
803 tracing::warn!("Progress channel closed for task {}", task_id);
804 break;
805 }
806 Err(_) => {
807 if let Some(updated_task) = self.tasks.get(&task_id) {
809 tracing::debug!(
810 "Task {} current status: {:?}",
811 task_id,
812 updated_task.info.status
813 );
814 if updated_task.info.is_terminal() {
815 return Ok(updated_task.info.clone());
816 }
817 }
818 }
819 }
820 }
821
822 Err(Error::task(
823 Some(task_id),
824 None,
825 "Task completion wait failed",
826 ))
827 }
828
829 pub async fn get_stats(&self) -> TaskManagerStats {
830 self.stats.read().await.clone()
831 }
832
833 pub async fn cleanup_old_tasks(&self, max_age: Duration) -> u64 {
834 let cutoff_time = Time::now() - chrono::Duration::from_std(max_age).unwrap_or_default();
835 let mut removed_count = 0u64;
836
837 let task_ids_to_remove: Vec<Uuid> = self
838 .tasks
839 .iter()
840 .filter_map(|entry| {
841 let task_info = &entry.value().info;
842 if task_info.is_terminal() {
843 if let Some(completed_at) = task_info.completed_at {
844 if completed_at < cutoff_time {
845 return Some(task_info.id);
846 }
847 }
848 }
849 None
850 })
851 .collect();
852
853 for task_id in task_ids_to_remove {
854 if self.tasks.remove(&task_id).is_some() {
855 removed_count += 1;
856 }
857 }
858
859 removed_count
860 }
861
862 async fn start_workers(&mut self) -> Result<()> {
863 let worker_count = 4;
864 tracing::info!("Starting {} task workers", worker_count);
865
866 for worker_id in 0..worker_count {
867 let tasks = Arc::clone(&self.tasks);
868 let stats = Arc::clone(&self.stats);
869 let semaphore = Arc::clone(&self.concurrency_semaphore);
870 let event_bus = self.event_bus.clone();
871 let shutdown_flag = Arc::clone(&self.shutdown_flag);
872
873 let handle = tokio::spawn(async move {
874 Self::task_worker(worker_id, tasks, stats, semaphore, event_bus, shutdown_flag)
875 .await;
876 });
877
878 self.worker_handles.push(handle);
879 }
880
881 tracing::info!("Started {} task workers", worker_count);
882 Ok(())
883 }
884
885 async fn task_worker(
886 worker_id: usize,
887 tasks: Arc<DashMap<Uuid, TaskExecution>>,
888 stats: Arc<RwLock<TaskManagerStats>>,
889 semaphore: Arc<Semaphore>,
890 event_bus: Option<Arc<EventBusManager>>,
891 shutdown_flag: Arc<tokio::sync::RwLock<bool>>,
892 ) {
893 tracing::info!("Task worker {} started", worker_id);
894
895 loop {
896 if *shutdown_flag.read().await {
898 tracing::info!("Task worker {} shutting down", worker_id);
899 break;
900 }
901
902 let claimed_task_id = {
904 let mut claimed = None;
905 for mut entry in tasks.iter_mut() {
906 let task_id = *entry.key();
907 let task_ref = entry.value_mut(); if task_ref.info.status == TaskStatus::Pending {
910 task_ref.info.status = TaskStatus::Running;
911 task_ref.info.started_at = Some(Time::now());
912 claimed = Some(task_id);
913 break;
914 }
915 }
916 claimed
917 };
918
919 if let Some(task_id) = claimed_task_id {
920 tracing::info!("Worker {} claimed task {}", worker_id, task_id);
921
922 if let Ok(permit) = semaphore.acquire().await {
924 let task_execution_data = {
926 if let Some(task_entry) = tasks.get(&task_id) {
927 let task = task_entry.value();
928
929 {
931 let mut stats_guard = stats.write().await;
932 stats_guard.currently_pending =
933 stats_guard.currently_pending.saturating_sub(1);
934 stats_guard.currently_running += 1;
935 tracing::debug!(
936 "Stats: {} pending, {} running",
937 stats_guard.currently_pending,
938 stats_guard.currently_running
939 );
940 }
941
942 if let Some(event_bus) = &event_bus {
944 let event = TaskStatusChangedEvent {
945 task_id,
946 name: task.info.name.clone(),
947 old_status: TaskStatus::Pending,
948 new_status: TaskStatus::Running,
949 timestamp: Time::now(),
950 source: "task_manager".to_string(),
951 metadata: task.info.metadata.clone(),
952 };
953 let _ = event_bus.publish(event).await;
954 }
955
956 let context = TaskContext {
958 task_id,
959 name: task.info.name.clone(),
960 category: task.info.category.clone(),
961 plugin_id: task.info.plugin_id.clone(),
962 correlation_id: task.info.correlation_id,
963 progress: Arc::new(TaskProgressReporter {
964 task_id,
965 progress_sender: task.progress_sender.clone(),
966 }),
967 cancellation_token: task.cancellation_token.clone(),
968 metadata: task.info.metadata.clone(),
969 };
970
971 Some((
972 Arc::clone(&task.definition.function),
973 context,
974 task.info.timeout,
975 task.progress_sender.clone(),
976 ))
977 } else {
978 None
979 }
980 };
981
982 if let Some((function, context, task_timeout, progress_sender)) =
983 task_execution_data
984 {
985 let start_time = Instant::now();
987 tracing::info!(
988 "Worker {} executing task {} with timeout {:?}",
989 worker_id,
990 task_id,
991 task_timeout
992 );
993
994 let future = function(context);
996
997 let execution_result = timeout(task_timeout, future).await;
999 let execution_duration = start_time.elapsed();
1000
1001 tracing::info!(
1002 "Task {} execution completed in {:?}",
1003 task_id,
1004 execution_duration
1005 );
1006
1007 let (new_status, result) = match execution_result {
1009 Ok(Ok(data)) => {
1010 tracing::info!("Task {} completed successfully", task_id);
1011 let result = TaskResult {
1012 success: true,
1013 data: Some(data),
1014 error: None,
1015 duration: execution_duration,
1016 resource_usage: ResourceUsage::default(),
1017 metadata: HashMap::new(),
1018 };
1019 (TaskStatus::Completed, Some(result))
1020 }
1021 Ok(Err(error)) => {
1022 tracing::error!("Task {} failed: {}", task_id, error);
1023 let result = TaskResult {
1024 success: false,
1025 data: None,
1026 error: Some(error.to_string()),
1027 duration: execution_duration,
1028 resource_usage: ResourceUsage::default(),
1029 metadata: HashMap::new(),
1030 };
1031 (TaskStatus::Failed, Some(result))
1032 }
1033 Err(_) => {
1034 tracing::error!("Task {} timed out", task_id);
1035 let result = TaskResult {
1036 success: false,
1037 data: None,
1038 error: Some("Task timed out".to_string()),
1039 duration: execution_duration,
1040 resource_usage: ResourceUsage::default(),
1041 metadata: HashMap::new(),
1042 };
1043 (TaskStatus::TimedOut, Some(result))
1044 }
1045 };
1046
1047 if let Some(mut task_entry) = tasks.get_mut(&task_id) {
1049 let task = task_entry.value_mut();
1050 task.info.status = new_status;
1051 task.info.completed_at = Some(Time::now());
1052 task.info.result = result;
1053
1054 let final_progress = TaskProgress::new(100, "Task completed");
1056 let _ = progress_sender.send(final_progress);
1057
1058 tracing::info!("Task {} final status: {:?}", task_id, new_status);
1059 }
1060
1061 {
1063 let mut stats_guard = stats.write().await;
1064 stats_guard.currently_running =
1065 stats_guard.currently_running.saturating_sub(1);
1066 match new_status {
1067 TaskStatus::Completed => stats_guard.total_completed += 1,
1068 TaskStatus::Failed => stats_guard.total_failed += 1,
1069 TaskStatus::TimedOut => stats_guard.total_failed += 1,
1070 _ => {}
1071 }
1072
1073 let total_completed =
1075 stats_guard.total_completed + stats_guard.total_failed;
1076 if total_completed > 0 {
1077 stats_guard.avg_execution_time_ms = (stats_guard
1078 .avg_execution_time_ms
1079 * (total_completed - 1) as f64
1080 + execution_duration.as_millis() as f64)
1081 / total_completed as f64;
1082 }
1083
1084 tracing::debug!(
1085 "Updated stats: {} completed, {} failed, {} running",
1086 stats_guard.total_completed,
1087 stats_guard.total_failed,
1088 stats_guard.currently_running
1089 );
1090 }
1091
1092 if let Some(event_bus) = &event_bus {
1094 let event = TaskStatusChangedEvent {
1095 task_id,
1096 name: tasks
1097 .get(&task_id)
1098 .map(|t| t.info.name.clone())
1099 .unwrap_or_default(),
1100 old_status: TaskStatus::Running,
1101 new_status,
1102 timestamp: Time::now(),
1103 source: "task_manager".to_string(),
1104 metadata: HashMap::new(),
1105 };
1106 let _ = event_bus.publish(event).await;
1107 }
1108 }
1109
1110 drop(permit); }
1112 } else {
1113 tokio::time::sleep(Duration::from_millis(100)).await;
1115 }
1116 }
1117
1118 tracing::info!("Task worker {} stopped", worker_id);
1119 }
1120
1121 async fn publish_status_change_event(
1122 &self,
1123 task_info: &TaskInfo,
1124 old_status: TaskStatus,
1125 new_status: TaskStatus,
1126 ) {
1127 if let Some(event_bus) = &self.event_bus {
1128 let event = TaskStatusChangedEvent {
1129 task_id: task_info.id,
1130 name: task_info.name.clone(),
1131 old_status,
1132 new_status,
1133 timestamp: Time::now(),
1134 source: "task_manager".to_string(),
1135 metadata: task_info.metadata.clone(),
1136 };
1137 let _ = event_bus.publish(event).await;
1138 }
1139 }
1140
1141 async fn stop_workers(&mut self) {
1142 tracing::info!("Stopping task workers");
1143
1144 *self.shutdown_flag.write().await = true;
1146
1147 for handle in self.worker_handles.drain(..) {
1149 let _ = tokio::time::timeout(Duration::from_secs(5), handle).await;
1150 }
1151
1152 tracing::info!("All task workers stopped");
1153 }
1154}
1155
1156#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1157#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1158impl Manager for TaskManager {
1159 fn name(&self) -> &str {
1160 "task_manager"
1161 }
1162
1163 fn id(&self) -> Uuid {
1164 Uuid::new_v4() }
1166
1167 async fn initialize(&mut self) -> Result<()> {
1168 self.state
1169 .set_state(crate::manager::ManagerState::Initializing)
1170 .await;
1171
1172 tracing::info!("Initializing task manager");
1173
1174 self.start_workers().await?;
1176
1177 self.state
1178 .set_state(crate::manager::ManagerState::Running)
1179 .await;
1180 tracing::info!("Task manager initialized successfully");
1181 Ok(())
1182 }
1183
1184 async fn shutdown(&mut self) -> Result<()> {
1185 self.state
1186 .set_state(crate::manager::ManagerState::ShuttingDown)
1187 .await;
1188
1189 tracing::info!("Shutting down task manager");
1190
1191 let running_tasks: Vec<Uuid> = self
1193 .tasks
1194 .iter()
1195 .filter_map(|entry| {
1196 let task = entry.value();
1197 if task.info.status == TaskStatus::Running && task.info.cancellable {
1198 Some(task.info.id)
1199 } else {
1200 None
1201 }
1202 })
1203 .collect();
1204
1205 for task_id in running_tasks {
1206 let _ = self.cancel_task(task_id).await;
1207 }
1208
1209 self.stop_workers().await;
1211
1212 self.cleanup_old_tasks(Duration::from_secs(0)).await;
1214
1215 self.state
1216 .set_state(crate::manager::ManagerState::Shutdown)
1217 .await;
1218 tracing::info!("Task manager shutdown complete");
1219 Ok(())
1220 }
1221
1222 async fn status(&self) -> ManagerStatus {
1223 let mut status = self.state.status().await;
1224 let stats = self.get_stats().await;
1225
1226 status.add_metadata("total_tasks", serde_json::Value::from(stats.total_created));
1227 status.add_metadata(
1228 "completed_tasks",
1229 serde_json::Value::from(stats.total_completed),
1230 );
1231 status.add_metadata("failed_tasks", serde_json::Value::from(stats.total_failed));
1232 status.add_metadata(
1233 "running_tasks",
1234 serde_json::Value::from(stats.currently_running),
1235 );
1236 status.add_metadata(
1237 "pending_tasks",
1238 serde_json::Value::from(stats.currently_pending),
1239 );
1240 status.add_metadata(
1241 "avg_execution_time_ms",
1242 serde_json::Value::from(stats.avg_execution_time_ms),
1243 );
1244
1245 status
1246 }
1247}
1248
/// Fluent builder for `TaskDefinition`; see `TaskBuilder::new` for the defaults.
pub struct TaskBuilder {
    name: String,
    category: TaskCategory,
    priority: TaskPriority,
    plugin_id: Option<String>,
    /// Ids of tasks that must be `Completed` when the built task is submitted.
    dependencies: Vec<Uuid>,
    timeout: Duration,
    max_retries: u32,
    cancellable: bool,
    metadata: Metadata,
    correlation_id: Option<CorrelationId>,
}
1261
1262impl TaskBuilder {
1263 pub fn new(name: impl Into<String>) -> Self {
1264 Self {
1265 name: name.into(),
1266 category: TaskCategory::Core,
1267 priority: TaskPriority::Normal,
1268 plugin_id: None,
1269 dependencies: Vec::new(),
1270 timeout: Duration::from_secs(300), max_retries: 0,
1272 cancellable: true,
1273 metadata: HashMap::new(),
1274 correlation_id: None,
1275 }
1276 }
1277
1278 pub fn category(mut self, category: TaskCategory) -> Self {
1279 self.category = category;
1280 self
1281 }
1282
1283 pub fn priority(mut self, priority: TaskPriority) -> Self {
1284 self.priority = priority;
1285 self
1286 }
1287
1288 pub fn plugin_id(mut self, plugin_id: impl Into<String>) -> Self {
1289 self.plugin_id = Some(plugin_id.into());
1290 self
1291 }
1292
1293 pub fn dependency(mut self, task_id: Uuid) -> Self {
1294 self.dependencies.push(task_id);
1295 self
1296 }
1297
1298 pub fn timeout(mut self, timeout: Duration) -> Self {
1299 self.timeout = timeout;
1300 self
1301 }
1302
1303 pub fn max_retries(mut self, max_retries: u32) -> Self {
1304 self.max_retries = max_retries;
1305 self
1306 }
1307
1308 pub fn cancellable(mut self, cancellable: bool) -> Self {
1309 self.cancellable = cancellable;
1310 self
1311 }
1312
1313 pub fn metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
1314 self.metadata.insert(key.into(), value);
1315 self
1316 }
1317
1318 pub fn correlation_id(mut self, correlation_id: CorrelationId) -> Self {
1319 self.correlation_id = Some(correlation_id);
1320 self
1321 }
1322
1323 pub fn build<F, Fut>(self, function: F) -> TaskDefinition
1324 where
1325 F: Fn(TaskContext) -> Fut + Send + Sync + 'static,
1326 Fut: Future<Output = Result<serde_json::Value>> + Send + 'static,
1327 {
1328 let task_function: TaskFunction = Arc::new(move |ctx| Box::pin(function(ctx)));
1329
1330 TaskDefinition {
1331 id: Uuid::new_v4(),
1332 name: self.name,
1333 category: self.category,
1334 priority: self.priority,
1335 plugin_id: self.plugin_id,
1336 dependencies: self.dependencies,
1337 timeout: self.timeout,
1338 max_retries: self.max_retries,
1339 cancellable: self.cancellable,
1340 metadata: self.metadata,
1341 correlation_id: self.correlation_id,
1342 function: task_function,
1343 }
1344 }
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly initialised manager reports `Running` and shuts down cleanly.
    #[tokio::test]
    async fn test_task_manager_initialization() {
        let config = TaskConfig::default();
        let mut manager = TaskManager::new(config);

        manager.initialize().await.unwrap();

        let status = manager.status().await;
        assert_eq!(status.state, crate::manager::ManagerState::Running);

        manager.shutdown().await.unwrap();
    }

    // The builder, progress and display tests below never await anything, so
    // they are plain `#[test]`s instead of spinning up a Tokio runtime.

    /// Explicitly configured builder fields are carried onto the definition.
    #[test]
    fn test_task_builder() {
        let task = TaskBuilder::new("test_task")
            .category(TaskCategory::User)
            .priority(TaskPriority::High)
            .timeout(Duration::from_secs(60))
            .cancellable(true)
            .metadata("key", serde_json::Value::String("value".to_string()))
            .build(|_ctx| async { Ok(serde_json::Value::String("completed".to_string())) });

        assert_eq!(task.name, "test_task");
        assert_eq!(task.category, TaskCategory::User);
        assert_eq!(task.priority, TaskPriority::High);
        assert_eq!(task.timeout, Duration::from_secs(60));
        assert!(task.cancellable);
        assert!(task.metadata.contains_key("key"));
        assert_eq!(
            task.metadata.get("key"),
            Some(&serde_json::Value::String("value".to_string()))
        );
    }

    /// A builder with no setters applied yields the documented defaults.
    #[test]
    fn test_task_builder_defaults() {
        let task = TaskBuilder::new("defaults").build(|_ctx| async { Ok(serde_json::Value::Null) });

        assert_eq!(task.name, "defaults");
        assert_eq!(task.category, TaskCategory::Core);
        assert_eq!(task.priority, TaskPriority::Normal);
        assert_eq!(task.timeout, Duration::from_secs(300));
        assert_eq!(task.max_retries, 0);
        assert!(task.cancellable);
        assert!(task.plugin_id.is_none());
        assert!(task.dependencies.is_empty());
        assert!(task.metadata.is_empty());
        assert!(task.correlation_id.is_none());
    }

    /// Plugin id, dependencies, retries and cancellability setters are honoured.
    #[test]
    fn test_task_builder_dependencies_and_retries() {
        let dep = Uuid::new_v4();
        let task = TaskBuilder::new("dependent")
            .plugin_id("my_plugin")
            .dependency(dep)
            .max_retries(3)
            .cancellable(false)
            .build(|_ctx| async { Ok(serde_json::Value::Null) });

        assert_eq!(task.plugin_id.as_deref(), Some("my_plugin"));
        assert_eq!(task.dependencies, vec![dep]);
        assert_eq!(task.max_retries, 3);
        assert!(!task.cancellable);
    }

    /// End-to-end: a submitted task runs to completion and can be awaited.
    #[tokio::test]
    async fn test_task_submission_and_execution() {
        let config = TaskConfig::default();
        let mut manager = TaskManager::new(config);
        manager.initialize().await.unwrap();

        let task = TaskBuilder::new("test_task")
            .timeout(Duration::from_secs(10))
            .build(|ctx| async move {
                ctx.report_percent(50, "Half way done");
                tokio::time::sleep(Duration::from_millis(100)).await;
                ctx.report_percent(100, "Complete");
                Ok(serde_json::Value::String("completed".to_string()))
            });

        let task_id = manager.submit_task(task).await.unwrap();

        let task_info = manager
            .wait_for_task(task_id, Some(Duration::from_secs(5)))
            .await
            .unwrap();
        assert_eq!(task_info.name, "test_task");
        assert_eq!(task_info.status, TaskStatus::Completed);

        manager.shutdown().await.unwrap();
    }

    /// Progress setters and the step-based constructor update the expected fields.
    #[test]
    fn test_task_progress() {
        let mut progress = TaskProgress::default();
        assert_eq!(progress.percent, 0);

        progress.set_percent(50);
        assert_eq!(progress.percent, 50);

        progress.set_message("Half complete");
        assert_eq!(progress.message, "Half complete");

        let step_progress = TaskProgress::with_steps(5, 10, "Processing step 5");
        assert_eq!(step_progress.percent, 50);
        assert_eq!(step_progress.current_step, Some(5));
        assert_eq!(step_progress.total_steps, Some(10));
    }

    #[test]
    fn test_task_category_display() {
        assert_eq!(TaskCategory::Core.to_string(), "core");
        assert_eq!(TaskCategory::Plugin.to_string(), "plugin");
        assert_eq!(TaskCategory::Ui.to_string(), "ui");
        assert_eq!(TaskCategory::Io.to_string(), "io");
        assert_eq!(
            TaskCategory::Custom("custom".to_string()).to_string(),
            "custom"
        );
    }

    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Pending.to_string(), "PENDING");
        assert_eq!(TaskStatus::Running.to_string(), "RUNNING");
        assert_eq!(TaskStatus::Completed.to_string(), "COMPLETED");
        assert_eq!(TaskStatus::Failed.to_string(), "FAILED");
        assert_eq!(TaskStatus::Cancelled.to_string(), "CANCELLED");
        assert_eq!(TaskStatus::TimedOut.to_string(), "TIMED_OUT");
        assert_eq!(TaskStatus::Paused.to_string(), "PAUSED");
    }

    #[test]
    fn test_task_priority_ordering_and_default() {
        assert_eq!(TaskPriority::default(), TaskPriority::Normal);
        assert!(TaskPriority::Low < TaskPriority::Normal);
        assert!(TaskPriority::Normal < TaskPriority::High);
        assert!(TaskPriority::High < TaskPriority::Critical);
    }
}