renegade_sdk/renegade_wallet_client/websocket/
task_waiter.rs1use std::{
5 collections::HashMap,
6 future::Future,
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll},
10 time::Duration,
11};
12
13use futures_util::{FutureExt, Stream, future::BoxFuture};
14use renegade_external_api::types::{ApiTask, TaskUpdateMessage};
15use tokio::sync::{
16 RwLock,
17 oneshot::{self, Receiver as OneshotReceiver, Sender as OneshotSender},
18};
19use tokio_stream::StreamExt;
20use tracing::error;
21use uuid::Uuid;
22
23use crate::RenegadeClientError;
24
/// Default time a task waiter waits for a terminal task status: one minute.
pub const DEFAULT_TASK_TIMEOUT: Duration = Duration::from_millis(60_000);
31
32type TaskNotificationTx = OneshotSender<TaskStatusNotification>;
38type TaskNotificationRx = OneshotReceiver<TaskStatusNotification>;
40
41type NotificationMap = Arc<RwLock<HashMap<Uuid, TaskNotificationTx>>>;
43
44type TaskWaiterFuture = BoxFuture<'static, Result<(), RenegadeClientError>>;
46
47pub fn create_notification_channel() -> (TaskNotificationTx, TaskNotificationRx) {
53 oneshot::channel()
54}
55
/// Terminal status of a task, sent to the waiter registered for that task.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskStatusNotification {
    /// The task reached a completed state.
    Success,
    /// The task reached a failed state.
    Failed {
        /// Human-readable description of the failure.
        error: String,
    },
}
71
72impl TaskStatusNotification {
73 pub fn into_result(self, task_id: Uuid) -> Result<(), RenegadeClientError> {
75 match self {
76 Self::Success => Ok(()),
77 Self::Failed { error } => Err(RenegadeClientError::task(task_id, error)),
78 }
79 }
80}
81
82#[derive(Clone)]
88pub struct TaskWaiterManager {
89 notifications: NotificationMap,
91}
92
93impl TaskWaiterManager {
94 pub fn new<S>(tasks_topic: S) -> Self
96 where
97 S: Stream<Item = TaskUpdateMessage> + Unpin + Send + 'static,
98 {
99 let this = Self { notifications: Arc::new(RwLock::new(HashMap::new())) };
100
101 let this_clone = this.clone();
102 tokio::spawn(async move { this_clone.watch_task_updates(tasks_topic).await });
103
104 this
105 }
106
107 pub async fn create_task_waiter(&self, task_id: Uuid, timeout: Duration) -> TaskWaiter {
109 let (tx, rx) = create_notification_channel();
110 self.notifications.write().await.insert(task_id, tx);
111 TaskWaiter::new(task_id, rx, timeout)
112 }
113
114 async fn watch_task_updates<S>(&self, mut tasks_topic: S)
118 where
119 S: Stream<Item = TaskUpdateMessage> + Unpin,
120 {
121 while let Some(message) = tasks_topic.next().await {
122 self.handle_task_update(message.task).await;
123 }
124
125 error!("Task update stream closed");
126 }
127
128 async fn handle_task_update(&self, task: ApiTask) {
131 let ApiTask { id, state, .. } = task;
132 let state = state.to_lowercase();
133 if state.contains("completed") {
134 self.handle_completed_task(id).await;
135 } else if state.contains("failed") {
136 self.handle_failed_task(id, state).await;
137 }
138 }
139
140 async fn handle_completed_task(&self, task_id: Uuid) {
142 let mut notifications = self.notifications.write().await;
143
144 let tx = match notifications.remove(&task_id) {
145 Some(tx) => tx,
146 None => return,
147 };
148
149 let _ = tx.send(TaskStatusNotification::Success);
151 }
152
153 async fn handle_failed_task(&self, task_id: Uuid, error: String) {
155 let mut notifications = self.notifications.write().await;
156
157 let tx = match notifications.remove(&task_id) {
158 Some(tx) => tx,
159 None => return,
160 };
161
162 let _ = tx.send(TaskStatusNotification::Failed { error });
164 }
165}
166
167pub struct TaskWaiter {
174 task_id: Uuid,
176 notification_rx: Option<TaskNotificationRx>,
179 timeout: Duration,
181 fut: Option<TaskWaiterFuture>,
183}
184
185impl TaskWaiter {
186 pub fn new(task_id: Uuid, notification_rx: TaskNotificationRx, timeout: Duration) -> Self {
188 Self { task_id, notification_rx: Some(notification_rx), timeout, fut: None }
189 }
190
191 async fn watch_task(
193 task_id: Uuid,
194 notification_rx: TaskNotificationRx,
195 timeout: Duration,
196 ) -> Result<(), RenegadeClientError> {
197 let timeout = tokio::time::timeout(timeout, notification_rx);
199 let notification = timeout
200 .await
201 .map_err(|_| RenegadeClientError::task(task_id, "Task timed out"))?
202 .map_err(|_| RenegadeClientError::task(task_id, "Task waiter closed"))?;
203
204 notification.into_result(task_id)
205 }
206}
207
208impl Future for TaskWaiter {
209 type Output = Result<(), RenegadeClientError>;
210
211 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
212 let this = self.get_mut();
213 if this.fut.is_none() {
214 let notification_rx = this.notification_rx.take().unwrap();
215 let fut = Self::watch_task(this.task_id, notification_rx, this.timeout).boxed();
216 this.fut = Some(fut);
217 }
218
219 this.fut.as_mut().unwrap().as_mut().poll(cx)
220 }
221}