qorzen_oxide/
task.rs

1// src/task.rs
2
3//! Async task management system with progress tracking and lifecycle management
4
5use std::collections::HashMap;
6use std::fmt;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::AtomicU64;
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::utils::Time;
14use async_trait::async_trait;
15use chrono::{DateTime, Utc};
16use dashmap::DashMap;
17use serde::{Deserialize, Serialize};
18use tokio::sync::{broadcast, RwLock, Semaphore};
19use tokio::time::{timeout, Instant};
20use uuid::Uuid;
21
22use crate::config::TaskConfig;
23use crate::error::{Error, Result};
24use crate::event::{Event, EventBusManager};
25use crate::manager::{ManagedState, Manager, ManagerStatus};
26use crate::types::{CorrelationId, Metadata};
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
29pub enum TaskStatus {
30    Pending,
31    Running,
32    Completed,
33    Failed,
34    Cancelled,
35    TimedOut,
36    Paused,
37}
38
39impl fmt::Display for TaskStatus {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            Self::Pending => write!(f, "PENDING"),
43            Self::Running => write!(f, "RUNNING"),
44            Self::Completed => write!(f, "COMPLETED"),
45            Self::Failed => write!(f, "FAILED"),
46            Self::Cancelled => write!(f, "CANCELLED"),
47            Self::TimedOut => write!(f, "TIMED_OUT"),
48            Self::Paused => write!(f, "PAUSED"),
49        }
50    }
51}
52
53#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
54pub enum TaskPriority {
55    Low = 0,
56    Normal = 50,
57    High = 100,
58    Critical = 200,
59}
60
61impl Default for TaskPriority {
62    fn default() -> Self {
63        Self::Normal
64    }
65}
66
67#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
68pub enum TaskCategory {
69    Core,
70    Plugin,
71    Ui,
72    Io,
73    Background,
74    User,
75    Maintenance,
76    Custom(String),
77}
78
79impl fmt::Display for TaskCategory {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        match self {
82            Self::Core => write!(f, "core"),
83            Self::Plugin => write!(f, "plugin"),
84            Self::Ui => write!(f, "ui"),
85            Self::Io => write!(f, "io"),
86            Self::Background => write!(f, "background"),
87            Self::User => write!(f, "user"),
88            Self::Maintenance => write!(f, "maintenance"),
89            Self::Custom(name) => write!(f, "{}", name),
90        }
91    }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct TaskProgress {
96    pub percent: u8,
97    pub message: String,
98    pub current_step: Option<u32>,
99    pub total_steps: Option<u32>,
100    pub updated_at: DateTime<Utc>,
101    pub metadata: Metadata,
102}
103
104impl Default for TaskProgress {
105    fn default() -> Self {
106        Self {
107            percent: 0,
108            message: String::new(),
109            current_step: None,
110            total_steps: None,
111            updated_at: Time::now(),
112            metadata: HashMap::new(),
113        }
114    }
115}
116
117impl TaskProgress {
118    pub fn new(percent: u8, message: impl Into<String>) -> Self {
119        Self {
120            percent: percent.min(100),
121            message: message.into(),
122            updated_at: Time::now(),
123            ..Default::default()
124        }
125    }
126
127    pub fn with_steps(current: u32, total: u32, message: impl Into<String>) -> Self {
128        let percent = if total > 0 {
129            ((current as f64 / total as f64) * 100.0) as u8
130        } else {
131            0
132        };
133
134        Self {
135            percent,
136            message: message.into(),
137            current_step: Some(current),
138            total_steps: Some(total),
139            updated_at: Time::now(),
140            ..Default::default()
141        }
142    }
143
144    pub fn set_percent(&mut self, percent: u8) {
145        self.percent = percent.min(100);
146        self.updated_at = Time::now();
147    }
148
149    pub fn set_message(&mut self, message: impl Into<String>) {
150        self.message = message.into();
151        self.updated_at = Time::now();
152    }
153
154    pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
155        self.metadata.insert(key.into(), value);
156        self.updated_at = Time::now();
157    }
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct TaskResult {
162    pub success: bool,
163    pub data: Option<serde_json::Value>,
164    pub error: Option<String>,
165    pub duration: Duration,
166    pub resource_usage: ResourceUsage,
167    pub metadata: Metadata,
168}
169
170#[derive(Debug, Clone, Default, Serialize, Deserialize)]
171pub struct ResourceUsage {
172    pub peak_memory_bytes: u64,
173    pub cpu_time_ms: u64,
174    pub file_operations: u32,
175    pub network_bytes: u64,
176}
177
178pub trait ProgressReporter: Send + Sync + fmt::Debug {
179    fn report(&self, progress: TaskProgress);
180    fn report_percent(&self, percent: u8, message: String) {
181        self.report(TaskProgress::new(percent, message));
182    }
183    fn report_step(&self, current: u32, total: u32, message: String) {
184        self.report(TaskProgress::with_steps(current, total, message));
185    }
186}
187
188// Platform-specific cancellation token implementation
189#[derive(Debug, Clone)]
190pub struct CancellationToken {
191    #[cfg(not(target_arch = "wasm32"))]
192    inner: tokio_util::sync::CancellationToken,
193    #[cfg(target_arch = "wasm32")]
194    cancelled: Arc<std::sync::atomic::AtomicBool>,
195}
196
197impl Default for CancellationToken {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl CancellationToken {
204    pub fn new() -> Self {
205        Self {
206            #[cfg(not(target_arch = "wasm32"))]
207            inner: tokio_util::sync::CancellationToken::new(),
208            #[cfg(target_arch = "wasm32")]
209            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
210        }
211    }
212
213    #[allow(dead_code)]
214    fn default() -> Self {
215        Self::new()
216    }
217
218    pub fn cancel(&self) {
219        #[cfg(not(target_arch = "wasm32"))]
220        self.inner.cancel();
221        #[cfg(target_arch = "wasm32")]
222        self.cancelled.store(true, Ordering::SeqCst);
223    }
224
225    pub fn is_cancelled(&self) -> bool {
226        #[cfg(not(target_arch = "wasm32"))]
227        return self.inner.is_cancelled();
228        #[cfg(target_arch = "wasm32")]
229        return self.cancelled.load(Ordering::SeqCst);
230    }
231
232    pub async fn cancelled(&self) {
233        #[cfg(not(target_arch = "wasm32"))]
234        self.inner.cancelled().await;
235        #[cfg(target_arch = "wasm32")]
236        {
237            // Simple polling implementation for WASM
238            while !self.is_cancelled() {
239                tokio::time::sleep(Duration::from_millis(10)).await;
240            }
241        }
242    }
243}
244
245#[derive(Debug)]
246pub struct TaskContext {
247    pub task_id: Uuid,
248    pub name: String,
249    pub category: TaskCategory,
250    pub plugin_id: Option<String>,
251    pub correlation_id: Option<CorrelationId>,
252    pub progress: Arc<dyn ProgressReporter>,
253    pub cancellation_token: CancellationToken,
254    pub metadata: Metadata,
255}
256
257impl TaskContext {
258    pub fn is_cancelled(&self) -> bool {
259        self.cancellation_token.is_cancelled()
260    }
261
262    pub async fn cancelled(&self) {
263        self.cancellation_token.cancelled().await;
264    }
265
266    pub fn report_progress(&self, progress: TaskProgress) {
267        self.progress.report(progress);
268    }
269
270    pub fn report_percent(&self, percent: u8, message: impl Into<String>) {
271        self.progress.report_percent(percent, message.into());
272    }
273
274    pub fn report_step(&self, current: u32, total: u32, message: impl Into<String>) {
275        self.progress.report_step(current, total, message.into());
276    }
277}
278
279pub type TaskFunction = Arc<
280    dyn Fn(TaskContext) -> Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send>>
281        + Send
282        + Sync,
283>;
284
285pub struct TaskDefinition {
286    pub id: Uuid,
287    pub name: String,
288    pub category: TaskCategory,
289    pub priority: TaskPriority,
290    pub plugin_id: Option<String>,
291    pub dependencies: Vec<Uuid>,
292    pub timeout: Duration,
293    pub max_retries: u32,
294    pub cancellable: bool,
295    pub metadata: Metadata,
296    pub correlation_id: Option<CorrelationId>,
297    pub function: TaskFunction,
298}
299
300impl fmt::Debug for TaskDefinition {
301    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302        f.debug_struct("TaskDefinition")
303            .field("id", &self.id)
304            .field("name", &self.name)
305            .field("category", &self.category)
306            .field("priority", &self.priority)
307            .field("plugin_id", &self.plugin_id)
308            .field("dependencies", &self.dependencies)
309            .field("timeout", &self.timeout)
310            .field("max_retries", &self.max_retries)
311            .field("cancellable", &self.cancellable)
312            .field("metadata", &self.metadata)
313            .field("correlation_id", &self.correlation_id)
314            .field("function", &"<function>")
315            .finish()
316    }
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct TaskInfo {
321    pub id: Uuid,
322    pub name: String,
323    pub category: TaskCategory,
324    pub priority: TaskPriority,
325    pub status: TaskStatus,
326    pub plugin_id: Option<String>,
327    pub dependencies: Vec<Uuid>,
328    pub created_at: DateTime<Utc>,
329    pub started_at: Option<DateTime<Utc>>,
330    pub completed_at: Option<DateTime<Utc>>,
331    pub progress: TaskProgress,
332    pub result: Option<TaskResult>,
333    pub retry_count: u32,
334    pub max_retries: u32,
335    pub timeout: Duration,
336    pub cancellable: bool,
337    pub correlation_id: Option<CorrelationId>,
338    pub metadata: Metadata,
339}
340
341impl TaskInfo {
342    pub fn from_definition(definition: &TaskDefinition) -> Self {
343        Self {
344            id: definition.id,
345            name: definition.name.clone(),
346            category: definition.category.clone(),
347            priority: definition.priority,
348            status: TaskStatus::Pending,
349            plugin_id: definition.plugin_id.clone(),
350            dependencies: definition.dependencies.clone(),
351            created_at: Time::now(),
352            started_at: None,
353            completed_at: None,
354            progress: TaskProgress::default(),
355            result: None,
356            retry_count: 0,
357            max_retries: definition.max_retries,
358            timeout: definition.timeout,
359            cancellable: definition.cancellable,
360            correlation_id: definition.correlation_id,
361            metadata: definition.metadata.clone(),
362        }
363    }
364
365    pub fn duration(&self) -> Option<Duration> {
366        if let (Some(started), Some(completed)) = (self.started_at, self.completed_at) {
367            Some((completed - started).to_std().ok()?)
368        } else {
369            None
370        }
371    }
372
373    pub fn is_terminal(&self) -> bool {
374        matches!(
375            self.status,
376            TaskStatus::Completed
377                | TaskStatus::Failed
378                | TaskStatus::Cancelled
379                | TaskStatus::TimedOut
380        )
381    }
382
383    pub fn can_retry(&self) -> bool {
384        matches!(self.status, TaskStatus::Failed | TaskStatus::TimedOut)
385            && self.retry_count < self.max_retries
386    }
387}
388
389#[derive(Debug)]
390struct TaskExecution {
391    info: TaskInfo,
392    definition: TaskDefinition,
393    cancellation_token: CancellationToken,
394    progress_sender: broadcast::Sender<TaskProgress>,
395}
396
397#[derive(Debug, Clone, Serialize, Deserialize)]
398pub struct TaskManagerStats {
399    pub total_created: u64,
400    pub total_completed: u64,
401    pub total_failed: u64,
402    pub total_cancelled: u64,
403    pub currently_running: u32,
404    pub currently_pending: u32,
405    pub by_category: HashMap<String, u64>,
406    pub by_priority: HashMap<TaskPriority, u64>,
407    pub avg_execution_time_ms: f64,
408    pub total_resource_usage: ResourceUsage,
409}
410
411#[derive(Debug, Clone, Serialize, Deserialize)]
412pub struct TaskCreatedEvent {
413    pub task_id: Uuid,
414    pub name: String,
415    pub category: TaskCategory,
416    pub priority: TaskPriority,
417    pub timestamp: DateTime<Utc>,
418    pub source: String,
419    pub metadata: Metadata,
420}
421
422impl Event for TaskCreatedEvent {
423    fn event_type(&self) -> &'static str {
424        "task.created"
425    }
426
427    fn source(&self) -> &str {
428        &self.source
429    }
430
431    fn metadata(&self) -> &Metadata {
432        &self.metadata
433    }
434
435    fn as_any(&self) -> &dyn std::any::Any {
436        self
437    }
438}
439
440#[derive(Debug, Clone, Serialize, Deserialize)]
441pub struct TaskStatusChangedEvent {
442    pub task_id: Uuid,
443    pub name: String,
444    pub old_status: TaskStatus,
445    pub new_status: TaskStatus,
446    pub timestamp: DateTime<Utc>,
447    pub source: String,
448    pub metadata: Metadata,
449}
450
451impl Event for TaskStatusChangedEvent {
452    fn event_type(&self) -> &'static str {
453        "task.status_changed"
454    }
455
456    fn source(&self) -> &str {
457        &self.source
458    }
459
460    fn metadata(&self) -> &Metadata {
461        &self.metadata
462    }
463
464    fn as_any(&self) -> &dyn std::any::Any {
465        self
466    }
467}
468
469#[derive(Debug, Clone, Serialize, Deserialize)]
470pub struct TaskProgressEvent {
471    pub task_id: Uuid,
472    pub name: String,
473    pub progress: TaskProgress,
474    pub timestamp: DateTime<Utc>,
475    pub source: String,
476    pub metadata: Metadata,
477}
478
479impl Event for TaskProgressEvent {
480    fn event_type(&self) -> &'static str {
481        "task.progress"
482    }
483
484    fn source(&self) -> &str {
485        &self.source
486    }
487
488    fn metadata(&self) -> &Metadata {
489        &self.metadata
490    }
491
492    fn as_any(&self) -> &dyn std::any::Any {
493        self
494    }
495}
496
497#[derive(Debug)]
498struct TaskProgressReporter {
499    task_id: Uuid,
500    progress_sender: broadcast::Sender<TaskProgress>,
501}
502
503impl ProgressReporter for TaskProgressReporter {
504    fn report(&self, progress: TaskProgress) {
505        tracing::debug!(
506            "Task {} progress: {}% - {}",
507            self.task_id,
508            progress.percent,
509            progress.message
510        );
511        let _ = self.progress_sender.send(progress);
512    }
513}
514
515#[derive(Debug)]
516pub struct TaskManager {
517    state: ManagedState,
518    #[allow(dead_code)]
519    config: TaskConfig,
520    tasks: Arc<DashMap<Uuid, TaskExecution>>,
521    stats: Arc<RwLock<TaskManagerStats>>,
522    #[allow(dead_code)]
523    task_counter: Arc<AtomicU64>,
524    concurrency_semaphore: Arc<Semaphore>,
525    event_bus: Option<Arc<EventBusManager>>,
526    worker_handles: Vec<tokio::task::JoinHandle<()>>,
527    shutdown_flag: Arc<tokio::sync::RwLock<bool>>,
528}
529
530impl TaskManager {
531    pub fn new(config: TaskConfig) -> Self {
532        let max_concurrent = config.max_concurrent;
533        Self {
534            state: ManagedState::new(Uuid::new_v4(), "task_manager"),
535            config,
536            tasks: Arc::new(DashMap::new()),
537            stats: Arc::new(RwLock::new(TaskManagerStats {
538                total_created: 0,
539                total_completed: 0,
540                total_failed: 0,
541                total_cancelled: 0,
542                currently_running: 0,
543                currently_pending: 0,
544                by_category: HashMap::new(),
545                by_priority: HashMap::new(),
546                avg_execution_time_ms: 0.0,
547                total_resource_usage: ResourceUsage::default(),
548            })),
549            task_counter: Arc::new(AtomicU64::new(0)),
550            concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
551            event_bus: None,
552            worker_handles: Vec::new(),
553            shutdown_flag: Arc::new(tokio::sync::RwLock::new(false)),
554        }
555    }
556
557    pub fn set_event_bus(&mut self, event_bus: Arc<EventBusManager>) {
558        self.event_bus = Some(event_bus);
559    }
560
561    pub async fn submit_task(&self, definition: TaskDefinition) -> Result<Uuid> {
562        let task_id = definition.id;
563        let task_info = TaskInfo::from_definition(&definition);
564
565        tracing::info!("Submitting task {} ({})", task_info.name, task_id);
566
567        // Check dependencies
568        for dep_id in &task_info.dependencies {
569            if let Some(dep_task) = self.tasks.get(dep_id) {
570                if !dep_task.info.is_terminal() {
571                    return Err(Error::task(
572                        Some(task_id),
573                        Some(format!("dep_{}", dep_id)),
574                        format!("Dependency task {} is not completed", dep_id),
575                    ));
576                }
577                if dep_task.info.status != TaskStatus::Completed {
578                    return Err(Error::task(
579                        Some(task_id),
580                        Some(format!("dep_{}", dep_id)),
581                        format!("Dependency task {} failed", dep_id),
582                    ));
583                }
584            } else {
585                return Err(Error::task(
586                    Some(task_id),
587                    Some(format!("dep_{}", dep_id)),
588                    format!("Dependency task {} not found", dep_id),
589                ));
590            }
591        }
592
593        let (progress_sender, _) = broadcast::channel(100);
594        let cancellation_token = CancellationToken::new();
595
596        let execution = TaskExecution {
597            info: task_info.clone(),
598            definition,
599            cancellation_token,
600            progress_sender,
601        };
602
603        // Add to tasks collection
604        self.tasks.insert(task_id, execution);
605        tracing::debug!(
606            "Task {} added to collection, total tasks: {}",
607            task_id,
608            self.tasks.len()
609        );
610
611        // Update statistics
612        {
613            let mut stats = self.stats.write().await;
614            stats.total_created += 1;
615            stats.currently_pending += 1;
616            *stats
617                .by_category
618                .entry(task_info.category.to_string())
619                .or_insert(0) += 1;
620            *stats.by_priority.entry(task_info.priority).or_insert(0) += 1;
621            tracing::debug!("Stats updated: {} pending tasks", stats.currently_pending);
622        }
623
624        // Publish task created event
625        if let Some(event_bus) = &self.event_bus {
626            let event = TaskCreatedEvent {
627                task_id,
628                name: task_info.name.clone(),
629                category: task_info.category,
630                priority: task_info.priority,
631                timestamp: Time::now(),
632                source: "task_manager".to_string(),
633                metadata: task_info.metadata.clone(),
634            };
635            let _ = event_bus.publish(event).await;
636        }
637
638        tracing::info!("Task {} submitted successfully", task_id);
639        Ok(task_id)
640    }
641
642    pub async fn cancel_task(&self, task_id: Uuid) -> Result<bool> {
643        if let Some(mut task) = self.tasks.get_mut(&task_id) {
644            if !task.info.cancellable {
645                return Err(Error::task(Some(task_id), None, "Task is not cancellable"));
646            }
647
648            if task.info.is_terminal() {
649                return Ok(false);
650            }
651
652            // Cancel the task
653            task.cancellation_token.cancel();
654            task.info.status = TaskStatus::Cancelled;
655            task.info.completed_at = Some(Time::now());
656
657            // Update statistics
658            {
659                let mut stats = self.stats.write().await;
660                stats.total_cancelled += 1;
661                if task.info.status == TaskStatus::Running {
662                    stats.currently_running = stats.currently_running.saturating_sub(1);
663                } else {
664                    stats.currently_pending = stats.currently_pending.saturating_sub(1);
665                }
666            }
667
668            // Publish status change event
669            self.publish_status_change_event(
670                &task.info,
671                TaskStatus::Running,
672                TaskStatus::Cancelled,
673            )
674            .await;
675
676            Ok(true)
677        } else {
678            Err(Error::task(Some(task_id), None, "Task not found"))
679        }
680    }
681
682    pub async fn get_task_info(&self, task_id: Uuid) -> Option<TaskInfo> {
683        self.tasks.get(&task_id).map(|task| task.info.clone())
684    }
685
686    pub async fn list_tasks(
687        &self,
688        status_filter: Option<TaskStatus>,
689        category_filter: Option<TaskCategory>,
690        limit: Option<usize>,
691    ) -> Vec<TaskInfo> {
692        let tasks: Vec<TaskInfo> = self
693            .tasks
694            .iter()
695            .filter_map(|entry| {
696                let task_info = &entry.value().info;
697
698                if let Some(status) = status_filter {
699                    if task_info.status != status {
700                        return None;
701                    }
702                }
703
704                if let Some(category) = &category_filter {
705                    if task_info.category != *category {
706                        return None;
707                    }
708                }
709
710                Some(task_info.clone())
711            })
712            .collect();
713
714        if let Some(limit) = limit {
715            tasks.into_iter().take(limit).collect()
716        } else {
717            tasks
718        }
719    }
720
721    pub async fn wait_for_task(
722        &self,
723        task_id: Uuid,
724        timeout_duration: Option<Duration>,
725    ) -> Result<TaskInfo> {
726        tracing::info!(
727            "Waiting for task {} with timeout {:?}",
728            task_id,
729            timeout_duration
730        );
731
732        if let Some(task) = self.tasks.get(&task_id) {
733            if task.info.is_terminal() {
734                tracing::info!(
735                    "Task {} already completed with status: {:?}",
736                    task_id,
737                    task.info.status
738                );
739                return Ok(task.info.clone());
740            }
741
742            let progress_receiver = task.progress_sender.subscribe();
743            drop(task); // Release the DashMap reference
744
745            let wait_future = self.wait_for_completion(task_id, progress_receiver);
746
747            if let Some(timeout_duration) = timeout_duration {
748                match timeout(timeout_duration, wait_future).await {
749                    Ok(result) => result,
750                    Err(_) => {
751                        tracing::error!(
752                            "Task {} wait timed out after {:?}",
753                            task_id,
754                            timeout_duration
755                        );
756                        Err(Error::timeout("Task wait timeout"))
757                    }
758                }
759            } else {
760                wait_future.await
761            }
762        } else {
763            Err(Error::task(Some(task_id), None, "Task not found"))
764        }
765    }
766
767    async fn wait_for_completion(
768        &self,
769        task_id: Uuid,
770        mut progress_receiver: broadcast::Receiver<TaskProgress>,
771    ) -> Result<TaskInfo> {
772        loop {
773            // Check if task is completed
774            if let Some(updated_task) = self.tasks.get(&task_id) {
775                if updated_task.info.is_terminal() {
776                    tracing::info!(
777                        "Task {} completed with status: {:?}",
778                        task_id,
779                        updated_task.info.status
780                    );
781                    return Ok(updated_task.info.clone());
782                }
783            }
784
785            // Wait for progress update or timeout
786            match tokio::time::timeout(Duration::from_millis(500), progress_receiver.recv()).await {
787                Ok(Ok(progress)) => {
788                    tracing::debug!(
789                        "Task {} progress: {}% - {}",
790                        task_id,
791                        progress.percent,
792                        progress.message
793                    );
794                    continue;
795                }
796                Ok(Err(_)) => {
797                    // Channel closed, check final status
798                    if let Some(updated_task) = self.tasks.get(&task_id) {
799                        if updated_task.info.is_terminal() {
800                            return Ok(updated_task.info.clone());
801                        }
802                    }
803                    tracing::warn!("Progress channel closed for task {}", task_id);
804                    break;
805                }
806                Err(_) => {
807                    // Timeout on progress, check task status anyway
808                    if let Some(updated_task) = self.tasks.get(&task_id) {
809                        tracing::debug!(
810                            "Task {} current status: {:?}",
811                            task_id,
812                            updated_task.info.status
813                        );
814                        if updated_task.info.is_terminal() {
815                            return Ok(updated_task.info.clone());
816                        }
817                    }
818                }
819            }
820        }
821
822        Err(Error::task(
823            Some(task_id),
824            None,
825            "Task completion wait failed",
826        ))
827    }
828
829    pub async fn get_stats(&self) -> TaskManagerStats {
830        self.stats.read().await.clone()
831    }
832
833    pub async fn cleanup_old_tasks(&self, max_age: Duration) -> u64 {
834        let cutoff_time = Time::now() - chrono::Duration::from_std(max_age).unwrap_or_default();
835        let mut removed_count = 0u64;
836
837        let task_ids_to_remove: Vec<Uuid> = self
838            .tasks
839            .iter()
840            .filter_map(|entry| {
841                let task_info = &entry.value().info;
842                if task_info.is_terminal() {
843                    if let Some(completed_at) = task_info.completed_at {
844                        if completed_at < cutoff_time {
845                            return Some(task_info.id);
846                        }
847                    }
848                }
849                None
850            })
851            .collect();
852
853        for task_id in task_ids_to_remove {
854            if self.tasks.remove(&task_id).is_some() {
855                removed_count += 1;
856            }
857        }
858
859        removed_count
860    }
861
862    async fn start_workers(&mut self) -> Result<()> {
863        let worker_count = 4;
864        tracing::info!("Starting {} task workers", worker_count);
865
866        for worker_id in 0..worker_count {
867            let tasks = Arc::clone(&self.tasks);
868            let stats = Arc::clone(&self.stats);
869            let semaphore = Arc::clone(&self.concurrency_semaphore);
870            let event_bus = self.event_bus.clone();
871            let shutdown_flag = Arc::clone(&self.shutdown_flag);
872
873            let handle = tokio::spawn(async move {
874                Self::task_worker(worker_id, tasks, stats, semaphore, event_bus, shutdown_flag)
875                    .await;
876            });
877
878            self.worker_handles.push(handle);
879        }
880
881        tracing::info!("Started {} task workers", worker_count);
882        Ok(())
883    }
884
885    async fn task_worker(
886        worker_id: usize,
887        tasks: Arc<DashMap<Uuid, TaskExecution>>,
888        stats: Arc<RwLock<TaskManagerStats>>,
889        semaphore: Arc<Semaphore>,
890        event_bus: Option<Arc<EventBusManager>>,
891        shutdown_flag: Arc<tokio::sync::RwLock<bool>>,
892    ) {
893        tracing::info!("Task worker {} started", worker_id);
894
895        loop {
896            // Check shutdown flag
897            if *shutdown_flag.read().await {
898                tracing::info!("Task worker {} shutting down", worker_id);
899                break;
900            }
901
902            // Try to find and claim a pending task atomically
903            let claimed_task_id = {
904                let mut claimed = None;
905                for mut entry in tasks.iter_mut() {
906                    let task_id = *entry.key();
907                    let task_ref = entry.value_mut(); // ✅ mutable access
908
909                    if task_ref.info.status == TaskStatus::Pending {
910                        task_ref.info.status = TaskStatus::Running;
911                        task_ref.info.started_at = Some(Time::now());
912                        claimed = Some(task_id);
913                        break;
914                    }
915                }
916                claimed
917            };
918
919            if let Some(task_id) = claimed_task_id {
920                tracing::info!("Worker {} claimed task {}", worker_id, task_id);
921
922                // Acquire semaphore permit
923                if let Ok(permit) = semaphore.acquire().await {
924                    // Extract the task data we need for execution
925                    let task_execution_data = {
926                        if let Some(task_entry) = tasks.get(&task_id) {
927                            let task = task_entry.value();
928
929                            // Update statistics
930                            {
931                                let mut stats_guard = stats.write().await;
932                                stats_guard.currently_pending =
933                                    stats_guard.currently_pending.saturating_sub(1);
934                                stats_guard.currently_running += 1;
935                                tracing::debug!(
936                                    "Stats: {} pending, {} running",
937                                    stats_guard.currently_pending,
938                                    stats_guard.currently_running
939                                );
940                            }
941
942                            // Publish status change event
943                            if let Some(event_bus) = &event_bus {
944                                let event = TaskStatusChangedEvent {
945                                    task_id,
946                                    name: task.info.name.clone(),
947                                    old_status: TaskStatus::Pending,
948                                    new_status: TaskStatus::Running,
949                                    timestamp: Time::now(),
950                                    source: "task_manager".to_string(),
951                                    metadata: task.info.metadata.clone(),
952                                };
953                                let _ = event_bus.publish(event).await;
954                            }
955
956                            // Create context for task execution
957                            let context = TaskContext {
958                                task_id,
959                                name: task.info.name.clone(),
960                                category: task.info.category.clone(),
961                                plugin_id: task.info.plugin_id.clone(),
962                                correlation_id: task.info.correlation_id,
963                                progress: Arc::new(TaskProgressReporter {
964                                    task_id,
965                                    progress_sender: task.progress_sender.clone(),
966                                }),
967                                cancellation_token: task.cancellation_token.clone(),
968                                metadata: task.info.metadata.clone(),
969                            };
970
971                            Some((
972                                Arc::clone(&task.definition.function),
973                                context,
974                                task.info.timeout,
975                                task.progress_sender.clone(),
976                            ))
977                        } else {
978                            None
979                        }
980                    };
981
982                    if let Some((function, context, task_timeout, progress_sender)) =
983                        task_execution_data
984                    {
985                        // Execute the task function
986                        let start_time = Instant::now();
987                        tracing::info!(
988                            "Worker {} executing task {} with timeout {:?}",
989                            worker_id,
990                            task_id,
991                            task_timeout
992                        );
993
994                        // Call the function to get the future
995                        let future = function(context);
996
997                        // Execute with timeout
998                        let execution_result = timeout(task_timeout, future).await;
999                        let execution_duration = start_time.elapsed();
1000
1001                        tracing::info!(
1002                            "Task {} execution completed in {:?}",
1003                            task_id,
1004                            execution_duration
1005                        );
1006
1007                        // Update task with result
1008                        let (new_status, result) = match execution_result {
1009                            Ok(Ok(data)) => {
1010                                tracing::info!("Task {} completed successfully", task_id);
1011                                let result = TaskResult {
1012                                    success: true,
1013                                    data: Some(data),
1014                                    error: None,
1015                                    duration: execution_duration,
1016                                    resource_usage: ResourceUsage::default(),
1017                                    metadata: HashMap::new(),
1018                                };
1019                                (TaskStatus::Completed, Some(result))
1020                            }
1021                            Ok(Err(error)) => {
1022                                tracing::error!("Task {} failed: {}", task_id, error);
1023                                let result = TaskResult {
1024                                    success: false,
1025                                    data: None,
1026                                    error: Some(error.to_string()),
1027                                    duration: execution_duration,
1028                                    resource_usage: ResourceUsage::default(),
1029                                    metadata: HashMap::new(),
1030                                };
1031                                (TaskStatus::Failed, Some(result))
1032                            }
1033                            Err(_) => {
1034                                tracing::error!("Task {} timed out", task_id);
1035                                let result = TaskResult {
1036                                    success: false,
1037                                    data: None,
1038                                    error: Some("Task timed out".to_string()),
1039                                    duration: execution_duration,
1040                                    resource_usage: ResourceUsage::default(),
1041                                    metadata: HashMap::new(),
1042                                };
1043                                (TaskStatus::TimedOut, Some(result))
1044                            }
1045                        };
1046
1047                        // Update task info
1048                        if let Some(mut task_entry) = tasks.get_mut(&task_id) {
1049                            let task = task_entry.value_mut();
1050                            task.info.status = new_status;
1051                            task.info.completed_at = Some(Time::now());
1052                            task.info.result = result;
1053
1054                            // Send final progress update
1055                            let final_progress = TaskProgress::new(100, "Task completed");
1056                            let _ = progress_sender.send(final_progress);
1057
1058                            tracing::info!("Task {} final status: {:?}", task_id, new_status);
1059                        }
1060
1061                        // Update statistics
1062                        {
1063                            let mut stats_guard = stats.write().await;
1064                            stats_guard.currently_running =
1065                                stats_guard.currently_running.saturating_sub(1);
1066                            match new_status {
1067                                TaskStatus::Completed => stats_guard.total_completed += 1,
1068                                TaskStatus::Failed => stats_guard.total_failed += 1,
1069                                TaskStatus::TimedOut => stats_guard.total_failed += 1,
1070                                _ => {}
1071                            }
1072
1073                            // Update average execution time
1074                            let total_completed =
1075                                stats_guard.total_completed + stats_guard.total_failed;
1076                            if total_completed > 0 {
1077                                stats_guard.avg_execution_time_ms = (stats_guard
1078                                    .avg_execution_time_ms
1079                                    * (total_completed - 1) as f64
1080                                    + execution_duration.as_millis() as f64)
1081                                    / total_completed as f64;
1082                            }
1083
1084                            tracing::debug!(
1085                                "Updated stats: {} completed, {} failed, {} running",
1086                                stats_guard.total_completed,
1087                                stats_guard.total_failed,
1088                                stats_guard.currently_running
1089                            );
1090                        }
1091
1092                        // Publish status change event
1093                        if let Some(event_bus) = &event_bus {
1094                            let event = TaskStatusChangedEvent {
1095                                task_id,
1096                                name: tasks
1097                                    .get(&task_id)
1098                                    .map(|t| t.info.name.clone())
1099                                    .unwrap_or_default(),
1100                                old_status: TaskStatus::Running,
1101                                new_status,
1102                                timestamp: Time::now(),
1103                                source: "task_manager".to_string(),
1104                                metadata: HashMap::new(),
1105                            };
1106                            let _ = event_bus.publish(event).await;
1107                        }
1108                    }
1109
1110                    drop(permit); // Release semaphore
1111                }
1112            } else {
1113                // No pending tasks, wait a bit
1114                tokio::time::sleep(Duration::from_millis(100)).await;
1115            }
1116        }
1117
1118        tracing::info!("Task worker {} stopped", worker_id);
1119    }
1120
1121    async fn publish_status_change_event(
1122        &self,
1123        task_info: &TaskInfo,
1124        old_status: TaskStatus,
1125        new_status: TaskStatus,
1126    ) {
1127        if let Some(event_bus) = &self.event_bus {
1128            let event = TaskStatusChangedEvent {
1129                task_id: task_info.id,
1130                name: task_info.name.clone(),
1131                old_status,
1132                new_status,
1133                timestamp: Time::now(),
1134                source: "task_manager".to_string(),
1135                metadata: task_info.metadata.clone(),
1136            };
1137            let _ = event_bus.publish(event).await;
1138        }
1139    }
1140
1141    async fn stop_workers(&mut self) {
1142        tracing::info!("Stopping task workers");
1143
1144        // Set shutdown flag
1145        *self.shutdown_flag.write().await = true;
1146
1147        // Wait for workers to finish
1148        for handle in self.worker_handles.drain(..) {
1149            let _ = tokio::time::timeout(Duration::from_secs(5), handle).await;
1150        }
1151
1152        tracing::info!("All task workers stopped");
1153    }
1154}
1155
1156#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1157#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1158impl Manager for TaskManager {
1159    fn name(&self) -> &str {
1160        "task_manager"
1161    }
1162
1163    fn id(&self) -> Uuid {
1164        Uuid::new_v4() // Simplified
1165    }
1166
1167    async fn initialize(&mut self) -> Result<()> {
1168        self.state
1169            .set_state(crate::manager::ManagerState::Initializing)
1170            .await;
1171
1172        tracing::info!("Initializing task manager");
1173
1174        // Start task execution workers
1175        self.start_workers().await?;
1176
1177        self.state
1178            .set_state(crate::manager::ManagerState::Running)
1179            .await;
1180        tracing::info!("Task manager initialized successfully");
1181        Ok(())
1182    }
1183
1184    async fn shutdown(&mut self) -> Result<()> {
1185        self.state
1186            .set_state(crate::manager::ManagerState::ShuttingDown)
1187            .await;
1188
1189        tracing::info!("Shutting down task manager");
1190
1191        // Cancel all running tasks
1192        let running_tasks: Vec<Uuid> = self
1193            .tasks
1194            .iter()
1195            .filter_map(|entry| {
1196                let task = entry.value();
1197                if task.info.status == TaskStatus::Running && task.info.cancellable {
1198                    Some(task.info.id)
1199                } else {
1200                    None
1201                }
1202            })
1203            .collect();
1204
1205        for task_id in running_tasks {
1206            let _ = self.cancel_task(task_id).await;
1207        }
1208
1209        // Stop workers
1210        self.stop_workers().await;
1211
1212        // Clean up old tasks
1213        self.cleanup_old_tasks(Duration::from_secs(0)).await;
1214
1215        self.state
1216            .set_state(crate::manager::ManagerState::Shutdown)
1217            .await;
1218        tracing::info!("Task manager shutdown complete");
1219        Ok(())
1220    }
1221
1222    async fn status(&self) -> ManagerStatus {
1223        let mut status = self.state.status().await;
1224        let stats = self.get_stats().await;
1225
1226        status.add_metadata("total_tasks", serde_json::Value::from(stats.total_created));
1227        status.add_metadata(
1228            "completed_tasks",
1229            serde_json::Value::from(stats.total_completed),
1230        );
1231        status.add_metadata("failed_tasks", serde_json::Value::from(stats.total_failed));
1232        status.add_metadata(
1233            "running_tasks",
1234            serde_json::Value::from(stats.currently_running),
1235        );
1236        status.add_metadata(
1237            "pending_tasks",
1238            serde_json::Value::from(stats.currently_pending),
1239        );
1240        status.add_metadata(
1241            "avg_execution_time_ms",
1242            serde_json::Value::from(stats.avg_execution_time_ms),
1243        );
1244
1245        status
1246    }
1247}
1248
1249pub struct TaskBuilder {
1250    name: String,
1251    category: TaskCategory,
1252    priority: TaskPriority,
1253    plugin_id: Option<String>,
1254    dependencies: Vec<Uuid>,
1255    timeout: Duration,
1256    max_retries: u32,
1257    cancellable: bool,
1258    metadata: Metadata,
1259    correlation_id: Option<CorrelationId>,
1260}
1261
1262impl TaskBuilder {
1263    pub fn new(name: impl Into<String>) -> Self {
1264        Self {
1265            name: name.into(),
1266            category: TaskCategory::Core,
1267            priority: TaskPriority::Normal,
1268            plugin_id: None,
1269            dependencies: Vec::new(),
1270            timeout: Duration::from_secs(300), // 5 minutes default
1271            max_retries: 0,
1272            cancellable: true,
1273            metadata: HashMap::new(),
1274            correlation_id: None,
1275        }
1276    }
1277
1278    pub fn category(mut self, category: TaskCategory) -> Self {
1279        self.category = category;
1280        self
1281    }
1282
1283    pub fn priority(mut self, priority: TaskPriority) -> Self {
1284        self.priority = priority;
1285        self
1286    }
1287
1288    pub fn plugin_id(mut self, plugin_id: impl Into<String>) -> Self {
1289        self.plugin_id = Some(plugin_id.into());
1290        self
1291    }
1292
1293    pub fn dependency(mut self, task_id: Uuid) -> Self {
1294        self.dependencies.push(task_id);
1295        self
1296    }
1297
1298    pub fn timeout(mut self, timeout: Duration) -> Self {
1299        self.timeout = timeout;
1300        self
1301    }
1302
1303    pub fn max_retries(mut self, max_retries: u32) -> Self {
1304        self.max_retries = max_retries;
1305        self
1306    }
1307
1308    pub fn cancellable(mut self, cancellable: bool) -> Self {
1309        self.cancellable = cancellable;
1310        self
1311    }
1312
1313    pub fn metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
1314        self.metadata.insert(key.into(), value);
1315        self
1316    }
1317
1318    pub fn correlation_id(mut self, correlation_id: CorrelationId) -> Self {
1319        self.correlation_id = Some(correlation_id);
1320        self
1321    }
1322
1323    pub fn build<F, Fut>(self, function: F) -> TaskDefinition
1324    where
1325        F: Fn(TaskContext) -> Fut + Send + Sync + 'static,
1326        Fut: Future<Output = Result<serde_json::Value>> + Send + 'static,
1327    {
1328        let task_function: TaskFunction = Arc::new(move |ctx| Box::pin(function(ctx)));
1329
1330        TaskDefinition {
1331            id: Uuid::new_v4(),
1332            name: self.name,
1333            category: self.category,
1334            priority: self.priority,
1335            plugin_id: self.plugin_id,
1336            dependencies: self.dependencies,
1337            timeout: self.timeout,
1338            max_retries: self.max_retries,
1339            cancellable: self.cancellable,
1340            metadata: self.metadata,
1341            correlation_id: self.correlation_id,
1342            function: task_function,
1343        }
1344    }
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_task_manager_initialization() {
        let config = TaskConfig::default();
        let mut manager = TaskManager::new(config);

        manager.initialize().await.unwrap();

        let status = manager.status().await;
        assert_eq!(status.state, crate::manager::ManagerState::Running);

        manager.shutdown().await.unwrap();
    }

    // Building a definition awaits nothing, so no async runtime is needed.
    #[test]
    fn test_task_builder() {
        let task = TaskBuilder::new("test_task")
            .category(TaskCategory::User)
            .priority(TaskPriority::High)
            .timeout(Duration::from_secs(60))
            .cancellable(true)
            .metadata("key", serde_json::Value::String("value".to_string()))
            .build(|_ctx| async { Ok(serde_json::Value::String("completed".to_string())) });

        assert_eq!(task.name, "test_task");
        assert_eq!(task.category, TaskCategory::User);
        assert_eq!(task.priority, TaskPriority::High);
        assert_eq!(task.timeout, Duration::from_secs(60));
        assert!(task.cancellable);
        assert_eq!(
            task.metadata.get("key"),
            Some(&serde_json::Value::String("value".to_string()))
        );
    }

    // Pins the documented defaults of `TaskBuilder::new`.
    #[test]
    fn test_task_builder_defaults() {
        let dependency = Uuid::new_v4();
        let task = TaskBuilder::new("defaults")
            .dependency(dependency)
            .build(|_ctx| async { Ok(serde_json::Value::Null) });

        assert_eq!(task.name, "defaults");
        assert_eq!(task.category, TaskCategory::Core);
        assert_eq!(task.priority, TaskPriority::Normal);
        assert_eq!(task.timeout, Duration::from_secs(300));
        assert_eq!(task.max_retries, 0);
        assert!(task.cancellable);
        assert!(task.plugin_id.is_none());
        assert!(task.correlation_id.is_none());
        assert!(task.metadata.is_empty());
        assert_eq!(task.dependencies, vec![dependency]);
    }

    #[tokio::test]
    async fn test_task_submission_and_execution() {
        let config = TaskConfig::default();
        let mut manager = TaskManager::new(config);
        manager.initialize().await.unwrap();

        let task = TaskBuilder::new("test_task")
            .timeout(Duration::from_secs(10))
            .build(|ctx| async move {
                ctx.report_percent(50, "Half way done");
                tokio::time::sleep(Duration::from_millis(100)).await;
                ctx.report_percent(100, "Complete");
                Ok(serde_json::Value::String("completed".to_string()))
            });

        let task_id = manager.submit_task(task).await.unwrap();

        let task_info = manager
            .wait_for_task(task_id, Some(Duration::from_secs(5)))
            .await
            .unwrap();
        assert_eq!(task_info.name, "test_task");
        assert_eq!(task_info.status, TaskStatus::Completed);

        manager.shutdown().await.unwrap();
    }

    // Pure value manipulation; a plain synchronous test suffices.
    #[test]
    fn test_task_progress() {
        let mut progress = TaskProgress::default();
        assert_eq!(progress.percent, 0);

        progress.set_percent(50);
        assert_eq!(progress.percent, 50);

        progress.set_message("Half complete");
        assert_eq!(progress.message, "Half complete");

        let step_progress = TaskProgress::with_steps(5, 10, "Processing step 5");
        assert_eq!(step_progress.percent, 50);
        assert_eq!(step_progress.current_step, Some(5));
        assert_eq!(step_progress.total_steps, Some(10));
    }

    #[test]
    fn test_task_category_display() {
        assert_eq!(TaskCategory::Core.to_string(), "core");
        assert_eq!(TaskCategory::Plugin.to_string(), "plugin");
        assert_eq!(
            TaskCategory::Custom("custom".to_string()).to_string(),
            "custom"
        );
    }
}