renegade_sdk/renegade_wallet_client/websocket/
task_waiter.rs

//! A task waiter is a structure that waits for a task to complete, then
//! transforms the task's terminal status into a `Result`
3
4use std::{
5    collections::HashMap,
6    future::Future,
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10    time::Duration,
11};
12
13use futures_util::{FutureExt, Stream, future::BoxFuture};
14use renegade_external_api::types::{ApiTask, TaskUpdateMessage};
15use tokio::sync::{
16    RwLock,
17    oneshot::{self, Receiver as OneshotReceiver, Sender as OneshotSender},
18};
19use tokio_stream::StreamExt;
20use tracing::error;
21use uuid::Uuid;
22
23use crate::RenegadeClientError;
24
25// -------------
26// | Constants |
27// -------------
28
/// Default amount of time (one minute) to wait for a task to reach a terminal
/// state before giving up
pub const DEFAULT_TASK_TIMEOUT: Duration = Duration::from_secs(60);
31
32// ----------------
33// | Type Aliases |
34// ----------------
35
36/// A oneshot channel on which to send task status notifications
37type TaskNotificationTx = OneshotSender<TaskStatusNotification>;
38/// A oneshot channel on which to receive task status notifications
39type TaskNotificationRx = OneshotReceiver<TaskStatusNotification>;
40
41/// A map of task IDs to their corresponding notification channels
42type NotificationMap = Arc<RwLock<HashMap<Uuid, TaskNotificationTx>>>;
43
44/// The future type for a task waiter
45type TaskWaiterFuture = BoxFuture<'static, Result<(), RenegadeClientError>>;
46
47// -------------------
48// | Channel Helpers |
49// -------------------
50
51/// Create a new notification channel
52pub fn create_notification_channel() -> (TaskNotificationTx, TaskNotificationRx) {
53    oneshot::channel()
54}
55
56// ---------
57// | Types |
58// ---------
59
/// The terminal status of a task, as forwarded to the waiter awaiting it
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskStatusNotification {
    /// The task finished successfully
    Success,
    /// The task ended in failure
    Failed {
        /// Message describing why the task failed
        error: String,
    },
}
71
72impl TaskStatusNotification {
73    /// Convert the task status into a Result<(), RenegadeClientError>
74    pub fn into_result(self, task_id: Uuid) -> Result<(), RenegadeClientError> {
75        match self {
76            Self::Success => Ok(()),
77            Self::Failed { error } => Err(RenegadeClientError::task(task_id, error)),
78        }
79    }
80}
81
82// -----------------------
83// | Task Waiter Manager |
84// -----------------------
85
86/// Manages sending notifications to task waiters
87#[derive(Clone)]
88pub struct TaskWaiterManager {
89    /// The notification map
90    notifications: NotificationMap,
91}
92
93impl TaskWaiterManager {
94    /// Create a new task waiter manager
95    pub fn new<S>(tasks_topic: S) -> Self
96    where
97        S: Stream<Item = TaskUpdateMessage> + Unpin + Send + 'static,
98    {
99        let this = Self { notifications: Arc::new(RwLock::new(HashMap::new())) };
100
101        let this_clone = this.clone();
102        tokio::spawn(async move { this_clone.watch_task_updates(tasks_topic).await });
103
104        this
105    }
106
107    /// Create a task waiter which can be awaited until the given task completes
108    pub async fn create_task_waiter(&self, task_id: Uuid, timeout: Duration) -> TaskWaiter {
109        let (tx, rx) = create_notification_channel();
110        self.notifications.write().await.insert(task_id, tx);
111        TaskWaiter::new(task_id, rx, timeout)
112    }
113
114    /// A persistent loop which watches for task updates and forward the task
115    /// status notification to the appropriate receiver if the task's status
116    /// is being awaited
117    async fn watch_task_updates<S>(&self, mut tasks_topic: S)
118    where
119        S: Stream<Item = TaskUpdateMessage> + Unpin,
120    {
121        while let Some(message) = tasks_topic.next().await {
122            self.handle_task_update(message.task).await;
123        }
124
125        error!("Task update stream closed");
126    }
127
128    /// Handle a task update, forwarding the task status notification to the
129    /// appropriate receiver if the task's status is being awaited
130    async fn handle_task_update(&self, task: ApiTask) {
131        let ApiTask { id, state, .. } = task;
132        let state = state.to_lowercase();
133        if state.contains("completed") {
134            self.handle_completed_task(id).await;
135        } else if state.contains("failed") {
136            self.handle_failed_task(id, state).await;
137        }
138    }
139
140    /// Handle a completed task, forwarding a success notification
141    async fn handle_completed_task(&self, task_id: Uuid) {
142        let mut notifications = self.notifications.write().await;
143
144        let tx = match notifications.remove(&task_id) {
145            Some(tx) => tx,
146            None => return,
147        };
148
149        // We explicitly ignore errors here in case the receiver is dropped
150        let _ = tx.send(TaskStatusNotification::Success);
151    }
152
153    /// Handle a failed task, forwarding a failure notification
154    async fn handle_failed_task(&self, task_id: Uuid, error: String) {
155        let mut notifications = self.notifications.write().await;
156
157        let tx = match notifications.remove(&task_id) {
158            Some(tx) => tx,
159            None => return,
160        };
161
162        // We explicitly ignore errors here in case the receiver is dropped
163        let _ = tx.send(TaskStatusNotification::Failed { error });
164    }
165}
166
167// ---------------
168// | Task Waiter |
169// ---------------
170
171/// A thin wrapper around a notification channel that waits for a task to
172/// complete then transforms the status into a result
173pub struct TaskWaiter {
174    /// The task ID
175    task_id: Uuid,
176    /// The task status notification receiver.
177    /// This will be `taken` once the task waiter is first polled.
178    notification_rx: Option<TaskNotificationRx>,
179    /// The duration to wait for the task to complete before timing out
180    timeout: Duration,
181    /// The underlying future that waits for the task to complete
182    fut: Option<TaskWaiterFuture>,
183}
184
185impl TaskWaiter {
186    /// Create a new task waiter
187    pub fn new(task_id: Uuid, notification_rx: TaskNotificationRx, timeout: Duration) -> Self {
188        Self { task_id, notification_rx: Some(notification_rx), timeout, fut: None }
189    }
190
191    /// Watch a task until it terminates as a success or failure
192    async fn watch_task(
193        task_id: Uuid,
194        notification_rx: TaskNotificationRx,
195        timeout: Duration,
196    ) -> Result<(), RenegadeClientError> {
197        // Register a notification channel for the task and await
198        let timeout = tokio::time::timeout(timeout, notification_rx);
199        let notification = timeout
200            .await
201            .map_err(|_| RenegadeClientError::task(task_id, "Task timed out"))?
202            .map_err(|_| RenegadeClientError::task(task_id, "Task waiter closed"))?;
203
204        notification.into_result(task_id)
205    }
206}
207
208impl Future for TaskWaiter {
209    type Output = Result<(), RenegadeClientError>;
210
211    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
212        let this = self.get_mut();
213        if this.fut.is_none() {
214            let notification_rx = this.notification_rx.take().unwrap();
215            let fut = Self::watch_task(this.task_id, notification_rx, this.timeout).boxed();
216            this.fut = Some(fut);
217        }
218
219        this.fut.as_mut().unwrap().as_mut().poll(cx)
220    }
221}