command_module.rs 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. // Copyright 2019-2021 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. use heck::{ToLowerCamelCase, ToSnakeCase};
  5. use proc_macro::TokenStream;
  6. use proc_macro2::{Span, TokenStream as TokenStream2};
  7. use quote::{format_ident, quote, quote_spanned};
  8. use syn::{
  9. parse::{Parse, ParseStream},
  10. parse_quote,
  11. spanned::Spanned,
  12. Data, DeriveInput, Error, Fields, Ident, ItemFn, LitStr, Token,
  13. };
  14. pub(crate) fn generate_command_enum(mut input: DeriveInput) -> TokenStream {
  15. let mut deserialize_functions = TokenStream2::new();
  16. let mut errors = TokenStream2::new();
  17. input.attrs.push(parse_quote!(#[allow(dead_code)]));
  18. match &mut input.data {
  19. Data::Enum(data_enum) => {
  20. for variant in &mut data_enum.variants {
  21. let mut feature: Option<Ident> = None;
  22. let mut error_message: Option<String> = None;
  23. for attr in &variant.attrs {
  24. if attr.path.is_ident("cmd") {
  25. let r = attr
  26. .parse_args_with(|input: ParseStream| {
  27. if let Ok(f) = input.parse::<Ident>() {
  28. feature.replace(f);
  29. input.parse::<Token![,]>()?;
  30. let error_message_raw: LitStr = input.parse()?;
  31. error_message.replace(error_message_raw.value());
  32. }
  33. Ok(quote!())
  34. })
  35. .unwrap_or_else(syn::Error::into_compile_error);
  36. errors.extend(r);
  37. }
  38. }
  39. if let Some(f) = feature {
  40. let error_message = if let Some(e) = error_message {
  41. let e = e.to_string();
  42. quote!(#e)
  43. } else {
  44. quote!("This API is not enabled in the allowlist.")
  45. };
  46. let deserialize_function_name = quote::format_ident!("__{}_deserializer", variant.ident);
  47. deserialize_functions.extend(quote! {
  48. #[cfg(not(#f))]
  49. #[allow(non_snake_case)]
  50. fn #deserialize_function_name<'de, D, T>(deserializer: D) -> ::std::result::Result<T, D::Error>
  51. where
  52. D: ::serde::de::Deserializer<'de>,
  53. {
  54. ::std::result::Result::Err(::serde::de::Error::custom(crate::Error::ApiNotAllowlisted(#error_message.into()).to_string()))
  55. }
  56. });
  57. let deserialize_function_name = deserialize_function_name.to_string();
  58. variant
  59. .attrs
  60. .push(parse_quote!(#[cfg_attr(not(#f), serde(deserialize_with = #deserialize_function_name))]));
  61. }
  62. }
  63. }
  64. _ => {
  65. return Error::new(
  66. Span::call_site(),
  67. "`command_enum` is only implemented for enums",
  68. )
  69. .to_compile_error()
  70. .into()
  71. }
  72. };
  73. TokenStream::from(quote! {
  74. #errors
  75. #input
  76. #deserialize_functions
  77. })
  78. }
  79. pub(crate) fn generate_run_fn(input: DeriveInput) -> TokenStream {
  80. let name = &input.ident;
  81. let data = &input.data;
  82. let mut errors = TokenStream2::new();
  83. let mut is_async = false;
  84. let attrs = input.attrs;
  85. for attr in attrs {
  86. if attr.path.is_ident("cmd") {
  87. let r = attr
  88. .parse_args_with(|input: ParseStream| {
  89. if let Ok(token) = input.parse::<Ident>() {
  90. is_async = token == "async";
  91. }
  92. Ok(quote!())
  93. })
  94. .unwrap_or_else(syn::Error::into_compile_error);
  95. errors.extend(r);
  96. }
  97. }
  98. let maybe_await = if is_async { quote!(.await) } else { quote!() };
  99. let maybe_async = if is_async { quote!(async) } else { quote!() };
  100. let mut matcher;
  101. match data {
  102. Data::Enum(data_enum) => {
  103. matcher = TokenStream2::new();
  104. for variant in &data_enum.variants {
  105. let variant_name = &variant.ident;
  106. let mut feature = None;
  107. for attr in &variant.attrs {
  108. if attr.path.is_ident("cmd") {
  109. let r = attr
  110. .parse_args_with(|input: ParseStream| {
  111. if let Ok(f) = input.parse::<Ident>() {
  112. feature.replace(f);
  113. input.parse::<Token![,]>()?;
  114. let _: LitStr = input.parse()?;
  115. }
  116. Ok(quote!())
  117. })
  118. .unwrap_or_else(syn::Error::into_compile_error);
  119. errors.extend(r);
  120. }
  121. }
  122. let maybe_feature_check = if let Some(f) = feature {
  123. quote!(#[cfg(#f)])
  124. } else {
  125. quote!()
  126. };
  127. let (fields_in_variant, variables) = match &variant.fields {
  128. Fields::Unit => (quote_spanned! { variant.span() => }, quote!()),
  129. Fields::Unnamed(fields) => {
  130. let mut variables = TokenStream2::new();
  131. for i in 0..fields.unnamed.len() {
  132. let variable_name = format_ident!("value{}", i);
  133. variables.extend(quote!(#variable_name,));
  134. }
  135. (quote_spanned! { variant.span() => (#variables) }, variables)
  136. }
  137. Fields::Named(fields) => {
  138. let mut variables = TokenStream2::new();
  139. for field in &fields.named {
  140. let ident = field.ident.as_ref().unwrap();
  141. variables.extend(quote!(#ident,));
  142. }
  143. (
  144. quote_spanned! { variant.span() => { #variables } },
  145. variables,
  146. )
  147. }
  148. };
  149. let mut variant_execute_function_name = format_ident!(
  150. "{}",
  151. variant_name.to_string().to_snake_case().to_lowercase()
  152. );
  153. variant_execute_function_name.set_span(variant_name.span());
  154. matcher.extend(quote_spanned! {
  155. variant.span() => #maybe_feature_check #name::#variant_name #fields_in_variant => #name::#variant_execute_function_name(context, #variables)#maybe_await.map(Into::into),
  156. });
  157. }
  158. matcher.extend(quote! {
  159. _ => Err(crate::error::into_anyhow("API not in the allowlist (https://tauri.studio/docs/api/config#tauri.allowlist)")),
  160. });
  161. }
  162. _ => {
  163. return Error::new(
  164. Span::call_site(),
  165. "CommandModule is only implemented for enums",
  166. )
  167. .to_compile_error()
  168. .into()
  169. }
  170. };
  171. let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
  172. let expanded = quote! {
  173. #errors
  174. impl #impl_generics #name #ty_generics #where_clause {
  175. pub #maybe_async fn run<R: crate::Runtime>(self, context: crate::endpoints::InvokeContext<R>) -> super::Result<crate::endpoints::InvokeResponse> {
  176. match self {
  177. #matcher
  178. }
  179. }
  180. }
  181. };
  182. TokenStream::from(expanded)
  183. }
  184. /// Attributes for the module enum variant handler.
  185. pub struct HandlerAttributes {
  186. allowlist: Ident,
  187. }
  188. impl Parse for HandlerAttributes {
  189. fn parse(input: ParseStream) -> syn::Result<Self> {
  190. Ok(Self {
  191. allowlist: input.parse()?,
  192. })
  193. }
  194. }
  195. pub enum AllowlistCheckKind {
  196. Runtime,
  197. Serde,
  198. }
  199. pub struct HandlerTestAttributes {
  200. allowlist: Ident,
  201. error_message: String,
  202. allowlist_check_kind: AllowlistCheckKind,
  203. }
  204. impl Parse for HandlerTestAttributes {
  205. fn parse(input: ParseStream) -> syn::Result<Self> {
  206. let allowlist = input.parse()?;
  207. input.parse::<Token![,]>()?;
  208. let error_message_raw: LitStr = input.parse()?;
  209. let error_message = error_message_raw.value();
  210. let allowlist_check_kind =
  211. if let (Ok(_), Ok(i)) = (input.parse::<Token![,]>(), input.parse::<Ident>()) {
  212. if i == "runtime" {
  213. AllowlistCheckKind::Runtime
  214. } else {
  215. AllowlistCheckKind::Serde
  216. }
  217. } else {
  218. AllowlistCheckKind::Serde
  219. };
  220. Ok(Self {
  221. allowlist,
  222. error_message,
  223. allowlist_check_kind,
  224. })
  225. }
  226. }
  227. pub fn command_handler(attributes: HandlerAttributes, function: ItemFn) -> TokenStream2 {
  228. let allowlist = attributes.allowlist;
  229. quote!(
  230. #[cfg(#allowlist)]
  231. #function
  232. )
  233. }
  234. pub fn command_test(attributes: HandlerTestAttributes, function: ItemFn) -> TokenStream2 {
  235. let allowlist = attributes.allowlist;
  236. let error_message = attributes.error_message.as_str();
  237. let signature = function.sig.clone();
  238. let enum_variant_name = function.sig.ident.to_string().to_lower_camel_case();
  239. let response = match attributes.allowlist_check_kind {
  240. AllowlistCheckKind::Runtime => {
  241. let test_name = function.sig.ident.clone();
  242. quote!(super::Cmd::#test_name(crate::test::mock_invoke_context()))
  243. }
  244. AllowlistCheckKind::Serde => quote! {
  245. serde_json::from_str::<super::Cmd>(&format!(r#"{{ "cmd": "{}", "data": null }}"#, #enum_variant_name))
  246. },
  247. };
  248. quote!(
  249. #[cfg(#allowlist)]
  250. #function
  251. #[cfg(not(#allowlist))]
  252. #[allow(unused_variables)]
  253. #[quickcheck_macros::quickcheck]
  254. #signature {
  255. if let Err(e) = #response {
  256. assert!(e.to_string().contains(#error_message));
  257. } else {
  258. panic!("unexpected response");
  259. }
  260. }
  261. )
  262. }