Jelajahi Sumber

feat(core): improved command matching with macros, fixes #1157 (#1301)

Co-authored-by: chip <chip@chip.sh>
Co-authored-by: Lucas Nogueira <lucas@tauri.studio>
Co-authored-by: Lucas Fernandes Nogueira <lucasfernandesnog@gmail.com>
Noah Klayman 4 tahun lalu
induk
melakukan
1f2e7a3226

+ 5 - 0
.changes/simple-command-matching.md

@@ -0,0 +1,5 @@
+---
+"tauri-macros": minor
+---
+
+Added new macros to simplify the creation of commands that can be called by the webview.

+ 0 - 10
cli/tauri.js/templates/src-tauri/src/cmd.rs

@@ -1,10 +0,0 @@
-use serde::Deserialize;
-
-#[derive(Deserialize)]
-#[serde(tag = "cmd", rename_all = "camelCase")]
-pub enum Cmd {
-  // your custom commands
-  // multiple arguments are allowed
-  // note that rename_all = "camelCase": you need to use "myCustomCommand" on JS
-  MyCustomCommand { argument: String },
-}

+ 0 - 18
cli/tauri.js/templates/src-tauri/src/main.rs

@@ -3,29 +3,11 @@
   windows_subsystem = "windows"
 )]
 
