async_runtime.rs 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. // Copyright 2019-2021 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. //! The singleton async runtime used by Tauri and exposed to users.
  5. //!
//! Tauri uses the [`tokio`] runtime to run initialization code, such as
//! [`Plugin::initialize`](../plugin/trait.Plugin.html#method.initialize) and [`crate::Builder::setup`] hooks.
//! This module also re-exports some common items most developers need from [`tokio`]. If the
//! one you need isn't here, you can use the types in [`tokio`] directly.
  10. //! For custom command handlers, it's recommended to use a plain `async fn` command.
  11. use futures_lite::future::FutureExt;
  12. use once_cell::sync::OnceCell;
  13. pub use tokio::{
  14. runtime::{Handle as TokioHandle, Runtime as TokioRuntime},
  15. sync::{
  16. mpsc::{channel, Receiver, Sender},
  17. Mutex, RwLock,
  18. },
  19. task::JoinHandle as TokioJoinHandle,
  20. };
  21. use std::{
  22. future::Future,
  23. pin::Pin,
  24. task::{Context, Poll},
  25. };
  26. static RUNTIME: OnceCell<GlobalRuntime> = OnceCell::new();
  27. struct GlobalRuntime {
  28. runtime: Option<Runtime>,
  29. handle: RuntimeHandle,
  30. }
  31. impl GlobalRuntime {
  32. fn handle(&self) -> RuntimeHandle {
  33. if let Some(r) = &self.runtime {
  34. r.handle()
  35. } else {
  36. self.handle.clone()
  37. }
  38. }
  39. fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
  40. where
  41. F: Future + Send + 'static,
  42. F::Output: Send + 'static,
  43. {
  44. if let Some(r) = &self.runtime {
  45. r.spawn(task)
  46. } else {
  47. self.handle.spawn(task)
  48. }
  49. }
  50. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  51. where
  52. F: FnOnce() -> R + Send + 'static,
  53. R: Send + 'static,
  54. {
  55. if let Some(r) = &self.runtime {
  56. r.spawn_blocking(func)
  57. } else {
  58. self.handle.spawn_blocking(func)
  59. }
  60. }
  61. fn block_on<F: Future>(&self, task: F) -> F::Output {
  62. if let Some(r) = &self.runtime {
  63. r.block_on(task)
  64. } else {
  65. self.handle.block_on(task)
  66. }
  67. }
  68. }
  69. /// A runtime used to execute asynchronous tasks.
  70. pub enum Runtime {
  71. /// The tokio runtime.
  72. Tokio(TokioRuntime),
  73. }
  74. impl Runtime {
  75. /// Gets a reference to the [`TokioRuntime`].
  76. pub fn inner(&self) -> &TokioRuntime {
  77. let Self::Tokio(r) = self;
  78. r
  79. }
  80. /// Returns a handle of the async runtime.
  81. pub fn handle(&self) -> RuntimeHandle {
  82. match self {
  83. Self::Tokio(r) => RuntimeHandle::Tokio(r.handle().clone()),
  84. }
  85. }
  86. /// Spawns a future onto the runtime.
  87. pub fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
  88. where
  89. F: Future + Send + 'static,
  90. F::Output: Send + 'static,
  91. {
  92. match self {
  93. Self::Tokio(r) => JoinHandle::Tokio(r.spawn(task)),
  94. }
  95. }
  96. /// Runs the provided function on an executor dedicated to blocking operations.
  97. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  98. where
  99. F: FnOnce() -> R + Send + 'static,
  100. R: Send + 'static,
  101. {
  102. match self {
  103. Self::Tokio(r) => JoinHandle::Tokio(r.spawn_blocking(func)),
  104. }
  105. }
  106. /// Runs a future to completion on runtime.
  107. pub fn block_on<F: Future>(&self, task: F) -> F::Output {
  108. match self {
  109. Self::Tokio(r) => r.block_on(task),
  110. }
  111. }
  112. }
  113. /// An owned permission to join on a task (await its termination).
  114. #[derive(Debug)]
  115. pub enum JoinHandle<T> {
  116. /// The tokio JoinHandle.
  117. Tokio(TokioJoinHandle<T>),
  118. }
  119. impl<T> JoinHandle<T> {
  120. /// Gets a reference to the [`TokioJoinHandle`].
  121. pub fn inner(&self) -> &TokioJoinHandle<T> {
  122. let Self::Tokio(t) = self;
  123. t
  124. }
  125. /// Abort the task associated with the handle.
  126. ///
  127. /// Awaiting a cancelled task might complete as usual if the task was
  128. /// already completed at the time it was cancelled, but most likely it
  129. /// will fail with a cancelled `JoinError`.
  130. pub fn abort(&self) {
  131. match self {
  132. Self::Tokio(t) => t.abort(),
  133. }
  134. }
  135. }
  136. impl<T> Future for JoinHandle<T> {
  137. type Output = crate::Result<T>;
  138. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  139. match self.get_mut() {
  140. Self::Tokio(t) => t.poll(cx).map_err(|e| crate::Error::JoinError(Box::new(e))),
  141. }
  142. }
  143. }
  144. /// A handle to the async runtime
  145. #[derive(Clone)]
  146. pub enum RuntimeHandle {
  147. /// The tokio handle.
  148. Tokio(TokioHandle),
  149. }
  150. impl RuntimeHandle {
  151. /// Gets a reference to the [`TokioHandle`].
  152. pub fn inner(&self) -> &TokioHandle {
  153. let Self::Tokio(h) = self;
  154. h
  155. }
  156. /// Runs the provided function on an executor dedicated to blocking operations.
  157. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  158. where
  159. F: FnOnce() -> R + Send + 'static,
  160. R: Send + 'static,
  161. {
  162. match self {
  163. Self::Tokio(h) => JoinHandle::Tokio(h.spawn_blocking(func)),
  164. }
  165. }
  166. /// Spawns a future onto the runtime.
  167. pub fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
  168. where
  169. F: Future + Send + 'static,
  170. F::Output: Send + 'static,
  171. {
  172. match self {
  173. Self::Tokio(h) => JoinHandle::Tokio(h.spawn(task)),
  174. }
  175. }
  176. /// Runs a future to completion on runtime.
  177. pub fn block_on<F: Future>(&self, task: F) -> F::Output {
  178. match self {
  179. Self::Tokio(h) => h.block_on(task),
  180. }
  181. }
  182. }
  183. fn default_runtime() -> GlobalRuntime {
  184. let runtime = Runtime::Tokio(TokioRuntime::new().unwrap());
  185. let handle = runtime.handle();
  186. GlobalRuntime {
  187. runtime: Some(runtime),
  188. handle,
  189. }
  190. }
  191. /// Sets the runtime to use to execute asynchronous tasks.
  192. /// For convinience, this method takes a [`TokioHandle`].
  193. /// Note that you cannot drop the underlying [`TokioRuntime`].
  194. ///
  195. /// # Example
  196. ///
  197. /// ```rust
  198. /// #[tokio::main]
  199. /// async fn main() {
  200. /// // perform some async task before initializing the app
  201. /// do_something().await;
  202. /// // share the current runtime with Tauri
  203. /// tauri::async_runtime::set(tokio::runtime::Handle::current());
  204. ///
  205. /// // bootstrap the tauri app...
  206. /// // tauri::Builder::default().run().unwrap();
  207. /// }
  208. ///
  209. /// async fn do_something() {}
  210. /// ```
  211. ///
  212. /// # Panics
  213. ///
  214. /// Panics if the runtime is already set.
  215. pub fn set(handle: TokioHandle) {
  216. RUNTIME
  217. .set(GlobalRuntime {
  218. runtime: None,
  219. handle: RuntimeHandle::Tokio(handle),
  220. })
  221. .unwrap_or_else(|_| panic!("runtime already initialized"))
  222. }
  223. /// Returns a handle of the async runtime.
  224. pub fn handle() -> RuntimeHandle {
  225. let runtime = RUNTIME.get_or_init(default_runtime);
  226. runtime.handle()
  227. }
  228. /// Runs a future to completion on runtime.
  229. pub fn block_on<F: Future>(task: F) -> F::Output {
  230. let runtime = RUNTIME.get_or_init(default_runtime);
  231. runtime.block_on(task)
  232. }
  233. /// Spawns a future onto the runtime.
  234. pub fn spawn<F>(task: F) -> JoinHandle<F::Output>
  235. where
  236. F: Future + Send + 'static,
  237. F::Output: Send + 'static,
  238. {
  239. let runtime = RUNTIME.get_or_init(default_runtime);
  240. runtime.spawn(task)
  241. }
  242. /// Runs the provided function on an executor dedicated to blocking operations.
  243. pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
  244. where
  245. F: FnOnce() -> R + Send + 'static,
  246. R: Send + 'static,
  247. {
  248. let runtime = RUNTIME.get_or_init(default_runtime);
  249. runtime.spawn_blocking(func)
  250. }
  251. pub(crate) fn safe_block_on<F>(task: F) -> F::Output
  252. where
  253. F: Future + Send + 'static,
  254. F::Output: Send + 'static,
  255. {
  256. if tokio::runtime::Handle::try_current().is_ok() {
  257. let (tx, rx) = std::sync::mpsc::sync_channel(1);
  258. spawn(async move {
  259. tx.send(task.await).unwrap();
  260. });
  261. rx.recv().unwrap()
  262. } else {
  263. block_on(task)
  264. }
  265. }
  266. #[cfg(test)]
  267. mod tests {
  268. use super::*;
  269. #[tokio::test]
  270. async fn runtime_spawn() {
  271. let join = spawn(async { 5 });
  272. assert_eq!(join.await.unwrap(), 5);
  273. }
  274. #[test]
  275. fn runtime_block_on() {
  276. assert_eq!(block_on(async { 0 }), 0);
  277. }
  278. #[tokio::test]
  279. async fn handle_spawn() {
  280. let handle = handle();
  281. let join = handle.spawn(async { 5 });
  282. assert_eq!(join.await.unwrap(), 5);
  283. }
  284. #[test]
  285. fn handle_block_on() {
  286. let handle = handle();
  287. assert_eq!(handle.block_on(async { 0 }), 0);
  288. }
  289. #[tokio::test]
  290. async fn handle_abort() {
  291. let handle = handle();
  292. let join = handle.spawn(async { 5 });
  293. join.abort();
  294. if let crate::Error::JoinError(raw_box) = join.await.unwrap_err() {
  295. let raw_error = raw_box.downcast::<tokio::task::JoinError>().unwrap();
  296. assert!(raw_error.is_cancelled());
  297. } else {
  298. panic!("Abort did not result in the expected `JoinError`");
  299. }
  300. }
  301. }