Prechádzať zdrojové kódy

fix(acl): command scope should not error out if missing (#8675)

* fix(acl): command scope should not error out if missing

* propagate error
Lucas Fernandes Nogueira 1 rok pred
rodič
commit
2631e97e2b
1 zmenil súbory, kde vykonal 51 pridanie a 27 odobranie
  1. 51 27
      core/tauri/src/command/authority.rs

+ 51 - 27
core/tauri/src/command/authority.rs

@@ -2,8 +2,8 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-License-Identifier: MIT
 
-use std::collections::BTreeMap;
 use std::fmt::Debug;
+use std::{collections::BTreeMap, ops::Deref};
 
 use serde::de::DeserializeOwned;
 use state::TypeMap;
@@ -112,9 +112,27 @@ impl<T: Debug + DeserializeOwned + Send + Sync + 'static> ScopeValue<T> {
   }
 }
 
+#[derive(Debug)]
+enum OwnedOrRef<'a, T: Debug> {
+  Owned(T),
+  Ref(&'a T),
+}
+
+impl<'a, T: Debug> Deref for OwnedOrRef<'a, T> {
+  type Target = T;
+  fn deref(&self) -> &Self::Target {
+    match self {
+      Self::Owned(t) => t,
+      Self::Ref(r) => r,
+    }
+  }
+}
+
 /// Access scope for a command that can be retrieved directly in the command function.
 #[derive(Debug)]
-pub struct CommandScope<'a, T: Debug + DeserializeOwned + Send + Sync + 'static>(&'a ScopeValue<T>);
+pub struct CommandScope<'a, T: Debug + DeserializeOwned + Send + Sync + 'static>(
+  OwnedOrRef<'a, ScopeValue<T>>,
+);
 
 impl<'a, T: Debug + DeserializeOwned + Send + Sync + 'static> CommandScope<'a, T> {
   /// What this access scope allows.
@@ -133,22 +151,22 @@ impl<'a, R: Runtime, T: Debug + DeserializeOwned + Send + Sync + 'static> Comman
 {
   /// Grabs the [`ResolvedScope`] from the [`CommandItem`] and returns the associated [`CommandScope`].
   fn from_command(command: CommandItem<'a, R>) -> Result<Self, InvokeError> {
-    command
-      .acl
-      .as_ref()
-      .and_then(|resolved| resolved.scope)
-      .and_then(|scope_id| {
+    if let Some(scope_id) = command.acl.as_ref().and_then(|resolved| resolved.scope) {
+      Ok(CommandScope(OwnedOrRef::Ref(
         command
           .message
           .webview
           .manager()
           .runtime_authority
           .scope_manager
-          .get_command_scope_typed(&scope_id)
-          .unwrap_or_default()
-          .map(CommandScope)
-      })
-      .ok_or_else(|| InvokeError::from_anyhow(anyhow::anyhow!("scope not found")))
+          .get_command_scope_typed(&scope_id)?,
+      )))
+    } else {
+      Ok(CommandScope(OwnedOrRef::Owned(ScopeValue {
+        allow: Vec::new(),
+        deny: Vec::new(),
+      })))
+    }
   }
 }
 
@@ -175,6 +193,11 @@ impl<'a, R: Runtime, T: Debug + DeserializeOwned + Send + Sync + 'static> Comman
   fn from_command(command: CommandItem<'a, R>) -> Result<Self, InvokeError> {
     command
       .plugin
+      .ok_or_else(|| {
+        InvokeError::from_anyhow(anyhow::anyhow!(
+          "global scope not available for app commands"
+        ))
+      })
       .and_then(|plugin| {
         command
           .message
@@ -183,10 +206,9 @@ impl<'a, R: Runtime, T: Debug + DeserializeOwned + Send + Sync + 'static> Comman
           .runtime_authority
           .scope_manager
           .get_global_scope_typed(plugin)
-          .ok()
+          .map_err(InvokeError::from_error)
       })
       .map(GlobalScope)
-      .ok_or_else(|| InvokeError::from_anyhow(anyhow::anyhow!("global scope not found")))
   }
 }
 
@@ -228,29 +250,31 @@ impl ScopeManager {
   fn get_command_scope_typed<T: Send + Sync + DeserializeOwned + Debug + 'static>(
     &self,
     key: &ScopeKey,
-  ) -> crate::Result<Option<&ScopeValue<T>>> {
+  ) -> crate::Result<&ScopeValue<T>> {
     let cache = self.command_cache.get(key).unwrap();
     match cache.try_get() {
-      cached @ Some(_) => Ok(cached),
-      None => match self.command_scope.get(key).map(|r| {
+      Some(cached) => Ok(cached),
+      None => {
+        let resolved_scope = self
+          .command_scope
+          .get(key)
+          .unwrap_or_else(|| panic!("missing command scope for key {key}"));
+
         let mut allow: Vec<T> = Vec::new();
         let mut deny: Vec<T> = Vec::new();
 
-        for allowed in &r.allow {
+        for allowed in &resolved_scope.allow {
           allow.push(serde_json::from_value(allowed.clone().into())?);
         }
-        for denied in &r.deny {
+        for denied in &resolved_scope.deny {
           deny.push(serde_json::from_value(denied.clone().into())?);
         }
 
-        crate::Result::Ok(Some(ScopeValue { allow, deny }))
-      }) {
-        None => Ok(None),
-        Some(value) => {
-          let _ = cache.set(value);
-          Ok(cache.try_get())
-        }
-      },
+        let value = ScopeValue { allow, deny };
+
+        let _ = cache.set(value);
+        Ok(cache.get())
+      }
     }
   }
 }