Browse Source

feat(core): allow running along another tokio runtime, closes #2838 (#2973)

Lucas Fernandes Nogueira 3 years ago
parent
commit
a3537078dd

+ 5 - 0
.changes/async-runtime-refactor.md

@@ -0,0 +1,5 @@
+---
+"tauri": patch
+---
+
+**Breaking change:** Refactored the types returned from the `async_runtime` module.

+ 5 - 0
.changes/async-runtime-set.md

@@ -0,0 +1,5 @@
+---
+"tauri": patch
+---
+
+Added `tauri::async_runtime::set` method, allowing to share your tokio runtime with Tauri.

+ 5 - 0
.changes/async-runtime-spawn-blocking.md

@@ -0,0 +1,5 @@
+---
+"tauri": patch
+---
+
+Added `tauri::async_runtime::spawn_blocking` API.

+ 5 - 0
.changes/fix-block-on-runtime.md

@@ -0,0 +1,5 @@
+---
+"tauri": patch
+---
+
+Avoid `async_runtime::block_on` panics when used along another tokio runtime.

+ 2 - 2
core/tauri/src/api/process/command.rs

@@ -312,7 +312,7 @@ impl Command {
   /// Stdin, stdout and stderr are ignored.
   pub fn status(self) -> crate::api::Result<ExitStatus> {
     let (mut rx, _child) = self.spawn()?;
-    let code = crate::async_runtime::block_on(async move {
+    let code = crate::async_runtime::safe_block_on(async move {
       let mut code = None;
       #[allow(clippy::collapsible_match)]
       while let Some(event) = rx.recv().await {
@@ -330,7 +330,7 @@ impl Command {
   pub fn output(self) -> crate::api::Result<Output> {
     let (mut rx, _child) = self.spawn()?;
 
-    let output = crate::async_runtime::block_on(async move {
+    let output = crate::async_runtime::safe_block_on(async move {
       let mut code = None;
       let mut stdout = String::new();
       let mut stderr = String::new();

+ 1 - 3
core/tauri/src/app.rs

@@ -1057,9 +1057,7 @@ impl<R: Runtime> Builder<R> {
               }
             };
             let listener = listener.clone();
-            crate::async_runtime::spawn(async move {
-              listener.lock().unwrap()(&app_handle, event);
-            });
+            listener.lock().unwrap()(&app_handle, event);
           });
       }
     }

+ 235 - 32
core/tauri/src/async_runtime.rs

@@ -12,9 +12,8 @@
 
 use futures_lite::future::FutureExt;
 use once_cell::sync::OnceCell;
-use tokio::runtime::Runtime;
 pub use tokio::{
-  runtime::Handle,
+  runtime::{Handle as TokioHandle, Runtime as TokioRuntime},
   sync::{
     mpsc::{channel, Receiver, Sender},
     Mutex, RwLock,
@@ -23,74 +22,240 @@ pub use tokio::{
 };
 
 use std::{
-  fmt,
   future::Future,
   pin::Pin,
   task::{Context, Poll},
 };
 
-static RUNTIME: OnceCell<Runtime> = OnceCell::new();
+static RUNTIME: OnceCell<GlobalRuntime> = OnceCell::new();
+
+struct GlobalRuntime {
+  runtime: Option<Runtime>,
+  handle: RuntimeHandle,
+}
+
+impl GlobalRuntime {
+  fn handle(&self) -> RuntimeHandle {
+    if let Some(r) = &self.runtime {
+      r.handle()
+    } else {
+      self.handle.clone()
+    }
+  }
+
+  fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
+  where
+    F: Future + Send + 'static,
+    F::Output: Send + 'static,
+  {
+    if let Some(r) = &self.runtime {
+      r.spawn(task)
+    } else {
+      self.handle.spawn(task)
+    }
+  }
+
+  pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
+  where
+    F: FnOnce() -> R + Send + 'static,
+    R: Send + 'static,
+  {
+    if let Some(r) = &self.runtime {
+      r.spawn_blocking(func)
+    } else {
+      self.handle.spawn_blocking(func)
+    }
+  }
+
+  fn block_on<F: Future>(&self, task: F) -> F::Output {
+    if let Some(r) = &self.runtime {
+      r.block_on(task)
+    } else {
+      self.handle.block_on(task)
+    }
+  }
+}
+
+/// A runtime used to execute asynchronous tasks.
+pub enum Runtime {
+  /// The tokio runtime.
+  Tokio(TokioRuntime),
+}
+
+impl Runtime {
+  /// Gets a reference to the [`TokioRuntime`].
+  pub fn inner(&self) -> &TokioRuntime {
+    let Self::Tokio(r) = self;
+    r
+  }
+
+  /// Returns a handle of the async runtime.
+  pub fn handle(&self) -> RuntimeHandle {
+    match self {
+      Self::Tokio(r) => RuntimeHandle::Tokio(r.handle().clone()),
+    }
+  }
+
+  /// Spawns a future onto the runtime.
+  pub fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
+  where
+    F: Future + Send + 'static,
+    F::Output: Send + 'static,
+  {
+    match self {
+      Self::Tokio(r) => JoinHandle::Tokio(r.spawn(task)),
+    }
+  }
+
+  /// Runs the provided function on an executor dedicated to blocking operations.
+  pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
+  where
+    F: FnOnce() -> R + Send + 'static,
+    R: Send + 'static,
+  {
+    match self {
+      Self::Tokio(r) => JoinHandle::Tokio(r.spawn_blocking(func)),
+    }
+  }
+
+  /// Runs a future to completion on runtime.
+  pub fn block_on<F: Future>(&self, task: F) -> F::Output {
+    match self {
+      Self::Tokio(r) => r.block_on(task),
+    }
+  }
+}
 
 /// An owned permission to join on a task (await its termination).
 #[derive(Debug)]
-pub struct JoinHandle<T>(TokioJoinHandle<T>);
+pub enum JoinHandle<T> {
+  /// The tokio JoinHandle.
+  Tokio(TokioJoinHandle<T>),
+}
 
 impl<T> JoinHandle<T> {
+  /// Gets a reference to the [`TokioJoinHandle`].
+  pub fn inner(&self) -> &TokioJoinHandle<T> {
+    let Self::Tokio(t) = self;
+    t
+  }
+
   /// Abort the task associated with the handle.
   ///
   /// Awaiting a cancelled task might complete as usual if the task was
   /// already completed at the time it was cancelled, but most likely it
   /// will fail with a cancelled `JoinError`.
   pub fn abort(&self) {
-    self.0.abort();
+    match self {
+      Self::Tokio(t) => t.abort(),
+    }
   }
 }
 
 impl<T> Future for JoinHandle<T> {
   type Output = crate::Result<T>;
-  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-    self
-      .0
-      .poll(cx)
-      .map_err(|e| crate::Error::JoinError(Box::new(e)))
+  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+    match self.get_mut() {
+      Self::Tokio(t) => t.poll(cx).map_err(|e| crate::Error::JoinError(Box::new(e))),
+    }
   }
 }
 
-/// Runtime handle definition.
-pub trait RuntimeHandle: fmt::Debug + Clone + Sync + Sync {
-  /// Spawns a future onto the runtime.
-  fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
-  where
-    F: Future + Send + 'static,
-    F::Output: Send + 'static;
-
-  /// Runs a future to completion on runtime.
-  fn block_on<F: Future>(&self, task: F) -> F::Output;
+/// A handle to the async runtime
+#[derive(Clone)]
+pub enum RuntimeHandle {
+  /// The tokio handle.
+  Tokio(TokioHandle),
 }
 
-impl RuntimeHandle for Handle {
-  fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
+impl RuntimeHandle {
+  /// Gets a reference to the [`TokioHandle`].
+  pub fn inner(&self) -> &TokioHandle {
+    let Self::Tokio(h) = self;
+    h
+  }
+
+  /// Runs the provided function on an executor dedicated to blocking operations.
+  pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
+  where
+    F: FnOnce() -> R + Send + 'static,
+    R: Send + 'static,
+  {
+    match self {
+      Self::Tokio(h) => JoinHandle::Tokio(h.spawn_blocking(func)),
+    }
+  }
+
+  /// Spawns a future onto the runtime.
+  pub fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
   where
     F: Future + Send + 'static,
     F::Output: Send + 'static,
   {
-    JoinHandle(self.spawn(task))
+    match self {
+      Self::Tokio(h) => JoinHandle::Tokio(h.spawn(task)),
+    }
   }
 
-  fn block_on<F: Future>(&self, task: F) -> F::Output {
-    self.block_on(task)
+  /// Runs a future to completion on runtime.
+  pub fn block_on<F: Future>(&self, task: F) -> F::Output {
+    match self {
+      Self::Tokio(h) => h.block_on(task),
+    }
+  }
+}
+
+fn default_runtime() -> GlobalRuntime {
+  let runtime = Runtime::Tokio(TokioRuntime::new().unwrap());
+  let handle = runtime.handle();
+  GlobalRuntime {
+    runtime: Some(runtime),
+    handle,
   }
 }
 
+/// Sets the runtime to use to execute asynchronous tasks.
+/// For convinience, this method takes a [`TokioHandle`].
+/// Note that you cannot drop the underlying [`TokioRuntime`].
+///
+/// # Example
+///
+/// ```rust
+/// #[tokio::main]
+/// async fn main() {
+///   // perform some async task before initializing the app
+///   do_something().await;
+///   // share the current runtime with Tauri
+///   tauri::async_runtime::set(tokio::runtime::Handle::current());
+///
+///   // bootstrap the tauri app...
+///   // tauri::Builder::default().run().unwrap();
+/// }
+///
+/// async fn do_something() {}
+/// ```
+///
+/// # Panics
+///
+/// Panics if the runtime is already set.
+pub fn set(handle: TokioHandle) {
+  RUNTIME
+    .set(GlobalRuntime {
+      runtime: None,
+      handle: RuntimeHandle::Tokio(handle),
+    })
+    .unwrap_or_else(|_| panic!("runtime already initialized"))
+}
+
 /// Returns a handle of the async runtime.
-pub fn handle() -> impl RuntimeHandle {
-  let runtime = RUNTIME.get_or_init(|| Runtime::new().unwrap());
-  runtime.handle().clone()
+pub fn handle() -> RuntimeHandle {
+  let runtime = RUNTIME.get_or_init(default_runtime);
+  runtime.handle()
 }
 
 /// Runs a future to completion on runtime.
 pub fn block_on<F: Future>(task: F) -> F::Output {
-  let runtime = RUNTIME.get_or_init(|| Runtime::new().unwrap());
+  let runtime = RUNTIME.get_or_init(default_runtime);
   runtime.block_on(task)
 }
 
@@ -100,13 +265,51 @@ where
   F: Future + Send + 'static,
   F::Output: Send + 'static,
 {
-  let runtime = RUNTIME.get_or_init(|| Runtime::new().unwrap());
-  JoinHandle(runtime.spawn(task))
+  let runtime = RUNTIME.get_or_init(default_runtime);
+  runtime.spawn(task)
+}
+
+/// Runs the provided function on an executor dedicated to blocking operations.
+pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
+where
+  F: FnOnce() -> R + Send + 'static,
+  R: Send + 'static,
+{
+  let runtime = RUNTIME.get_or_init(default_runtime);
+  runtime.spawn_blocking(func)
+}
+
+pub(crate) fn safe_block_on<F>(task: F) -> F::Output
+where
+  F: Future + Send + 'static,
+  F::Output: Send + 'static,
+{
+  if tokio::runtime::Handle::try_current().is_ok() {
+    let (tx, rx) = std::sync::mpsc::sync_channel(1);
+    spawn(async move {
+      tx.send(task.await).unwrap();
+    });
+    rx.recv().unwrap()
+  } else {
+    block_on(task)
+  }
 }
 
 #[cfg(test)]
 mod tests {
   use super::*;
+
+  #[tokio::test]
+  async fn runtime_spawn() {
+    let join = spawn(async { 5 });
+    assert_eq!(join.await.unwrap(), 5);
+  }
+
+  #[test]
+  fn runtime_block_on() {
+    assert_eq!(block_on(async { 0 }), 0);
+  }
+
   #[tokio::test]
   async fn handle_spawn() {
     let handle = handle();

+ 23 - 25
core/tauri/src/manager.rs

@@ -314,11 +314,12 @@ impl<R: Runtime> WindowManager<R> {
         let path_for_data = path.clone();
 
         // handle 206 (partial range) http request
-        if let Some(range) = request.headers().get("range") {
+        if let Some(range) = request.headers().get("range").cloned() {
           let mut status_code = 200;
           let path_for_data = path_for_data.clone();
           let mut response = HttpResponseBuilder::new();
-          let (response, status_code, data) = crate::async_runtime::block_on(async move {
+          let (headers, status_code, data) = crate::async_runtime::safe_block_on(async move {
+            let mut headers = HashMap::new();
             let mut buf = Vec::new();
             let mut file = tokio::fs::File::open(path_for_data.clone()).await.unwrap();
             // Get the file size
@@ -345,22 +346,25 @@ impl<R: Runtime> WindowManager<R> {
               // partial content
               status_code = 206;
 
-              response = response
-                .header("Connection", "Keep-Alive")
-                .header("Accept-Ranges", "bytes")
-                .header("Content-Length", real_length)
-                .header(
-                  "Content-Range",
-                  format!("bytes {}-{}/{}", range.start, last_byte, file_size),
-                );
+              headers.insert("Connection", "Keep-Alive".into());
+              headers.insert("Accept-Ranges", "bytes".into());
+              headers.insert("Content-Length", real_length.to_string());
+              headers.insert(
+                "Content-Range",
+                format!("bytes {}-{}/{}", range.start, last_byte, file_size),
+              );
 
               file.seek(SeekFrom::Start(range.start)).await.unwrap();
               file.take(real_length).read_to_end(&mut buf).await.unwrap();
             }
 
-            (response, status_code, buf)
+            (headers, status_code, buf)
           });
 
+          for (k, v) in headers {
+            response = response.header(k, v);
+          }
+
           if !data.is_empty() {
             let mime_type = MimeType::parse(&data, &path);
             return response.mimetype(&mime_type).status(status_code).body(data);
@@ -368,7 +372,7 @@ impl<R: Runtime> WindowManager<R> {
         }
 
         let data =
-          crate::async_runtime::block_on(async move { tokio::fs::read(path_for_data).await })?;
+          crate::async_runtime::safe_block_on(async move { tokio::fs::read(path_for_data).await })?;
         let mime_type = MimeType::parse(&data, &path);
         HttpResponseBuilder::new().mimetype(&mime_type).body(data)
       });
@@ -488,19 +492,13 @@ impl<R: Runtime> WindowManager<R> {
   fn prepare_file_drop(&self, app_handle: AppHandle<R>) -> FileDropHandler<R> {
     let manager = self.clone();
     Box::new(move |event, window| {
-      let manager = manager.clone();
-      let app_handle = app_handle.clone();
-      crate::async_runtime::block_on(async move {
-        let window = Window::new(manager.clone(), window, app_handle);
-        let _ = match event {
-          FileDropEvent::Hovered(paths) => {
-            window.emit_and_trigger("tauri://file-drop-hover", paths)
-          }
-          FileDropEvent::Dropped(paths) => window.emit_and_trigger("tauri://file-drop", paths),
-          FileDropEvent::Cancelled => window.emit_and_trigger("tauri://file-drop-cancelled", ()),
-          _ => unimplemented!(),
-        };
-      });
+      let window = Window::new(manager.clone(), window, app_handle.clone());
+      let _ = match event {
+        FileDropEvent::Hovered(paths) => window.emit_and_trigger("tauri://file-drop-hover", paths),
+        FileDropEvent::Dropped(paths) => window.emit_and_trigger("tauri://file-drop", paths),
+        FileDropEvent::Cancelled => window.emit_and_trigger("tauri://file-drop-cancelled", ()),
+        _ => unimplemented!(),
+      };
       true
     })
   }