async_runtime.rs 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. // Copyright 2019-2024 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. //! The singleton async runtime used by Tauri and exposed to users.
  5. //!
//! Tauri uses the [`tokio`] runtime to run initialization code, such as
//! [`Plugin::initialize`](../plugin/trait.Plugin.html#method.initialize) and [`crate::Builder::setup`] hooks.
//! This module also re-exports some common items most developers need from [`tokio`]. If one you
//! need isn't here, you can use the types in [`tokio`] directly.
  10. //! For custom command handlers, it's recommended to use a plain `async fn` command.
  11. pub use tokio::{
  12. runtime::{Handle as TokioHandle, Runtime as TokioRuntime},
  13. sync::{
  14. mpsc::{channel, Receiver, Sender},
  15. Mutex, RwLock,
  16. },
  17. task::JoinHandle as TokioJoinHandle,
  18. };
  19. use std::{
  20. future::Future,
  21. pin::Pin,
  22. sync::OnceLock,
  23. task::{Context, Poll},
  24. };
  25. static RUNTIME: OnceLock<GlobalRuntime> = OnceLock::new();
  26. struct GlobalRuntime {
  27. runtime: Option<Runtime>,
  28. handle: RuntimeHandle,
  29. }
  30. impl GlobalRuntime {
  31. fn handle(&self) -> RuntimeHandle {
  32. if let Some(r) = &self.runtime {
  33. r.handle()
  34. } else {
  35. self.handle.clone()
  36. }
  37. }
  38. fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
  39. where
  40. F: Future + Send + 'static,
  41. F::Output: Send + 'static,
  42. {
  43. if let Some(r) = &self.runtime {
  44. r.spawn(task)
  45. } else {
  46. self.handle.spawn(task)
  47. }
  48. }
  49. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  50. where
  51. F: FnOnce() -> R + Send + 'static,
  52. R: Send + 'static,
  53. {
  54. if let Some(r) = &self.runtime {
  55. r.spawn_blocking(func)
  56. } else {
  57. self.handle.spawn_blocking(func)
  58. }
  59. }
  60. fn block_on<F: Future>(&self, task: F) -> F::Output {
  61. if let Some(r) = &self.runtime {
  62. r.block_on(task)
  63. } else {
  64. self.handle.block_on(task)
  65. }
  66. }
  67. }
  68. /// A runtime used to execute asynchronous tasks.
  69. pub enum Runtime {
  70. /// The tokio runtime.
  71. Tokio(TokioRuntime),
  72. }
  73. impl Runtime {
  74. /// Gets a reference to the [`TokioRuntime`].
  75. pub fn inner(&self) -> &TokioRuntime {
  76. let Self::Tokio(r) = self;
  77. r
  78. }
  79. /// Returns a handle of the async runtime.
  80. pub fn handle(&self) -> RuntimeHandle {
  81. match self {
  82. Self::Tokio(r) => RuntimeHandle::Tokio(r.handle().clone()),
  83. }
  84. }
  85. /// Spawns a future onto the runtime.
  86. pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
  87. where
  88. F: Future + Send + 'static,
  89. F::Output: Send + 'static,
  90. {
  91. match self {
  92. Self::Tokio(r) => {
  93. let _guard = r.enter();
  94. JoinHandle::Tokio(tokio::spawn(task))
  95. }
  96. }
  97. }
  98. /// Runs the provided function on an executor dedicated to blocking operations.
  99. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  100. where
  101. F: FnOnce() -> R + Send + 'static,
  102. R: Send + 'static,
  103. {
  104. match self {
  105. Self::Tokio(r) => JoinHandle::Tokio(r.spawn_blocking(func)),
  106. }
  107. }
  108. /// Runs a future to completion on runtime.
  109. pub fn block_on<F: Future>(&self, task: F) -> F::Output {
  110. match self {
  111. Self::Tokio(r) => r.block_on(task),
  112. }
  113. }
  114. }
  115. /// An owned permission to join on a task (await its termination).
  116. #[derive(Debug)]
  117. pub enum JoinHandle<T> {
  118. /// The tokio JoinHandle.
  119. Tokio(TokioJoinHandle<T>),
  120. }
  121. impl<T> JoinHandle<T> {
  122. /// Gets a reference to the [`TokioJoinHandle`].
  123. pub fn inner(&self) -> &TokioJoinHandle<T> {
  124. let Self::Tokio(t) = self;
  125. t
  126. }
  127. /// Abort the task associated with the handle.
  128. ///
  129. /// Awaiting a cancelled task might complete as usual if the task was
  130. /// already completed at the time it was cancelled, but most likely it
  131. /// will fail with a cancelled `JoinError`.
  132. pub fn abort(&self) {
  133. match self {
  134. Self::Tokio(t) => t.abort(),
  135. }
  136. }
  137. }
  138. impl<T> Future for JoinHandle<T> {
  139. type Output = crate::Result<T>;
  140. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  141. match self.get_mut() {
  142. Self::Tokio(t) => Pin::new(t).poll(cx).map_err(Into::into),
  143. }
  144. }
  145. }
  146. /// A handle to the async runtime
  147. #[derive(Clone)]
  148. pub enum RuntimeHandle {
  149. /// The tokio handle.
  150. Tokio(TokioHandle),
  151. }
  152. impl RuntimeHandle {
  153. /// Gets a reference to the [`TokioHandle`].
  154. pub fn inner(&self) -> &TokioHandle {
  155. let Self::Tokio(h) = self;
  156. h
  157. }
  158. /// Runs the provided function on an executor dedicated to blocking operations.
  159. pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
  160. where
  161. F: FnOnce() -> R + Send + 'static,
  162. R: Send + 'static,
  163. {
  164. match self {
  165. Self::Tokio(h) => JoinHandle::Tokio(h.spawn_blocking(func)),
  166. }
  167. }
  168. /// Spawns a future onto the runtime.
  169. pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
  170. where
  171. F: Future + Send + 'static,
  172. F::Output: Send + 'static,
  173. {
  174. match self {
  175. Self::Tokio(h) => {
  176. let _guard = h.enter();
  177. JoinHandle::Tokio(tokio::spawn(task))
  178. }
  179. }
  180. }
  181. /// Runs a future to completion on runtime.
  182. pub fn block_on<F: Future>(&self, task: F) -> F::Output {
  183. match self {
  184. Self::Tokio(h) => h.block_on(task),
  185. }
  186. }
  187. }
  188. fn default_runtime() -> GlobalRuntime {
  189. let runtime = Runtime::Tokio(TokioRuntime::new().unwrap());
  190. let handle = runtime.handle();
  191. GlobalRuntime {
  192. runtime: Some(runtime),
  193. handle,
  194. }
  195. }
  196. /// Sets the runtime to use to execute asynchronous tasks.
  197. /// For convenience, this method takes a [`TokioHandle`].
  198. /// Note that you cannot drop the underlying [`TokioRuntime`].
  199. ///
  200. /// # Examples
  201. ///
  202. /// ```rust
  203. /// #[tokio::main]
  204. /// async fn main() {
  205. /// // perform some async task before initializing the app
  206. /// do_something().await;
  207. /// // share the current runtime with Tauri
  208. /// tauri::async_runtime::set(tokio::runtime::Handle::current());
  209. ///
  210. /// // bootstrap the tauri app...
  211. /// // tauri::Builder::default().run().unwrap();
  212. /// }
  213. ///
  214. /// async fn do_something() {}
  215. /// ```
  216. ///
  217. /// # Panics
  218. ///
  219. /// Panics if the runtime is already set.
  220. pub fn set(handle: TokioHandle) {
  221. RUNTIME
  222. .set(GlobalRuntime {
  223. runtime: None,
  224. handle: RuntimeHandle::Tokio(handle),
  225. })
  226. .unwrap_or_else(|_| panic!("runtime already initialized"))
  227. }
  228. /// Returns a handle of the async runtime.
  229. pub fn handle() -> RuntimeHandle {
  230. let runtime = RUNTIME.get_or_init(default_runtime);
  231. runtime.handle()
  232. }
  233. /// Runs a future to completion on runtime.
  234. pub fn block_on<F: Future>(task: F) -> F::Output {
  235. let runtime = RUNTIME.get_or_init(default_runtime);
  236. runtime.block_on(task)
  237. }
  238. /// Spawns a future onto the runtime.
  239. pub fn spawn<F>(task: F) -> JoinHandle<F::Output>
  240. where
  241. F: Future + Send + 'static,
  242. F::Output: Send + 'static,
  243. {
  244. let runtime = RUNTIME.get_or_init(default_runtime);
  245. runtime.spawn(task)
  246. }
  247. /// Runs the provided function on an executor dedicated to blocking operations.
  248. pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
  249. where
  250. F: FnOnce() -> R + Send + 'static,
  251. R: Send + 'static,
  252. {
  253. let runtime = RUNTIME.get_or_init(default_runtime);
  254. runtime.spawn_blocking(func)
  255. }
  256. #[allow(dead_code)]
  257. pub(crate) fn safe_block_on<F>(task: F) -> F::Output
  258. where
  259. F: Future + Send + 'static,
  260. F::Output: Send + 'static,
  261. {
  262. if let Ok(handle) = tokio::runtime::Handle::try_current() {
  263. let (tx, rx) = std::sync::mpsc::sync_channel(1);
  264. let handle_ = handle.clone();
  265. handle.spawn_blocking(move || {
  266. tx.send(handle_.block_on(task)).unwrap();
  267. });
  268. rx.recv().unwrap()
  269. } else {
  270. block_on(task)
  271. }
  272. }
  273. #[cfg(test)]
  274. mod tests {
  275. use super::*;
  276. #[tokio::test]
  277. async fn runtime_spawn() {
  278. let join = spawn(async { 5 });
  279. assert_eq!(join.await.unwrap(), 5);
  280. }
  281. #[test]
  282. fn runtime_block_on() {
  283. assert_eq!(block_on(async { 0 }), 0);
  284. }
  285. #[tokio::test]
  286. async fn handle_spawn() {
  287. let handle = handle();
  288. let join = handle.spawn(async { 5 });
  289. assert_eq!(join.await.unwrap(), 5);
  290. }
  291. #[test]
  292. fn handle_block_on() {
  293. let handle = handle();
  294. assert_eq!(handle.block_on(async { 0 }), 0);
  295. }
  296. #[tokio::test]
  297. async fn handle_abort() {
  298. let handle = handle();
  299. let join = handle.spawn(async {
  300. // Here we sleep 1 second to ensure this task to be uncompleted when abort() invoked.
  301. tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
  302. 5
  303. });
  304. join.abort();
  305. if let crate::Error::JoinError(raw_error) = join.await.unwrap_err() {
  306. assert!(raw_error.is_cancelled());
  307. } else {
  308. panic!("Abort did not result in the expected `JoinError`");
  309. }
  310. }
  311. }