async_runtime.rs 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. // Copyright 2019-2021 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. //! The singleton async runtime used by Tauri and exposed to users.
  5. //!
//! Tauri uses the [`tokio`] runtime to run initialization code, such as
//! [`Plugin::initialize`](../plugin/trait.Plugin.html#method.initialize) and [`crate::Builder::setup`] hooks.
//! This module also re-exports some common items most developers need from [`tokio`]. If one you
//! need isn't here, you can use the types in [`tokio`] directly.
  10. //! For custom command handlers, it's recommended to use a plain `async fn` command.
  11. use futures_lite::future::FutureExt;
  12. use once_cell::sync::OnceCell;
  13. pub use tokio::{
  14. runtime::{Handle as TokioHandle, Runtime as TokioRuntime},
  15. sync::{
  16. mpsc::{channel, Receiver, Sender},
  17. Mutex, RwLock,
  18. },
  19. task::JoinHandle as TokioJoinHandle,
  20. };
  21. use std::{
  22. future::Future,
  23. pin::Pin,
  24. task::{Context, Poll},
  25. };
  26. static RUNTIME: OnceCell<GlobalRuntime> = OnceCell::new();
  27. struct GlobalRuntime {
  28. runtime: Option<Runtime>,
  29. handle: RuntimeHandle,
  30. }
  31. impl GlobalRuntime {
  32. fn handle(&self) -> RuntimeHandle {
  33. if let Some(r) = &self.runtime {
  34. r.handle()
  35. } else {
  36. self.handle.clone()
  37. }
  38. }
  39. fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
  40. where
  41. F: Future + Send + 'static,
  42. F::Output: Send + 'static,
  43. {
  44. if let Some(r) = &self.runtime {
  45. r.spawn(task)
  46. } else {
  47. self.handle.spawn(task)
  48. }
  49. }
  50. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  51. where
  52. F: FnOnce() -> R + Send + 'static,
  53. R: Send + 'static,
  54. {
  55. if let Some(r) = &self.runtime {
  56. r.spawn_blocking(func)
  57. } else {
  58. self.handle.spawn_blocking(func)
  59. }
  60. }
  61. fn block_on<F: Future>(&self, task: F) -> F::Output {
  62. if let Some(r) = &self.runtime {
  63. r.block_on(task)
  64. } else {
  65. self.handle.block_on(task)
  66. }
  67. }
  68. }
  69. /// A runtime used to execute asynchronous tasks.
  70. pub enum Runtime {
  71. /// The tokio runtime.
  72. Tokio(TokioRuntime),
  73. }
  74. impl Runtime {
  75. /// Gets a reference to the [`TokioRuntime`].
  76. pub fn inner(&self) -> &TokioRuntime {
  77. let Self::Tokio(r) = self;
  78. r
  79. }
  80. /// Returns a handle of the async runtime.
  81. pub fn handle(&self) -> RuntimeHandle {
  82. match self {
  83. Self::Tokio(r) => RuntimeHandle::Tokio(r.handle().clone()),
  84. }
  85. }
  86. /// Spawns a future onto the runtime.
  87. pub fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
  88. where
  89. F: Future + Send + 'static,
  90. F::Output: Send + 'static,
  91. {
  92. match self {
  93. Self::Tokio(r) => {
  94. let _guard = r.enter();
  95. JoinHandle::Tokio(tokio::spawn(task))
  96. }
  97. }
  98. }
  99. /// Runs the provided function on an executor dedicated to blocking operations.
  100. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  101. where
  102. F: FnOnce() -> R + Send + 'static,
  103. R: Send + 'static,
  104. {
  105. match self {
  106. Self::Tokio(r) => JoinHandle::Tokio(r.spawn_blocking(func)),
  107. }
  108. }
  109. /// Runs a future to completion on runtime.
  110. pub fn block_on<F: Future>(&self, task: F) -> F::Output {
  111. match self {
  112. Self::Tokio(r) => r.block_on(task),
  113. }
  114. }
  115. }
  116. /// An owned permission to join on a task (await its termination).
  117. #[derive(Debug)]
  118. pub enum JoinHandle<T> {
  119. /// The tokio JoinHandle.
  120. Tokio(TokioJoinHandle<T>),
  121. }
  122. impl<T> JoinHandle<T> {
  123. /// Gets a reference to the [`TokioJoinHandle`].
  124. pub fn inner(&self) -> &TokioJoinHandle<T> {
  125. let Self::Tokio(t) = self;
  126. t
  127. }
  128. /// Abort the task associated with the handle.
  129. ///
  130. /// Awaiting a cancelled task might complete as usual if the task was
  131. /// already completed at the time it was cancelled, but most likely it
  132. /// will fail with a cancelled `JoinError`.
  133. pub fn abort(&self) {
  134. match self {
  135. Self::Tokio(t) => t.abort(),
  136. }
  137. }
  138. }
  139. impl<T> Future for JoinHandle<T> {
  140. type Output = crate::Result<T>;
  141. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  142. match self.get_mut() {
  143. Self::Tokio(t) => t.poll(cx).map_err(Into::into),
  144. }
  145. }
  146. }
  147. /// A handle to the async runtime
  148. #[derive(Clone)]
  149. pub enum RuntimeHandle {
  150. /// The tokio handle.
  151. Tokio(TokioHandle),
  152. }
  153. impl RuntimeHandle {
  154. /// Gets a reference to the [`TokioHandle`].
  155. pub fn inner(&self) -> &TokioHandle {
  156. let Self::Tokio(h) = self;
  157. h
  158. }
  159. /// Runs the provided function on an executor dedicated to blocking operations.
  160. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  161. where
  162. F: FnOnce() -> R + Send + 'static,
  163. R: Send + 'static,
  164. {
  165. match self {
  166. Self::Tokio(h) => JoinHandle::Tokio(h.spawn_blocking(func)),
  167. }
  168. }
  169. /// Spawns a future onto the runtime.
  170. pub fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
  171. where
  172. F: Future + Send + 'static,
  173. F::Output: Send + 'static,
  174. {
  175. match self {
  176. Self::Tokio(h) => {
  177. let _guard = h.enter();
  178. JoinHandle::Tokio(tokio::spawn(task))
  179. }
  180. }
  181. }
  182. /// Runs a future to completion on runtime.
  183. pub fn block_on<F: Future>(&self, task: F) -> F::Output {
  184. match self {
  185. Self::Tokio(h) => h.block_on(task),
  186. }
  187. }
  188. }
  189. fn default_runtime() -> GlobalRuntime {
  190. let runtime = Runtime::Tokio(TokioRuntime::new().unwrap());
  191. let handle = runtime.handle();
  192. GlobalRuntime {
  193. runtime: Some(runtime),
  194. handle,
  195. }
  196. }
  197. /// Sets the runtime to use to execute asynchronous tasks.
  198. /// For convenience, this method takes a [`TokioHandle`].
  199. /// Note that you cannot drop the underlying [`TokioRuntime`].
  200. ///
  201. /// # Examples
  202. ///
  203. /// ```rust
  204. /// #[tokio::main]
  205. /// async fn main() {
  206. /// // perform some async task before initializing the app
  207. /// do_something().await;
  208. /// // share the current runtime with Tauri
  209. /// tauri::async_runtime::set(tokio::runtime::Handle::current());
  210. ///
  211. /// // bootstrap the tauri app...
  212. /// // tauri::Builder::default().run().unwrap();
  213. /// }
  214. ///
  215. /// async fn do_something() {}
  216. /// ```
  217. ///
  218. /// # Panics
  219. ///
  220. /// Panics if the runtime is already set.
  221. pub fn set(handle: TokioHandle) {
  222. RUNTIME
  223. .set(GlobalRuntime {
  224. runtime: None,
  225. handle: RuntimeHandle::Tokio(handle),
  226. })
  227. .unwrap_or_else(|_| panic!("runtime already initialized"))
  228. }
  229. /// Returns a handle of the async runtime.
  230. pub fn handle() -> RuntimeHandle {
  231. let runtime = RUNTIME.get_or_init(default_runtime);
  232. runtime.handle()
  233. }
  234. /// Runs a future to completion on runtime.
  235. pub fn block_on<F: Future>(task: F) -> F::Output {
  236. let runtime = RUNTIME.get_or_init(default_runtime);
  237. runtime.block_on(task)
  238. }
  239. /// Spawns a future onto the runtime.
  240. pub fn spawn<F>(task: F) -> JoinHandle<F::Output>
  241. where
  242. F: Future + Send + 'static,
  243. F::Output: Send + 'static,
  244. {
  245. let runtime = RUNTIME.get_or_init(default_runtime);
  246. runtime.spawn(task)
  247. }
  248. /// Runs the provided function on an executor dedicated to blocking operations.
  249. pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
  250. where
  251. F: FnOnce() -> R + Send + 'static,
  252. R: Send + 'static,
  253. {
  254. let runtime = RUNTIME.get_or_init(default_runtime);
  255. runtime.spawn_blocking(func)
  256. }
  257. #[allow(dead_code)]
  258. pub(crate) fn safe_block_on<F>(task: F) -> F::Output
  259. where
  260. F: Future + Send + 'static,
  261. F::Output: Send + 'static,
  262. {
  263. if let Ok(handle) = tokio::runtime::Handle::try_current() {
  264. let (tx, rx) = std::sync::mpsc::sync_channel(1);
  265. let handle_ = handle.clone();
  266. handle.spawn_blocking(move || {
  267. tx.send(handle_.block_on(task)).unwrap();
  268. });
  269. rx.recv().unwrap()
  270. } else {
  271. block_on(task)
  272. }
  273. }
  274. #[cfg(test)]
  275. mod tests {
  276. use super::*;
  277. #[tokio::test]
  278. async fn runtime_spawn() {
  279. let join = spawn(async { 5 });
  280. assert_eq!(join.await.unwrap(), 5);
  281. }
  282. #[test]
  283. fn runtime_block_on() {
  284. assert_eq!(block_on(async { 0 }), 0);
  285. }
  286. #[tokio::test]
  287. async fn handle_spawn() {
  288. let handle = handle();
  289. let join = handle.spawn(async { 5 });
  290. assert_eq!(join.await.unwrap(), 5);
  291. }
  292. #[test]
  293. fn handle_block_on() {
  294. let handle = handle();
  295. assert_eq!(handle.block_on(async { 0 }), 0);
  296. }
  297. #[tokio::test]
  298. async fn handle_abort() {
  299. let handle = handle();
  300. let join = handle.spawn(async {
  301. // Here we sleep 1 second to ensure this task to be uncompleted when abort() invoked.
  302. tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
  303. 5
  304. });
  305. join.abort();
  306. if let crate::Error::JoinError(raw_error) = join.await.unwrap_err() {
  307. assert!(raw_error.is_cancelled());
  308. } else {
  309. panic!("Abort did not result in the expected `JoinError`");
  310. }
  311. }
  312. }