command_module.rs 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. // Copyright 2019-2021 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. use heck::SnakeCase;
  5. use proc_macro::TokenStream;
  6. use proc_macro2::{Span, TokenStream as TokenStream2, TokenTree};
  7. use quote::{format_ident, quote, quote_spanned};
  8. use syn::{
  9. parse::{Parse, ParseStream},
  10. spanned::Spanned,
  11. Data, DeriveInput, Error, Fields, FnArg, Ident, ItemFn, LitStr, Pat, Token,
  12. };
  13. pub fn generate_run_fn(input: DeriveInput) -> TokenStream {
  14. let name = &input.ident;
  15. let data = &input.data;
  16. let mut is_async = false;
  17. let attrs = input.attrs;
  18. for attr in attrs {
  19. if attr.path.is_ident("cmd") {
  20. let _ = attr.parse_args_with(|input: ParseStream| {
  21. while let Some(token) = input.parse()? {
  22. if let TokenTree::Ident(ident) = token {
  23. is_async |= ident == "async";
  24. }
  25. }
  26. Ok(())
  27. });
  28. }
  29. }
  30. let maybe_await = if is_async { quote!(.await) } else { quote!() };
  31. let maybe_async = if is_async { quote!(async) } else { quote!() };
  32. let mut matcher;
  33. match data {
  34. Data::Enum(data_enum) => {
  35. matcher = TokenStream2::new();
  36. for variant in &data_enum.variants {
  37. let variant_name = &variant.ident;
  38. let (fields_in_variant, variables) = match &variant.fields {
  39. Fields::Unit => (quote_spanned! { variant.span() => }, quote!()),
  40. Fields::Unnamed(fields) => {
  41. let mut variables = TokenStream2::new();
  42. for i in 0..fields.unnamed.len() {
  43. let variable_name = format_ident!("value{}", i);
  44. variables.extend(quote!(#variable_name,));
  45. }
  46. (quote_spanned! { variant.span() => (#variables) }, variables)
  47. }
  48. Fields::Named(fields) => {
  49. let mut variables = TokenStream2::new();
  50. for field in &fields.named {
  51. let ident = field.ident.as_ref().unwrap();
  52. variables.extend(quote!(#ident,));
  53. }
  54. (
  55. quote_spanned! { variant.span() => { #variables } },
  56. variables,
  57. )
  58. }
  59. };
  60. let mut variant_execute_function_name = format_ident!(
  61. "{}",
  62. variant_name.to_string().to_snake_case().to_lowercase()
  63. );
  64. variant_execute_function_name.set_span(variant_name.span());
  65. matcher.extend(quote_spanned! {
  66. variant.span() => #name::#variant_name #fields_in_variant => #name::#variant_execute_function_name(context, #variables)#maybe_await.map(Into::into),
  67. });
  68. }
  69. }
  70. _ => {
  71. return Error::new(
  72. Span::call_site(),
  73. "CommandModule is only implemented for enums",
  74. )
  75. .to_compile_error()
  76. .into()
  77. }
  78. };
  79. let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
  80. let expanded = quote! {
  81. impl #impl_generics #name #ty_generics #where_clause {
  82. pub #maybe_async fn run<R: crate::Runtime>(self, context: crate::endpoints::InvokeContext<R>) -> crate::Result<crate::endpoints::InvokeResponse> {
  83. match self {
  84. #matcher
  85. }
  86. }
  87. }
  88. };
  89. TokenStream::from(expanded)
  90. }
  91. /// Attributes for the module enum variant handler.
  92. pub struct HandlerAttributes {
  93. allowlist: Ident,
  94. error_message: String,
  95. }
  96. impl Parse for HandlerAttributes {
  97. fn parse(input: ParseStream) -> syn::Result<Self> {
  98. let allowlist = input.parse()?;
  99. input.parse::<Token![,]>()?;
  100. let raw: LitStr = input.parse()?;
  101. let error_message = raw.value();
  102. Ok(Self {
  103. allowlist,
  104. error_message,
  105. })
  106. }
  107. }
  108. pub struct HandlerTestAttributes {
  109. allowlist: Ident,
  110. error_message: String,
  111. is_async: bool,
  112. }
  113. impl Parse for HandlerTestAttributes {
  114. fn parse(input: ParseStream) -> syn::Result<Self> {
  115. let allowlist = input.parse()?;
  116. input.parse::<Token![,]>()?;
  117. let error_message_raw: LitStr = input.parse()?;
  118. let error_message = error_message_raw.value();
  119. let _ = input.parse::<Token![,]>();
  120. let is_async = input
  121. .parse::<Ident>()
  122. .map(|i| i == "async")
  123. .unwrap_or_default();
  124. Ok(Self {
  125. allowlist,
  126. error_message,
  127. is_async,
  128. })
  129. }
  130. }
  131. pub fn command_handler(attributes: HandlerAttributes, function: ItemFn) -> TokenStream2 {
  132. let allowlist = attributes.allowlist;
  133. let error_message = attributes.error_message.as_str();
  134. let signature = function.sig.clone();
  135. quote!(
  136. #[cfg(#allowlist)]
  137. #function
  138. #[cfg(not(#allowlist))]
  139. #[allow(unused_variables)]
  140. #[allow(unused_mut)]
  141. #signature {
  142. Err(crate::Error::ApiNotAllowlisted(
  143. #error_message.to_string(),
  144. ))
  145. }
  146. )
  147. }
  148. pub fn command_test(attributes: HandlerTestAttributes, function: ItemFn) -> TokenStream2 {
  149. let allowlist = attributes.allowlist;
  150. let is_async = attributes.is_async;
  151. let error_message = attributes.error_message.as_str();
  152. let signature = function.sig.clone();
  153. let test_name = function.sig.ident.clone();
  154. let mut args = quote!();
  155. for arg in &function.sig.inputs {
  156. if let FnArg::Typed(t) = arg {
  157. if let Pat::Ident(i) = &*t.pat {
  158. let ident = &i.ident;
  159. args.extend(quote!(#ident,))
  160. }
  161. }
  162. }
  163. let response = if is_async {
  164. quote!(crate::async_runtime::block_on(
  165. super::Cmd::#test_name(crate::test::mock_invoke_context(), #args)
  166. ))
  167. } else {
  168. quote!(super::Cmd::#test_name(crate::test::mock_invoke_context(), #args))
  169. };
  170. quote!(
  171. #[cfg(#allowlist)]
  172. #function
  173. #[cfg(not(#allowlist))]
  174. #[quickcheck_macros::quickcheck]
  175. #signature {
  176. if let Err(crate::Error::ApiNotAllowlisted(e)) = #response {
  177. assert_eq!(e, #error_message);
  178. } else {
  179. panic!("unexpected response");
  180. }
  181. }
  182. )
  183. }