command.rs 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. use proc_macro2::TokenStream;
  2. use quote::{format_ident, quote};
  3. use syn::{
  4. parse::Parser, punctuated::Punctuated, FnArg, Ident, ItemFn, Meta, NestedMeta, Pat, Path,
  5. ReturnType, Token, Type,
  6. };
  7. pub fn generate_command(attrs: Vec<NestedMeta>, function: ItemFn) -> TokenStream {
  8. // Check if "with_manager" attr was passed to macro
  9. let uses_manager = attrs.iter().any(|a| {
  10. if let NestedMeta::Meta(Meta::Path(path)) = a {
  11. path
  12. .get_ident()
  13. .map(|i| *i == "with_manager")
  14. .unwrap_or(false)
  15. } else {
  16. false
  17. }
  18. });
  19. let fn_name = function.sig.ident.clone();
  20. let fn_name_str = fn_name.to_string();
  21. let fn_wrapper = format_ident!("{}_wrapper", fn_name);
  22. let returns_result = match function.sig.output {
  23. ReturnType::Type(_, ref ty) => match &**ty {
  24. Type::Path(type_path) => {
  25. type_path
  26. .path
  27. .segments
  28. .first()
  29. .map(|seg| seg.ident.to_string())
  30. == Some("Result".to_string())
  31. }
  32. _ => false,
  33. },
  34. ReturnType::Default => false,
  35. };
  36. // Split function args into names and types
  37. let (mut names, mut types): (Vec<Ident>, Vec<Path>) = function
  38. .sig
  39. .inputs
  40. .iter()
  41. .map(|param| {
  42. let mut arg_name = None;
  43. let mut arg_type = None;
  44. if let FnArg::Typed(arg) = param {
  45. if let Pat::Ident(ident) = arg.pat.as_ref() {
  46. arg_name = Some(ident.ident.clone());
  47. }
  48. if let Type::Path(path) = arg.ty.as_ref() {
  49. arg_type = Some(path.path.clone());
  50. }
  51. }
  52. (
  53. arg_name.clone().unwrap(),
  54. arg_type.unwrap_or_else(|| panic!("Invalid type for arg \"{}\"", arg_name.unwrap())),
  55. )
  56. })
  57. .unzip();
  58. // If function doesn't take the webview manager, wrapper just takes webview manager generically and ignores it
  59. // Otherwise the wrapper uses the specific type from the original function declaration
  60. let mut manager_arg_type = quote!(::tauri::WebviewManager<A>);
  61. let mut application_ext_generic = quote!(<A: ::tauri::ApplicationExt>);
  62. let manager_arg_maybe = match types.first() {
  63. Some(first_type) if uses_manager => {
  64. // Give wrapper specific type
  65. manager_arg_type = quote!(#first_type);
  66. // Generic is no longer needed
  67. application_ext_generic = quote!();
  68. // Remove webview manager arg from list so it isn't expected as arg from JS
  69. types.drain(0..1);
  70. names.drain(0..1);
  71. // Tell wrapper to pass webview manager to original function
  72. quote!(_manager,)
  73. }
  74. // Tell wrapper not to pass webview manager to original function
  75. _ => quote!(),
  76. };
  77. let await_maybe = if function.sig.asyncness.is_some() {
  78. quote!(.await)
  79. } else {
  80. quote!()
  81. };
  82. // if the command handler returns a Result,
  83. // we just map the values to the ones expected by Tauri
  84. // otherwise we wrap it with an `Ok()`, converting the return value to tauri::InvokeResponse
  85. // note that all types must implement `serde::Serialize`.
  86. let return_value = if returns_result {
  87. quote! {
  88. match #fn_name(#manager_arg_maybe #(parsed_args.#names),*)#await_maybe {
  89. Ok(value) => ::core::result::Result::Ok(value.into()),
  90. Err(e) => ::core::result::Result::Err(tauri::Error::Command(::serde_json::to_value(e)?)),
  91. }
  92. }
  93. } else {
  94. quote! { ::core::result::Result::Ok(#fn_name(#manager_arg_maybe #(parsed_args.#names),*)#await_maybe.into()) }
  95. };
  96. quote! {
  97. #function
  98. pub async fn #fn_wrapper #application_ext_generic(_manager: #manager_arg_type, arg: ::serde_json::Value) -> ::tauri::Result<::tauri::InvokeResponse> {
  99. #[derive(::serde::Deserialize)]
  100. #[serde(rename_all = "camelCase")]
  101. struct ParsedArgs {
  102. #(#names: #types),*
  103. }
  104. let parsed_args: ParsedArgs = ::serde_json::from_value(arg).map_err(|e| ::tauri::Error::InvalidArgs(#fn_name_str, e))?;
  105. #return_value
  106. }
  107. }
  108. }
  109. pub fn generate_handler(item: proc_macro::TokenStream) -> TokenStream {
  110. // Get paths of functions passed to macro
  111. let paths = <Punctuated<Path, Token![,]>>::parse_terminated
  112. .parse(item)
  113. .expect("generate_handler!: Failed to parse list of command functions");
  114. // Get names of functions, used for match statement
  115. let fn_names = paths
  116. .iter()
  117. .map(|p| p.segments.last().unwrap().ident.clone());
  118. // Get paths to wrapper functions
  119. let fn_wrappers = paths.iter().map(|func| {
  120. let mut func = func.clone();
  121. let mut last_segment = func.segments.last_mut().unwrap();
  122. last_segment.ident = format_ident!("{}_wrapper", last_segment.ident);
  123. func
  124. });
  125. quote! {
  126. |webview_manager, arg| async move {
  127. let dispatch: ::std::result::Result<::tauri::DispatchInstructions, ::serde_json::Error> =
  128. ::serde_json::from_str(&arg);
  129. match dispatch {
  130. Err(e) => Err(e.into()),
  131. Ok(dispatch) => {
  132. match dispatch.cmd.as_str() {
  133. #(stringify!(#fn_names) => #fn_wrappers(webview_manager, dispatch.args).await,)*
  134. _ => Err(tauri::Error::UnknownApi(None)),
  135. }
  136. }
  137. }
  138. }
  139. }
  140. }