-mod cmd;
-
 #[derive(tauri::FromTauriContext)]
 struct Context;
 
 fn main() {
   tauri::AppBuilder::<Context>::new()
-    .invoke_handler(|_webview, arg| async move {
-      use cmd::Cmd::*;
-      match serde_json::from_str(&arg) {
-        Err(e) => Err(e.into()),
-        Ok(command) => {
-          match command {
-            // definitions for your custom commands from Cmd here
-            MyCustomCommand { argument } => {
-              //  your command code
-              println!("{}", argument);
-            }
-          }
-          Ok(().into())
-        }
-      }
-    })
     .build()
     .unwrap()
     .run();

+ 150 - 0
tauri-macros/src/command.rs

@@ -0,0 +1,150 @@
+use proc_macro2::TokenStream;
+use quote::{format_ident, quote};
+use syn::{
+  parse::Parser, punctuated::Punctuated, FnArg, Ident, ItemFn, Meta, NestedMeta, Pat, Path,
+  ReturnType, Token, Type,
+};
+
+pub fn generate_command(attrs: Vec<NestedMeta>, function: ItemFn) -> TokenStream {
+  // Check if "with_manager" attr was passed to macro
+  let uses_manager = attrs.iter().any(|a| {
+    if let NestedMeta::Meta(Meta::Path(path)) = a {
+      path
+        .get_ident()
+        .map(|i| *i == "with_manager")
+        .unwrap_or(false)
+    } else {
+      false
+    }
+  });
+
+  let fn_name = function.sig.ident.clone();
+  let fn_name_str = fn_name.to_string();
+  let fn_wrapper = format_ident!("{}_wrapper", fn_name);
+  let returns_result = match function.sig.output {
+    ReturnType::Type(_, ref ty) => match &**ty {
+      Type::Path(type_path) => {
+        type_path
+          .path
+          .segments
+          .first()
+          .map(|seg| seg.ident.to_string())
+          == Some("Result".to_string())
+      }
+      _ => false,
+    },
+    ReturnType::Default => false,
+  };
+
+  // Split function args into names and types
+  let (mut names, mut types): (Vec<Ident>, Vec<Path>) = function
+    .sig
+    .inputs
+    .iter()
+    .map(|param| {
+      let mut arg_name = None;
+      let mut arg_type = None;
+      if let FnArg::Typed(arg) = param {
+        if let Pat::Ident(ident) = arg.pat.as_ref() {
+          arg_name = Some(ident.ident.clone());
+        }
+        if let Type::Path(path) = arg.ty.as_ref() {
+          arg_type = Some(path.path.clone());
+        }
+      }
+      (
+        arg_name.clone().unwrap(),
+        arg_type.unwrap_or_else(|| panic!("Invalid type for arg \"{}\"", arg_name.unwrap())),
+      )
+    })
+    .unzip();
+
+  // If function doesn't take the webview manager, wrapper just takes webview manager generically and ignores it
+  // Otherwise the wrapper uses the specific type from the original function declaration
+  let mut manager_arg_type = quote!(::tauri::WebviewManager<A>);
+  let mut application_ext_generic = quote!(<A: ::tauri::ApplicationExt>);
+  let manager_arg_maybe = match types.first() {
+    Some(first_type) if uses_manager => {
+      // Give wrapper specific type
+      manager_arg_type = quote!(#first_type);
+      // Generic is no longer needed
+      application_ext_generic = quote!();
+      // Remove webview manager arg from list so it isn't expected as arg from JS
+      types.drain(0..1);
+      names.drain(0..1);
+      // Tell wrapper to pass webview manager to original function
+      quote!(_manager,)
+    }
+    // Tell wrapper not to pass webview manager to original function
+    _ => quote!(),
+  };
+  let await_maybe = if function.sig.asyncness.is_some() {
+    quote!(.await)
+  } else {
+    quote!()
+  };
+
+  // if the command handler returns a Result,
+  // we just map the values to the ones expected by Tauri
+  // otherwise we wrap it with an `Ok()`, converting the return value to tauri::InvokeResponse
+  // note that all types must implement `serde::Serialize`.
+  let return_value = if returns_result {
+    quote! {
+      match #fn_name(#manager_arg_maybe #(parsed_args.#names),*)#await_maybe {
+        Ok(value) => ::core::result::Result::Ok(value.into()),
+        Err(e) => ::core::result::Result::Err(tauri::Error::Command(::serde_json::to_value(e)?)),
+      }
+    }
+  } else {
+    quote! { ::core::result::Result::Ok(#fn_name(#manager_arg_maybe #(parsed_args.#names),*)#await_maybe.into()) }
+  };
+
+  quote! {
+    #function
+    pub async fn #fn_wrapper #application_ext_generic(_manager: #manager_arg_type, arg: ::serde_json::Value) -> ::tauri::Result<::tauri::InvokeResponse> {
+      #[derive(::serde::Deserialize)]
+      #[serde(rename_all = "camelCase")]
+      struct ParsedArgs {
+        #(#names: #types),*
+      }
+      let parsed_args: ParsedArgs = ::serde_json::from_value(arg).map_err(|e| ::tauri::Error::InvalidArgs(#fn_name_str, e))?;
+      #return_value
+    }
+  }
+}
+
+pub fn generate_handler(item: proc_macro::TokenStream) -> TokenStream {
+  // Get paths of functions passed to macro
+  let paths = <Punctuated<Path, Token![,]>>::parse_terminated
+    .parse(item)
+    .expect("generate_handler!: Failed to parse list of command functions");
+
+  // Get names of functions, used for match statement
+  let fn_names = paths
+    .iter()
+    .map(|p| p.segments.last().unwrap().ident.clone());
+
+  // Get paths to wrapper functions
+  let fn_wrappers = paths.iter().map(|func| {
+    let mut func = func.clone();
+    let mut last_segment = func.segments.last_mut().unwrap();
+    last_segment.ident = format_ident!("{}_wrapper", last_segment.ident);
+    func
+  });
+
+  quote! {
+    |webview_manager, arg| async move {
+      let dispatch: ::std::result::Result<::tauri::DispatchInstructions, ::serde_json::Error> =
+      ::serde_json::from_str(&arg);
+      match dispatch {
+        Err(e) => Err(e.into()),
+        Ok(dispatch) => {
+          match dispatch.cmd.as_str() {
+            #(stringify!(#fn_names) => #fn_wrappers(webview_manager, dispatch.args).await,)*
+            _ => Err(tauri::Error::UnknownApi(None)),
+          }
+        }
+      }
+    }
+  }
+}

+ 16 - 1
tauri-macros/src/lib.rs

@@ -1,7 +1,8 @@
 extern crate proc_macro;
 use proc_macro::TokenStream;
-use syn::{parse_macro_input, DeriveInput};
+use syn::{parse_macro_input, AttributeArgs, DeriveInput, ItemFn};
 
+mod command;
 mod error;
 mod expand;
 mod include_dir;
@@ -17,3 +18,17 @@ pub fn load_context(ast: TokenStream) -> TokenStream {
     .unwrap_or_else(|e| e.into_compile_error(&name))
     .into()
 }
+
+#[proc_macro_attribute]
+pub fn command(attrs: TokenStream, item: TokenStream) -> TokenStream {
+  let function = parse_macro_input!(item as ItemFn);
+  let attrs = parse_macro_input!(attrs as AttributeArgs);
+  let gen = command::generate_command(attrs, function);
+  gen.into()
+}
+
+#[proc_macro]
+pub fn generate_handler(item: TokenStream) -> TokenStream {
+  let gen = command::generate_handler(item);
+  gen.into()
+}

+ 10 - 11
tauri/examples/api/src-tauri/src/cmd.rs

@@ -1,4 +1,5 @@
 use serde::Deserialize;
+use tauri::command;
 
 #[derive(Debug, Deserialize)]
 pub struct RequestBody {
@@ -6,15 +7,13 @@ pub struct RequestBody {
   name: String,
 }
 
-#[derive(Deserialize)]
-#[serde(tag = "cmd", rename_all = "camelCase")]
-pub enum Cmd {
-  LogOperation {
-    event: String,
-    payload: Option<String>,
-  },
-  PerformRequest {
-    endpoint: String,
-    body: RequestBody,
-  },
+#[command]
+pub fn log_operation(event: String, payload: Option<String>) {
+  println!("{} {:?}", event, payload);
+}
+
+#[command]
+pub fn perform_request(endpoint: String, body: RequestBody) -> String {
+  println!("{} {:?}", endpoint, body);
+  "message response".into()
 }

+ 4 - 16
tauri/examples/api/src-tauri/src/main.rs

@@ -31,22 +31,10 @@ fn main() {
           .expect("failed to emit");
       });
     })
-    .invoke_handler(|_webview_manager, arg| async move {
-      use cmd::Cmd::*;
-      match serde_json::from_str(&arg) {
-        Err(e) => Err(e.into()),
-        Ok(command) => match command {
-          LogOperation { event, payload } => {
-            println!("{} {:?}", event, payload);
-            Ok(().into())
-          }
-          PerformRequest { endpoint, body } => {
-            println!("{} {:?}", endpoint, body);
-            Ok("message response".into())
-          }
-        },
-      }
-    })
+    .invoke_handler(tauri::generate_handler![
+      cmd::log_operation,
+      cmd::perform_request
+    ])
     .build()
     .unwrap()
     .run();

+ 2 - 2
tauri/examples/api/src/components/Communication.svelte

@@ -8,7 +8,7 @@
 
   function log() {
     invoke({
-      cmd: "logOperation",
+      cmd: "log_operation",
       event: "tauri-click",
       payload: "this payload is optional because we used Option in Rust"
     });
@@ -16,7 +16,7 @@
 
   function performRequest() {
     invoke({
-      cmd: "performRequest",
+      cmd: "perform_request",
       endpoint: "dummy endpoint arg",
       body: {
         id: 5,

+ 0 - 10
tauri/examples/helloworld/src-tauri/src/cmd.rs

@@ -1,10 +0,0 @@
-use serde::Deserialize;
-
-#[derive(Deserialize)]
-#[serde(tag = "cmd", rename_all = "camelCase")]
-pub enum Cmd {
-  // your custom commands
-  // multiple arguments are allowed
-  // note that rename_all = "camelCase": you need to use "myCustomCommand" on JS
-  MyCustomCommand { argument: String },
-}

+ 6 - 18
tauri/examples/helloworld/src-tauri/src/main.rs

@@ -3,30 +3,18 @@
   windows_subsystem = "windows"
 )]
 
-mod cmd;
-
 #[derive(tauri::FromTauriContext)]
 #[config_path = "examples/helloworld/src-tauri/tauri.conf.json"]
 struct Context;
 
+#[tauri::command]
+fn my_custom_command(argument: String) {
+  println!("{}", argument);
+}
+
 fn main() {
   tauri::AppBuilder::<Context>::new()
-    .invoke_handler(|_webview, arg| async move {
-      use cmd::Cmd::*;
-      match serde_json::from_str(&arg) {
-        Err(e) => Err(e.into()),
-        Ok(command) => {
-          match command {
-            // definitions for your custom commands from Cmd here
-            MyCustomCommand { argument } => {
-              //  your command code
-              println!("{}", argument);
-            }
-          }
-          Ok(().into())
-        }
-      }
-    })
+    .invoke_handler(tauri::generate_handler![my_custom_command])
     .build()
     .unwrap()
     .run();

+ 10 - 1
tauri/src/app.rs

@@ -1,5 +1,5 @@
 use futures::future::BoxFuture;
-use serde::Serialize;
+use serde::{Deserialize, Serialize};
 use serde_json::Value as JsonValue;
 use tauri_api::{config::Config, private::AsTauriContext};
 
@@ -63,6 +63,15 @@ impl<T: Serialize> From<T> for InvokeResponse {
   }
 }
 
+#[derive(Deserialize)]
+#[allow(missing_docs)]
+#[serde(tag = "cmd", rename_all = "camelCase")]
+pub struct DispatchInstructions {
+  pub cmd: String,
+  #[serde(flatten)]
+  pub args: JsonValue,
+}
+
 /// The application runner.
 pub struct App<A: ApplicationExt> {
   /// The JS message handler.

+ 6 - 0
tauri/src/error.rs

@@ -43,6 +43,12 @@ pub enum Error {
   /// API not whitelisted on tauri.conf.json
   #[error("'{0}' not on the allowlist (https://tauri.studio/docs/api/config#tauri.allowlist)")]
   ApiNotAllowlisted(String),
+  /// Command error (userland).
+  #[error("{0}")]
+  Command(serde_json::Value),
+  /// Invalid args when running a command.
+  #[error("invalid args for command `{0}`: {1}")]
+  InvalidArgs(&'static str, serde_json::Error),
 }
 
 impl From<serde_json::Error> for Error {

+ 1 - 1
tauri/src/lib.rs

@@ -31,7 +31,7 @@ pub type SyncTask = Box<dyn FnOnce() + Send>;
 
 pub use app::*;
 pub use tauri_api as api;
-pub use tauri_macros::FromTauriContext;
+pub use tauri_macros::{command, generate_handler, FromTauriContext};
 
 /// The Tauri webview implementations.
 pub mod flavors {